Click here to Skip to main content
15,892,537 members
Articles / Web Development / ASP.NET

Signum Framework Principles

Rate me:
Please Sign up or sign in to vote.
4.74/5 (27 votes)
25 Jul 2011 · CPOL · 18 min read · 99.4K · 1.1K · 86
Explains the philosophy behind Signum Framework, an ORM with a full LINQ Provider that encourages an entities-first approach.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using Signum.Utilities;
using Signum.Utilities.DataStructures;
using Signum.Utilities.ExpressionTrees;
using System.Reflection;
using Signum.Utilities.Reflection;
using System.Collections;
using Signum.Engine.Properties;

namespace Signum.Engine.Linq
{
    /// <summary>
    /// Nominator is a class that walks an expression tree bottom up, determining the set of
    /// candidate expressions that are possible columns of a select expression.
    /// </summary>
    /// <remarks>
    /// A node is nominated (added to <c>candidates</c>) when it can be translated to SQL. Most visit
    /// methods visit their children first and nominate the rebuilt node only when the relevant children
    /// were nominated. Along the way, some string, DateTime, TimeSpan and Math members are rewritten
    /// into SQL function expressions.
    /// </remarks>
    internal class Nominator : DbExpressionVisitor
    {
        public static ConditionsRewriter ConditionsRewriter = new ConditionsRewriter();

        // string.Concat(string, string): used to build the LIKE patterns of Contains/StartsWith/EndsWith.
        static readonly MethodInfo concatMethod = typeof(string).GetMethod("Concat", new[] { typeof(string), typeof(string) });

        // Table aliases whose columns may be nominated. Null when used from QueryBinder (see FullNominate):
        // then every column is nominated and the extra rewrites guarded by "existingAliases == null" are enabled.
        string[] existingAliases;

        // Nodes found to be translatable to SQL.
        readonly HashSet<Expression> candidates = new HashSet<Expression>();

        private Nominator() { }

        /// <summary>
        /// Visits <paramref name="expression"/> and collects the nodes that can be translated to SQL,
        /// nominating only columns whose alias is contained in <paramref name="existingAliases"/>.
        /// </summary>
        /// <param name="expression">The expression to analyze.</param>
        /// <param name="existingAliases">Aliases whose columns may be nominated.</param>
        /// <param name="newExpression">The (possibly rewritten) expression produced by the visit.</param>
        /// <returns>The set of expressions found to be translatable to SQL.</returns>
        internal static HashSet<Expression> Nominate(Expression expression, string[] existingAliases, out Expression newExpression)
        {
            Nominator n = new Nominator { existingAliases = existingAliases };
            newExpression = n.Visit(expression);
            return n.candidates;
        }

        /// <summary>
        /// Visits <paramref name="expression"/> without alias restrictions and requires the whole
        /// (rewritten) expression to be translatable to SQL.
        /// </summary>
        /// <param name="expression">The expression to translate.</param>
        /// <param name="isCondition">
        /// true to finish with <c>ConditionsRewriter.MakeSqlCondition</c>; false to finish with
        /// <c>ConditionsRewriter.MakeSqlValue</c>.
        /// </param>
        /// <returns>The rewritten expression.</returns>
        /// <exception cref="ApplicationException">The rewritten root was not nominated.</exception>
        internal static Expression FullNominate(Expression expression, bool isCondition)
        {
            Nominator n = new Nominator { existingAliases = null };
            Expression result = n.Visit(expression);
            if (!n.candidates.Contains(result))
                throw new ApplicationException(Resources.TheExpressionCanTBeTranslatedToSQL + expression.ToString());

            return isCondition
                ? ConditionsRewriter.MakeSqlCondition(result)
                : ConditionsRewriter.MakeSqlValue(result);
        }

        protected override Expression VisitColumn(ColumnExpression column)
        {
            // existingAliases is null when used in QueryBinder, not ColumnProjector.
            // This allows making function changes in WHERE clauses while keeping the full expression
            // (instead of compressing it into one column).
            bool isVisible = existingAliases == null || existingAliases.Contains(column.Alias);
            if (isVisible)
                candidates.Add(column);
            return column;
        }

