Click here to Skip to main content
Click here to Skip to main content
Add your own
alternative version

Modifying LINQ Expressions with Rewrite Rules

Dmitri Raiko, 18 Mar 2008 CPOL
Rewriting query expressions is a simple yet safe and powerful technique for modifying queries dynamically at runtime.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
using System.Reflection;
using System.Data.Linq.SqlClient;

using Rewrite;

namespace RewriteSql {
    /// <summary>
    /// Factory for binary filter predicates over a single property of <typeparamref name="TEntity"/>.
    /// A filter template is built once per property name (via reflection) and cached; every call to
    /// <see cref="Create"/> then substitutes the requested compare operation and value into that template.
    /// </summary>
    /// <typeparam name="TEntity">Entity type whose property is being filtered.</typeparam>
    public static class GenericBinFilter<TEntity> {
        // Filter templates keyed by property name. Static members of a generic class are
        // per closed type, so every TEntity gets its own cache.
        private static readonly Dictionary<string, IPropertyGenericBinFilter<TEntity>> _cache 
            = new Dictionary<string,IPropertyGenericBinFilter<TEntity>>();

        // Guards _cache: Dictionary is not safe for concurrent writers, and this cache is
        // shared by every thread that filters on TEntity.
        private static readonly object _cacheLock = new object();

        /// <summary>
        /// Creates a predicate that compares property <paramref name="propertyName"/> of
        /// <typeparamref name="TEntity"/> with <paramref name="value"/> using the operation
        /// named by <paramref name="compareOp"/>.
        /// </summary>
        /// <param name="propertyName">Name of a property of <typeparamref name="TEntity"/>.</param>
        /// <param name="compareOp">Operation name understood by <see cref="CompareOpDecoder.Decode"/> for the property's type.</param>
        /// <param name="value">Value to compare with; handed to <c>EvaluateLiteral.CreateRhs</c> together with the property type.</param>
        /// <returns>The rewritten filter expression.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="propertyName"/> is null or empty, or <typeparamref name="TEntity"/> has no such property.
        /// Exceptions thrown by <see cref="CompareOpDecoder.Decode"/> propagate unchanged.
        /// </exception>
        public static Expression<Func<TEntity, bool>> Create( string propertyName, string compareOp, object value ) {
            if( string.IsNullOrEmpty( propertyName))
                throw new ArgumentException( "Null or empty value", "propertyName");

            IPropertyGenericBinFilter<TEntity> propFilter = GetPropertyFilter( propertyName );

            // A fresh rewriter per call; it is seeded with the cached template expression.
            SimpleRewriter rwr = new SimpleRewriter( propFilter.Expression );

            // replace compare op and value
            LambdaExpression opRhs = CompareOpDecoder.Decode( propFilter.PropertyType, compareOp );
            rwr.ApplyOnce( new Rule( propFilter.CompareOpLhs, opRhs ) );
            rwr.ApplyOnce( new Rule( propFilter.ValueLhs, EvaluateLiteral.CreateRhs( propFilter.PropertyType, value ) ) );

            return (Expression<Func<TEntity, bool>>)rwr.Expression;
        }

        // Returns the cached filter template for the property, building and caching it on first use.
        private static IPropertyGenericBinFilter<TEntity> GetPropertyFilter( string propertyName ) {
            lock( _cacheLock ) {
                IPropertyGenericBinFilter<TEntity> propFilter;
                if( _cache.TryGetValue( propertyName, out propFilter ) )
                    return propFilter;

                PropertyInfo pi = typeof( TEntity ).GetProperty( propertyName );
                if( pi == null )
                    throw new ArgumentException( string.Format(
                        "Type '{0}' doesn't have property '{1}'.",
                        typeof( TEntity ).Name, propertyName ), "propertyName" );

                // The property type is only known at runtime, so the generic filter
                // PropertyGenericBinFilter<TEntity, TProperty> has to be closed via reflection.
                Type[] genTypes = { typeof(TEntity), pi.PropertyType};
                Type filterType = typeof( PropertyGenericBinFilter<,> ).MakeGenericType( genTypes );
                propFilter = (IPropertyGenericBinFilter<TEntity>)Activator.CreateInstance( filterType );

                // Build "e => e.<property>" and substitute it for the template's property getter.
                ParameterExpression p = Expression.Parameter( typeof( TEntity ), "e" );
                LambdaExpression propRhs = Expression.Lambda(
                        Expression.MakeMemberAccess( p, pi ),
                        p );
                propFilter.RewriteGetter( propRhs );

                _cache.Add( propertyName, propFilter );
                return propFilter;
            }
        }
    }

