Click here to Skip to main content
15,884,388 members
Articles / Desktop Programming / Win32

Cache IQueryable for Better LINQ-to-SQL Performance

Rate me:
Please Sign up or sign in to vote.
4.69/5 (10 votes)
21 May 2012 | CPOL | 10 min read | 60K | 448 | 33
An approach to improving LINQ-to-SQL performance while keeping its maintainability advantage over DataReader.
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data;
using System.Data.Common;
using System.Data.Linq;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace Prototype.Database
{
    public static class Utility
    {
        #region Log
        /// <summary>
        /// When true, <c>Log</c> and <c>LogObject</c> write to the console; when false they return immediately.
        /// </summary>
        public static bool Verbose { get; set; }

        // Serializer used by LogObject; it is assigned once here and never replaced, hence readonly.
        private static readonly System.Web.Script.Serialization.JavaScriptSerializer Serializer = new System.Web.Script.Serialization.JavaScriptSerializer();

        /// <summary>
        /// When <c>Verbose</c> is set, writes <paramref name="obj"/> to the console serialized by
        /// the shared JavaScriptSerializer, on a line prefixed with "*** ".
        /// </summary>
        /// <param name="obj">Object to serialize and print.</param>
        public static void LogObject(object obj)
        {
            if (Verbose)
            {
                // The prefix is written before serializing, matching the original output order.
                Console.Write("*** ");
                Console.WriteLine(Serializer.Serialize(obj));
            }
        }

        /// <summary>
        /// When <c>Verbose</c> is set, writes one console line prefixed with "*** ".
        /// </summary>
        /// <param name="format">Composite format string, or a literal message when no arguments are given.</param>
        /// <param name="args">
        /// Format arguments. When null or empty, <paramref name="format"/> is written verbatim,
        /// so braces in it are not interpreted as format items.
        /// </param>
        public static void Log(string format, params object[] args)
        {
            if (Verbose)
            {
                Console.Write("*** ");

                var hasArguments = (args != null) && (args.Length > 0);

                if (hasArguments)
                    Console.WriteLine(format, args);
                else
                    Console.WriteLine(format);
            }
        }

        /// <summary>
        /// Debugging aid that prints an expression tree to the console, one node per line,
        /// indented two spaces per level of depth.
        /// </summary>
        public class VisitExpression
        {
            /// <summary>
            /// Print <paramref name="ex"/> and all of its descendants to the console.
            /// </summary>
            /// <param name="ex">Root of the expression tree to print.</param>
            /// <exception cref="ArgumentNullException"><paramref name="ex"/> is null.</exception>
            /// <exception cref="NotImplementedException">
            /// The tree contains a node type, or a member binding kind other than assignment, that is not supported.
            /// </exception>
            public static void Visit(Expression ex)
            {
                if (ex == null)
                    throw new ArgumentNullException("ex");

                BaseVisit(ex, 0);
            }

            /// <summary>
            /// Print one line describing the node (its node type and ToString()), then dispatch
            /// to the visitor for the node's expression class to print its children.
            /// </summary>
            /// <param name="ex">Node to print; a null node prints nothing.</param>
            /// <param name="indent">Depth of the node; each level indents by two spaces.</param>
            private static void BaseVisit(Expression ex, int indent)
            {
                // Optional children are null, e.g. MemberExpression.Expression for a static member
                // or MethodCallExpression.Object for a static method.
                if (ex == null)
                    return;

                Console.WriteLine("{0}{1} {2}", new string(' ', indent * 2), ex.NodeType, ex);

                switch (Map.From(ex.NodeType))
                {
                    case ExpressionEnum.Binary:
                        Visit((BinaryExpression)ex, indent);
                        break;
                    case ExpressionEnum.MethodCall:
                        Visit((MethodCallExpression)ex, indent);
                        break;
                    case ExpressionEnum.Constant:
                        Visit((ConstantExpression)ex, indent);
                        break;
                    case ExpressionEnum.Unary:
                        Visit((UnaryExpression)ex, indent);
                        break;
                    case ExpressionEnum.Lambda:
                        Visit((LambdaExpression)ex, indent);
                        break;
                    case ExpressionEnum.Member:
                        Visit((MemberExpression)ex, indent);
                        break;
                    case ExpressionEnum.Parameter:
                        Visit((ParameterExpression)ex, indent);
                        break;
                    case ExpressionEnum.New:
                        Visit((NewExpression)ex, indent);
                        break;
                    case ExpressionEnum.MemberInit:
                        Visit((MemberInitExpression)ex, indent);
                        break;
                    case ExpressionEnum.Conditional:
                        Visit((ConditionalExpression)ex, indent);
                        break;
                    case ExpressionEnum.TypeBinary:
                        Visit((TypeBinaryExpression)ex, indent);
                        break;
                    case ExpressionEnum.Invocation:
                        Visit((InvocationExpression)ex, indent);
                        break;
                    case ExpressionEnum.ListInit:
                        Visit((ListInitExpression)ex, indent);
                        break;
                    case ExpressionEnum.NewArray:
                        Visit((NewArrayExpression)ex, indent);
                        break;
                    default:
                        throw new NotImplementedException(ex.NodeType.ToString());
                }
            }

            private static void Visit(MethodCallExpression ex, int indent)
            {
                // Instance calls have a target object; static calls leave it null (skipped by BaseVisit).
                BaseVisit(ex.Object, indent + 1);

                foreach (var arg in ex.Arguments)
                    BaseVisit(arg, indent + 1);
            }

            private static void Visit(ConstantExpression ex, int indent) { }

            private static void Visit(UnaryExpression ex, int indent)
            {
                BaseVisit(ex.Operand, indent + 1);
            }

            private static void Visit(LambdaExpression ex, int indent)
            {
                BaseVisit(ex.Body, indent + 1);
            }

            private static void Visit(BinaryExpression ex, int indent)
            {
                BaseVisit(ex.Left, indent + 1);
                BaseVisit(ex.Right, indent + 1);

                // Conversion is only present for some coalesce/compound forms; null is skipped.
                BaseVisit(ex.Conversion, indent + 1);
            }

            private static void Visit(MemberExpression ex, int indent)
            {
                // Null for static members; BaseVisit skips it.
                BaseVisit(ex.Expression, indent + 1);
            }

            private static void Visit(ParameterExpression ex, int indent) { }

            private static void Visit(NewExpression ex, int indent)
            {
                foreach (var sub in ex.Arguments)
                    BaseVisit(sub, indent + 1);
            }

            private static void Visit(MemberInitExpression ex, int indent)
            {
                BaseVisit(ex.NewExpression, indent + 1);

                foreach (var mb in ex.Bindings)
                {
                    switch (mb.BindingType)
                    {
                        case MemberBindingType.Assignment:
                            BaseVisit(((MemberAssignment)mb).Expression, indent + 1);
                            break;

                        default:
                            throw new NotImplementedException(mb.BindingType.ToString());
                    }
                }
            }

            private static void Visit(ConditionalExpression ex, int indent)
            {
                BaseVisit(ex.Test, indent + 1);
                BaseVisit(ex.IfTrue, indent + 1);
                BaseVisit(ex.IfFalse, indent + 1);
            }

            private static void Visit(TypeBinaryExpression ex, int indent)
            {
                BaseVisit(ex.Expression, indent + 1);
            }

            private static void Visit(InvocationExpression ex, int indent)
            {
                BaseVisit(ex.Expression, indent + 1);

                foreach (var arg in ex.Arguments)
                    BaseVisit(arg, indent + 1);
            }

            private static void Visit(ListInitExpression ex, int indent)
            {
                BaseVisit(ex.NewExpression, indent + 1);

                foreach (var initializer in ex.Initializers)
                    foreach (var arg in initializer.Arguments)
                        BaseVisit(arg, indent + 1);
            }

            private static void Visit(NewArrayExpression ex, int indent)
            {
                foreach (var item in ex.Expressions)
                    BaseVisit(item, indent + 1);
            }
        }
        #endregion

        #region Supporting class
        /// <summary>
        /// The expression-node class (from System.Linq.Expressions) that a node belongs to,
        /// as returned by Map.From. Underlying values are spelled out and equal declaration order.
        /// </summary>
        internal enum ExpressionEnum
        {
            /// <summary>BinaryExpression</summary>
            Binary = 0,
            /// <summary>UnaryExpression</summary>
            Unary = 1,
            /// <summary>MethodCallExpression</summary>
            MethodCall = 2,
            /// <summary>ConditionalExpression</summary>
            Conditional = 3,
            /// <summary>ConstantExpression</summary>
            Constant = 4,
            /// <summary>InvocationExpression</summary>
            Invocation = 5,
            /// <summary>LambdaExpression</summary>
            Lambda = 6,
            /// <summary>ListInitExpression</summary>
            ListInit = 7,
            /// <summary>MemberExpression</summary>
            Member = 8,
            /// <summary>MemberInitExpression</summary>
            MemberInit = 9,
            /// <summary>NewExpression</summary>
            New = 10,
            /// <summary>NewArrayExpression</summary>
            NewArray = 11,
            /// <summary>ParameterExpression</summary>
            Parameter = 12,
            /// <summary>TypeBinaryExpression</summary>
            TypeBinary = 13,
        }

        /// <summary>
        /// Classifies an <see cref="ExpressionType"/> node type by the expression class
        /// (<see cref="ExpressionEnum"/>) that represents nodes of that type.
        /// </summary>
        internal static class Map
        {
            // Node type -> expression class. Filled once by the type initializer and only read afterwards.
            private static readonly Dictionary<ExpressionType, ExpressionEnum> Categories = CreateCategories();

            /// <summary>
            /// Classify <paramref name="type"/>.
            /// </summary>
            /// <param name="type">Node type to classify.</param>
            /// <returns>The expression class that represents nodes of <paramref name="type"/>.</returns>
            /// <exception cref="NotImplementedException">The node type is not in the table.</exception>
            public static ExpressionEnum From(ExpressionType type)
            {
                ExpressionEnum category;

                if (!Categories.TryGetValue(type, out category))
                    throw new NotImplementedException(type.ToString());

                return category;
            }

            /// <summary>
            /// Build the classification table.
            /// </summary>
            private static Dictionary<ExpressionType, ExpressionEnum> CreateCategories()
            {
                var table = new Dictionary<ExpressionType, ExpressionEnum>();

                Register(table, ExpressionEnum.Binary,
                         ExpressionType.Add, ExpressionType.AddChecked,
                         ExpressionType.And, ExpressionType.AndAlso,
                         ExpressionType.ArrayIndex, ExpressionType.Coalesce,
                         ExpressionType.Divide, ExpressionType.Equal,
                         ExpressionType.ExclusiveOr, ExpressionType.GreaterThan,
                         ExpressionType.GreaterThanOrEqual, ExpressionType.LeftShift,
                         ExpressionType.LessThan, ExpressionType.LessThanOrEqual,
                         ExpressionType.Modulo, ExpressionType.Multiply,
                         ExpressionType.MultiplyChecked, ExpressionType.NotEqual,
                         ExpressionType.Or, ExpressionType.OrElse,
                         ExpressionType.Power, ExpressionType.RightShift,
                         ExpressionType.Subtract, ExpressionType.SubtractChecked);

                Register(table, ExpressionEnum.Unary,
                         ExpressionType.ArrayLength, ExpressionType.Convert,
                         ExpressionType.ConvertChecked, ExpressionType.Negate,
                         ExpressionType.UnaryPlus, ExpressionType.NegateChecked,
                         ExpressionType.Not, ExpressionType.Quote,
                         ExpressionType.TypeAs);

                Register(table, ExpressionEnum.MethodCall, ExpressionType.Call);
                Register(table, ExpressionEnum.Conditional, ExpressionType.Conditional);
                Register(table, ExpressionEnum.Constant, ExpressionType.Constant);
                Register(table, ExpressionEnum.Invocation, ExpressionType.Invoke);
                Register(table, ExpressionEnum.Lambda, ExpressionType.Lambda);
                Register(table, ExpressionEnum.ListInit, ExpressionType.ListInit);
                Register(table, ExpressionEnum.Member, ExpressionType.MemberAccess);
                Register(table, ExpressionEnum.MemberInit, ExpressionType.MemberInit);
                Register(table, ExpressionEnum.New, ExpressionType.New);
                Register(table, ExpressionEnum.NewArray, ExpressionType.NewArrayInit, ExpressionType.NewArrayBounds);
                Register(table, ExpressionEnum.Parameter, ExpressionType.Parameter);
                Register(table, ExpressionEnum.TypeBinary, ExpressionType.TypeIs);

                return table;
            }

            /// <summary>
            /// Map every node type in <paramref name="types"/> to <paramref name="category"/>.
            /// </summary>
            private static void Register(Dictionary<ExpressionType, ExpressionEnum> table,
                                         ExpressionEnum category,
                                         params ExpressionType[] types)
            {
                foreach (var nodeType in types)
                    table.Add(nodeType, category);
            }
        }

        /// <summary>
        /// Structural comparison of expression trees against a fixed origin expression.
        /// </summary>
        /// <remarks>
        /// Node types, result types, methods, constructors, members, constant values and parameter
        /// names must all match. For a member read directly from a constant object (such as a captured
        /// closure instance), the object's value is not compared, only its type. For list-typed fields of
        /// such objects, the element counts are compared instead of the contents.
        /// </remarks>
        internal class CompareExpression
        {
            private readonly Expression _Origin;

            /// <summary>
            /// Create a comparer anchored on <paramref name="origin"/>.
            /// </summary>
            /// <param name="origin">Expression that every later comparison is made against; may be null.</param>
            public CompareExpression(Expression origin)
            {
                _Origin = origin;
            }

            /// <summary>
            /// Compare <paramref name="compare"/> with the origin expression.
            /// </summary>
            /// <param name="compare">Expression to test.</param>
            /// <returns>True when both trees have the same structure under the rules described on this class.</returns>
            /// <exception cref="NotImplementedException">A node type without a comparison is encountered.</exception>
            public bool Equals(Expression compare)
            {
                return BaseCompare(_Origin, compare);
            }

            /// <summary>
            /// Compare two expressions generally (null-ness, node type, result type),
            /// then dispatch to the comparison for their expression class.
            /// </summary>
            /// <param name="ex1">Expression from the origin tree.</param>
            /// <param name="ex2">Expression from the compared tree.</param>
            /// <returns>True when the two sub-trees match.</returns>
            private static bool BaseCompare(Expression ex1, Expression ex2)
            {
                if (ex1 == null)
                    return ex2 == null;
                if (ex2 == null)
                    return false;

                if (ex1.NodeType != ex2.NodeType)
                    return false;

                // The same node type can still produce different results, e.g. Convert to int vs. Convert to long.
                if (ex1.Type != ex2.Type)
                    return false;

                var type = Map.From(ex1.NodeType);

                switch (type)
                {
                    case ExpressionEnum.Binary:
                        return Compare((BinaryExpression)ex1, (BinaryExpression)ex2);
                    case ExpressionEnum.Unary:
                        return Compare((UnaryExpression)ex1, (UnaryExpression)ex2);
                    case ExpressionEnum.MethodCall:
                        return Compare((MethodCallExpression)ex1, (MethodCallExpression)ex2);
                    case ExpressionEnum.Constant:
                        return Compare((ConstantExpression)ex1, (ConstantExpression)ex2);
                    case ExpressionEnum.Lambda:
                        return Compare((LambdaExpression)ex1, (LambdaExpression)ex2);
                    case ExpressionEnum.Member:
                        return Compare((MemberExpression)ex1, (MemberExpression)ex2);
                    case ExpressionEnum.Parameter:
                        return Compare((ParameterExpression)ex1, (ParameterExpression)ex2);
                    case ExpressionEnum.TypeBinary:
                        return Compare((TypeBinaryExpression)ex1, (TypeBinaryExpression)ex2);
                    case ExpressionEnum.New:
                        return Compare((NewExpression)ex1, (NewExpression)ex2);
                    case ExpressionEnum.MemberInit:
                        return Compare((MemberInitExpression)ex1, (MemberInitExpression)ex2);
                    case ExpressionEnum.Conditional:
                        return Compare((ConditionalExpression)ex1, (ConditionalExpression)ex2);
                    default:
                        throw new NotImplementedException(type.ToString());
                }
            }

            /// <summary>
            /// Compare two expression lists element by element.
            /// </summary>
            private static bool CompareAll<T>(ReadOnlyCollection<T> list1, ReadOnlyCollection<T> list2)
                where T : Expression
            {
                if (list1.Count != list2.Count)
                    return false;

                for (var i = 0; i < list1.Count; i++)
                    if (!BaseCompare(list1[i], list2[i]))
                        return false;

                return true;
            }

            private static bool Compare(ParameterExpression ex1, ParameterExpression ex2)
            {
                // Parameter types were already compared in BaseCompare.
                return string.Equals(ex1.Name, ex2.Name);
            }

            private static bool Compare(BinaryExpression ex1, BinaryExpression ex2)
            {
                if (ex1.IsLifted != ex2.IsLifted)
                    return false;

                if (ex1.IsLiftedToNull != ex2.IsLiftedToNull)
                    return false;

                if (ex1.Method != ex2.Method)
                    return false;

                if (!BaseCompare(ex1.Conversion, ex2.Conversion))
                    return false;

                if (!BaseCompare(ex1.Left, ex2.Left))
                    return false;

                return BaseCompare(ex1.Right, ex2.Right);
            }

            private static bool Compare(TypeBinaryExpression ex1, TypeBinaryExpression ex2)
            {
                if (ex1.TypeOperand != ex2.TypeOperand)
                    return false;

                return BaseCompare(ex1.Expression, ex2.Expression);
            }

            private static bool Compare(UnaryExpression ex1, UnaryExpression ex2)
            {
                if (ex1.IsLifted != ex2.IsLifted)
                    return false;

                if (ex1.IsLiftedToNull != ex2.IsLiftedToNull)
                    return false;

                if (ex1.Method != ex2.Method)
                    return false;

                return BaseCompare(ex1.Operand, ex2.Operand);
            }

            private static bool Compare(MethodCallExpression ex1, MethodCallExpression ex2)
            {
                if (ex1.Method != ex2.Method)
                    return false;

                // Instance calls (e.g. list.Contains(...)) also depend on their target; null for static calls.
                if (!BaseCompare(ex1.Object, ex2.Object))
                    return false;

                return CompareAll(ex1.Arguments, ex2.Arguments);
            }

            private static bool Compare(ConstantExpression ex1, ConstantExpression ex2)
            {
                // object.Equals is null-safe: a null constant (e.g. "x.Name == null") no longer throws.
                return object.Equals(ex1.Value, ex2.Value);
            }

            private static bool Compare(LambdaExpression ex1, LambdaExpression ex2)
            {
                if (!BaseCompare(ex1.Body, ex2.Body))
                    return false;

                return CompareAll(ex1.Parameters, ex2.Parameters);
            }

            private static bool Compare(NewExpression ex1, NewExpression ex2)
            {
                if (ex1.Constructor != ex2.Constructor)
                    return false;

                if (!CompareAll(ex1.Arguments, ex2.Arguments))
                    return false;

                // Members is null unless the node was built with member information
                // (as the C# compiler does for anonymous types).
                if ((ex1.Members == null) != (ex2.Members == null))
                    return false;

                if (ex1.Members != null)
                {
                    if (ex1.Members.Count != ex2.Members.Count)
                        return false;

                    for (var i = 0; i < ex1.Members.Count; i++)
                        if (ex1.Members[i] != ex2.Members[i])
                            return false;
                }

                return true;
            }

            private static bool Compare(MemberInitExpression ex1, MemberInitExpression ex2)
            {
                if (!Compare(ex1.NewExpression, ex2.NewExpression))
                    return false;

                if (ex1.Bindings.Count != ex2.Bindings.Count)
                    return false;

                for (var i = 0; i < ex1.Bindings.Count; i++)
                {
                    var binding1 = ex1.Bindings[i];
                    var binding2 = ex2.Bindings[i];

                    if ((binding1.BindingType != binding2.BindingType) || (binding1.Member != binding2.Member))
                        return false;

                    switch (binding1.BindingType)
                    {
                        case MemberBindingType.Assignment:
                            if (!BaseCompare(((MemberAssignment)binding1).Expression, ((MemberAssignment)binding2).Expression))
                                return false;
                            break;

                        default:
                            throw new NotImplementedException(binding1.BindingType.ToString());
                    }
                }

                return true;
            }

            private static bool Compare(ConditionalExpression ex1, ConditionalExpression ex2)
            {
                return BaseCompare(ex1.Test, ex2.Test) &&
                       BaseCompare(ex1.IfTrue, ex2.IfTrue) &&
                       BaseCompare(ex1.IfFalse, ex2.IfFalse);
            }

            private static bool Compare(MemberExpression ex1, MemberExpression ex2)
            {
                if (ex1.Member != ex2.Member)
                    return false;

                var owner1 = ex1.Expression as ConstantExpression;
                var owner2 = ex2.Expression as ConstantExpression;

                if ((owner1 == null) || (owner2 == null))
                {
                    // Ordinary member access (on a parameter, on another member, or static when both are null):
                    // the owning expression is part of the query shape and must match too.
                    return BaseCompare(ex1.Expression, ex2.Expression);
                }

                // Member read from a constant object such as a captured closure. Its value presumably
                // varies per query and is supplied as a parameter, so only the owner's type is compared.
                if (owner1.Type != owner2.Type)
                    return false;

                if (ex1.Member.MemberType == MemberTypes.Field)
                {
                    var field = (FieldInfo)ex1.Member;

                    if (field.FieldType.GetInterfaces().Contains(typeof(System.Collections.IList)))
                    {
                        // For collections, which may turn into separate parameters, compare the number of values
                        var list1 = (System.Collections.IList)field.GetValue(owner1.Value);
                        var list2 = (System.Collections.IList)field.GetValue(owner2.Value);

                        if ((list1 == null) || (list2 == null))
                            return (list1 == null) && (list2 == null);

                        // NOTE(review): the original rationale comment was truncated ("In order to save cached queries,").
                        // The direction of this check (origin may not hold more values than the compared list)
                        // presumably matches how cached parameter slots are reused; confirm before changing it.
                        if (list1.Count > list2.Count)
                            return false;
                    }
                }

                return true;
            }
        }

        /// <summary>
        /// For each distinct display class (closure) instance found as a constant in an expression tree,
        /// builds a lambda that navigates from the tree's root to that instance. A constant counts as a
        /// display class when its type name contains "DisplayClass".
        /// </summary>
        internal class GetDisplayClass
        {
            /// <summary>
            /// Find every distinct display class instance in <paramref name="expression"/> and build a getter for each.
            /// </summary>
            /// <param name="expression">Root of the expression tree to search.</param>
            /// <returns>
            /// A dictionary keyed by display class instance. Each value is a lambda that takes a root expression
            /// and returns the constant value at the same position in that tree. The getters cast only to public
            /// node classes, so they can also be applied to other trees with the same shape (presumably how
            /// callers reuse them; confirm).
            /// </returns>
            /// <exception cref="NotImplementedException">The tree contains a node type this search does not support.</exception>
            public static Dictionary<object, Expression<Func<Expression, object>>> Build(Expression expression)
            {
                var getters = new Dictionary<object, Expression<Func<Expression, object>>>();
                var found = new List<object>();

                while (true)
                {
                    var accessors = new Stack<Access>();

                    // Each pass returns the first display class not yet in 'found'; stop when none remain.
                    var display = BaseFind(expression, accessors, found);

                    if (display == null)
                        break;

                    getters[display] = CompileStack(accessors);

                    found.Add(display);
                }

                return getters;
            }

            /// <summary>
            /// Kind of navigation step recorded on the accessor stack
            /// </summary>
            private enum MethodType
            {
                Convert, // Convert an object
                PropertyOrField, // Retrieve a property or field
                CollectionIndex, // Retrieve one Expression from list
                MemberCollectionIndex, // Retrieve one MemberBinding from list
            }

            /// <summary>
            /// Individual step for getting display class.
            /// Parameter is a member name, a Type, or an int index depending on Method.
            /// </summary>
            private class Access
            {
                public MethodType Method { get; set; }
                public object Parameter { get; set; }
            }

            /// <summary>
            /// Turn the stack of Access into a lambda expression. The stack is consumed (left empty).
            /// </summary>
            /// <param name="accessors">Steps pushed innermost-first, so popping yields them root-first.</param>
            /// <returns>A lambda taking the root expression and returning the located constant value.</returns>
            private static Expression<Func<Expression, object>> CompileStack(Stack<Access> accessors)
            {
                var input = Expression.Parameter(typeof(Expression), "ex");
                Expression current = input;

                while (accessors.Count > 0)
                {
                    var b = accessors.Pop();

                    switch (b.Method)
                    {
                        case MethodType.PropertyOrField:
                            current = Expression.PropertyOrField(current, (string)b.Parameter);
                            break;

                        case MethodType.Convert:
                            current = Expression.Convert(current, (Type)b.Parameter);
                            break;

                        case MethodType.CollectionIndex:
                            current = Expression.Call(current,
                                                      typeof(ReadOnlyCollection<Expression>).GetMethod("get_Item"),
                                                      Expression.Constant((int)b.Parameter));
                            break;

                        case MethodType.MemberCollectionIndex:
                            current = Expression.Call(current,
                                typeof(ReadOnlyCollection<MemberBinding>).GetMethod("get_Item"),
                                Expression.Constant((int)b.Parameter));
                            break;
                    }
                }

                // Compose the final Lambda expression
                return Expression.Lambda<Func<Expression, object>>(current, input);
            }

            /// <summary>
            /// Record one navigation step: cast the current node to <paramref name="nodeType"/>, then read
            /// <paramref name="member"/>. The member access is pushed first because the stack is built
            /// innermost-first; CompileStack pops the cast before the member access.
            /// </summary>
            /// <remarks>
            /// Public node classes are used rather than ex.GetType(), which can be an internal runtime
            /// subclass of the node class.
            /// </remarks>
            private static void PushStep(Stack<Access> accessors, string member, Type nodeType)
            {
                accessors.Push(new Access { Method = MethodType.PropertyOrField, Parameter = member });
                accessors.Push(new Access { Method = MethodType.Convert, Parameter = nodeType });
            }

            /// <summary>
            /// Search <paramref name="ex"/> for the first display class instance not in <paramref name="found"/>.
            /// When one is found, the steps leading to it have been pushed onto <paramref name="accessors"/>;
            /// when none is found, nothing has been pushed.
            /// </summary>
            private static object BaseFind(Expression ex, Stack<Access> accessors, List<object> found)
            {
                if (ex == null)
                    return null;

                var type = Map.From(ex.NodeType);

                switch (type)
                {
                    case ExpressionEnum.MethodCall:
                        return Find((MethodCallExpression)ex, accessors, found);
                    case ExpressionEnum.Constant:
                        return Find((ConstantExpression)ex, accessors, found);
                    case ExpressionEnum.Unary:
                        return Find((UnaryExpression)ex, accessors, found);
                    case ExpressionEnum.Lambda:
                        return Find((LambdaExpression)ex, accessors, found);
                    case ExpressionEnum.Binary:
                        return Find((BinaryExpression)ex, accessors, found);
                    case ExpressionEnum.Member:
                        return Find((MemberExpression)ex, accessors, found);
                    case ExpressionEnum.Parameter:
                        return Find((ParameterExpression)ex, accessors, found);
                    case ExpressionEnum.New:
                        return Find((NewExpression)ex, accessors, found);
                    case ExpressionEnum.MemberInit:
                        return Find((MemberInitExpression)ex, accessors, found);
                    case ExpressionEnum.Conditional:
                        return Find((ConditionalExpression)ex, accessors, found);
                    case ExpressionEnum.TypeBinary:
                        return Find((TypeBinaryExpression)ex, accessors, found);
                    default:
                        throw new NotImplementedException(ex.NodeType.ToString());
                }
            }

            private static object Find(MethodCallExpression ex, Stack<Access> accessors, List<object> found)
            {
                var display = BaseFind(ex.Object, accessors, found);
                if (display != null)
                {
                    PushStep(accessors, "Object", typeof(MethodCallExpression));
                    return display;
                }

                for (var i = 0; i < ex.Arguments.Count; i++)
                {
                    display = BaseFind(ex.Arguments[i], accessors, found);

                    if (display == null)
                        continue;

                    accessors.Push(new Access { Method = MethodType.CollectionIndex, Parameter = i });
                    PushStep(accessors, "Arguments", typeof(MethodCallExpression));

                    break;
                }

                return display;
            }

            private static object Find(NewExpression ex, Stack<Access> accessors, List<object> found)
            {
                for (var i = 0; i < ex.Arguments.Count; i++)
                {
                    var display = BaseFind(ex.Arguments[i], accessors, found);

                    if (display == null)
                        continue;

                    accessors.Push(new Access { Method = MethodType.CollectionIndex, Parameter = i });
                    PushStep(accessors, "Arguments", typeof(NewExpression));

                    return display;
                }

                return null;
            }

            private static object Find(ConstantExpression ex, Stack<Access> accessors, List<object> found)
            {
                if (!ex.Type.Name.Contains("DisplayClass") || found.Contains(ex.Value))
                    return null;

                PushStep(accessors, "Value", typeof(ConstantExpression));

                return ex.Value;
            }

            private static object Find(UnaryExpression ex, Stack<Access> accessors, List<object> found)
            {
                var display = BaseFind(ex.Operand, accessors, found);

                if (display != null)
                    PushStep(accessors, "Operand", typeof(UnaryExpression));

                return display;
            }

            private static object Find(LambdaExpression ex, Stack<Access> accessors, List<object> found)
            {
                var display = BaseFind(ex.Body, accessors, found);

                if (display != null)
                    PushStep(accessors, "Body", typeof(LambdaExpression));

                return display;
            }

            private static object Find(BinaryExpression ex, Stack<Access> accessors, List<object> found)
            {
                var display = BaseFind(ex.Left, accessors, found);

                if (display != null)
                {
                    PushStep(accessors, "Left", typeof(BinaryExpression));
                    return display;
                }

                display = BaseFind(ex.Right, accessors, found);

                if (display != null)
                    PushStep(accessors, "Right", typeof(BinaryExpression));

                return display;
            }

            private static object Find(MemberExpression ex, Stack<Access> accessors, List<object> found)
            {
                // ex.Expression is null for static members; BaseFind then returns null.
                var display = BaseFind(ex.Expression, accessors, found);

                if (display != null)
                    PushStep(accessors, "Expression", typeof(MemberExpression));

                return display;
            }

            private static object Find(ParameterExpression ex, Stack<Access> accessors, List<object> found)
            {
                return null;
            }

            private static object Find(ConditionalExpression ex, Stack<Access> accessors, List<object> found)
            {
                var display = BaseFind(ex.Test, accessors, found);
                var member = "Test";

                if (display == null)
                {
                    display = BaseFind(ex.IfTrue, accessors, found);
                    member = "IfTrue";
                }

                if (display == null)
                {
                    display = BaseFind(ex.IfFalse, accessors, found);
                    member = "IfFalse";
                }

                if (display != null)
                    PushStep(accessors, member, typeof(ConditionalExpression));

                return display;
            }

            private static object Find(TypeBinaryExpression ex, Stack<Access> accessors, List<object> found)
            {
                var display = BaseFind(ex.Expression, accessors, found);

                if (display != null)
                    PushStep(accessors, "Expression", typeof(TypeBinaryExpression));

                return display;
            }

            private static object Find(MemberInitExpression ex, Stack<Access> accessors, List<object> found)
            {
                var display = BaseFind(ex.NewExpression, accessors, found);

                if (display != null)
                {
                    PushStep(accessors, "NewExpression", typeof(MemberInitExpression));
                    return display;
                }

                for (var i = 0; i < ex.Bindings.Count; i++)
                {
                    switch (ex.Bindings[i].BindingType)
                    {
                        case MemberBindingType.Assignment:
                            display = BaseFind(((MemberAssignment)ex.Bindings[i]).Expression, accessors, found);

                            if (display != null)
                            {
                                // Innermost first: MemberAssignment.Expression, then Bindings[i], then MemberInitExpression.Bindings.
                                PushStep(accessors, "Expression", typeof(MemberAssignment));
                                accessors.Push(new Access { Method = MethodType.MemberCollectionIndex, Parameter = i });
                                PushStep(accessors, "Bindings", typeof(MemberInitExpression));

                                return display;
                            }
                            break;

                        default:
                            throw new NotImplementedException(ex.Bindings[i].BindingType.ToString());
                    }
                }

                return null;
            }
        }

        /// <summary>
        /// 
        /// </summary>
        internal class ExractParameter
        {
            /// <summary>
            /// Collect the sub-expressions of <paramref name="ex"/> that the Find methods append as parameters.
            /// </summary>
            /// <param name="ex">Root expression to scan.</param>
            /// <returns>The collected expressions, in the order the Find methods appended them.</returns>
            public static List<Expression> Do(Expression ex)
            {
                var collected = new List<Expression>();

                // BaseFind's verdict for the whole tree is not needed here; only its additions to 'collected' are.
                BaseFind(ex, collected);

                return collected;
            }

            #region Find methods
            /// <summary>
            /// Dispatch <paramref name="ex"/> to the Find overload for its expression class.
            /// </summary>
            /// <param name="ex">Node to scan; must not be null.</param>
            /// <param name="paramexs">Accumulator of parameter expressions, appended to by the Find overloads.</param>
            /// <returns>The result of the matching Find overload.</returns>
            /// <exception cref="NotImplementedException">The node's expression class has no Find overload here.</exception>
            private static bool BaseFind(Expression ex, List<Expression> paramexs)
            {
                var category = Map.From(ex.NodeType);

                if (category == ExpressionEnum.MethodCall)
                    return Find((MethodCallExpression)ex, paramexs);
                if (category == ExpressionEnum.Constant)
                    return Find((ConstantExpression)ex, paramexs);
                if (category == ExpressionEnum.Unary)
                    return Find((UnaryExpression)ex, paramexs);
                if (category == ExpressionEnum.Lambda)
                    return Find((LambdaExpression)ex, paramexs);
                if (category == ExpressionEnum.Binary)
                    return Find((BinaryExpression)ex, paramexs);
                if (category == ExpressionEnum.Member)
                    return Find((MemberExpression)ex, paramexs);
                if (category == ExpressionEnum.Parameter)
                    return Find((ParameterExpression)ex, paramexs);
                if (category == ExpressionEnum.New)
                    return Find((NewExpression)ex, paramexs);
                if (category == ExpressionEnum.MemberInit)
                    return Find((MemberInitExpression)ex, paramexs);

                throw new NotImplementedException(ex.NodeType.ToString());
            }

            /// <summary>
            /// Scan a method call for parameter expressions.
            /// </summary>
            /// <remarks>
            /// Contains calls declared on Enumerable, or on a type implementing IList, go to dedicated handlers.
            /// Otherwise the target object (if any) and every argument are scanned. When the call as a whole is
            /// not reported valid, each argument that was valid on its own is appended to
            /// <paramref name="paramexs"/> in argument order. The target object itself is never appended.
            /// NOTE(review): "valid" presumably means the sub-tree can be sent as a single SQL parameter.
            /// The leaf Find overloads that decide this are not shown here, so confirm.
            /// NOTE(review): arguments are scanned last-to-first, but valid ones are appended first-to-last.
            /// The resulting order of paramexs presumably must match LINQ-to-SQL's parameter numbering;
            /// do not reorder without checking.
            /// </remarks>
            /// <param name="ex">Call to scan.</param>
            /// <param name="paramexs">Accumulator of parameter expressions; appended to in place.</param>
            /// <returns>True when the target object (if any) and all arguments are valid.</returns>
            private static bool Find(MethodCallExpression ex, List<Expression> paramexs)
            {
                if ((ex.Method.Name == "Contains") &&
                    (ex.Method.DeclaringType == typeof(Enumerable)))
                    return FindCallEnumerableContains(ex, paramexs);
                if ((ex.Method.Name == "Contains") &&
                    (ex.Method.DeclaringType.GetInterfaces().Contains(typeof(System.Collections.IList))))
                    return FindCallListContains(ex, paramexs);

                // Only a few call shapes appear in LINQ queries; everything else is handled generically below.
                // ex.Object is scanned for nested parameters, but is never itself appended to paramexs.
                var obj = true;
                if (ex.Object != null)
                    obj = BaseFind(ex.Object, paramexs);

                // results[i]: whether argument i was valid on its own
                var results = new bool[ex.Arguments.Count];
                var valid = 0;

                for (var i = ex.Arguments.Count - 1; i >= 0; i--)
                {
                    results[i] = BaseFind(ex.Arguments[i], paramexs);
                    valid += (results[i]) ? 1 : 0;
                }

                // The whole call cannot be a single parameter, so promote the individually-valid arguments.
                if (!obj || (valid != ex.Arguments.Count))
                {
                    for (var i = 0; i < ex.Arguments.Count; i++)
                    {
                        if (results[i])
                            paramexs.Add(ex.Arguments[i]);
                    }
                }

                return obj && (valid == ex.Arguments.Count);
            }

            /// <summary>
            /// Handles the extension method Enumerable.Contains(source, item[, comparer]).
            /// </summary>
            /// <remarks>
            /// Arguments[0] (the sequence being searched) is never classified. Instead, after the other
            /// arguments are processed, it is expanded into one expression per element via AddListValues.
            /// The remaining arguments are classified from last to first, and those classified true are
            /// appended to <paramref name="paramexs"/> in forward order.
            /// Because Arguments[0] is never counted, the "all arguments classified true" test can never
            /// succeed. This method therefore always takes the append branch and always returns false.
            /// </remarks>
            /// <param name="ex">The Enumerable.Contains call.</param>
            /// <param name="paramexs">Accumulates the expressions later compiled into parameter setters.</param>
            /// <returns>Whether every argument was classified true (in practice always false).</returns>
            private static bool FindCallEnumerableContains(MethodCallExpression ex, List<Expression> paramexs)
            {
                var args = ex.Arguments;
                var standalone = new bool[args.Count];
                var standaloneCount = 0;

                // Index 0 is the container itself and is skipped here.
                for (var k = args.Count - 1; k > 0; k--)
                {
                    standalone[k] = BaseFind(args[k], paramexs);
                    if (standalone[k])
                        standaloneCount++;
                }

                var allStandalone = standaloneCount == args.Count;
                if (!allStandalone)
                {
                    for (var k = 1; k < args.Count; k++)
                    {
                        if (standalone[k])
                            paramexs.Add(args[k]);
                    }
                }

                AddListValues(paramexs, args[0]);

                return allStandalone;
            }

            /// <summary>
            /// Handles an instance Contains call declared on a type that implements IList
            /// (for example List&lt;T&gt;.Contains(item)).
            /// </summary>
            /// <remarks>
            /// Arguments are classified from last to first. Unless all of them are classified true, those
            /// that were are appended to <paramref name="paramexs"/> in forward order. The receiver
            /// (ex.Object, the list being searched) is then expanded into one expression per element via
            /// AddListValues.
            /// </remarks>
            /// <param name="ex">The instance Contains call.</param>
            /// <param name="paramexs">Accumulates the expressions later compiled into parameter setters.</param>
            /// <returns>True only when every argument was classified true by BaseFind.</returns>
            private static bool FindCallListContains(MethodCallExpression ex, List<Expression> paramexs)
            {
                var args = ex.Arguments;
                var flags = new bool[args.Count];
                var hits = 0;

                for (var k = args.Count; k-- > 0; )
                {
                    flags[k] = BaseFind(args[k], paramexs);
                    if (flags[k])
                        hits++;
                }

                var allHit = hits == args.Count;
                if (!allHit)
                {
                    for (var k = 0; k < flags.Length; k++)
                    {
                        if (flags[k])
                            paramexs.Add(args[k]);
                    }
                }

                AddListValues(paramexs, ex.Object);

                return allHit;
            }

            /// <summary>
            /// Expands a captured list into one element-access expression per item.
            /// </summary>
            /// <remarks>
            /// <paramref name="objex"/> is evaluated once, now, only to learn how many elements the list
            /// holds. Each generated expression re-reads its element from <paramref name="objex"/> when it
            /// is later evaluated, and has the shape
            /// <c>((IList)objex)[((ICollection)objex).Count &gt; i ? i : 0]</c>, typed as object.
            /// The number of generated expressions is fixed by the list's size at the time of this call.
            /// The list is accessed through the non-generic IList/ICollection interfaces, so this works
            /// whatever the static type of <paramref name="objex"/> is (array, List&lt;T&gt;,
            /// IEnumerable&lt;T&gt;, ...), as long as the runtime value implements IList.
            /// </remarks>
            /// <param name="exps">Receives one element-access expression per list item.</param>
            /// <param name="objex">Parameterless expression that evaluates to the list.</param>
            /// <exception cref="NotSupportedException">
            /// <paramref name="objex"/> evaluates to null or to a value that does not implement
            /// <see cref="System.Collections.IList"/>.
            /// </exception>
            private static void AddListValues(List<Expression> exps, Expression objex)
            {
                var obj = Expression.Lambda(objex).Compile().DynamicInvoke();
                var list = obj as System.Collections.IList;
                if (list == null)
                    throw new NotSupportedException(
                        "Contains is only supported on a non-null System.Collections.IList; got " +
                        (obj == null ? "null" : obj.GetType().ToString()) + ".");

                var asList = Expression.Convert(objex, typeof(System.Collections.IList));
                var count = Expression.Property(
                    Expression.Convert(objex, typeof(System.Collections.ICollection)), "Count");
                var getItem = typeof(System.Collections.IList).GetMethod("get_Item");

                for (var i = 0; i < list.Count; i++)
                {
                    // Fall back to index 0 if the list has shrunk since the query was cached.
                    var ex =
                        Expression.Call(
                            asList,
                            getItem,
                            Expression.Condition(
                                Expression.GreaterThan(count, Expression.Constant(i)),
                                Expression.Constant(i),
                                Expression.Constant(0)));

                    exps.Add(ex);
                }
            }

            /// <summary>
            /// Classifies a constant. Every constant is classified true except a value whose type is named
            /// "Table`1" (the LINQ-to-SQL table at the root of the query), which is classified false.
            /// Nothing is appended to <paramref name="paramexs"/>.
            /// </summary>
            private static bool Find(ConstantExpression ex, List<Expression> paramexs)
            {
                var isTableRoot = ex.Type.Name.Equals("Table`1");
                return !isTableRoot;
            }

            /// <summary>
            /// Classifies a unary node (conversion, quote, negation, ...) by classifying its operand.
            /// </summary>
            private static bool Find(UnaryExpression ex, List<Expression> paramexs)
            {
                var operand = ex.Operand;
                return BaseFind(operand, paramexs);
            }

            /// <summary>
            /// Classifies a lambda by classifying its body. The lambda's own parameters are only reached
            /// through the body, where the ParameterExpression overload classifies them false.
            /// </summary>
            private static bool Find(LambdaExpression ex, List<Expression> paramexs)
            {
                var body = ex.Body;
                return BaseFind(body, paramexs);
            }

            /// <summary>
            /// Classifies a binary node. Both operands are classified, left first.
            /// If both are classified true, the whole node is. Otherwise the operand that was classified
            /// true (if any) is appended to <paramref name="paramexs"/> and false is returned.
            /// </summary>
            private static bool Find(BinaryExpression ex, List<Expression> paramexs)
            {
                var leftOk = BaseFind(ex.Left, paramexs);
                var rightOk = BaseFind(ex.Right, paramexs);

                var both = leftOk && rightOk;
                if (!both)
                {
                    // At most one side can be true here.
                    if (leftOk)
                        paramexs.Add(ex.Left);
                    else if (rightOk)
                        paramexs.Add(ex.Right);
                }

                return both;
            }

            /// <summary>
            /// Classifies a member access (field or property) by classifying the instance it is read from.
            /// </summary>
            /// <remarks>
            /// NOTE(review): ex.Expression is null for static members (e.g. DateTime.Now). Whether
            /// BaseFind accepts null is not visible from here; confirm.
            /// </remarks>
            private static bool Find(MemberExpression ex, List<Expression> paramexs)
            {
                var instance = ex.Expression;
                return BaseFind(instance, paramexs);
            }

            /// <summary>
            /// Classifies a lambda parameter: always false, and nothing is appended.
            /// </summary>
            private static bool Find(ParameterExpression ex, List<Expression> paramexs)
            {
                const bool evaluable = false;
                return evaluable;
            }

            /// <summary>
            /// Classifies a constructor call by its arguments, from last to first.
            /// If all are classified true, the whole construction is. Otherwise those that were are
            /// appended to <paramref name="paramexs"/> in forward order and false is returned.
            /// </summary>
            private static bool Find(NewExpression ex, List<Expression> paramexs)
            {
                var args = ex.Arguments;
                var evaluable = new bool[args.Count];
                var evaluableCount = 0;

                for (var k = args.Count - 1; k >= 0; k--)
                {
                    evaluable[k] = BaseFind(args[k], paramexs);
                    if (evaluable[k])
                        evaluableCount++;
                }

                if (evaluableCount == args.Count)
                    return true;

                for (var k = 0; k < args.Count; k++)
                {
                    if (evaluable[k])
                        paramexs.Add(args[k]);
                }

                return false;
            }

            /// <summary>
            /// Classifies an object initializer by the values of its member assignments, in binding order.
            /// If all are classified true, the whole initializer is. Otherwise the assigned values that were
            /// are appended to <paramref name="paramexs"/> in binding order and false is returned.
            /// </summary>
            /// <remarks>
            /// Only MemberBindingType.Assignment is supported; other binding kinds throw NotImplementedException.
            /// NOTE(review): ex.NewExpression (constructor arguments) is not visited. Confirm that
            /// initializers with constructor arguments cannot reach this method.
            /// </remarks>
            private static bool Find(MemberInitExpression ex, List<Expression> paramexs)
            {
                var bindings = ex.Bindings;
                var evaluable = new bool[bindings.Count];
                var evaluableCount = 0;

                for (var b = 0; b < bindings.Count; b++)
                {
                    if (bindings[b].BindingType != MemberBindingType.Assignment)
                        throw new NotImplementedException(bindings[b].BindingType.ToString());

                    evaluable[b] = BaseFind(((MemberAssignment)bindings[b]).Expression, paramexs);
                    if (evaluable[b])
                        evaluableCount++;
                }

                if (evaluableCount == bindings.Count)
                    return true;

                // Every binding is an assignment here; the loop above would have thrown otherwise.
                for (var b = 0; b < bindings.Count; b++)
                {
                    if (evaluable[b])
                        paramexs.Add(((MemberAssignment)bindings[b]).Expression);
                }

                return false;
            }
            #endregion
        }

        /// <summary>
        /// Turns the parameter-value expressions found by ExractParameter into lambdas. In each lambda,
        /// every constant holding a captured display-class instance is replaced by a lambda parameter,
        /// so that one compiled setter can be re-invoked with the display-class values taken from a
        /// later query.
        /// </summary>
        internal class InjectDisplayGetters
        {
            /// <summary>
            /// Builds one lambda per expression in <paramref name="exs"/>.
            /// All lambdas take the same parameters: one per key of <paramref name="getters"/>, typed as
            /// that key's runtime type, in the enumeration order of <paramref name="getters"/>' keys.
            /// </summary>
            /// <param name="exs">Parameter-value expressions, one per command parameter.</param>
            /// <param name="getters">Display-class instance mapped to the getter that extracts it.</param>
            /// <returns>One untyped lambda per entry of <paramref name="exs"/>.</returns>
            public static LambdaExpression[] Build(List<Expression> exs, Dictionary<object, Expression<Func<Expression, object>>> getters)
            {
                // The values are later supplied positionally, in getters.Values order (same dictionary,
                // same order as getters.Keys). So the parameter list is fixed to the key order here
                // instead of relying on a second dictionary's Values order.
                var keys = getters.Keys.ToArray();
                var lambdaParameters = new ParameterExpression[keys.Length];
                var parameters = new Dictionary<object, ParameterExpression>();

                for (var i = 0; i < keys.Length; i++)
                {
                    lambdaParameters[i] = Expression.Parameter(keys[i].GetType());
                    parameters[keys[i]] = lambdaParameters[i];
                }

                var deles = new LambdaExpression[exs.Count];

                for (var i = 0; i < exs.Count; i++)
                    deles[i] = Expression.Lambda(BaseProcess(exs[i], parameters), lambdaParameters);

                return deles;
            }

            #region Individual process
            /// <summary>Dispatches on node kind; unsupported kinds throw NotImplementedException.</summary>
            private static Expression BaseProcess(Expression ex, Dictionary<object, ParameterExpression> parameters)
            {
                switch (Map.From(ex.NodeType))
                {
                    case ExpressionEnum.Constant:
                        return Process((ConstantExpression)ex, parameters);
                    case ExpressionEnum.Member:
                        return Process((MemberExpression)ex, parameters);
                    case ExpressionEnum.MethodCall:
                        return Process((MethodCallExpression)ex, parameters);
                    case ExpressionEnum.Conditional:
                        return Process((ConditionalExpression)ex, parameters);
                    case ExpressionEnum.Binary:
                        return Process((BinaryExpression)ex, parameters);
                    case ExpressionEnum.Unary:
                        return Process((UnaryExpression)ex, parameters);
                    case ExpressionEnum.New:
                        return Process((NewExpression)ex, parameters);
                    case ExpressionEnum.Lambda:
                        return Process((LambdaExpression)ex, parameters);
                    case ExpressionEnum.MemberInit:
                        return Process((MemberInitExpression)ex, parameters);
                    case ExpressionEnum.Parameter:
                        return Process((ParameterExpression)ex, parameters);
                    default:
                        throw new NotImplementedException(ex.NodeType.ToString());
                }
            }

            /// <summary>
            /// Replaces a display-class constant with its lambda parameter; any other constant is kept as-is.
            /// </summary>
            private static Expression Process(ConstantExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                // Dictionary rejects null keys, so null constants must not be looked up.
                ParameterExpression parameter;
                if (ex.Value != null && parameters.TryGetValue(ex.Value, out parameter))
                    return parameter;

                // Reuse the node so its declared Type is preserved. Expression.Constant(value) would
                // re-infer the type from the runtime value, or use object for null.
                return ex;
            }

            /// <summary>Rebuilds a member access over the processed instance (none for static members).</summary>
            private static Expression Process(MemberExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                var instance = ex.Expression == null ? null : BaseProcess(ex.Expression, parameters);

                return Expression.MakeMemberAccess(instance, ex.Member);
            }

            /// <summary>Rebuilds a call over the processed receiver (none for static methods) and arguments.</summary>
            private static Expression Process(MethodCallExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                var instance = ex.Object == null ? null : BaseProcess(ex.Object, parameters);

                return
                    Expression.Call(
                        instance,
                        ex.Method,
                        (from arg in ex.Arguments select BaseProcess(arg, parameters)).ToArray());
            }

            private static Expression Process(ConditionalExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                return
                    Expression.Condition(
                        BaseProcess(ex.Test, parameters),
                        BaseProcess(ex.IfTrue, parameters),
                        BaseProcess(ex.IfFalse, parameters), ex.Type);
            }

            private static Expression Process(BinaryExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                return
                    Expression.MakeBinary(
                        ex.NodeType,
                        BaseProcess(ex.Left, parameters),
                        BaseProcess(ex.Right, parameters),
                        ex.IsLiftedToNull, ex.Method, ex.Conversion);
            }

            /// <summary>Rebuilds a unary node, keeping any user-defined operator method.</summary>
            private static Expression Process(UnaryExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                return Expression.MakeUnary(ex.NodeType, BaseProcess(ex.Operand, parameters), ex.Type, ex.Method);
            }

            /// <summary>Rebuilds a constructor call with processed arguments.</summary>
            private static Expression Process(NewExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                // new S() for a value type has no constructor and no arguments to rewrite.
                if (ex.Constructor == null)
                    return ex;

                var args = (from arg in ex.Arguments select BaseProcess(arg, parameters)).ToArray();

                // Members is only set when arguments map to members (e.g. anonymous types).
                return ex.Members == null
                    ? Expression.New(ex.Constructor, args)
                    : Expression.New(ex.Constructor, args, ex.Members);
            }

            private static Expression Process(LambdaExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                return
                    Expression.Lambda(
                        ex.Type,
                        BaseProcess(ex.Body, parameters),
                        ex.Parameters);
            }

            /// <summary>Rebuilds an object initializer; only assignment bindings are supported.</summary>
            private static Expression Process(MemberInitExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                var binds = new List<MemberBinding>();

                foreach (var b in ex.Bindings)
                {
                    switch (b.BindingType)
                    {
                        case MemberBindingType.Assignment:
                            var ma = (MemberAssignment)b;

                            binds.Add(ma.Update(BaseProcess(ma.Expression, parameters)));
                            break;

                        default:
                            throw new NotImplementedException(b.BindingType.ToString());
                    }
                }

                return
                    Expression.MemberInit(
                        (NewExpression)BaseProcess(ex.NewExpression, parameters),
                        binds);
            }

            /// <summary>Lambda parameters are left untouched.</summary>
            private static Expression Process(ParameterExpression ex, Dictionary<object, ParameterExpression> parameters)
            {
                return ex;
            }
            #endregion
        }

        /// <summary>
        /// Collects the Select lambdas whose body is a method call ("conversions"), innermost first.
        /// </summary>
        /// <remarks>
        /// The search follows only Queryable.Select (source, then selector) and Queryable.Where (source
        /// only). Any other call stops the search at that point.
        /// Select lambdas whose body is an object initializer, a parameter or a member access are not
        /// collected. Any other body kind throws NotImplementedException.
        /// </remarks>
        internal class GetConversion
        {
            /// <summary>Returns the conversion lambdas found in <paramref name="ex"/>, innermost first.</summary>
            public static List<LambdaExpression> Find(Expression ex)
            {
                var collected = new List<LambdaExpression>();
                BaseFind(ex, collected);
                return collected;
            }

            private static void BaseFind(Expression ex, List<LambdaExpression> found)
            {
                var kind = Map.From(ex.NodeType);

                if (kind == ExpressionEnum.MethodCall)
                    Find((MethodCallExpression)ex, found);
                else if (kind == ExpressionEnum.Constant)
                    Find((ConstantExpression)ex, found);
                else if (kind == ExpressionEnum.Unary)
                    Find((UnaryExpression)ex, found);
                else if (kind == ExpressionEnum.Lambda)
                    Find((LambdaExpression)ex, found);
                else
                    throw new NotImplementedException(ex.NodeType.ToString());
            }

            private static void Find(MethodCallExpression ex, List<LambdaExpression> found)
            {
                if (ex.Method.DeclaringType != typeof(Queryable))
                    return;

                var name = ex.Method.Name;
                if (name == "Select")
                {
                    // Source first, so inner conversions precede outer ones.
                    BaseFind(ex.Arguments[0], found);
                    BaseFind(ex.Arguments[1], found);
                }
                else if (name == "Where")
                {
                    // The predicate is not a conversion; only follow the source.
                    BaseFind(ex.Arguments[0], found);
                }
            }

            /// <summary>Constants (e.g. the table root) end the search.</summary>
            private static void Find(ConstantExpression ex, List<LambdaExpression> found)
            {
            }

            /// <summary>Unwraps a unary node (such as the Quote around a selector).</summary>
            private static void Find(UnaryExpression ex, List<LambdaExpression> found)
            {
                var operand = ex.Operand;
                BaseFind(operand, found);
            }

            private static void Find(LambdaExpression ex, List<LambdaExpression> found)
            {
                var bodyKind = Map.From(ex.Body.NodeType);

                if (bodyKind == ExpressionEnum.MethodCall)
                {
                    found.Add(ex);
                    return;
                }

                if (bodyKind == ExpressionEnum.MemberInit ||
                    bodyKind == ExpressionEnum.Parameter ||
                    bodyKind == ExpressionEnum.Member)
                    return;

                throw new NotImplementedException(ex.Body.NodeType.ToString());
            }
        }

        /// <summary>
        /// Substitutes an expression for a lambda parameter inside a conversion lambda's body. In
        /// BuildObjectReader the substituted expression is the object initializer that materialises a row.
        /// </summary>
        internal class InjectConversion
        {
            /// <summary>
            /// Returns a copy of <paramref name="body"/> in which every occurrence of
            /// <paramref name="param"/> is replaced by <paramref name="value"/>.
            /// </summary>
            /// <param name="body">Body of the conversion lambda.</param>
            /// <param name="param">The lambda parameter to replace (compared by reference).</param>
            /// <param name="value">Expression to put in its place.</param>
            /// <returns>The rewritten body; any node kind is supported.</returns>
            public static Expression Build(Expression body, ParameterExpression param, Expression value)
            {
                // A visitor handles every node kind and rebuilds method calls with their processed
                // receiver. The previous hand-written version recursed on the call itself for instance
                // methods (infinite recursion) and rejected anything but calls and parameters.
                return new ParameterReplacer(param, value).Visit(body);
            }

            /// <summary>Replaces one specific ParameterExpression wherever it occurs.</summary>
            private sealed class ParameterReplacer : ExpressionVisitor
            {
                private readonly ParameterExpression _target;
                private readonly Expression _replacement;

                public ParameterReplacer(ParameterExpression target, Expression replacement)
                {
                    _target = target;
                    _replacement = replacement;
                }

                protected override Expression VisitParameter(ParameterExpression node)
                {
                    return node == _target ? _replacement : node;
                }
            }
        }

        /// <summary>
        /// Finds the LINQ-to-SQL (MS LINQ) DataContext behind a query. It walks the query expression
        /// to the first constant whose type is named "Table`1" and reads that value's Context property.
        /// </summary>
        internal class ExtractContext
        {
            /// <summary>
            /// Returns the DataContext of the table at the root of <paramref name="ex"/>.
            /// </summary>
            /// <param name="ex">Query expression (a method-call chain over a table constant).</param>
            /// <returns>The table's DataContext.</returns>
            /// <exception cref="NotSupportedException">No "Table`1" constant was found in the expression.</exception>
            public static DataContext Do(Expression ex)
            {
                var table = BaseFind(ex);
                if (table == null)
                    throw new NotSupportedException(
                        "The query is not rooted in a LINQ-to-SQL Table<T>, so its DataContext cannot be located.");

                // Reflection is used because the table's element type is not known here.
                var property = table.GetType().GetProperty("Context");

                return (DataContext)property.GetValue(table, null);
            }

            /// <summary>Dispatches on node kind; only method calls and constants are supported.</summary>
            private static object BaseFind(Expression ex)
            {
                switch (Map.From(ex.NodeType))
                {
                    case ExpressionEnum.MethodCall:
                        return Find((MethodCallExpression)ex);
                    case ExpressionEnum.Constant:
                        return Find((ConstantExpression)ex);
                    default:
                        throw new NotImplementedException(ex.NodeType.ToString());
                }
            }

            /// <summary>
            /// Searches the call's arguments left to right and returns the first table found, or null.
            /// </summary>
            /// <remarks>
            /// Arguments that are neither calls nor constants (e.g. a quoted lambda) make BaseFind throw.
            /// The search presumably relies on the source argument, which is checked first, leading to
            /// the table.
            /// </remarks>
            protected static object Find(MethodCallExpression ex)
            {
                object obj = null;

                for (var i = 0; i < ex.Arguments.Count; i++)
                {
                    obj = BaseFind(ex.Arguments[i]);
                    if (obj != null)
                        break;
                }

                return obj;
            }

            /// <summary>Returns the constant's value if its type is named "Table`1", otherwise null.</summary>
            protected static object Find(ConstantExpression ex)
            {
                return ex.Type.Name.Equals("Table`1") ? ex.Value : null;
            }
        }
        #endregion

        /// <summary>
        /// One cache entry per query shape. It holds the expression the entry was built from, the
        /// command LINQ-to-SQL generated for it, and the compiled delegates that refresh the command's
        /// parameters and materialise its rows.
        /// </summary>
        private class CachedQuery
        {
            /// <summary>Query expression the entry was built from; compared against incoming queries.</summary>
            public Expression Expression { get; set; }

            /// <summary>Command generated by DataContext.GetCommand; reused on every execution.</summary>
            public DbCommand Command { get; set; }

            /// <summary>
            /// Compiled getters, each extracting one captured display-class instance from a query expression.
            /// </summary>
            public Func<Expression, object>[] DisplayClassGetter { get; set; }

            /// <summary>
            /// Compiled setters, index-aligned with Command.Parameters. Each takes the display-class values
            /// and returns that parameter's value.
            /// </summary>
            public Delegate[] ParameterSetters { get; set; }

            /// <summary>Row materialiser; null until built from the first reader's schema.</summary>
            public Func<IDataReader, object> ObjectReader { get; set; }
        }

        /// <summary>
        /// Executes <paramref name="query"/> through a cached DbCommand and object reader instead of
        /// having LINQ-to-SQL translate it again.
        /// </summary>
        /// <remarks>
        /// Each call does three things:
        /// - pulls the captured display-class values out of the query expression;
        /// - recomputes the cached command's parameter values from them (a null result is sent as DBNull.Value);
        /// - reads the rows with the cached ObjectReader, which is built on the first execution.
        /// If the command's connection is not open it is opened for this call and closed afterwards.
        /// A connection that was already open is left open.
        /// </remarks>
        /// <typeparam name="T">Element type of the query.</typeparam>
        /// <param name="query">LINQ-to-SQL query to execute.</param>
        /// <returns>The materialised rows, in reader order.</returns>
        public static List<T> ExecuteList<T>(this IQueryable<T> query) where T : class, new()
        {
            // Retrieve/Create the cached query object
            var stored = GetCachedQuery(query);
            var command = stored.Command;

            // Extract variable values from the query
            Log("Actual display class");
            var values = new object[stored.DisplayClassGetter.Length];
            for (var i = 0; i < values.Length; i++)
            {
                values[i] = stored.DisplayClassGetter[i](query.Expression);
                LogObject(values[i]);
            }

            // Set Command parameter values (setters are index-aligned with the parameters)
            Log("Actual parameter:");
            for (var i = 0; i < command.Parameters.Count; i++)
            {
                var parameter = command.Parameters[i];

                // Providers such as SqlClient treat a null Value as "parameter not supplied";
                // SQL NULL must be sent as DBNull.Value.
                parameter.Value = stored.ParameterSetters[i].DynamicInvoke(values) ?? DBNull.Value;
                Log("{0}: {1}", parameter.ParameterName, parameter.Value);
            }

            // Open the connection only if needed, and close only what was opened here
            var connection = command.Connection;
            var openedHere = connection.State != ConnectionState.Open;
            if (openedHere)
                connection.Open();

            try
            {
                // Retrieve the data
                var queried = new List<T>();

                using (var reader = command.ExecuteReader())
                {
                    if (stored.ObjectReader == null)
                    {
                        Log("Building reader");

                        // Build the object reader upon first retrieval
                        BuildObjectReader(typeof(T), stored, reader);
                    }

                    while (reader.Read())
                        queried.Add((T)stored.ObjectReader(reader));
                }

                // Done
                return queried;
            }
            finally
            {
                if (openedHere)
                    connection.Close();
            }
        }

        /// <summary>
        /// Cache of prepared queries. Despite the name, these are not database stored procedures:
        /// each entry is a CachedQuery (generated command plus compiled setters/reader).
        /// GetCachedQuery searches it linearly. A new entry is added whenever no existing entry's
        /// expression compares equal under CompareExpression.
        /// NOTE(review): no code in this part of the file removes entries or locks the list. Confirm
        /// that unbounded growth and single-threaded use are intended.
        /// </summary>
        private static List<CachedQuery> Procedures = new List<CachedQuery>();

        /// <summary>
        /// Returns the first cached entry whose expression CompareExpression considers equal to
        /// <paramref name="query"/>'s expression. On a miss, builds a new entry, caches it and returns it.
        /// </summary>
        /// <remarks>
        /// NOTE(review): a cached entry's Command comes from the DataContext that first built it.
        /// Confirm that CompareExpression distinguishes queries from different contexts, or that only
        /// one context uses this cache.
        /// </remarks>
        /// <param name="query">Query to look up.</param>
        /// <returns>The matching or newly created entry.</returns>
        private static CachedQuery GetCachedQuery(IQueryable query)
        {
            var comparer = new CompareExpression(query.Expression);

            foreach (var candidate in Procedures)
            {
                if (comparer.Equals(candidate.Expression))
                    return candidate;
            }

            // Miss: build a new entry and remember it
            Log("Create new procedure");
            var created = CreateProcedure(query);
            Procedures.Add(created);
            return created;
        }

        /// <summary>
        /// Builds a new cache entry for <paramref name="query"/>. It:
        /// - asks the query's DataContext for the SQL command;
        /// - compiles getters that pull the captured display-class instances out of a query expression;
        /// - compiles one setter per command parameter that recomputes its value from those instances.
        /// </summary>
        /// <param name="query">LINQ-to-SQL query whose shape is being cached.</param>
        /// <returns>The new entry. ObjectReader is left null here.</returns>
        /// <exception cref="InvalidOperationException">
        /// The number of parameter setters extracted from the expression differs from the number of
        /// parameters in the generated command.
        /// </exception>
        private static CachedQuery CreateProcedure(IQueryable query)
        {
            var context = ExtractContext.Do(query.Expression);

            var procedure = new CachedQuery
                                {
                                    Expression = query.Expression,
                                    // Let the context create the command
                                    Command = context.GetCommand(query),
                                };

            // Log the command
            Log("Command:");
            Log(procedure.Command.CommandText);
            Log("Parameters:");
            for (var i = 0; i < procedure.Command.Parameters.Count; i++)
                Log("{0}={1}", procedure.Command.Parameters[i].ParameterName, procedure.Command.Parameters[i].Value);

            // Compile the display object getter
            var getters = GetDisplayClass.Build(query.Expression);
            Log("Getters:");
            foreach (var g in getters)
                Log(g.Value.ToString());

            procedure.DisplayClassGetter = (from g in getters.Values select g.Compile()).ToArray();

            // Compile the parameter setters
            var setters = InjectDisplayGetters.Build(ExractParameter.Do(query.Expression), getters);
            Log("Setters:");
            foreach (var s in setters)
                Log(s.ToString());

            // Setters are matched to command parameters purely by position, so the counts must agree.
            if (setters.Length != procedure.Command.Parameters.Count)
                throw new InvalidOperationException("Parameter mismatch");

            procedure.ParameterSetters = (from s in setters select s.Compile()).ToArray();

            return procedure;
        }

        /// <summary>
        /// Builds and stores procedure.ObjectReader, a compiled delegate of the form
        /// <c>reader =&gt; conversion(... new ReadingType { Property = reader.GetXxx(index), ... })</c>.
        /// </summary>
        /// <remarks>
        /// ReadingType is <paramref name="target"/>, unless the query contains conversions (see
        /// GetConversion). In that case it is the parameter type of the innermost conversion, and the
        /// conversions are applied on top of the initializer in the order found.
        /// Every public, writable, non-indexed instance property with a same-named column in
        /// <paramref name="reader"/> is populated. Properties without a matching column are left at
        /// their default value.
        /// </remarks>
        /// <param name="target">Element type requested by the caller.</param>
        /// <param name="procedure">Cache entry whose ObjectReader is set.</param>
        /// <param name="reader">Reader for the result set; used for its schema only.</param>
        /// <exception cref="InvalidOperationException">The reading type has no public parameterless constructor.</exception>
        private static void BuildObjectReader(Type target, CachedQuery procedure, IDataReader reader)
        {
            // Reader parameter used in the reader expression
            var readerRef = Expression.Parameter(typeof(IDataReader), "reader");
            var readingType = target;

            // If there are further conversions, use the initial type
            var conversions = GetConversion.Find(procedure.Expression);
            if (conversions.Count > 0)
                readingType = conversions[0].Parameters[0].Type;

            // Bound to readingType
            var bindings = new List<MemberBinding>();

            foreach (var property in readingType.GetProperties(BindingFlags.Public | BindingFlags.Instance))
            {
                // Expression.Bind needs a settable, non-indexed property.
                if (!property.CanWrite || property.GetIndexParameters().Length > 0)
                    continue;

                int idx;
                try
                {
                    idx = reader.GetOrdinal(property.Name);
                }
                catch (IndexOutOfRangeException)
                {
                    // IDataRecord.GetOrdinal reports "no such column" this way; leave the property unset.
                    continue;
                }

                bindings.Add(Expression.Bind(property, BuildReadRow(property, reader, idx, readerRef)));
            }

            var constructor = readingType.GetConstructor(Type.EmptyTypes);
            if (constructor == null)
                throw new InvalidOperationException(
                    readingType + " has no public parameterless constructor, so rows cannot be materialised into it.");

            Expression step = Expression.MemberInit(Expression.New(constructor), bindings);

            // Make further conversions if any
            foreach (var conversion in conversions)
                step = InjectConversion.Build(conversion.Body, conversion.Parameters[0], step);

            Log(step.ToString());

            procedure.ObjectReader = Expression.Lambda<Func<IDataReader, object>>(step, readerRef).Compile();
        }

        /// <summary>
        /// Builds the expression that reads column <paramref name="idx"/> of the current row and yields
        /// a value of the property type of <paramref name="member"/>.
        /// </summary>
        /// <remarks>
        /// The typed IDataRecord getter is chosen from the column's field type as reported by
        /// <paramref name="reader"/>. The value is converted to the property type when the two differ
        /// (e.g. int to int?).
        /// When the property can hold null (a reference type or Nullable&lt;T&gt;), the read is guarded with
        /// IDataRecord.IsDBNull so that a database NULL yields null. For non-nullable value types a NULL
        /// column still surfaces as the getter's exception.
        /// </remarks>
        /// <param name="member">Property being populated; must be a PropertyInfo.</param>
        /// <param name="reader">Reader for the result set; used for its schema only.</param>
        /// <param name="idx">Column ordinal.</param>
        /// <param name="reference">Parameter standing for the reader at execution time.</param>
        /// <returns>An expression whose type is the property's type.</returns>
        /// <exception cref="NotImplementedException">The column's field type has no supported typed getter.</exception>
        private static Expression BuildReadRow(MemberInfo member, IDataReader reader, int idx, ParameterExpression reference)
        {
            var readerType = typeof(IDataRecord);

            var sourceType = reader.GetFieldType(idx);
            var destType = ((PropertyInfo)member).PropertyType;

            string getter;
            if (sourceType == typeof(int))
                getter = "GetInt32";
            else if (sourceType == typeof(long))
                getter = "GetInt64";
            else if (sourceType == typeof(short))
                getter = "GetInt16";
            else if (sourceType == typeof(byte))
                getter = "GetByte";
            else if (sourceType == typeof(bool))
                getter = "GetBoolean";
            else if (sourceType == typeof(double))
                getter = "GetDouble";
            else if (sourceType == typeof(float))
                getter = "GetFloat";
            else if (sourceType == typeof(decimal))
                getter = "GetDecimal";
            else if (sourceType == typeof(DateTime))
                getter = "GetDateTime";
            else if (sourceType == typeof(Guid))
                getter = "GetGuid";
            else if (sourceType == typeof(string))
                getter = "GetString";
            else
                throw new NotImplementedException(sourceType.ToString());

            var index = Expression.Constant(idx);
            Expression read = Expression.Call(reference, readerType.GetMethod(getter), index);

            // Convert to the property's own type. Converting to the declaring type, as before,
            // fails for every mismatch such as int -> int?.
            if (sourceType != destType)
                read = Expression.Convert(read, destType);

            var canHoldNull = !destType.IsValueType || Nullable.GetUnderlyingType(destType) != null;
            if (canHoldNull)
                read = Expression.Condition(
                    Expression.Call(reference, readerType.GetMethod("IsDBNull"), index),
                    Expression.Constant(null, destType),
                    read);

            return read;
        }
    }
}

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
Software Developer (Senior) 3PLearning
Australia
Lead Developer, MMO Game Company
Tester, Microsoft

Comments and Discussions