        protected override NewExpression VisitNew(NewExpression nex)
        {
            if (existingAliases != null)
                return base.VisitNew(nex);

            IEnumerable<Expression> args = this.VisitExpressionList(nex.Arguments);
            if (args != nex.Arguments)
            {
                // Apparently anonymous types need exact type matching, so the Members are kept when present.
                nex = nex.Members != null
                    ? Expression.New(nex.Constructor, args, nex.Members)
                    : Expression.New(nex.Constructor, args);
            }

            if (args.All(a => candidates.Contains(a)))
                candidates.Add(nex);

            return nex;
        }

        protected override Expression VisitConstant(ConstantExpression c)
        {
            candidates.Add(c);
            return c;
        }

        /// <summary>
        /// Nominates one-cell projections. SingleIsZero / SingleGreaterThanZero projections are replaced
        /// by an int-typed Single projection compared with 0.
        /// </summary>
        protected override Expression VisitProjection(ProjectionExpression proj)
        {
            if (!proj.IsOneCell)
                return base.VisitProjection(proj);

            if (proj.UniqueFunction == UniqueFunction.SingleIsZero)
                return NominateSingleComparedToZero(proj, ExpressionType.Equal);

            if (proj.UniqueFunction == UniqueFunction.SingleGreaterThanZero)
                return NominateSingleComparedToZero(proj, ExpressionType.GreaterThan);

            candidates.Add(proj);
            return proj;
        }

        // Builds "Single(int projection) <comparison> 0" and nominates both the projection and the comparison.
        private Expression NominateSingleComparedToZero(ProjectionExpression proj, ExpressionType comparison)
        {
            var newProj = new ProjectionExpression(typeof(int), proj.Source, proj.Projector, UniqueFunction.Single);
            candidates.Add(newProj);
            Expression result = Expression.MakeBinary(comparison, newProj, Expression.Constant(0));
            candidates.Add(result);
            return result;
        }

        protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction)
        {
            candidates.Add(sqlFunction);
            return sqlFunction;
        }

        protected Expression TrySqlFunction(SqlFunction sqlFunction, Type type, params Expression[] expression)
        {
            return TrySqlFunction(sqlFunction.ToString(), type, expression);
        }

        /// <summary>
        /// Builds a SQL function call over the visited arguments when every argument is nominated.
        /// </summary>
        /// <param name="sqlFunction">The SQL function name.</param>
        /// <param name="type">The CLR type of the resulting expression.</param>
        /// <param name="expression">The arguments; null entries are skipped (optional arguments).</param>
        /// <returns>The nominated function expression, or null if some argument is not translatable.</returns>
        protected Expression TrySqlFunction(string sqlFunction, Type type, params Expression[] expression)
        {
            expression = expression.NotNull().ToArray();
            Expression[] newExpressions = new Expression[expression.Length];

            for (int i = 0; i < expression.Length; i++)
            {
                newExpressions[i] = Visit(expression[i]);
                if (!candidates.Contains(newExpressions[i]))
                    return null;
            }

            var result = new SqlFunctionExpression(type, sqlFunction, newExpressions);
            candidates.Add(result);
            return result;
        }

        /// <summary>
        /// Translates "(left - right).TotalXxx" into DATEDIFF(part, right, left).
        /// Returns null if <paramref name="expression"/> is not a subtraction or an operand is not translatable.
        /// </summary>
        private SqlFunctionExpression TrySqlDifference(SqlEnums sqlEnums, Type type, Expression expression)
        {
            BinaryExpression be = expression as BinaryExpression;

            if (be == null || be.NodeType != ExpressionType.Subtract)
                return null;

            Expression left = Visit(be.Left);
            if (!candidates.Contains(left))
                return null;

            Expression right = Visit(be.Right);
            if (!candidates.Contains(right))
                return null;

            // DATEDIFF(part, start, end) computes end - start, hence right comes before left.
            SqlFunctionExpression result = new SqlFunctionExpression(type, SqlFunction.DATEDIFF.ToString(), new Expression[]{
                new SqlEnumExpression(sqlEnums), right, left});

            candidates.Add(result);

            return result;
        }

