Click here to Skip to main content
11,496,146 members (798 online)
Click here to Skip to main content
Add your own
alternative version

Modifying LINQ Expressions with Rewrite Rules

, 18 Mar 2008 CPOL 42.6K 582 62
Rewriting query expressions is a simple and yet safe and powerful technique to modify queries dynamically at runtime.
The site is currently in read-only mode for maintenance. Posting of new items will be available again shortly.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Collections.ObjectModel;
using System.Text;
using System.Reflection;

namespace Rewrite {
    /// <summary>
    /// Rewrites an expression tree by applying a <see cref="Rule"/> to at most one
    /// sub-expression per <see cref="ApplyOnce(Rule)"/> call. Each node is tested against
    /// the rule's left-hand side before its children are visited, so an enclosing
    /// match wins over a nested one.
    /// </summary>
    public class SimpleRewriter: ExpressionVisitor{
        private Rule _rule;
        private MatchExpressions _match;
        // True once a replacement has been made during the current ApplyOnce() call.
        private bool _done;
        // Current state of the expression; updated by every ApplyOnce() call.
        private Expression _expr;
        // Sub-expression that was replaced by the last successful application.
        private Expression _was;
        // Expression that was put in place of _was.
        private Expression _is;

        /// <summary>
        /// Creates a rewriter that operates on <paramref name="sourceExpr"/>.
        /// </summary>
        /// <param name="sourceExpr">Expression to be rewritten.</param>
        public SimpleRewriter( Expression sourceExpr ) {
            _expr = sourceExpr;
            _match = new MatchExpressions();
        }

        /// <summary>
        /// Applies <paramref name="rule"/> to the first matching sub-expression of the
        /// current expression and stores the result as the new current expression.
        /// </summary>
        /// <param name="rule">Rule to apply.</param>
        /// <returns>The (possibly rewritten) current expression.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rule"/> is null.</exception>
        public Expression ApplyOnce( Rule rule){
            if( rule == null )
                throw new ArgumentNullException( "rule" );

            _rule = rule;
            _done = false;
            // Forget the previous replacement so that Replaced/ReplacedFor never
            // report stale data when this application does not match anything.
            _was = null;
            _is = null;
            _expr = this.Visit( _expr );
            return _expr;
        }

        /// <summary>True if the last <see cref="ApplyOnce(Rule)"/> call replaced something.</summary>
        public bool Success {
            get { return _done; }
        }

        /// <summary>The current expression, including all rewrites applied so far.</summary>
        public Expression Expression {
            get { return _expr; }
        }

        /// <summary>
        /// The sub-expression replaced by the last <see cref="ApplyOnce(Rule)"/> call,
        /// or null if that call did not match.
        /// </summary>
        public Expression Replaced {
            get { return _was; }
        }

        /// <summary>
        /// The expression substituted by the last <see cref="ApplyOnce(Rule)"/> call,
        /// or null if that call did not match.
        /// </summary>
        public Expression ReplacedFor {
            get { return _is; }
        }

        /// <summary>
        /// Convenience wrapper: applies <paramref name="rule"/> once to
        /// <paramref name="sourceExpr"/> and returns the result.
        /// </summary>
        public static Expression ApplyOnce( Expression sourceExpr, Rule rule ) {
            SimpleRewriter srw = new SimpleRewriter( sourceExpr );
            return srw.ApplyOnce( rule );
        }

        /// <summary>
        /// Typed convenience wrapper; the result is cast back to the lambda type of
        /// <paramref name="sourceExpr"/>.
        /// </summary>
        public static Expression<TFunc> ApplyOnce<TFunc>( Expression<TFunc> sourceExpr, Rule rule ) {
            SimpleRewriter srw = new SimpleRewriter( sourceExpr );
            return (Expression<TFunc>)srw.ApplyOnce( rule );
        }