    /// <summary>
    /// Maps compare operation names (for example "==", "StartsWith", "&lt;=") to lambda expressions
    /// implementing them. Operations are grouped by the <see cref="TypeCode"/> of the compared value;
    /// operation names are looked up case-insensitively.
    /// </summary>
    public static class CompareOpDecoder {
        /// <summary>
        /// Registered operations: type code of the operands, then operation name, then implementation.
        /// </summary>
        public static Dictionary<TypeCode, Dictionary<string, LambdaExpression>> CompareOpDictionary;

        /// <summary>
        /// Initializes compare operation dictionaries for two supported types of values: string and int. 
        /// </summary>
        static CompareOpDecoder() {
            CompareOpDictionary = new Dictionary<TypeCode, Dictionary<string, LambdaExpression>>();

            #region string compare
            Dictionary<string, LambdaExpression> stringOps = new Dictionary<string, LambdaExpression>( StringComparer.InvariantCultureIgnoreCase );
            stringOps.AddStringPredicate( "==", ( x, y ) => x == y );
            stringOps.AddStringPredicate( "!=", ( x, y ) => x != y );
            stringOps.AddStringPredicate( "StartsWith", ( x, y ) => x.StartsWith( y ) );
            stringOps.AddStringPredicate( "EndsWith", ( x, y ) => x.EndsWith( y ) );
            stringOps.AddStringPredicate( "Contains", ( x, y ) => x.Contains( y ) );
            // The "I" prefixed variants ignore case by upper-casing both operands.
            stringOps.AddStringPredicate( "IStartsWith", ( x, y ) => x.ToUpper().StartsWith( y.ToUpper() ) );
            stringOps.AddStringPredicate( "IEndsWith", ( x, y ) => x.ToUpper().EndsWith( y.ToUpper() ) );
            stringOps.AddStringPredicate( "IContains", ( x, y ) => x.ToUpper().Contains( y.ToUpper() ) );
            stringOps.AddStringPredicate( "Like", ( x, y ) => SqlMethods.Like( x, y ) );
            stringOps.AddStringPredicate( "ILike", ( x, y ) => SqlMethods.Like( x.ToUpper(), y.ToUpper() ) );

            CompareOpDictionary.Add( TypeCode.String, stringOps );
            #endregion

            #region int compare
            Dictionary<string, LambdaExpression> intOps = new Dictionary<string, LambdaExpression>( StringComparer.InvariantCultureIgnoreCase );
            intOps.AddIntPredicate( "==", ( x, y ) => x == y );
            intOps.AddIntPredicate( "!=", ( x, y ) => x != y );
            intOps.AddIntPredicate( ">", ( x, y ) => x > y );
            intOps.AddIntPredicate( ">=", ( x, y ) => x >= y );
            intOps.AddIntPredicate( "<", ( x, y ) => x < y );
            intOps.AddIntPredicate( "<=", ( x, y ) => x <= y );

            CompareOpDictionary.Add( TypeCode.Int32, intOps );
            #endregion
        }

        // Typed wrapper around dict.Add: the parameter type makes the compiler turn the
        // lambda argument into an expression tree over (int, int).
        private static void AddIntPredicate( this Dictionary<string, LambdaExpression> dict,
            string opName,
            Expression<Func<int, int, bool>> expr ) {
            dict.Add( opName, expr );
        }

        // Typed wrapper around dict.Add: the parameter type makes the compiler turn the
        // lambda argument into an expression tree over (string, string).
        private static void AddStringPredicate( this Dictionary<string, LambdaExpression> dict, 
            string opName, 
            Expression<Func<string, string, bool>> expr ) {
            dict.Add( opName, expr );
        }

        /// <summary>
        /// Lists the operation names registered for the type code of <paramref name="type"/>.
        /// </summary>
        /// <param name="type">Type of the compared values.</param>
        /// <returns>The operation names, or <c>null</c> when no operations are registered for the type.</returns>
        public static IEnumerable<string> GetCompareOpNames( Type type ) {
            Dictionary<string, LambdaExpression> dict;
            if( !CompareOpDictionary.TryGetValue( Type.GetTypeCode( type ), out dict ) )
                return null; // type not registered; callers must check for null
            return dict.Keys;
        }