        /// <summary>
        /// Translates DateTime.Date by subtracting the hour, minute, second and millisecond parts
        /// of the visited expression with nested DATEADD calls. Returns null if it is not translatable.
        /// </summary>
        private Expression TrySqlDate(Expression expression)
        {
            Expression expr = Visit(expression);
            if (!candidates.Contains(expr))
                return null;

            Expression result = DateAdd(SqlEnums.hour, MinusDatePart(SqlEnums.hour, expr),
                                    DateAdd(SqlEnums.minute, MinusDatePart(SqlEnums.minute, expr),
                                        DateAdd(SqlEnums.second, MinusDatePart(SqlEnums.second, expr),
                                            DateAdd(SqlEnums.millisecond, MinusDatePart(SqlEnums.millisecond, expr), expr))));

            candidates.Add(result);
            return result;
        }

        // DATEADD(part, number, date) typed as DateTime.
        private Expression DateAdd(SqlEnums part, Expression number, Expression date)
        {
            return new SqlFunctionExpression(typeof(DateTime), SqlFunction.DATEADD.ToString(), new Expression[] { new SqlEnumExpression(part), number, date });
        }

        // -DATEPART(part, date) typed as int.
        private Expression MinusDatePart(SqlEnums part, Expression dateExpression)
        {
            return Expression.Negate(new SqlFunctionExpression(typeof(int), SqlFunction.DATEPART.ToString(), new Expression[] { new SqlEnumExpression(part), dateExpression }));
        }

        protected override Expression VisitBinary(BinaryExpression b)
        {
            if (existingAliases == null)
            {
                // SmartEqualizer may rewrite Equal/NotEqual into a different expression; if so, visit that instead.
                var rewritten = Transform(b);
                if (rewritten != null)
                    return Visit(rewritten);
            }

            Expression left = this.Visit(b.Left);
            Expression right = this.Visit(b.Right);
            Expression conversion = this.Visit(b.Conversion);
            if (left != b.Left || right != b.Right || conversion != b.Conversion)
            {
                if (b.NodeType == ExpressionType.Coalesce && b.Conversion != null)
                    b = Expression.Coalesce(left, right, conversion as LambdaExpression);
                else
                    b = Expression.MakeBinary(b.NodeType, left, right, b.IsLiftedToNull, b.Method);
            }

            if (candidates.Contains(left) && candidates.Contains(right))
                candidates.Add(b);

            return b;
        }

        /// <summary>
        /// For Equal/NotEqual, asks SmartEqualizer.PolymorphicEqual for an equality expression.
        /// Returns null when nothing changed (same node type and same operands) or for other node types;
        /// for NotEqual, the rewritten equality is negated.
        /// </summary>
        private static Expression Transform(BinaryExpression b)
        {
            if (b.NodeType != ExpressionType.Equal && b.NodeType != ExpressionType.NotEqual)
                return null;

            var newb = SmartEqualizer.PolymorphicEqual(b.Left, b.Right);

            if (newb.NodeType == b.NodeType && ((BinaryExpression)newb).Map(nb => nb.Left == b.Left && nb.Right == b.Right))
                return null;

            return b.NodeType == ExpressionType.NotEqual ? Expression.Not(newb) : newb;
        }