        /// <summary>
        /// Replaces <paramref name="exp"/> if it matches the rule's left-hand side;
        /// otherwise continues the traversal. After the first replacement every
        /// remaining node is returned untouched.
        /// </summary>
        protected override Expression Visit( Expression exp ) {
            if( exp == null || _done)
                return exp;

            if( _match.Match( exp, _rule.Lhs ) ) {
                _was = exp;
                _done = true;
                // Instantiate the right-hand side with the bindings found by Match().
                _is = _match.Multiply( _rule.Rhs );
                return _is;
            }

            return base.Visit( exp );
        }
    }

    /// <summary>
    /// A rewrite rule: a left-hand side pattern and a right-hand side replacement,
    /// both given as lambda expressions of the same delegate type.
    /// </summary>
    public class Rule {
        /// <summary>
        /// Creates a rule from two lambdas of identical type.
        /// </summary>
        /// <param name="lhs">Pattern to look for.</param>
        /// <param name="rhs">Replacement for a matched pattern.</param>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        /// <exception cref="ArgumentException">The two lambdas have different types.</exception>
        public Rule( LambdaExpression lhs, LambdaExpression rhs ) {
            if( lhs == null ) {
                throw new ArgumentNullException( "lhs" );
            }
            if( rhs == null ) {
                throw new ArgumentNullException( "rhs" );
            }

            bool sameSignature = lhs.Type == rhs.Type;
            if( !sameSignature ) {
                string message = string.Format(
                    "Rhs type {0} is not equal to Lhs type {1}.",
                    rhs.Type.Name,
                    lhs.Type.Name );
                throw new ArgumentException( message );
            }

            Lhs = lhs;
            Rhs = rhs;
        }

        /// <summary>Creates a rule from two parameterless lambdas.</summary>
        public static Rule Create<TResult>( Expression<Func<TResult>> lhs, Expression<Func<TResult>> rhs ) {
            return new Rule( lhs, rhs );
        }

        /// <summary>Creates a rule from two single-parameter lambdas.</summary>
        public static Rule Create<TArg, TResult>( Expression<Func<TArg, TResult>> lhs, Expression<Func<TArg,TResult>> rhs ) {
            return new Rule( lhs, rhs );
        }

        /// <summary>The pattern (left-hand side) of the rule.</summary>
        public LambdaExpression Lhs { get; private set; }

        /// <summary>The replacement (right-hand side) of the rule.</summary>
        public LambdaExpression Rhs { get; private set; }
    }

    /// <summary>
    /// Helpers that build parameterless lambdas returning a literal value, for use
    /// as the right-hand side of a rule.
    /// </summary>
    public static class EvaluateLiteral {
        /// <summary>
        /// Builds a lambda returning <paramref name="value"/>. If the runtime type of
        /// the constant differs from <paramref name="type"/>, the constant is wrapped
        /// in a conversion to <paramref name="type"/>.
        /// </summary>
        /// <param name="type">Desired result type of the lambda body.</param>
        /// <param name="value">Literal value; a null value yields a constant of type object.</param>
        public static LambdaExpression CreateRhs( Type type, object value ) {
            ConstantExpression cex = Expression.Constant( value );
            LambdaExpression rhs;
            if( type != cex.Type ) {
                rhs = Expression.Lambda( Expression.Convert( cex, type ) );
            } else
                rhs = Expression.Lambda( cex );
            return rhs;
        }

        /// <summary>
        /// Builds a strongly typed lambda returning <paramref name="value"/>.
        /// </summary>
        /// <typeparam name="VType">Static type of the literal and of the lambda result.</typeparam>
        /// <param name="value">Literal value; may be null for reference or nullable types.</param>
        public static Expression<Func<VType>> CreateRhs<VType>( VType value ) {
            // The constant must be typed with the static type VType. Typing it by the
            // runtime type of the value (object for null, the underlying type for a
            // nullable, the derived type for a base-typed argument) produces a lambda
            // of a different delegate type, and the cast to Expression<Func<VType>> fails.
            ConstantExpression cex = Expression.Constant( value, typeof( VType ) );
            return Expression.Lambda<Func<VType>>( cex );
        }
    }