        /// <summary>
        /// Looks up the lambda expression implementing <paramref name="compareOp"/> for values of
        /// type <paramref name="valueType"/>.
        /// </summary>
        /// <param name="valueType">Type of the compared values.</param>
        /// <param name="compareOp">Operation name; matched case-insensitively.</param>
        /// <returns>The registered two-operand predicate expression.</returns>
        /// <exception cref="ArgumentNullException">An argument is <c>null</c>.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// No operations are registered for <paramref name="valueType"/>, or
        /// <paramref name="compareOp"/> is not one of them.
        /// </exception>
        public static LambdaExpression Decode( Type valueType, string compareOp ) {
            if( valueType == null )
                throw new ArgumentNullException( "valueType");
            if( compareOp == null)
                throw new ArgumentNullException( "compareOp");

            // Note: the single-string ArgumentOutOfRangeException constructor takes the
            // parameter NAME, so the (paramName, message) overload is required below.
            Dictionary<string, LambdaExpression> dict;
            if( !CompareOpDictionary.TryGetValue( Type.GetTypeCode( valueType ), out dict ) )
                throw new ArgumentOutOfRangeException( "valueType", string.Format( 
                    "Decoder dictionary doesn't contain compare operations over type '{0}'.", 
                    valueType));

            LambdaExpression ret;
            if( !dict.TryGetValue( compareOp, out ret ) )
                throw new ArgumentOutOfRangeException( "compareOp", string.Format(
                    "Decoder dictionary doesn't contain compare operation '{0}' over type '{1}'.",
                    compareOp, valueType ) );

            return ret;
        }
    }

    /// <summary>
    /// Non-generic-in-value view of a per-property filter template over <typeparamref name="TEntity"/>,
    /// so that callers can work with it without knowing the property type at compile time.
    /// </summary>
    /// <typeparam name="TEntity">Entity type the filter applies to.</typeparam>
    public interface IPropertyGenericBinFilter<TEntity> {
        /// <summary>Type of the filtered property's value.</summary>
        Type PropertyType { get; }

        /// <summary>Current filter expression.</summary>
        Expression<Func<TEntity, bool>> Expression { get; }

        /// <summary>Left-hand side used in a rewrite rule that replaces the compare operation.</summary>
        LambdaExpression CompareOpLhs { get; }

        /// <summary>Left-hand side used in a rewrite rule that replaces the compared value.</summary>
        LambdaExpression ValueLhs { get; }

        /// <summary>Substitutes <paramref name="rhs"/> for the property getter of the filter.</summary>
        /// <param name="rhs">Lambda reading the property from an entity.</param>
        void RewriteGetter( LambdaExpression rhs );
    }

    /// <summary>
    /// Filter template of the shape <c>op( prop( e ), value )</c> for a property of type
    /// <typeparamref name="TValue"/> on <typeparamref name="TEntity"/>. The three parts
    /// (<c>op</c>, <c>prop</c>, <c>value</c>) are placeholders; matching left-hand-side
    /// expressions are exposed so that rewrite rules can replace each part.
    /// </summary>
    /// <typeparam name="TEntity">Entity type the filter applies to.</typeparam>
    /// <typeparam name="TValue">Type of the filtered property.</typeparam>
    public class PropertyGenericBinFilter<TEntity,TValue>: IPropertyGenericBinFilter<TEntity> {
        // The template itself; replaced by RewriteGetter.
        private Expression<Func<TEntity, bool>> _expr;

        // Patterns describing how each placeholder appears inside _expr.
        private Expression<Func<TValue>> _valueLhs;
        private Expression<Func<TValue, TValue, bool>> _compareOpLhs;
        private Expression<Func<TEntity, TValue>> _propertyLhs;

        /// <summary>Builds the template and the three placeholder patterns.</summary>
        public PropertyGenericBinFilter() {
            // Placeholders only: they are never invoked here. The template and the patterns
            // all capture these same locals, so the patterns refer to exactly the
            // placeholders that occur in the template.
            TValue value = default( TValue );
            Func<TEntity, TValue> prop = null;
            Func<TValue, TValue, bool> op = null;

            _expr = e => op( prop( e ), value );

            _valueLhs = () => value;
            _propertyLhs = e => prop( e );
            _compareOpLhs = ( x, y ) => op( x, y );
        }

        #region IPropertyGenericBinFilter<TEntity> Members
        /// <summary>Type of the filtered property's value, i.e. <typeparamref name="TValue"/>.</summary>
        public Type PropertyType {
            get { return typeof( TValue ); }
        }