        /// <summary>
        /// Translates "test ? ifTrue : ifFalse" into a CASE expression when all three parts are nominated.
        /// Nested conditionals in the false branch are flattened into a single CASE with more WHENs.
        /// </summary>
        protected override Expression VisitConditional(ConditionalExpression c)
        {
            Expression result = c;
            Expression test = this.Visit(c.Test);
            Expression ifTrue = this.Visit(c.IfTrue);
            Expression ifFalse = this.Visit(c.IfFalse);
            if (test != c.Test || ifTrue != c.IfTrue || ifFalse != c.IfFalse)
            {
                result = Expression.Condition(test, ifTrue, ifFalse);
            }

            if (candidates.Contains(test) && candidates.Contains(ifTrue) && candidates.Contains(ifFalse))
            {
                Expression newTest = ConditionsRewriter.MakeSqlCondition(test);

                if (ifFalse.NodeType == (ExpressionType)DbExpressionType.Case)
                {
                    var oldCase = (CaseExpression)ifFalse;
                    candidates.Remove(ifFalse); // just to save some memory: it is absorbed into the new CASE
                    result = new CaseExpression(oldCase.Whens.PreAnd(new When(newTest, ifTrue)), oldCase.DefaultValue);
                }
                else
                    result = new CaseExpression(new[] { new When(newTest, ifTrue) }, ifFalse);

                candidates.Add(result);
            }
            return result;
        }

        protected override Expression VisitUnary(UnaryExpression u)
        {
            Expression operand = this.Visit(u.Operand);
            if (operand != u.Operand)
                u = Expression.MakeUnary(u.NodeType, operand, u.Type, u.Method);

            if (candidates.Contains(operand))
                candidates.Add(u);

            return u;
        }

        protected override Expression VisitEnumExpression(EnumExpression enumExp)
        {
            var id = (ColumnExpression)Visit(enumExp.ID);
            if (id != enumExp.ID)
                enumExp = new EnumExpression(enumExp.Type, id);

            if (candidates.Contains(id))
                candidates.Add(enumExp);

            return enumExp;
        }

        protected override Expression VisitSqlEnum(SqlEnumExpression sqlEnum)
        {
            candidates.Add(sqlEnum);
            return sqlEnum;
        }

        /// <summary>
        /// Builds "expression LIKE pattern" when both sides are nominated; otherwise returns null.
        /// The pattern is partially evaluated first so client-side parts become constants.
        /// </summary>
        private LikeExpression TryLike(Expression expression, Expression pattern)
        {
            pattern = Evaluator.PartialEval(pattern);
            Expression newPattern = Visit(pattern);
            Expression newExpression = Visit(expression);

            if (!candidates.Contains(newPattern) || !candidates.Contains(newExpression))
                return null;

            LikeExpression result = new LikeExpression(newExpression, newPattern);
            candidates.Add(result);
            return result;
        }