    /// <summary>
    /// Expression visitor that substitutes parameters by expressions according to a
    /// given dictionary; parameters without an entry are left to the base visitor.
    /// </summary>
    public class Multiplier: ExpressionVisitor {
        // Parameter -> expression to put in its place.
        private readonly Dictionary<ParameterExpression, Expression> _subst;

        private Multiplier( Dictionary<ParameterExpression, Expression> subst ) {
            _subst = subst;
        }

        /// <summary>
        /// Returns <paramref name="expr"/> with every parameter found in
        /// <paramref name="subst"/> replaced by its mapped expression.
        /// </summary>
        /// <param name="expr">Expression to instantiate.</param>
        /// <param name="subst">Substitution to apply.</param>
        public static Expression Multiply( Expression expr, Dictionary<ParameterExpression, Expression> subst ) {
            return new Multiplier( subst ).Visit( expr );
        }

        /// <summary>
        /// Looks the parameter up in the substitution and returns the replacement
        /// when there is one.
        /// </summary>
        protected override Expression VisitParameter( ParameterExpression p ) {
            Expression replacement;
            bool hasReplacement = _subst.TryGetValue( p, out replacement );
            return hasReplacement ? replacement : base.VisitParameter( p );
        }
    }
    
    // TODO: Compare value types, convert them. 
    /// <summary>
    /// Matches an expression against a pattern given as a lambda expression. The
    /// parameters of the pattern lambda act as pattern variables: each one is bound
    /// to the sub-expression found at its position, and every further occurrence of
    /// the same variable must be equal to the first binding.
    /// </summary>
    public class MatchExpressions: ExpressionsCompareVisitor {
        // Pattern variable -> bound sub-expression (null while still unbound).
        private Dictionary<ParameterExpression, Expression> _subst;
        // Type of the last pattern; null until Match() has been called.
        private Type _patternType;
        private ReadOnlyCollection<ParameterExpression> _patternParams;

        public MatchExpressions(  ) {
        }

        /// <summary>
        /// Tries to match <paramref name="expr1"/> against the body of
        /// <paramref name="pattern"/>, recording the bindings of the pattern variables.
        /// </summary>
        /// <param name="expr1">Expression to test.</param>
        /// <param name="pattern">Pattern; its parameters are the pattern variables.</param>
        /// <returns>True if the expression matches the pattern.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="pattern"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// The match succeeded but some pattern variable never occurred in the pattern body.
        /// </exception>
        public bool Match( Expression expr1, LambdaExpression pattern ) {
            if( pattern == null )
                throw new ArgumentNullException( "pattern" );

            _patternType = pattern.Type;
            _patternParams = pattern.Parameters;
            _subst = pattern.Parameters.ToDictionary( p => p, p => (Expression)null);
            
            bool success = Visit( expr1, pattern.Body );
            if( success ) {
                // Collect the unbound variables once instead of scanning the dictionary twice.
                string[] unbound = _subst
                    .Where( x => x.Value == null )
                    .Select( x => x.Key.Name )
                    .ToArray();
                if( unbound.Length > 0 )
                    throw new ArgumentException(
                        string.Format(
                            "Parameters '{0}' don't occur in the body of the pattern.",
                            string.Join( ", ", unbound ) ) );
            }

            return success;
        }