        /// <summary>Current filter expression.</summary>
        public Expression<Func<TEntity, bool>> Expression {
            get { return _expr; }
        }

        /// <summary>Pattern for the compare operation placeholder.</summary>
        public LambdaExpression CompareOpLhs {
            get { return _compareOpLhs; }
        }

        /// <summary>Pattern for the value placeholder.</summary>
        public LambdaExpression ValueLhs {
            get { return _valueLhs; }
        }

        /// <summary>
        /// Applies, once, a rule that replaces the property placeholder in the filter
        /// expression with <paramref name="rhs"/> and stores the result.
        /// </summary>
        /// <param name="rhs">Property getter; must be an <c>Expression&lt;Func&lt;TEntity, TValue&gt;&gt;</c>.</param>
        public void RewriteGetter( LambdaExpression rhs ) {
            Expression<Func<TEntity, TValue>> getter = (Expression<Func<TEntity, TValue>>)rhs;
            _expr = SimpleRewriter.ApplyOnce( _expr, Rule.Create( _propertyLhs, getter ) );
        }
        #endregion
    }

    /// <summary>
    /// Helpers for swapping a placeholder ("dummy") predicate inside a query expression
    /// for a real filter, and for turning the rewritten expression back into a query.
    /// </summary>
    public static class WhereRewriter {
        /// <summary>
        /// Rewrites <paramref name="expr"/>, replacing the dummy predicate with
        /// <paramref name="filter"/>. When <paramref name="filter"/> is <c>null</c>, the
        /// <c>Where</c> call holding the dummy predicate is removed instead.
        /// </summary>
        /// <typeparam name="TEntity">Element type the predicate applies to.</typeparam>
        /// <param name="expr">Query expression to rewrite.</param>
        /// <param name="dummyPred">Placeholder predicate to look for in <paramref name="expr"/>.</param>
        /// <param name="filter">Replacement predicate, or <c>null</c> to drop the filter altogether.</param>
        /// <returns>The rewritten expression.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="expr"/> or <paramref name="dummyPred"/> is <c>null</c>.
        /// </exception>
        public static Expression Rewrite<TEntity>(
            Expression expr,
            Expression<Func<TEntity, bool>> dummyPred,
            Expression<Func<TEntity, bool>> filter ) 
        {
            if( expr == null )
                throw new ArgumentNullException( "expr", "Expression is null" );
            if( dummyPred == null )
                throw new ArgumentNullException( "dummyPred", "Lhs to find dummy predicate is null" );

            SimpleRewriter rwr = new SimpleRewriter( expr );

            if( filter == null){
                // Two steps: first normalize the dummy predicate to the locally known form
                // "z => p( z )", then match the whole Where call on that form and drop it.
                Func<TEntity, bool> p = null;
                // Replace the dummy predicate in the query for locally defined predicate "p"
                rwr.ApplyOnce( FilterBuilder.FilterRule( dummyPred, z => p( z ) ) );

                // remove .Where( y => p(y)) from the query
                rwr.ApplyOnce( Rule.Create<IQueryable<TEntity>,IQueryable<TEntity>>(
                    x => x.Where( y => p( y ) ), 
                    x => x ) );
            } else {
                rwr.ApplyOnce( FilterBuilder.FilterRule( dummyPred, filter ) );
            }

            return rwr.Expression;
        }

        /// <summary>
        /// Creates a new query from <paramref name="expr"/> using the provider of <paramref name="q"/>.
        /// </summary>
        /// <typeparam name="TQuery">Element type of the query.</typeparam>
        /// <param name="q">Query whose provider is used.</param>
        /// <param name="expr">Expression the new query is created from.</param>
        /// <returns>The query created by the provider.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="q"/> or <paramref name="expr"/> is <c>null</c>.
        /// </exception>
        public static IQueryable<TQuery> RecreateQuery<TQuery>(
            IQueryable<TQuery> q,
            Expression expr ) 
        {
            if( q == null )
                throw new ArgumentNullException( "q" );
            if( expr == null )
                throw new ArgumentNullException( "expr" );

            return q.Provider.CreateQuery<TQuery>( expr );
        }
    }
}

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)

Share

About the Author

Dmitri Raiko

Germany
No Biography provided

| Advertise | Privacy | Terms of Use | Mobile
Web01 | 2.8.141220.1 | Last Updated 18 Mar 2008
Article Copyright 2008 by Dmitri Raiko
Everything else Copyright © CodeProject, 1999-2014
Layout: fixed | fluid