        /// <summary>
        /// Translates known properties (string.Length, DateTime parts, TimeSpan totals, ...) into SQL
        /// functions, and Nullable&lt;T&gt;.Value / HasValue into a conversion / null check.
        /// </summary>
        protected override Expression VisitMemberAccess(MemberExpression m)
        {
            Expression newM = new Switch<string, Expression>(m.Member.DeclaringType.TypeName() + "." + m.Member.Name)
                            .Case("string.Length", a => TrySqlFunction(SqlFunction.LEN, m.Type, m.Expression))
                            .Case("Math.PI", a => TrySqlFunction(SqlFunction.PI, m.Type))
                            .Case("DateTime.Now", a => TrySqlFunction(SqlFunction.GETDATE, m.Type))
                            .Case("DateTime.Year", a => TrySqlFunction(SqlFunction.YEAR, m.Type, m.Expression))
                            .Case("DateTime.Month", a => TrySqlFunction(SqlFunction.MONTH, m.Type, m.Expression))
                            .Case("DateTime.Day", a => TrySqlFunction(SqlFunction.DAY, m.Type, m.Expression))
                            .Case("DateTime.DayOfYear", a => TrySqlFunction(SqlFunction.DATEPART, m.Type, new SqlEnumExpression(SqlEnums.dayofyear), m.Expression))
                            .Case("DateTime.Hour", a => TrySqlFunction(SqlFunction.DATEPART, m.Type, new SqlEnumExpression(SqlEnums.hour), m.Expression))
                            .Case("DateTime.Minute", a => TrySqlFunction(SqlFunction.DATEPART, m.Type, new SqlEnumExpression(SqlEnums.minute), m.Expression))
                            .Case("DateTime.Second", a => TrySqlFunction(SqlFunction.DATEPART, m.Type, new SqlEnumExpression(SqlEnums.second), m.Expression))
                            .Case("DateTime.Millisecond", a => TrySqlFunction(SqlFunction.DATEPART, m.Type, new SqlEnumExpression(SqlEnums.millisecond), m.Expression))
                            .Case("DateTime.Date", a => TrySqlDate(m.Expression))
                            .Case("TimeSpan.TotalDays", a => TrySqlDifference(SqlEnums.day, m.Type, m.Expression))
                            .Case("TimeSpan.TotalHours", a => TrySqlDifference(SqlEnums.hour, m.Type, m.Expression))
                            .Case("TimeSpan.TotalMilliseconds", a => TrySqlDifference(SqlEnums.millisecond, m.Type, m.Expression))
                            .Case("TimeSpan.TotalSeconds", a => TrySqlDifference(SqlEnums.second, m.Type, m.Expression))
                            .Case("TimeSpan.TotalMinutes", a => TrySqlDifference(SqlEnums.minute, m.Type, m.Expression))
                            .Default((Expression)null);

            if (newM != null)
                return newM;

            // Static members have no instance expression (m.Expression == null); only instance
            // Nullable<T>.Value / HasValue accesses are handled here.
            if (m.Expression != null && m.Expression.Type.IsNullable() && (m.Member.Name == "Value" || m.Member.Name == "HasValue"))
            {
                Expression expression = this.Visit(m.Expression);

                if (m.Member.Name == "Value")
                    newM = Expression.Convert(expression, m.Expression.Type.UnNullify());
                else
                    newM = Expression.NotEqual(expression, Expression.Constant(null));

                if (candidates.Contains(expression))
                    candidates.Add(newM);

                return newM;
            }

            return base.VisitMemberAccess(m);
        }