        /// <summary>
        /// Instantiates <paramref name="rhs"/> with the bindings of the last match. The
        /// i-th parameter of <paramref name="rhs"/> receives the binding of the i-th
        /// pattern parameter.
        /// </summary>
        /// <param name="rhs">Lambda whose body is to be instantiated.</param>
        /// <returns>The body of <paramref name="rhs"/> with its parameters substituted.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="rhs"/> is null.</exception>
        /// <exception cref="InvalidOperationException">No match has been performed yet.</exception>
        public Expression Multiply( LambdaExpression rhs ) {
            if( rhs == null)
                throw new ArgumentNullException( "rhs");
            if( _patternType == null )
                throw new InvalidOperationException( "Matching must be done before multiplication." );

            Dictionary<ParameterExpression, Expression> renamedSubst =
                rhs.Parameters
                .Select( ( pe, i ) => new { P = pe, E = _subst[ _patternParams[ i ] ] } )
                .ToDictionary( x => x.P, x => x.E );

            return Multiplier.Multiply( rhs.Body, renamedSubst );
        }

        /// <summary>
        /// Compares <paramref name="exp"/> with the pattern node <paramref name="exp2"/>.
        /// A pattern variable binds to <paramref name="exp"/> on its first occurrence and
        /// must equal its binding on later occurrences; everything else is compared
        /// structurally by the base class.
        /// </summary>
        protected override bool Visit( Expression exp, Expression exp2 ) {
            if( exp != null && exp2 != null && exp2.NodeType == ExpressionType.Parameter) {
                ParameterExpression lhsVar = (ParameterExpression)exp2;
                Expression varSubst;
                // Only parameters declared by the pattern lambda are pattern variables.
                // Parameters of lambdas nested inside the pattern, and parameters the
                // pattern does not declare at all, are compared as ordinary nodes; a
                // plain indexer lookup here would throw KeyNotFoundException for the latter.
                if( !base.IsInternalParameter( lhsVar )
                    && _subst != null
                    && _subst.TryGetValue( lhsVar, out varSubst ) ) {
                    if( varSubst == null ) {
                        // First occurrence: bind if the types are compatible.
                        if( !lhsVar.Type.IsAssignableFrom( exp.Type ) )
                            return false;
                        _subst[ lhsVar ] = exp;
                        return true;
                    }

                    // Repeated occurrence: must be equal to the existing binding.
                    ExpressionComparer cmp = new ExpressionComparer();
                    return cmp.AreEqual( exp, varSubst );
                }
            }
            
            return base.Visit( exp, exp2 );
        }

        /// <summary>
        /// Returns a printable dump of the current bindings for debugging, or an empty
        /// string if no match has been attempted yet.
        /// </summary>
        public string DbgSubst() {
            if( _subst == null )
                return string.Empty;

            StringBuilder sb = new StringBuilder();
            foreach( var su in _subst )
                sb.AppendFormat( "\n{0} {1}   <-  {2}",
                    su.Key.Type.Name,
                    su.Key.Name,
                    su.Value );
            return sb.ToString();
        }
    }

    /// <summary>
    /// Concrete, directly usable variant of the abstract
    /// <see cref="ExpressionsCompareVisitor"/>.
    /// </summary>
    public class ExpressionComparer: ExpressionsCompareVisitor
    {
        /// <summary>
        /// Compares two expressions using the base class comparison.
        /// </summary>
        /// <param name="expr1">First expression.</param>
        /// <param name="expr2">Second expression.</param>
        /// <returns>The result of the base class comparison of the two expressions.</returns>
        public bool AreEqual( Expression expr1, Expression expr2 )
        {
            // Deliberately a non-virtual call to the base implementation.
            bool equal = base.Visit( expr1, expr2 );
            return equal;
        }
    }

    /// <summary>
    /// Compares two expressions up to renaming of lambda parameters. 
    /// Result types of nodes are generally not compared; the exceptions are the
    /// types of lambda parameters, the operand type of TypeIs and the type of
    /// new-array nodes.
    /// </summary>
    public abstract class ExpressionsCompareVisitor {
        // Maps Expr1.Param -> Expr2.Param for the lambdas whose bodies are currently
        // being compared. Entries are added by VisitParameterList() and removed again
        // by VisitLambda().
        private readonly Dictionary<ParameterExpression, ParameterExpression> _paramDict;

        protected ExpressionsCompareVisitor() {
            _paramDict = new Dictionary<ParameterExpression, ParameterExpression>();
        }

        /// <summary>
        /// Tells whether <paramref name="p2"/> is a parameter of a lambda of the second
        /// expression that is currently being compared.
        /// </summary>
        protected bool IsInternalParameter( ParameterExpression p2 ) {
            return _paramDict.ContainsValue( p2 );
        }

        /// <summary>
        /// Compares two expression nodes; dispatches on the node type.
        /// </summary>
        /// <param name="exp">Node of the first expression; may be null.</param>
        /// <param name="exp2">Node of the second expression; may be null.</param>
        /// <returns>True if both are null or both are equal up to parameter renaming.</returns>
        /// <exception cref="NotSupportedException">The node type is not handled.</exception>
        protected virtual bool Visit( Expression exp, Expression exp2 ) {
            if( exp2 == null )
                return exp == null;
            if( exp == null )
                return false;

            if( exp.NodeType != exp2.NodeType ) {
                return false;
            }

            switch( exp.NodeType ) {
                case ExpressionType.Negate:
                case ExpressionType.NegateChecked:
                case ExpressionType.UnaryPlus:
                case ExpressionType.Not:
                case ExpressionType.Convert:
                case ExpressionType.ConvertChecked:
                case ExpressionType.ArrayLength:
                case ExpressionType.Quote:
                case ExpressionType.TypeAs:
                    return this.VisitUnary( (UnaryExpression)exp, (UnaryExpression)exp2 );
                case ExpressionType.Add:
                case ExpressionType.AddChecked:
                case ExpressionType.Subtract:
                case ExpressionType.SubtractChecked:
                case ExpressionType.Multiply:
                case ExpressionType.MultiplyChecked:
                case ExpressionType.Divide:
                case ExpressionType.Modulo:
                case ExpressionType.Power:
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                case ExpressionType.Or:
                case ExpressionType.OrElse:
                case ExpressionType.LessThan:
                case ExpressionType.LessThanOrEqual:
                case ExpressionType.GreaterThan:
                case ExpressionType.GreaterThanOrEqual:
                case ExpressionType.Equal:
                case ExpressionType.NotEqual:
                case ExpressionType.Coalesce:
                case ExpressionType.ArrayIndex:
                case ExpressionType.RightShift:
                case ExpressionType.LeftShift:
                case ExpressionType.ExclusiveOr:
                    return this.VisitBinary( (BinaryExpression)exp, (BinaryExpression)exp2 );
                case ExpressionType.TypeIs:
                    return this.VisitTypeIs( (TypeBinaryExpression)exp, (TypeBinaryExpression)exp2 );
                case ExpressionType.Conditional:
                    return this.VisitConditional( (ConditionalExpression)exp, (ConditionalExpression)exp2 );
                case ExpressionType.Constant:
                    return this.VisitConstant( (ConstantExpression)exp, (ConstantExpression)exp2 );
                case ExpressionType.Parameter:
                    return this.VisitParameter( (ParameterExpression)exp, (ParameterExpression)exp2 );
                case ExpressionType.MemberAccess:
                    return this.VisitMemberAccess( (MemberExpression)exp, (MemberExpression)exp2 );
                case ExpressionType.Call:
                    return this.VisitMethodCall( (MethodCallExpression)exp, (MethodCallExpression)exp2 );
                case ExpressionType.Lambda:
                    return this.VisitLambda( (LambdaExpression)exp, (LambdaExpression)exp2 );
                case ExpressionType.New:
                    return this.VisitNew( (NewExpression)exp, (NewExpression)exp2 );
                case ExpressionType.NewArrayInit:
                case ExpressionType.NewArrayBounds:
                    return this.VisitNewArray( (NewArrayExpression)exp, (NewArrayExpression)exp2 );
                case ExpressionType.Invoke:
                    return this.VisitInvocation( (InvocationExpression)exp, (InvocationExpression)exp2 );
                case ExpressionType.MemberInit:
                    return this.VisitMemberInit( (MemberInitExpression)exp, (MemberInitExpression)exp2 );
                case ExpressionType.ListInit:
                    return this.VisitListInit( (ListInitExpression)exp, (ListInitExpression)exp2 );
                default:
                    throw new NotSupportedException( string.Format( "Unhandled expression type: '{0}'", exp.NodeType ) );
            }
        }