        /// <summary>
        /// Translates methods marked with SqlMethodAttribute and known string, StringExtensions,
        /// DateTime and Math methods into SQL functions or LIKE expressions; anything else goes to the base visitor.
        /// </summary>
        protected override Expression VisitMethodCall(MethodCallExpression m)
        {
            SqlMethodAttribute sma = m.Method.SingleAttribute<SqlMethodAttribute>();
            if (sma != null)
                return TrySqlFunction(sma.Name ?? m.Method.Name, m.Type, m.Arguments.ToArray());

            Expression newM = new Switch<string, Expression>(m.Method.DeclaringType.TypeName() + "." + m.Method.MethodName())
                            .Case("string.IndexOf", a => TryCharIndex(m))
                            .Case("string.ToLower", a => TrySqlFunction(SqlFunction.LOWER, m.Type, m.Object))
                            .Case("string.ToUpper", a => TrySqlFunction(SqlFunction.UPPER, m.Type, m.Object))
                            .Case("string.TrimStart", a => TrySqlFunction(SqlFunction.LTRIM, m.Type, m.Object))
                            .Case("string.TrimEnd", a => TrySqlFunction(SqlFunction.RTRIM, m.Type, m.Object))
                            .Case("string.Replace", a => TrySqlFunction(SqlFunction.REPLACE, m.Type, m.Object, m.GetArgument("oldValue"), m.GetArgument("newValue")))
                            // SUBSTRING is 1-based, String.Substring is 0-based.
                            .Case("string.Substring", a => TrySqlFunction(SqlFunction.SUBSTRING, m.Type, m.Object, Expression.Add(m.GetArgument("startIndex"), Expression.Constant(1)), m.TryGetArgument("length") ?? Expression.Constant(int.MaxValue)))
                            // Escaping the patterns is very complicated for general expressions (it would have to be done in SQL),
                            // so wildcard characters inside the value are not escaped.
                            .Case("string.Contains", a => TryLike(m.Object, Expression.Add(Expression.Add(Expression.Constant("%"), m.GetArgument("value"), concatMethod), Expression.Constant("%"), concatMethod)))
                            .Case("string.StartsWith", a => TryLike(m.Object, Expression.Add(m.GetArgument("value"), Expression.Constant("%"), concatMethod)))
                            .Case("string.EndsWith", a => TryLike(m.Object, Expression.Add(Expression.Constant("%"), m.GetArgument("value"), concatMethod)))

                            .Case("StringExtensions.Left", a => TrySqlFunction(SqlFunction.LEFT, m.Type, m.GetArgument("s"), m.GetArgument("numChars")))
                            .Case("StringExtensions.Right", a => TrySqlFunction(SqlFunction.RIGHT, m.Type, m.GetArgument("s"), m.GetArgument("numChars")))
                            .Case("StringExtensions.Replicate", a => TrySqlFunction(SqlFunction.REPLICATE, m.Type, m.GetArgument("s"), m.GetArgument("times")))
                            .Case("StringExtensions.Reverse", a => TrySqlFunction(SqlFunction.REVERSE, m.Type, m.GetArgument("s")))
                            .Case("StringExtensions.Like", a => TryLike(m.GetArgument("s"), m.GetArgument("pattern")))

                            .Case("DateTime.AddDays", a => TrySqlFunction(SqlFunction.DATEADD, m.Type, new SqlEnumExpression(SqlEnums.day), m.GetArgument("value"), m.Object))
                            .Case("DateTime.AddHours", a => TrySqlFunction(SqlFunction.DATEADD, m.Type, new SqlEnumExpression(SqlEnums.hour), m.GetArgument("value"), m.Object))
                            .Case("DateTime.AddMilliseconds", a => TrySqlFunction(SqlFunction.DATEADD, m.Type, new SqlEnumExpression(SqlEnums.millisecond), m.GetArgument("value"), m.Object))
                            .Case("DateTime.AddMinutes", a => TrySqlFunction(SqlFunction.DATEADD, m.Type, new SqlEnumExpression(SqlEnums.minute), m.GetArgument("value"), m.Object))
                            .Case("DateTime.AddMonths", a => TrySqlFunction(SqlFunction.DATEADD, m.Type, new SqlEnumExpression(SqlEnums.month), m.GetArgument("value"), m.Object))
                            .Case("DateTime.AddSeconds", a => TrySqlFunction(SqlFunction.DATEADD, m.Type, new SqlEnumExpression(SqlEnums.second), m.GetArgument("value"), m.Object))
                            .Case("DateTime.AddYears", a => TrySqlFunction(SqlFunction.DATEADD, m.Type, new SqlEnumExpression(SqlEnums.year), m.GetArgument("value"), m.Object))

                            .Case("Math.Sign", a => TrySqlFunction(SqlFunction.SIGN, m.Type, m.GetArgument("value")))
                            .Case("Math.Abs", a => TrySqlFunction(SqlFunction.ABS, m.Type, m.GetArgument("value")))
                            .Case("Math.Sin", a => TrySqlFunction(SqlFunction.SIN, m.Type, m.GetArgument("a")))
                            .Case("Math.Asin", a => TrySqlFunction(SqlFunction.ASIN, m.Type, m.GetArgument("d")))
                            .Case("Math.Cos", a => TrySqlFunction(SqlFunction.COS, m.Type, m.GetArgument("d")))
                            .Case("Math.Acos", a => TrySqlFunction(SqlFunction.ACOS, m.Type, m.GetArgument("d")))
                            .Case("Math.Tan", a => TrySqlFunction(SqlFunction.TAN, m.Type, m.GetArgument("a")))
                            .Case("Math.Pow", a => TrySqlFunction(SqlFunction.POWER, m.Type, m.GetArgument("x"), m.GetArgument("y")))
                            .Case("Math.Sqrt", a => TrySqlFunction(SqlFunction.SQRT, m.Type, m.GetArgument("d")))
                            .Case("Math.Exp", a => TrySqlFunction(SqlFunction.EXP, m.Type, m.GetArgument("d")))
                            .Case("Math.Floor", a => TrySqlFunction(SqlFunction.FLOOR, m.Type, m.GetArgument("d")))
                            .Case("Math.Log10", a => TrySqlFunction(SqlFunction.Log10, m.Type, m.GetArgument("d")))
                            .Case("Math.Ceiling", a => TrySqlFunction(SqlFunction.CEILING, m.Type, m.TryGetArgument("d") ?? m.GetArgument("a")))
                            .Case("Math.Round", a => TrySqlFunction(SqlFunction.ROUND, m.Type,
                                m.TryGetArgument("a") ?? m.TryGetArgument("d") ?? m.GetArgument("value"),
                                m.TryGetArgument("decimals") ?? m.TryGetArgument("digits") ?? Expression.Constant(0)))
                            .Default((Expression)null);

            return newM ?? base.VisitMethodCall(m);
        }