        /// <summary>
        /// Compares two member bindings; bindings of different kinds or for different
        /// members are unequal.
        /// </summary>
        /// <exception cref="NotSupportedException">The binding type is not handled.</exception>
        protected virtual bool VisitBinding( MemberBinding binding, MemberBinding binding2 ) {
            if( binding.Member != binding2.Member )
                return false;
            // The casts below are only valid when both bindings are of the same kind.
            if( binding.BindingType != binding2.BindingType )
                return false;

            switch( binding.BindingType ) {
                case MemberBindingType.Assignment:
                    return this.VisitMemberAssignment( (MemberAssignment)binding, (MemberAssignment)binding2 );
                case MemberBindingType.MemberBinding:
                    return this.VisitMemberMemberBinding( (MemberMemberBinding)binding, (MemberMemberBinding)binding2 );
                case MemberBindingType.ListBinding:
                    return this.VisitMemberListBinding( (MemberListBinding)binding, (MemberListBinding)binding2 );
                default:
                    throw new NotSupportedException( string.Format( "Unhandled binding type '{0}'", binding.BindingType ) );
            }
        }

        /// <summary>Compares two element initializers: same Add method and equal arguments.</summary>
        protected virtual bool VisitElementInitializer( ElementInit initializer, ElementInit initializer2 ) {
            return initializer.AddMethod == initializer2.AddMethod
                && this.VisitExpressionList( initializer.Arguments, initializer2.Arguments );
        }

        /// <summary>Compares two unary nodes: same method, same lifting and equal operands.</summary>
        protected virtual bool VisitUnary( UnaryExpression u, UnaryExpression u2 ) {
            // NOTE(review): the target type of Convert/ConvertChecked/TypeAs is not
            // compared here, so conversions of equal operands to different types are
            // reported as equal when no conversion method is involved - confirm
            // whether that is intended.
            return u.Method == u2.Method
                && u.IsLifted == u2.IsLifted
                && u.IsLiftedToNull == u2.IsLiftedToNull
                && this.Visit( u.Operand, u2.Operand );
        }

        /// <summary>
        /// Compares two binary nodes: same method, same lifting, equal operands and
        /// equal conversion lambdas.
        /// </summary>
        protected virtual bool VisitBinary( BinaryExpression b, BinaryExpression b2 ) {
            return b.Method == b2.Method
                && b.IsLifted == b2.IsLifted
                && b.IsLiftedToNull == b2.IsLiftedToNull
                && this.Visit( b.Left, b2.Left )
                && this.Visit( b.Right, b2.Right )
                && this.Visit( b.Conversion, b2.Conversion );
        }

        /// <summary>Compares two TypeIs nodes: same type operand and equal expressions.</summary>
        protected virtual bool VisitTypeIs( TypeBinaryExpression b, TypeBinaryExpression b2 ) {
            return b.TypeOperand == b2.TypeOperand
                && this.Visit( b.Expression, b2.Expression );
        }

        /// <summary>Compares two constants by value only; their declared types are ignored.</summary>
        protected virtual bool VisitConstant( ConstantExpression c, ConstantExpression c2 ) {
            return object.Equals( c.Value, c2.Value );
        }

        /// <summary>Compares two conditional nodes part by part.</summary>
        protected virtual bool VisitConditional( ConditionalExpression c, ConditionalExpression c2 ) {
            return this.Visit( c.Test, c2.Test )
                && this.Visit( c.IfTrue, c2.IfTrue )
                && this.Visit( c.IfFalse, c2.IfFalse );
        }