        /// <summary>
        /// Translates string.IndexOf(value[, startIndex]) into CHARINDEX(value, str[, startIndex + 1]) - 1.
        /// CHARINDEX is 1-based and returns 0 when not found, so subtracting 1 gives String.IndexOf semantics (-1).
        /// Returns null if an argument is not translatable.
        /// </summary>
        private Expression TryCharIndex(MethodCallExpression m)
        {
            Expression startIndex = m.TryGetArgument("startIndex").TryCC(e => Expression.Add(e, Expression.Constant(1)));

            Expression charIndex = TrySqlFunction(SqlFunction.CHARINDEX, m.Type, m.GetArgument("value"), m.Object, startIndex);
            if (charIndex == null)
                return null;

            // The outer subtraction must be nominated too; otherwise parents never see IndexOf as translatable.
            Expression result = Expression.Subtract(charIndex, Expression.Constant(1));
            candidates.Add(result);
            return result;
        }

        protected override Expression VisitImplementedBy(ImplementedByExpression reference)
        {
            if (existingAliases != null)
                return base.VisitImplementedBy(reference);

            var newImple = reference.Implementations
              .NewIfChange(ri => Visit(ri.Field).Map(r => r == ri.Field ? ri : new ImplementationColumnExpression(ri.Type, (FieldInitExpression)r)));

            if (newImple != reference.Implementations)
                reference = new ImplementedByExpression(reference.Type, newImple);

            if (newImple.All(i => candidates.Contains(i.Field)))
                candidates.Add(reference);

            return reference;
        }

        protected override Expression VisitImplementedByAll(ImplementedByAllExpression reference)
        {
            if (existingAliases != null || (reference.Implementations != null && reference.Implementations.Count > 0))
                return base.VisitImplementedByAll(reference);

            var id = (ColumnExpression)Visit(reference.ID);
            var typeId = (ColumnExpression)Visit(reference.TypeID);

            if (id != reference.ID || typeId != reference.TypeID)
                reference = new ImplementedByAllExpression(reference.Type, id, typeId);

            if (candidates.Contains(id) && candidates.Contains(typeId))
                candidates.Add(reference);

            return reference;
        }

        protected override Expression VisitFieldInit(FieldInitExpression fieldInit)
        {
            if (existingAliases != null || (fieldInit.Bindings != null && fieldInit.Bindings.Count > 0))
                return base.VisitFieldInit(fieldInit);

            var id = Visit(fieldInit.ID);
            var alias = VisitFieldInitAlias(fieldInit.Alias);

            // Rebuild when either part changed; previously a changed alias alone was discarded.
            if (fieldInit.ID != id || fieldInit.Alias != alias)
                fieldInit = new FieldInitExpression(fieldInit.Type, alias, id);

            if (candidates.Contains(id))
                candidates.Add(fieldInit);

            return fieldInit;
        }
    }
}

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
Software Developer (Senior) Signum Software
Spain
I'm a computer scientist, one of the founders of Signum Software, and the lead developer behind Signum Framework.

www.signumframework.com

I love programming in C#, LINQ, compilers, algorithms, functional programming, computer graphics, maths...

Comments and Discussions