        /// <summary>
        /// Compares two parameters. They are equal if they are the same object (two
        /// sub-trees of the same expression may share parameters) or if
        /// <paramref name="p"/> is currently mapped to <paramref name="p2"/>.
        /// </summary>
        protected virtual bool VisitParameter( ParameterExpression p, ParameterExpression p2 ) {
            if( p == p2 )
                return true;

            // A parameter that is not bound by any lambda under comparison has no
            // mapping; that makes the parameters unequal rather than being an error.
            ParameterExpression mapped;
            return _paramDict.TryGetValue( p, out mapped ) && mapped == p2;
        }

        /// <summary>Compares two member accesses: same member and equal target expressions.</summary>
        protected virtual bool VisitMemberAccess( MemberExpression m, MemberExpression m2 ) {
            return m.Member == m2.Member
                && this.Visit( m.Expression, m2.Expression );
        }

        /// <summary>Compares two method calls: same method, equal target and equal arguments.</summary>
        protected virtual bool VisitMethodCall( MethodCallExpression m, MethodCallExpression m2 ) {
            return m.Method == m2.Method
                && this.Visit( m.Object, m2.Object )
                && this.VisitExpressionList( m.Arguments, m2.Arguments );
        }

        /// <summary>Compares two expression lists element by element.</summary>
        protected virtual bool VisitExpressionList( ReadOnlyCollection<Expression> original, ReadOnlyCollection<Expression> original2 ) {
            if( original.Count != original2.Count ) {
                return false;
            }

            for( int i = 0, n = original.Count; i < n; i++ ) {
                if( !this.Visit( original[ i ], original2[ i ] ) )
                    return false;
            }
            return true;
        }

        /// <summary>Compares two member assignments: same member and equal value expressions.</summary>
        protected virtual bool VisitMemberAssignment( MemberAssignment assignment, MemberAssignment assignment2 ) {
            return assignment.Member == assignment2.Member
             && this.Visit( assignment.Expression, assignment2.Expression );
        }

        /// <summary>Compares two nested member bindings: same member and equal binding lists.</summary>
        protected virtual bool VisitMemberMemberBinding( MemberMemberBinding binding, MemberMemberBinding binding2 ) {
            return binding.Member == binding2.Member
                && this.VisitBindingList( binding.Bindings, binding2.Bindings );
        }

        /// <summary>Compares two list bindings: same member and equal initializer lists.</summary>
        protected virtual bool VisitMemberListBinding( MemberListBinding binding, MemberListBinding binding2 ) {
            return binding.Member == binding2.Member
                && this.VisitElementInitializerList( binding.Initializers, binding2.Initializers );
        }

        /// <summary>Compares two binding lists element by element.</summary>
        protected virtual bool VisitBindingList( ReadOnlyCollection<MemberBinding> original, ReadOnlyCollection<MemberBinding> original2 ) {
            if( original.Count != original2.Count )
                return false;

            for( int i = 0, n = original.Count; i < n; i++ ) {
                if( !this.VisitBinding( original[ i ], original2[ i ] ) )
                    return false;
            }
            return true;
        }

        /// <summary>Compares two element initializer lists element by element.</summary>
        protected virtual bool VisitElementInitializerList( ReadOnlyCollection<ElementInit> original, ReadOnlyCollection<ElementInit> original2 ) {
            if( original.Count != original2.Count )
                return false;
            for( int i = 0, n = original.Count; i < n; i++ ) {
                if( !this.VisitElementInitializer( original[ i ], original2[ i ] ) )
                    return false;
            }
            return true;
        }

        /// <summary>
        /// Compares two lambdas: the parameter lists must agree in length and types,
        /// and the bodies must be equal with the parameters of the first lambda mapped
        /// to those of the second.
        /// </summary>
        protected virtual bool VisitLambda( LambdaExpression lambda, LambdaExpression lambda2 ) {
            bool ret = this.VisitParameterList( lambda.Parameters, lambda2.Parameters );
            if( ret )
                ret = this.Visit( lambda.Body, lambda2.Body );
            // Leave the scope of the lambda. If VisitParameterList() failed, not all
            // params are in the dict; Remove() simply reports false for those.
            foreach( var p in lambda.Parameters ) {
                _paramDict.Remove( p );
            }
            return ret;
        }

        /// <summary>
        /// Checks that two parameter lists have the same length and pairwise identical
        /// types, registering each accepted pair in the parameter mapping. Pairs
        /// registered before a mismatch stay registered until the caller cleans up.
        /// </summary>
        protected virtual bool VisitParameterList( ReadOnlyCollection<ParameterExpression> original, ReadOnlyCollection<ParameterExpression> original2 ) {
            if( original.Count != original2.Count )
                return false;
            for( int i = 0, n = original.Count; i < n; i++ ) {
                if( original[ i ].Type != original2[ i ].Type )
                    return false;
                _paramDict.Add( original[ i ], original2[ i ] );
            }
            return true;
        }

        /// <summary>Compares two constructor calls: same constructor, same members, equal arguments.</summary>
        protected virtual bool VisitNew( NewExpression nex, NewExpression nex2 ) {
            return nex.Constructor == nex2.Constructor
                && this.VisitMembersOfNew( nex.Members, nex2.Members )
                && this.VisitExpressionList( nex.Arguments, nex2.Arguments );
        }

        /// <summary>
        /// Compares the member lists of two constructor calls. A null list is equal
        /// only to another null list.
        /// </summary>
        protected virtual bool VisitMembersOfNew( ReadOnlyCollection<MemberInfo> original, ReadOnlyCollection<MemberInfo> original2 ) {
            // NewExpression.Members is null when the node was created without a
            // member list, so the lists cannot be dereferenced unconditionally.
            if( original == null || original2 == null )
                return original == null && original2 == null;

            if( original.Count != original2.Count )
                return false;
            for( int i = 0, n = original.Count; i < n; i++ ) {
                if( original[ i ] != original2[ i ] )
                    return false;
            }
            return true;
        }

        /// <summary>Compares two member-init nodes: equal constructor calls and equal bindings.</summary>
        protected virtual bool VisitMemberInit( MemberInitExpression init, MemberInitExpression init2 ) {
            return this.VisitNew( init.NewExpression, init2.NewExpression )
                && this.VisitBindingList( init.Bindings, init2.Bindings );
        }

        /// <summary>Compares two list-init nodes: equal constructor calls and equal initializers.</summary>
        protected virtual bool VisitListInit( ListInitExpression init, ListInitExpression init2 ) {
            return this.VisitNew( init.NewExpression, init2.NewExpression )
                && this.VisitElementInitializerList( init.Initializers, init2.Initializers );
        }

        /// <summary>Compares two new-array nodes: same array type and equal element/bound expressions.</summary>
        protected virtual bool VisitNewArray( NewArrayExpression na, NewArrayExpression na2 ) {
            return na.Type == na2.Type
                && this.VisitExpressionList( na.Expressions, na2.Expressions );
        }

        /// <summary>Compares two invocations: equal arguments and equal invoked expressions.</summary>
        protected virtual bool VisitInvocation( InvocationExpression iv, InvocationExpression iv2 ) {
            return this.VisitExpressionList( iv.Arguments, iv2.Arguments )
                && this.Visit( iv.Expression, iv2.Expression );
        }
    }
}

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)

Share

About the Author

Dmitri Raiko

Germany Germany
No Biography provided

| Advertise | Privacy | Terms of Use | Mobile
Web02 | 2.8.150520.1 | Last Updated 18 Mar 2008
Article Copyright 2008 by Dmitri Raiko
Everything else Copyright © CodeProject, 1999-2015
Layout: fixed | fluid