#region Copyright (c) 2008 by Jahmani Muigai Mwaura and Community
/*--------------------------------------------------------------------------------------------------
* LinqToSql, a Linq to Sql parser for the .NET Platform
* by Jahmani Mwaura and community
* ------------------------------------------------------------------------------------------------
* Version: LGPL 2.1
*
* Software distributed under the License is distributed on an "AS IS" basis,
* WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
* for the specific language governing rights and limitations under the
* License.
*
* The Original Code is any part of this file that is not marked as a contribution.
*
* The Initial Developer of the Original Code is Jahmani Muigai Mwaura.
* Portions created by the Initial Developer are Copyright (C) 2008
* the Initial Developer. All Rights Reserved.
*
* Contributor(s): None.
*--------------------------------------------------------------------------------------------------
*/
#endregion
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Common;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
namespace LinqToSql {
public class SqlExpressionParser : ExpressionVisitor {
// Enclosing parser when this instance translates a sub-statement (join / set-operation sides); null at top level.
private readonly SqlExpressionParser outerStatement = null;
// Aggregate to wrap around the selected expression (consumed by SelectHandler.GetFields).
private readonly AggregateType aggregateType = AggregateType.None;
// Accumulates the generated SQL; Translate treats a non-empty buffer as "already translated".
private readonly StringBuilder sb = new StringBuilder();
// Nesting depth; clause handlers are created at indentLevel + 1 and use it for indentation and table aliases.
private readonly int indentLevel = -1;
// Operators that could not be translated to SQL, in visit order (outermost call first);
// Execute replays them in memory on top of the SQL results.
private readonly Queue<MethodCallExpression> queryableMethods =
new Queue<MethodCallExpression>();
// Parameterless Distinct/First/FirstOrDefault calls that are additionally emitted as SELECT modifiers.
private readonly Stack<MethodCallExpression> optimizableQueryableMethods =
new Stack<MethodCallExpression>();
// Element types taken from IGrouping<,> type arguments that carry no TableAttribute, keyed by Type.GUID.
private readonly Dictionary<Guid, Type> typesFromGroupBy = new Dictionary<Guid, Type>();
// Clause handlers; each stays null (or empty) when the corresponding operator was not translated.
private SelectHandler selectHandler = null;
private WhereHandler whereHandler = null;
private JoinHandler joinHandler = null;
private CrossJoinHandler crossJoinHandler = null;
private Stack<OrderByHandler> orderByHandlers = new Stack<OrderByHandler>();
private SetOperationHandler setOperationHandler = null;
// Dialect helper that supplies the SQL keywords used when emitting clauses.
private readonly SqlSyntaxHelper syntaxHelper;
// Translated parsers keyed by syntax-helper type GUID + expression.ToString(); shared by all instances.
private static readonly ThreadSafeCache<string, SqlExpressionParser> parserCache =
new ThreadSafeCache<string, SqlExpressionParser>();
// Compiled delegates that apply a single untranslated operator to an in-memory source (see GetQueryableExecutor).
private static readonly ThreadSafeCache<string, Delegate> queryableExecutorCache =
new ThreadSafeCache<string, Delegate>();
/// <summary>Creates a top-level parser (no enclosing statement, indent level -1, no aggregate).</summary>
private SqlExpressionParser(SqlSyntaxHelper syntaxHelper)
    : this(-1, null, AggregateType.None, syntaxHelper) {
}
/// <summary>Creates a parser for a nested statement with no aggregate applied.</summary>
private SqlExpressionParser(int indentLevel, SqlExpressionParser outerStatement,
                            SqlSyntaxHelper syntaxHelper)
    : this(indentLevel, outerStatement, AggregateType.None, syntaxHelper) {
}
/// <summary>Primary constructor: every other overload funnels into this one.</summary>
/// <param name="indentLevel">Nesting depth of the statement.</param>
/// <param name="outerStatement">Enclosing parser, or null at top level.</param>
/// <param name="aggregateType">Aggregate to wrap around the selected expression.</param>
/// <param name="syntaxHelper">Dialect helper supplying SQL keywords.</param>
private SqlExpressionParser(int indentLevel, SqlExpressionParser outerStatement,
                            AggregateType aggregateType,
                            SqlSyntaxHelper syntaxHelper) {
    this.syntaxHelper = syntaxHelper;
    this.aggregateType = aggregateType;
    this.outerStatement = outerStatement;
    this.indentLevel = indentLevel;
}
/// <summary>
/// Translates a LINQ expression tree into SQL text for the given dialect.
/// Parsers are cached per dialect type and expression text.
/// </summary>
/// <param name="expression">The query expression to translate.</param>
/// <param name="syntaxHelper">Dialect helper supplying SQL keywords.</param>
/// <returns>The generated SQL statement.</returns>
public static string TranslateExpression(Expression expression, SqlSyntaxHelper syntaxHelper) {
    SqlExpressionParser translated = GetSqlExpressionParser(expression, syntaxHelper);
    string sql = translated.GetSQLStatement();
    return sql;
}
/// <summary>
/// Translates (or fetches the cached translation of) <paramref name="expression"/> and builds
/// the object that runs it against <paramref name="connection"/>.
/// </summary>
/// <param name="connection">Connection the generated SQL runs against.</param>
/// <param name="expression">The query expression.</param>
/// <param name="syntaxHelper">Dialect helper supplying SQL keywords.</param>
/// <returns>The result produced by <c>Execute</c> (an executor or the result of in-memory operators).</returns>
public static object ExecuteExpression(DbConnection connection, Expression expression,
                                       SqlSyntaxHelper syntaxHelper) {
    SqlExpressionParser translated = GetSqlExpressionParser(expression, syntaxHelper);
    object outcome = translated.Execute(connection, expression);
    return outcome;
}
/// <summary>
/// Returns a translated parser for <paramref name="expression"/>, reusing a cached one when the
/// same dialect type and expression text were translated before.
/// </summary>
/// <param name="expression">The query expression; must not be null.</param>
/// <param name="syntaxHelper">Dialect helper; must not be null.</param>
/// <returns>A parser whose <c>Translate</c> has already run.</returns>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
private static SqlExpressionParser GetSqlExpressionParser(Expression expression,
                                                          SqlSyntaxHelper syntaxHelper) {
    // Fail fast with a meaningful exception rather than a NullReferenceException from
    // syntaxHelper.GetType() / expression.ToString() when building the cache key.
    if (expression == null) {
        throw new ArgumentNullException("expression");
    }
    if (syntaxHelper == null) {
        throw new ArgumentNullException("syntaxHelper");
    }
    // Key: dialect type GUID + textual form of the expression.
    var key = syntaxHelper.GetType().GUID + expression.ToString();
    if (parserCache.ContainsKey(key)) {
        return parserCache[key];
    }
    var sqlExpressionParser = new SqlExpressionParser(syntaxHelper);
    sqlExpressionParser.Translate(expression);
    parserCache.TryAdd(key, sqlExpressionParser);
    return sqlExpressionParser;
}
/// <summary>
/// Builds the runtime object for an already-translated query. If untranslated operators were
/// collected, they are applied in memory on top of the SQL executor.
/// </summary>
/// <param name="connection">Connection handed to the executor.</param>
/// <param name="expression">The original query expression.</param>
/// <returns>An <c>Executor&lt;T&gt;</c>, a <c>ConstantEnumerable&lt;T&gt;</c>, or the result of the in-memory operators.</returns>
private object Execute(DbConnection connection, Expression expression) {
    // Translate skips constant sources, so there is no SQL to run for them.
    if (expression.NodeType == ExpressionType.Constant) {
        Type constantEnumerableType = typeof(ConstantEnumerable<>).MakeGenericType(expression.Type);
        return Activator.CreateInstance(constantEnumerableType);
    }
    Debug.Assert(sb.Length != 0);
    Debug.Assert(selectHandler != null);
    LambdaExpression selector = selectHandler.Selector;
    Type elementType = selector.Body.Type;
    // Same evaluation order as the argument list of the executor constructor.
    Type executorType = typeof(Executor<>).MakeGenericType(elementType);
    var evaluatedExpression = Evaluator.PartialEval(expression);
    var binder = Binder.GetBinder(selector);
    var executor = Activator.CreateInstance(executorType,
                                            connection,
                                            this,
                                            evaluatedExpression,
                                            binder);
    if (queryableMethods.Count != 0) {
        // Queue -> Stack so the most recently enqueued (innermost) operator is popped first.
        var pending = new Stack<MethodCallExpression>(queryableMethods);
        return ExecuteQueryableMethod(executor, pending, elementType);
    }
    return executor;
}
/// <summary>
/// Applies each pending LINQ operator in memory, feeding the output of one into the next.
/// </summary>
/// <param name="executor">Initial sequence (the SQL executor).</param>
/// <param name="queryableMethods">Operators still to apply; consumed by this call.</param>
/// <param name="executorSourceType">Element type of <paramref name="executor"/>.</param>
/// <returns>The output of the last operator applied.</returns>
private object ExecuteQueryableMethod(object executor,
                                      Stack<MethodCallExpression> queryableMethods,
                                      Type executorSourceType) {
    object current = executor;
    Type currentSourceType = executorSourceType;
    while (true) {
        MethodCallExpression method = queryableMethods.Pop();
        Delegate apply = GetQueryableExecutor(currentSourceType, method);
        var source = Queryable.AsQueryable((System.Collections.IEnumerable)current);
        object output = apply.DynamicInvoke(source);
        if (queryableMethods.Count == 0) {
            return output;
        }
        // The next operator consumes this operator's result sequence.
        current = output;
        currentSourceType = method.Type.GetGenericArguments()[0];
    }
}
/// <summary>
/// Returns (compiling and caching on first use) a delegate <c>source =&gt; Method(source, args...)</c>
/// that re-applies <paramref name="queryableMethod"/> to an in-memory source, reusing the call's
/// original non-source arguments.
/// </summary>
/// <param name="executorSourceType">Element type of the sequence the delegate will receive.</param>
/// <param name="queryableMethod">The untranslated Queryable/Enumerable call to replay.</param>
/// <returns>A compiled one-parameter delegate.</returns>
private static Delegate GetQueryableExecutor(Type executorSourceType, MethodCallExpression queryableMethod) {
    // The compiled delegate's parameter type is derived from executorSourceType, so it must be
    // part of the key; otherwise a delegate compiled for one source type could be handed back
    // for a call whose source type differs.
    var key = executorSourceType.GUID.ToString() +
              queryableMethod.Type.GUID +
              queryableMethod.Arguments[0].Type.GUID.ToString() +
              queryableMethod.ToString();
    if (queryableExecutorCache.ContainsKey(key)) {
        return queryableExecutorCache[key];
    }
    Type sourceType = QueryableMethodsProvider.GetQueryableType(executorSourceType);
    var source = Expression.Parameter(sourceType, "source");
    // Argument 0 (the original source) is replaced by the parameter; the remaining arguments
    // (lambdas, constants) are reused unchanged.
    var callArguments = new List<Expression> { source };
    callArguments.AddRange(queryableMethod.Arguments.Skip(1));
    var queryableExecutor = Expression.Lambda(Expression.Call(queryableMethod.Method,
                                                              callArguments.ToArray()),
                                              source);
    var result = queryableExecutor.Compile();
    queryableExecutorCache.TryAdd(key, result);
    return result;
}
/// <summary>
/// Visits <paramref name="expression"/> and emits SQL into <c>sb</c>, at most once per parser.
/// </summary>
/// <param name="expression">The query expression.</param>
/// <returns>The SQL text, or an empty string for a constant source not typed as object.</returns>
private string Translate(Expression expression) {
    var constant = expression as ConstantExpression;
    bool isTypedConstant = expression.NodeType == ExpressionType.Constant &&
                           constant.Type != typeof(object);
    if (isTypedConstant) {
        return string.Empty;
    }
    // A non-empty buffer means this (cached) parser has already produced its SQL.
    if (sb.Length == 0) {
        this.Visit(Evaluator.PartialEval(expression));
        EmitSelectStatement();
    }
    return sb.ToString();
}
/// <summary>
/// Dispatches each Queryable/Enumerable operator to the handler that can translate it.
/// Operators that cannot be translated are queued in <c>queryableMethods</c> for in-memory execution.
/// </summary>
/// <param name="m">The method call node being visited.</param>
/// <returns><paramref name="m"/> unchanged.</returns>
/// <exception cref="NotSupportedException">The method is not declared on Queryable or Enumerable.</exception>
protected override Expression VisitMethodCall(MethodCallExpression m) {
    Type declaringType = m.Method.DeclaringType;
    if (declaringType != typeof(Queryable) && declaringType != typeof(Enumerable)) {
        throw new NotSupportedException(string.Format("The method '{0}' is not supported", m.Method.Name));
    }
    switch (m.Method.Name) {
        case "Select":
            if (!GetSelectHandler(m)) {
                queryableMethods.Enqueue(m);
            }
            this.Visit(m.Arguments[0]);
            break;
        case "Join":
            Debug.Assert(joinHandler == null);
            Debug.Assert(crossJoinHandler == null);
            joinHandler = JoinHandler.GetJoinHandler(this, indentLevel + 1, m);
            this.Visit(m.Arguments[0]);
            break;
        case "SelectMany":
            Debug.Assert(crossJoinHandler == null);
            Debug.Assert(joinHandler == null);
            crossJoinHandler = CrossJoinHandler.GetCrossJoinHandler(this, indentLevel + 1, m);
            this.Visit(m.Arguments[0]);
            break;
        case "Where":
            if (!GetWhereHandler(m)) {
                queryableMethods.Enqueue(m);
            }
            this.Visit(m.Arguments[0]);
            break;
        case "OrderBy":
        case "OrderByDescending":
        case "ThenBy":
        case "ThenByDescending":
            if (!GetOrderByHandler(m)) {
                queryableMethods.Enqueue(m);
            }
            this.Visit(m.Arguments[0]);
            break;
        case "Distinct":
        case "First":
        case "FirstOrDefault":
            TryOptimize(m);
            this.Visit(m.Arguments[0]);
            break;
        case "Concat":
        case "Except":
        case "Intersect":
        case "Union":
            if (m.Arguments.Count == 2) {
                // The set-operation handler visits both operands itself.
                setOperationHandler =
                    SetOperationHandler.GetSetOperationHandler(this, indentLevel + 1, m);
            }
            else {
                // Overloads taking an IEqualityComparer run in memory. The source must still be
                // translated (as in the default case); otherwise filters or projections inside
                // m.Arguments[0] are dropped, because the in-memory replay substitutes the SQL
                // results for that argument.
                queryableMethods.Enqueue(m);
                this.Visit(m.Arguments[0]);
            }
            break;
        default:
            queryableMethods.Enqueue(m);
            this.Visit(m.Arguments[0]);
            break;
    }
    return m;
}
/// <summary>
/// Registers a Select call: the first one becomes the statement's projection and later ones
/// are nested into it.
/// </summary>
/// <param name="m">The Select call.</param>
/// <returns>False only when the first Select could not be translated (e.g. it projects over a grouping).</returns>
private bool GetSelectHandler(MethodCallExpression m) {
    if (selectHandler != null) {
        selectHandler.AddNestedSelect(m);
        return true;
    }
    selectHandler = SelectHandler.GetSelectHandler(this, indentLevel + 1, m, aggregateType);
    return selectHandler != null;
}
/// <summary>
/// Registers a Where call. The first one creates the WHERE handler, with parameter numbering
/// continuing from the enclosing statement; later ones are ANDed in.
/// </summary>
/// <param name="m">The Where call.</param>
/// <returns>False only when the first Where could not be translated.</returns>
private bool GetWhereHandler(MethodCallExpression m) {
    if (whereHandler != null) {
        whereHandler.AddCriteria(m);
        return true;
    }
    int parameterBaseIndex = outerStatement != null ? outerStatement.ParameterCount : 0;
    whereHandler = WhereHandler.GetWhereHandler(this, indentLevel + 1, m, parameterBaseIndex);
    return whereHandler != null;
}
/// <summary>
/// Tries to translate an OrderBy/ThenBy(Descending) call into an ORDER BY term.
/// </summary>
/// <param name="m">The ordering call.</param>
/// <returns>
/// True when a handler was created. False when the ordering cannot be translated; the caller
/// (VisitMethodCall) is then responsible for queueing <paramref name="m"/> for in-memory execution,
/// as it already does for GetSelectHandler and GetWhereHandler.
/// </returns>
private bool GetOrderByHandler(MethodCallExpression m) {
    var orderByHandler = OrderByHandler.GetOrderByHandler(this, indentLevel + 1, m);
    if (orderByHandler == null) {
        // Do not enqueue here: VisitMethodCall enqueues on a false return, and enqueueing in
        // both places made the ordering run twice in memory.
        return false;
    }
    orderByHandlers.Push(orderByHandler);
    return true;
}
/// <summary>
/// Queues Distinct/First/FirstOrDefault for in-memory execution. The parameterless overloads are
/// additionally recorded so they can also be emitted as SELECT modifiers.
/// </summary>
/// <param name="m">The Distinct/First/FirstOrDefault call.</param>
private void TryOptimize(MethodCallExpression m) {
    bool hasNoExtraArguments = m.Arguments.Count == 1;
    if (hasNoExtraArguments) {
        optimizableQueryableMethods.Push(m);
    }
    queryableMethods.Enqueue(m);
}
/// <summary>
/// Writes the complete statement into <c>sb</c>. For a set operation, the clause comes entirely
/// from the set-operation handler. Otherwise SELECT, joins, WHERE and ORDER BY are appended and
/// the alias placeholders are then resolved.
/// </summary>
private void EmitSelectStatement() {
    if (setOperationHandler == null) {
        GetSelectClause();
        GetJoinClause();
        GetCrossJoinClause();
        GetWhereClause();
        GetOrderByClause();
        string withPlaceholders = sb.ToString();
        sb.Length = 0;
        sb.Append(ReplaceAliases(withPlaceholders));
        return;
    }
    InitSelectHandler();
    sb.Append(setOperationHandler.GetSetOperationClause());
}
/// <summary>
/// Ensures a SELECT handler exists. Without an explicit Select, one is synthesized from the join
/// selector, the cross-join selector, or the inferred return type, in that order.
/// </summary>
/// <exception cref="InvalidOperationException">No return type can be inferred.</exception>
private void InitSelectHandler() {
    if (selectHandler != null) {
        return;
    }
    if (joinHandler != null) {
        LambdaExpression joinSelector = joinHandler.Selector;
        var joinSelect = QueryableMethodsProvider.GetSelectCall(joinSelector.Parameters[0].Type,
                                                                joinSelector);
        selectHandler = SelectHandler.GetSelectHandler(this, indentLevel + 1, joinSelect, aggregateType);
    }
    else if (crossJoinHandler != null) {
        var crossJoinSelect = QueryableMethodsProvider.GetSelectCall(crossJoinHandler.Selector);
        selectHandler = SelectHandler.GetSelectHandler(this, indentLevel + 1, crossJoinSelect, aggregateType);
    }
    else {
        Type inferredType = GetReturnType();
        if (inferredType == null) {
            throw new InvalidOperationException("Cannot translate statement");
        }
        selectHandler = SelectHandler.GetSelectHandler(this, indentLevel + 1, inferredType, aggregateType);
    }
}
/// <summary>
/// Appends the SELECT ... FROM clause. The FROM table alias is only emitted when no join supplies
/// the FROM sources.
/// </summary>
private void GetSelectClause() {
    InitSelectHandler();
    bool fromSingleTable = crossJoinHandler == null && joinHandler == null;
    string selectClause = selectHandler.GetSelectClause(fromSingleTable, optimizableQueryableMethods);
    sb.Append(selectClause);
}
/// <summary>Appends the inner-join FROM sources, if the query has a Join.</summary>
private void GetJoinClause() {
    if (joinHandler == null) {
        return;
    }
    sb.Append(joinHandler.GetJoinClause());
}
/// <summary>Appends the cross-join FROM sources, if the query has a SelectMany.</summary>
private void GetCrossJoinClause() {
    if (crossJoinHandler == null) {
        return;
    }
    sb.Append(crossJoinHandler.GetCrossJoinClause());
}
/// <summary>Appends the WHERE clause, if any Where was translated.</summary>
private void GetWhereClause() {
    if (whereHandler == null) {
        return;
    }
    sb.Append(whereHandler.GetWhereClause());
}
/// <summary>
/// Appends ORDER BY followed by a newline. Only the outermost statement emits it; nested
/// statements emit nothing here.
/// </summary>
private void GetOrderByClause() {
    if (!IsTopLevelOrderBy()) {
        return;
    }
    EmitOrderBy();
    sb.Append(Environment.NewLine);
}
/// <summary>True when this is the outermost statement and at least one ordering was translated.</summary>
private bool IsTopLevelOrderBy() {
    bool isOutermost = outerStatement == null;
    return isOutermost && orderByHandlers.Count > 0;
}
/// <summary>
/// Appends "ORDER BY" with one term per ordering handler, in stack order (last pushed first),
/// dropping duplicate terms.
/// </summary>
private void EmitOrderBy() {
    // Removing duplicates is a workaround (the original author called it a hack), not a semantic rule.
    string[] terms = orderByHandlers
        .Select(handler => handler.GetOrderByClause())
        .Distinct()
        .ToArray();
    sb.Append("ORDER BY " + string.Join(", ", terms));
}
/// <summary>
/// Resolves alias placeholders in <paramref name="aliasedSQL"/> by passing it through the select,
/// join and cross-join handlers, in that order, skipping any that are absent.
/// </summary>
/// <param name="aliasedSQL">SQL containing placeholders.</param>
/// <returns>SQL with placeholders replaced.</returns>
private string ReplaceAliases(string aliasedSQL) {
    string afterSelect = selectHandler == null ? aliasedSQL : selectHandler.ReplaceAliases(aliasedSQL);
    string afterJoin = joinHandler == null ? afterSelect : joinHandler.ReplaceAliases(afterSelect);
    return crossJoinHandler == null ? afterJoin : crossJoinHandler.ReplaceAliases(afterJoin);
}
/// <summary>
/// Infers the element type to select when no explicit Select exists. It checks the set operation,
/// the WHERE handler, the newest ordering, and the first queued operator's first generic argument,
/// in that order.
/// </summary>
/// <returns>The inferred type, or null when nothing supplies one.</returns>
private Type GetReturnType() {
    return setOperationHandler != null ? setOperationHandler.ReturnType
         : whereHandler != null ? whereHandler.ReturnType
         : orderByHandlers.Count != 0 ? orderByHandlers.Peek().ReturnType
         : queryableMethods.Count != 0 ? queryableMethods.Peek().Method.GetGenericArguments()[0]
         : null;
}
/// <summary>Table name (from its TableAttribute) of the type this statement selects from.</summary>
private string GetTableName() {
    Debug.Assert(selectHandler != null);
    Type sourceTable = selectHandler.TableType;
    return GetTableName(sourceTable);
}
/// <summary>Returns the SQL generated by a previous Translate call.</summary>
private string GetSQLStatement() {
    Debug.Assert(sb.Length > 0);
    string statement = sb.ToString();
    return statement;
}
/// <summary>
/// Number of SQL parameters consumed so far. It comes from this statement's WHERE handler,
/// otherwise from the enclosing statement, otherwise zero.
/// </summary>
private int ParameterCount {
    get {
        return whereHandler != null ? whereHandler.ParameterCount
             : outerStatement != null ? outerStatement.ParameterCount
             : 0;
    }
}
/// <summary>
/// Reports whether <paramref name="type"/> is an IGrouping, has an IGrouping type argument, or was
/// previously recorded as a grouping element type. As a side effect, it records the element types
/// (second type argument) of IGrouping type arguments that lack a TableAttribute.
/// </summary>
/// <param name="type">Type of a lambda's first parameter.</param>
/// <returns>True when the type stems from a grouping.</returns>
private bool FromIsOrContainsGrouping(Type type) {
    const string groupingName = "IGrouping`2";
    Type[] typeArguments = type.GetGenericArguments();
    // Compute the answer before recording, so types added below do not affect this call.
    bool isGrouping = type.Name == groupingName;
    bool hasGroupingArgument = typeArguments.Any(t => t.Name == groupingName);
    bool result = isGrouping || hasGroupingArgument || typesFromGroupBy.ContainsKey(type.GUID);
    var tableAttributeType = typeof(System.Data.Linq.Mapping.TableAttribute);
    foreach (Type argument in typeArguments) {
        if (argument.Name != groupingName) {
            continue;
        }
        Type elementType = argument.GetGenericArguments()[1];
        if (elementType.GetCustomAttributes(tableAttributeType, false).Length == 0) {
            typesFromGroupBy[elementType.GUID] = elementType;
        }
    }
    return result;
}
/// <summary>
/// True when <paramref name="m"/> is Count, Average, Max, Min or Sum declared on Queryable or Enumerable.
/// </summary>
/// <param name="m">The call to inspect.</param>
private static bool IsAggregateMethod(MethodCallExpression m) {
    Type declaringType = m.Method.DeclaringType;
    bool isLinqOperator = declaringType == typeof(Queryable) || declaringType == typeof(Enumerable);
    if (!isLinqOperator) {
        return false;
    }
    string[] aggregateNames = { "Count", "Average", "Max", "Min", "Sum" };
    return Array.IndexOf(aggregateNames, m.Method.Name) >= 0;
}
/// <summary>
/// SQL table alias for a nesting level: "t" followed by the level number (e.g. "t0", "t1").
/// </summary>
/// <param name="indentLevel">Nesting level.</param>
/// <returns>The alias.</returns>
private static string GetTableAlias(int indentLevel) {
    // The alias is a SQL identifier, not user-facing text, so format it culture-independently.
    return "t" + indentLevel.ToString(System.Globalization.CultureInfo.InvariantCulture);
}
/// <summary>
/// Returns <paramref name="indentLevel"/> tab characters, or an empty string for zero or negative levels.
/// </summary>
/// <param name="indentLevel">Number of tabs; a top-level parser uses -1.</param>
/// <returns>The indentation prefix.</returns>
private static string GetIndentation(int indentLevel) {
    // The former StringBuilder(indentLevel) threw ArgumentOutOfRangeException for negative levels.
    return indentLevel <= 0 ? string.Empty : new string('\t', indentLevel);
}
/// <summary>Unwraps any number of Quote nodes and returns the innermost operand.</summary>
/// <param name="e">Possibly quoted expression.</param>
private static Expression StripQuotes(Expression e) {
    return e.NodeType == ExpressionType.Quote
        ? StripQuotes(((UnaryExpression)e).Operand)
        : e;
}
/// <summary>
/// Extracts a lambda from an operator argument. The argument is either a (quoted) lambda or a
/// constant whose value is the lambda.
/// </summary>
/// <param name="expression">Operator argument.</param>
/// <returns>The lambda.</returns>
private static LambdaExpression GetLambdaExpression(Expression expression) {
    var unquoted = StripQuotes(expression) as LambdaExpression;
    if (unquoted != null) {
        return unquoted;
    }
    Debug.Assert(expression as ConstantExpression != null);
    Debug.Assert((expression as ConstantExpression).Value as LambdaExpression != null);
    // Direct casts: a wrong shape fails with an InvalidCastException.
    var holder = (ConstantExpression)expression;
    var fromConstant = (LambdaExpression)holder.Value;
    Debug.Assert(fromConstant != null);
    return fromConstant;
}
/// <summary>
/// Name of the first LINQ to SQL TableAttribute declared directly on <paramref name="tableType"/>.
/// Throws IndexOutOfRangeException when the type has no such attribute.
/// </summary>
/// <param name="tableType">Mapped entity type.</param>
private static string GetTableName(Type tableType) {
    Type attributeType = typeof(System.Data.Linq.Mapping.TableAttribute);
    object[] attributes = tableType.GetCustomAttributes(attributeType, false);
    var table = (System.Data.Linq.Mapping.TableAttribute)attributes[0];
    return table.Name;
}
/// <summary>Aggregate wrapped around the selected expression; None selects plain columns.</summary>
private enum AggregateType {
    None = 0,
    Count = 1,
    Sum = 2,
    Min = 3,
    Max = 4,
    Average = 5
}
/// <summary>
/// Emits the "SELECT [modifiers] fields FROM table AS alias" part of a statement. Inner Select
/// calls are folded into it as nested handlers.
/// </summary>
private class SelectHandler {
    private readonly SqlExpressionParser outerStatement;
    private readonly int indentLevel;
    private readonly AggregateType aggregateType = AggregateType.None;
    private readonly Type returnType = null;
    private readonly Type tableType = null;
    private readonly LambdaExpressionHandler lambdaHandler = null;
    private readonly LambdaExpression selector = null;
    // Selector body rendered once at construction; used as the aggregate argument.
    private readonly string selectorExpression = null;
    private readonly Stack<SelectHandler> nestedSelectHandlers =
        new Stack<SelectHandler>();
    /// <summary>Result type of the selector (second generic argument of its delegate type).</summary>
    public Type ReturnType {
        get {
            return returnType;
        }
    }
    /// <summary>Type of the selector's parameter, i.e. the table being selected from.</summary>
    public Type TableType {
        get {
            return tableType;
        }
    }
    private SelectHandler(SqlExpressionParser outerStatement, int indentLevel,
                          MethodCallExpression expression, AggregateType aggregateType) {
        this.outerStatement = outerStatement;
        this.indentLevel = indentLevel;
        this.aggregateType = aggregateType;
        selector = GetLambdaExpression(expression.Arguments[1]);
        returnType = selector.Type.GetGenericArguments()[1];
        tableType = selector.Parameters[0].Type;
        lambdaHandler = new LambdaExpressionHandler(indentLevel, selector, outerStatement);
        selectorExpression = lambdaHandler.GetExpressionAsString(true).ToString();
    }
    private SelectHandler(SqlExpressionParser outerStatement, int indentLevel, Type returnType,
                          AggregateType aggregateType) :
        this(outerStatement, indentLevel,
             QueryableMethodsProvider.GetSelectCall(returnType), aggregateType) {
    }
    /// <summary>
    /// Creates a handler for an explicit Select call, or returns null when the selector's parameter
    /// stems from a grouping (such projections run in memory instead).
    /// </summary>
    public static SelectHandler GetSelectHandler(SqlExpressionParser outerStatement, int indentLevel,
                                                 MethodCallExpression expression,
                                                 AggregateType aggregateType) {
        Debug.Assert(expression.Method.Name == "Select");
        Debug.Assert(expression.Arguments.Count == 2);
        Debug.Assert(expression.Arguments[0].Type.GetGenericArguments().Length == 1);
        var selector = GetLambdaExpression(expression.Arguments[1]).Parameters[0];
        if (outerStatement.FromIsOrContainsGrouping(selector.Type)) {
            return null;
        }
        SelectHandler selectHandler = new SelectHandler(outerStatement, indentLevel,
                                                        expression, aggregateType);
        return selectHandler;
    }
    /// <summary>Creates a handler that selects from <paramref name="returnType"/> via a synthesized Select call.</summary>
    public static SelectHandler GetSelectHandler(SqlExpressionParser outerStatement,
                                                 int indentLevel, Type returnType,
                                                 AggregateType aggregateType) {
        return new SelectHandler(outerStatement, indentLevel, returnType, aggregateType);
    }
    /// <summary>
    /// Renders the SELECT ... FROM lines.
    /// </summary>
    /// <param name="emitTableAlias">When true, the FROM clause names this handler's table; otherwise the join handlers supply it.</param>
    /// <param name="optimizableQueryableMethods">Distinct/First calls to render as SELECT modifiers.</param>
    public string GetSelectClause(bool emitTableAlias,
                                  Stack<MethodCallExpression> optimizableQueryableMethods) {
        var optimizedCalls = GetOptimizedCalls(optimizableQueryableMethods);
        StringBuilder sb = new StringBuilder();
        sb.Append(GetIndentation(indentLevel));
        sb.Append("SELECT " + optimizedCalls + " ");
        sb.Append(GetFields(GetTableAlias(indentLevel)));
        sb.Append(Environment.NewLine);
        sb.Append(GetIndentation(indentLevel));
        sb.Append("FROM ");
        EmitAlias(emitTableAlias, sb);
        sb.Append(Environment.NewLine);
        return sb.ToString();
    }
    /// <summary>
    /// Renders SELECT modifiers for Distinct/First/FirstOrDefault, in stack enumeration order.
    /// Each distinct keyword is emitted at most once.
    /// </summary>
    private string GetOptimizedCalls(Stack<MethodCallExpression> optimizableQueryableMethods) {
        var sb = new StringBuilder();
        // A SQL modifier may appear only once: e.g. Distinct().Distinct() previously produced
        // "DISTINCT DISTINCT". The relative order of different modifiers is still the stack order.
        var emittedKeywords = new List<string>();
        foreach (var optimizableQueryableMethod in optimizableQueryableMethods) {
            string keyword;
            switch (optimizableQueryableMethod.Method.Name) {
                case "Distinct":
                    keyword = outerStatement.syntaxHelper.GetDistinctKeyword();
                    break;
                case "First":
                case "FirstOrDefault":
                    keyword = outerStatement.syntaxHelper.GetFirstKeyword();
                    break;
                default:
                    continue;
            }
            if (emittedKeywords.Contains(keyword)) {
                continue;
            }
            emittedKeywords.Add(keyword);
            sb.Append(" " + keyword + " ");
        }
        return sb.ToString();
    }
    private void EmitAlias(bool emitTableAlias, StringBuilder sb) {
        if (emitTableAlias) {
            sb.Append(GetTableName(tableType));
            sb.Append(" AS " + GetTableAlias(indentLevel));
        }
    }
    /// <summary>
    /// Field list: the members the selector accesses or, if none, the value-type/string properties
    /// of the return type. When an aggregate is set, the list is replaced by the aggregate
    /// (Count emits only its keyword; the others wrap the selector expression).
    /// </summary>
    private string GetFields(string tableAlias) {
        var accessedFields = lambdaHandler.GetAccessedFields();
        string fieldList = null;
        if (accessedFields.Length != 0) {
            fieldList = GetFieldsFromSelector(accessedFields);
        }
        else {
            fieldList = GetFieldsFromReturnType(tableAlias);
        }
        var aggregateExpression = ReplaceAliases(selectorExpression);
        switch (aggregateType) {
            case AggregateType.None:
                return fieldList;
            case AggregateType.Average:
                return outerStatement.syntaxHelper.GetAverageKeyword() + "(" + aggregateExpression + ")";
            case AggregateType.Count:
                return outerStatement.syntaxHelper.GetCountKeyword() + " ";
            case AggregateType.Max:
                return outerStatement.syntaxHelper.GetMaxKeyword() + "(" + aggregateExpression + ")";
            case AggregateType.Min:
                return outerStatement.syntaxHelper.GetMinKeyword() + "(" + aggregateExpression + ")";
            case AggregateType.Sum:
                return outerStatement.syntaxHelper.GetSumKeyword() + "(" + aggregateExpression + ")";
            default:
                throw new InvalidOperationException();
        }
    }
    private string GetFieldsFromReturnType(string tableAlias) {
        var separator = string.Empty;
        if (!string.IsNullOrEmpty(tableAlias)) {
            separator = ".";
        }
        // Hack: a property may not correspond to a column in the table.
        return string.Join(", ", (from property in returnType.GetProperties()
                                  where property.PropertyType.IsValueType ||
                                        property.PropertyType == typeof(string)
                                  orderby property.Name
                                  select tableAlias + separator + property.Name)
                                 .ToArray());
    }
    private string GetFieldsFromSelector(string[] fields) {
        return ReplaceAliases(string.Join(", ", fields));
    }
    /// <summary>
    /// Replaces placeholders with this handler's alias. Nested selects are applied first, then the
    /// lambda handler, then the table/return type GUIDs.
    /// </summary>
    public string ReplaceAliases(string expression) {
        var expressionString = expression;
        foreach (var handler in nestedSelectHandlers) {
            expressionString = handler.ReplaceAliases(expressionString);
        }
        string alias = GetTableAlias(indentLevel);
        StringBuilder sb = new StringBuilder(lambdaHandler.ReplaceAliases(expressionString));
        sb.Replace(tableType.GUID.ToString(), alias);
        sb.Replace(returnType.GUID.ToString(), alias);
        return sb.ToString();
    }
    /// <summary>Records an inner Select call whose aliases are resolved together with this handler's.</summary>
    public void AddNestedSelect(MethodCallExpression expression) {
        nestedSelectHandlers.Push(new SelectHandler(outerStatement,
                                                    indentLevel,
                                                    expression,
                                                    AggregateType.None));
    }
    public LambdaExpression Selector {
        get {
            Debug.Assert(selector != null);
            return selector;
        }
    }
}
/// <summary>
/// Collects the predicates of chained Where calls and renders them as one WHERE clause joined by AND.
/// </summary>
private class WhereHandler {
    private readonly SqlExpressionParser outerStatement;
    private readonly Type returnType = null;
    private readonly int indentLevel;
    // One handler per predicate; the most recently added is on top.
    private readonly Stack<LambdaExpressionHandler> lambdaHandlers =
        new Stack<LambdaExpressionHandler>();

    /// <summary>Element type of the filtered source.</summary>
    public Type ReturnType {
        get { return returnType; }
    }

    /// <summary>Parameter count reported by the most recently added predicate.</summary>
    public int ParameterCount {
        get {
            LambdaExpressionHandler newest = lambdaHandlers.Peek();
            return newest.ParameterCount;
        }
    }

    private WhereHandler(SqlExpressionParser outerStatement,
                         int indentLevel,
                         MethodCallExpression expression,
                         int parameterBaseIndex) {
        this.outerStatement = outerStatement;
        this.indentLevel = indentLevel;
        Type[] sourceTypeArguments = expression.Arguments[0].Type.GetGenericArguments();
        returnType = sourceTypeArguments[0];
        PushPredicate(GetLambdaExpression(expression.Arguments[1]), parameterBaseIndex);
    }

    // Wraps a predicate lambda in a handler whose parameters are numbered from parameterBaseIndex.
    private void PushPredicate(LambdaExpression predicate, int parameterBaseIndex) {
        var handler = new LambdaExpressionHandler(indentLevel, predicate, parameterBaseIndex, outerStatement);
        lambdaHandlers.Push(handler);
    }

    /// <summary>
    /// Creates a handler for a Where call, or returns null when the predicate's parameter stems
    /// from a grouping.
    /// </summary>
    public static WhereHandler GetWhereHandler(SqlExpressionParser outerStatement,
                                               int indentLevel,
                                               MethodCallExpression expression,
                                               int parameterBaseIndex) {
        Debug.Assert(expression.Method.Name == "Where");
        Debug.Assert(expression.Arguments.Count == 2);
        Debug.Assert(expression.Arguments[0].Type.GetGenericArguments().Length == 1);
        ParameterExpression firstParameter = GetLambdaExpression(expression.Arguments[1]).Parameters[0];
        bool overGrouping = outerStatement.FromIsOrContainsGrouping(firstParameter.Type);
        return overGrouping
            ? null
            : new WhereHandler(outerStatement, indentLevel, expression, parameterBaseIndex);
    }

    /// <summary>Adds another Where predicate, numbering its parameters after the previous one's.</summary>
    public void AddCriteria(MethodCallExpression expression) {
        Debug.Assert(lambdaHandlers.Count > 0);
        LambdaExpression predicate = GetLambdaExpression(expression.Arguments[1]);
        Debug.Assert(!outerStatement.FromIsOrContainsGrouping(predicate.Parameters[0].Type));
        int nextParameterIndex = lambdaHandlers.Peek().ParameterCount;
        PushPredicate(predicate, nextParameterIndex);
    }

    /// <summary>Renders "WHERE p1 AND p2 ..." (newest predicate first) followed by a newline.</summary>
    public string GetWhereClause() {
        var criteria = new List<string>();
        foreach (LambdaExpressionHandler handler in lambdaHandlers) {
            criteria.Add(handler.GetExpressionAsString(false).ToString());
        }
        var clause = new StringBuilder();
        clause.Append(GetIndentation(indentLevel));
        clause.Append("WHERE ");
        clause.Append(string.Join(" AND ", criteria.ToArray()));
        clause.Append(Environment.NewLine);
        return clause.ToString();
    }
}
/// <summary>
/// Translates a Queryable.Join call into "left AS t{n+1} INNER JOIN right AS t{n+2} ON leftKey = rightKey",
/// where n is this handler's indent level. Each side is translated by its own nested parser.
/// </summary>
private class JoinHandler {
private readonly SqlExpressionParser outerStatement;
private readonly SqlExpressionParser leftStatement;
private readonly SqlExpressionParser rightStatement;
private readonly LambdaExpression selector = null;
private readonly LambdaExpressionHandler leftKeySelector = null;
private readonly LambdaExpressionHandler rightKeySelector = null;
private readonly Type leftReturnType = null;
private readonly Type rightReturnType = null;
private readonly int indentLevel;
// Result selector rebuilt by MakeSelector from the Join's resultSelector argument (Arguments[4]).
public LambdaExpression Selector {
get {
return selector;
}
}
// Order matters: both sides are translated before the key selectors are built, and the outer
// statement visits the inner sequence (Arguments[1]) last.
private JoinHandler(SqlExpressionParser outerStatement,
int indentLevel,
MethodCallExpression expression) {
this.outerStatement = outerStatement;
this.indentLevel = indentLevel;
selector = new MakeSelector(GetLambdaExpression(expression.Arguments[4])).Selector;
leftStatement = new SqlExpressionParser(indentLevel + 1, outerStatement,
outerStatement.syntaxHelper);
leftStatement.Translate(GetLeftSourceExpression(expression));
rightStatement = new SqlExpressionParser(indentLevel + 1, outerStatement,
outerStatement.syntaxHelper);
rightStatement.Translate(GetRightSourceExpression(expression));
// Arguments[2] / [3] are the outer / inner key selectors.
leftKeySelector = new LambdaExpressionHandler(indentLevel,
GetLambdaExpression(expression.Arguments[2]),
outerStatement);
rightKeySelector = new LambdaExpressionHandler(indentLevel,
GetLambdaExpression(expression.Arguments[3]),
outerStatement);
leftReturnType = leftStatement.selectHandler.ReturnType;
rightReturnType = rightStatement.selectHandler.ReturnType;
outerStatement.Visit(expression.Arguments[1]);
}
// Outer source: a nested call is translated as-is; a constant sequence becomes a synthesized
// Select over its element type.
private static Expression GetLeftSourceExpression(MethodCallExpression expression) {
switch (expression.Arguments[0].NodeType) {
case ExpressionType.Call:
return expression.Arguments[0];
case ExpressionType.Constant:
return GetSourceExpression(expression.Arguments[0]);
default:
throw new ArgumentException("Node type not supported " + expression.Arguments[0].NodeType);
}
}
// Inner source: always a synthesized Select over the element type of Arguments[1].
private static MethodCallExpression GetRightSourceExpression(MethodCallExpression expression) {
return GetSourceExpression(expression.Arguments[1]);
}
// Synthesized Select over the single generic argument of the source's type.
private static MethodCallExpression GetSourceExpression(Expression source) {
Debug.Assert(source.Type.GetGenericArguments().Length == 1);
return QueryableMethodsProvider.GetSelectCall(source.Type.GetGenericArguments()[0]);
}
public static JoinHandler GetJoinHandler(SqlExpressionParser outerStatement,
int indentLevel,
MethodCallExpression expression) {
Debug.Assert(expression.Method.Name == "Join");
Debug.Assert(expression.Arguments.Count == 5);
return new JoinHandler(outerStatement, indentLevel, expression);
}
// Renders the FROM sources: left table, inner-join keyword, right table, and the ON key equality.
public string GetJoinClause() {
StringBuilder sb = new StringBuilder();
sb.Append(GetIndentation(indentLevel));
sb.Append(leftStatement.GetTableName() + " AS " + GetTableAlias(indentLevel + 1));
sb.Append(Environment.NewLine);
sb.Append(GetIndentation(indentLevel));
sb.Append(" " + outerStatement.syntaxHelper.GetInnerJoinKeyword() + " ");
sb.Append(Environment.NewLine);
sb.Append(GetIndentation(indentLevel));
sb.Append(rightStatement.GetTableName() + " AS " + GetTableAlias(indentLevel + 2));
sb.Append(Environment.NewLine);
sb.Append(GetIndentation(indentLevel));
sb.Append(" ON ");
sb.Append(GetJoinExpression(leftKeySelector, leftReturnType,
GetTableAlias(indentLevel + 1)) + " = " +
GetJoinExpression(rightKeySelector, rightReturnType,
GetTableAlias(indentLevel + 2)));
sb.Append(Environment.NewLine);
return sb.ToString();
}
// Renders a key selector and replaces its type-GUID placeholder with the side's table alias.
private static StringBuilder GetJoinExpression(LambdaExpressionHandler handler,
Type type,
string tableAlias) {
return handler.GetExpressionAsString(false)
.Replace(type.GUID.ToString(), tableAlias);
}
// Maps left/right type GUIDs to t{n+1}/t{n+2}, then strips "<selector body GUID>." and "t{n}." prefixes.
// NOTE(review): assumes LambdaExpressionHandler emits type GUIDs as placeholders — confirm.
public string ReplaceAliases(string expression) {
StringBuilder sb = new StringBuilder(expression);
sb.Replace(leftReturnType.GUID.ToString(), GetTableAlias(indentLevel + 1));
sb.Replace(rightReturnType.GUID.ToString(), GetTableAlias(indentLevel + 2));
sb.Replace(selector.Body.Type.GUID.ToString() + ".", string.Empty);
sb.Replace(GetTableAlias(indentLevel) + ".", string.Empty);
return sb.ToString();
}
}
/// <summary>
/// Translates a three-argument SelectMany call into "left AS t{n+1} CROSS-JOIN-KEYWORD right AS t{n+2}",
/// where n is this handler's indent level. Each side is translated by its own nested parser.
/// </summary>
private class CrossJoinHandler {
private readonly SqlExpressionParser outerStatement;
private readonly SqlExpressionParser leftStatement;
private readonly SqlExpressionParser rightStatement;
private readonly LambdaExpression selector;
private readonly Type leftReturnType = null;
private readonly Type rightReturnType = null;
private readonly int indentLevel;
// Result selector rebuilt by MakeSelector from the last SelectMany argument.
public LambdaExpression Selector {
get {
return selector;
}
}
// Order matters: both sides are translated first; the outer statement then visits the
// collection selector's body (Arguments[1]).
private CrossJoinHandler(SqlExpressionParser outerStatement,
int indentLevel,
MethodCallExpression expression) {
this.outerStatement = outerStatement;
this.indentLevel = indentLevel;
selector = new MakeSelector(GetLambdaExpression(expression.Arguments.Last())).Selector;
leftStatement = new SqlExpressionParser(indentLevel + 1, outerStatement,
outerStatement.syntaxHelper);
leftStatement.Translate(GetLeftSourceExpression(expression));
rightStatement = new SqlExpressionParser(indentLevel + 1, outerStatement,
outerStatement.syntaxHelper);
rightStatement.Translate(GetRightSourceExpression(expression));
leftReturnType = leftStatement.selectHandler.ReturnType;
rightReturnType = rightStatement.selectHandler.ReturnType;
var sourceExpression = GetLambdaExpression(expression.Arguments[1]).Body;
outerStatement.Visit(sourceExpression);
}
// Outer source: a nested call is translated as-is; a constant sequence becomes a synthesized
// Select over its element type.
private static Expression GetLeftSourceExpression(MethodCallExpression expression) {
switch (expression.Arguments[0].NodeType) {
case ExpressionType.Call:
return expression.Arguments[0];
case ExpressionType.Constant:
return GetSourceExpression(expression.Arguments[0].Type.GetGenericArguments()[0]);
default:
throw new ArgumentException("Node type not supported " + expression.Arguments[0].NodeType);
}
}
// Inner source: synthesized Select over the type of the result selector's second parameter.
private static MethodCallExpression GetRightSourceExpression(MethodCallExpression expression) {
var rightSource = (GetLambdaExpression(expression.Arguments[2])).Parameters[1];
return GetSourceExpression(rightSource.Type);
}
private static MethodCallExpression GetSourceExpression(Type sourceType) {
return QueryableMethodsProvider.GetSelectCall(sourceType);
}
public static CrossJoinHandler GetCrossJoinHandler(SqlExpressionParser outerStatement,
int indentLevel,
MethodCallExpression expression) {
Debug.Assert(expression.Method.Name == "SelectMany");
Debug.Assert(expression.Arguments.Count == 3);
return new CrossJoinHandler(outerStatement, indentLevel, expression);
}
// Renders the FROM sources: left table, cross-join keyword, right table (no ON clause).
public string GetCrossJoinClause() {
StringBuilder sb = new StringBuilder();
sb.Append(GetIndentation(indentLevel));
sb.Append(leftStatement.GetTableName() + " AS " + GetTableAlias(indentLevel + 1));
sb.Append(Environment.NewLine);
sb.Append(GetIndentation(indentLevel));
sb.Append(" " + outerStatement.syntaxHelper.GetCrossJoinKeyword() + " ");
sb.Append(Environment.NewLine);
sb.Append(GetIndentation(indentLevel));
sb.Append(rightStatement.GetTableName() + " AS " + GetTableAlias(indentLevel + 2));
sb.Append(Environment.NewLine);
return sb.ToString();
}
// Maps left/right type GUIDs to t{n+1}/t{n+2}, then strips "<selector body GUID>." and "t{n}." prefixes.
// NOTE(review): assumes LambdaExpressionHandler emits type GUIDs as placeholders — confirm.
public string ReplaceAliases(string expression) {
StringBuilder sb = new StringBuilder(expression);
sb.Replace(leftReturnType.GUID.ToString(), GetTableAlias(indentLevel + 1));
sb.Replace(rightReturnType.GUID.ToString(), GetTableAlias(indentLevel + 2));
sb.Replace(selector.Body.Type.GUID.ToString() + ".", string.Empty);
sb.Replace(GetTableAlias(indentLevel) + ".", string.Empty);
return sb.ToString();
}
}
/// <summary>
/// One ORDER BY term built from an OrderBy/ThenBy(Descending) key selector.
/// </summary>
private class OrderByHandler {
    private readonly SqlExpressionParser outerStatement;
    private readonly Type returnType = null;
    private readonly LambdaExpressionHandler lambdaHandler = null;
    // "Desc" for the descending variants, otherwise empty.
    private readonly string orderByDirection = string.Empty;

    /// <summary>Element type of the ordered source.</summary>
    public Type ReturnType {
        get { return returnType; }
    }

    private OrderByHandler(SqlExpressionParser outerStatement, int indentLevel,
                           MethodCallExpression expression) {
        this.outerStatement = outerStatement;
        string methodName = expression.Method.Name;
        bool descending = methodName == "OrderByDescending" || methodName == "ThenByDescending";
        if (descending) {
            orderByDirection = "Desc";
        }
        returnType = expression.Arguments[0].Type.GetGenericArguments()[0];
        LambdaExpression keySelector = GetLambdaExpression(expression.Arguments[1]);
        lambdaHandler = new LambdaExpressionHandler(indentLevel, keySelector, outerStatement);
    }

    /// <summary>
    /// Creates a handler for an ordering call, or returns null when the key selector's parameter
    /// stems from a grouping.
    /// </summary>
    public static OrderByHandler GetOrderByHandler(SqlExpressionParser outerStatement,
                                                   int indentLevel,
                                                   MethodCallExpression expression) {
        string methodName = expression.Method.Name;
        Debug.Assert(methodName == "OrderBy" ||
                     methodName == "OrderByDescending" ||
                     methodName == "ThenBy" ||
                     methodName == "ThenByDescending");
        Debug.Assert(expression.Arguments.Count == 2);
        Debug.Assert(expression.Arguments[0].Type.GetGenericArguments().Length == 1);
        ParameterExpression keyParameter = GetLambdaExpression(expression.Arguments[1]).Parameters[0];
        if (outerStatement.FromIsOrContainsGrouping(keyParameter.Type)) {
            return null;
        }
        return new OrderByHandler(outerStatement, indentLevel, expression);
    }

    /// <summary>Renders "key direction" (ascending terms keep a trailing space).</summary>
    public string GetOrderByClause() {
        var keyText = lambdaHandler.GetExpressionAsString(false);
        return keyText + " " + orderByDirection;
    }
}
/// <summary>
/// Translates a two-argument Concat/Except/Intersect/Union into
/// "left-statement SET-KEYWORD right-statement". Each operand is translated by its own nested parser.
/// </summary>
private class SetOperationHandler {
    private readonly SqlExpressionParser outerStatement;
    private readonly SqlExpressionParser leftStatement;
    private readonly SqlExpressionParser rightStatement;
    private readonly Type returnType;
    private readonly int indentLevel;
    // LINQ method name, mapped to SQL by the syntax helper.
    private readonly string operation;

    /// <summary>Element type of the first operand.</summary>
    public Type ReturnType {
        get {
            return returnType;
        }
    }

    private SetOperationHandler(SqlExpressionParser outerStatement,
                                int indentLevel,
                                MethodCallExpression expression) {
        this.outerStatement = outerStatement;
        this.indentLevel = indentLevel;
        leftStatement = new SqlExpressionParser(indentLevel + 1, outerStatement,
                                                outerStatement.syntaxHelper);
        leftStatement.Translate(GetSourceExpression(expression, 0));
        rightStatement = new SqlExpressionParser(indentLevel + 1, outerStatement,
                                                 outerStatement.syntaxHelper);
        rightStatement.Translate(GetSourceExpression(expression, 1));
        operation = expression.Method.Name;
        returnType = expression.Arguments[0].Type.GetGenericArguments()[0];
        outerStatement.Visit(expression.Arguments[0]);
        outerStatement.Visit(expression.Arguments[1]);
    }

    /// <summary>
    /// Operand <paramref name="source"/> (0 = left, 1 = right). A nested call is translated as-is;
    /// a constant sequence becomes a synthesized Select over its element type.
    /// </summary>
    private static Expression GetSourceExpression(MethodCallExpression expression,
                                                  int source) {
        switch (expression.Arguments[source].NodeType) {
            case ExpressionType.Call:
                return expression.Arguments[source];
            case ExpressionType.Constant:
                return GetSourceExpression(expression.Arguments[source].Type);
            default:
                throw new ArgumentException("Node type not supported " + expression.Arguments[source].NodeType);
        }
    }

    /// <summary>Synthesized Select over the single generic argument of <paramref name="sourceType"/>.</summary>
    private static MethodCallExpression GetSourceExpression(Type sourceType) {
        Debug.Assert(sourceType.GetGenericArguments().Length == 1);
        // The element type is always generic argument 0. The previous code indexed by the operand
        // position, which threw IndexOutOfRangeException for a constant right-hand operand.
        return QueryableMethodsProvider.GetSelectCall(sourceType.GetGenericArguments()[0]);
    }

    public static SetOperationHandler GetSetOperationHandler(SqlExpressionParser outerStatement,
                                                             int indentLevel,
                                                             MethodCallExpression expression) {
        Debug.Assert(expression.Method.Name == "Concat" ||
                     expression.Method.Name == "Except" ||
                     expression.Method.Name == "Intersect" ||
                     expression.Method.Name == "Union");
        Debug.Assert(expression.Arguments.Count == 2);
        return new SetOperationHandler(outerStatement, indentLevel, expression);
    }

    /// <summary>Renders left SQL, the dialect's set-operation keyword, and right SQL, each on its own line.</summary>
    public string GetSetOperationClause() {
        StringBuilder sb = new StringBuilder();
        sb.Append(GetIndentation(indentLevel));
        sb.Append(leftStatement.GetSQLStatement());
        sb.Append(Environment.NewLine);
        sb.Append(GetIndentation(indentLevel));
        sb.Append(outerStatement.syntaxHelper.GetSetOperationKeyword(operation));
        sb.Append(Environment.NewLine);
        sb.Append(GetIndentation(indentLevel));
        sb.Append(rightStatement.GetSQLStatement());
        sb.Append(Environment.NewLine);
        return sb.ToString();
    }
}
private class LambdaExpressionHandler : ExpressionVisitor {
private readonly SqlExpressionParser outerStatement;
private readonly LambdaExpression lambdaExpression;
private readonly Guid lambaExpressionId;
private readonly int indentLevel;
private readonly Dictionary<string, string> aliases = new Dictionary<string, string>();
private readonly List<string> accessedColumns = new List<string>();
private readonly Stack<Expression> terms = new Stack<Expression>();
private readonly StringBuilder sb = new StringBuilder();
private readonly Dictionary<ConstantExpression, ConstantExpression> visitedConstants =
new Dictionary<ConstantExpression, ConstantExpression>();
private int parameterCount = 0;
/// <summary>
/// Current @pN counter. It starts at the base index given to the constructor, is incremented
/// for each new placeholder, and is replaced by a nested parser's count after a sub-query.
/// </summary>
public int ParameterCount {
    get { return this.parameterCount; }
}
/// <summary>Creates a handler whose @pN parameter numbering starts at zero.</summary>
public LambdaExpressionHandler(int indentLevel, LambdaExpression lambdaExpression,
                               SqlExpressionParser outerStatement)
    : this(indentLevel, lambdaExpression, 0, outerStatement) {
    // Intentionally empty: the chained constructor performs the translation.
}
/// <summary>
/// Translates <paramref name="lambdaExpression"/> eagerly. The lambda is normalized (Normalizer),
/// its tree is visited into a post-order stack of terms, and the stack is evaluated once so the
/// resulting SQL is cached for later GetExpressionAsString calls.
/// </summary>
/// <param name="indentLevel">Indentation level used for nested sub-query parsers and projection SQL.</param>
/// <param name="lambdaExpression">The lambda to translate.</param>
/// <param name="parameterBaseIndex">First number used for @pN parameter placeholders.</param>
/// <param name="outerStatement">Statement this lambda belongs to; supplies the SQL syntax helper.</param>
public LambdaExpressionHandler(int indentLevel, LambdaExpression lambdaExpression,
int parameterBaseIndex,
SqlExpressionParser outerStatement) {
this.indentLevel = indentLevel;
this.lambdaExpression = (LambdaExpression)new Normalizer().Normalize(lambdaExpression);
// The body type's GUID is the prefix of the alias keys built in GetNew.
lambaExpressionId = this.lambdaExpression.Body.Type.GUID;
this.parameterCount = parameterBaseIndex;
this.outerStatement = outerStatement;
this.Visit(this.lambdaExpression);
// Evaluate now; EvaluateTerms caches the SQL in sb, so later calls reuse it.
GetExpressionAsString(false);
}
/// <summary>
/// Post-order visit of a call: the target object, then the arguments, then the call itself is
/// pushed. Aggregates over the lambda's own parameter are also recorded as accessed columns,
/// keyed by the call's text.
/// </summary>
protected override Expression VisitMethodCall(MethodCallExpression m) {
    Visit(m.Object);
    VisitExpressionList(m.Arguments);
    terms.Push(m);
    bool aggregatesOverParameter = IsAggregateMethod(m)
        && GetSourceType(m) == lambdaExpression.Parameters[0].Type;
    if (aggregatesOverParameter) {
        accessedColumns.Add(m.ToString());
    }
    return m;
}
/// <summary>
/// Post-order visit of a unary node. Quote nodes are not recorded: the quoted expression is
/// visited in their place.
/// </summary>
protected override Expression VisitUnary(UnaryExpression u) {
    if (u.NodeType != ExpressionType.Quote) {
        Visit(u.Operand);
        terms.Push(u);
        return u;
    }
    return Visit(StripQuotes(u));
}
/// <summary>
/// Post-order visit: left subtree, right subtree, then the operator node. The right
/// operand therefore sits directly beneath the operator on the stack.
/// </summary>
protected override Expression VisitBinary(BinaryExpression b) {
    Expression leftSide = b.Left;
    Expression rightSide = b.Right;
    Visit(leftSide);
    Visit(rightSide);
    terms.Push(b);
    return b;
}
/// <summary>Leaf node: recorded as-is for later evaluation.</summary>
protected override Expression VisitConstant(ConstantExpression c) {
    Expression leaf = c;
    terms.Push(leaf);
    return c;
}
/// <summary>Leaf node: recorded as-is for later evaluation.</summary>
protected override Expression VisitParameter(ParameterExpression p) {
    Expression leaf = p;
    terms.Push(leaf);
    return p;
}
/// <summary>
/// Records a member access. When it is rooted at the lambda's parameter and has a scalar
/// type (value type or string), its hashed name is added to the accessed columns.
/// </summary>
protected override Expression VisitMemberAccess(MemberExpression m) {
    Type rootType = GetSourceType(m);
    // Heuristic: only scalar members are assumed to map to database columns.
    bool isScalar = m.Type.IsValueType || m.Type == typeof(string);
    if (rootType == lambdaExpression.Parameters[0].Type && isScalar) {
        accessedColumns.Add(GetHashedName(m));
    }
    terms.Push(m);
    return m;
}
/// <summary>Post-order visit of a constructor call: each argument, then the node itself.</summary>
protected override NewExpression VisitNew(NewExpression newExpression) {
    var ctorArguments = newExpression.Arguments;
    for (int index = 0; index < ctorArguments.Count; index++) {
        Visit(ctorArguments[index]);
    }
    terms.Push(newExpression);
    return newExpression;
}
/// <summary>Conditional (?:) expressions cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override Expression VisitConditional(ConditionalExpression c) {
    throw new InvalidOperationException(
        "Conditional (?:) expressions are not supported in a translated lambda.");
}
/// <summary>Collection element initializers cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override ElementInit VisitElementInitializer(ElementInit initializer) {
    throw new InvalidOperationException(
        "Collection element initializers are not supported in a translated lambda.");
}
/// <summary>'is' type tests cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override Expression VisitTypeIs(TypeBinaryExpression b) {
    throw new InvalidOperationException(
        "'is' type tests are not supported in a translated lambda.");
}
/// <summary>Nested member initializers cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override MemberMemberBinding VisitMemberMemberBinding(MemberMemberBinding binding) {
    throw new InvalidOperationException(
        "Nested member initializers are not supported in a translated lambda.");
}
/// <summary>Member list initializers cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override MemberListBinding VisitMemberListBinding(MemberListBinding binding) {
    throw new InvalidOperationException(
        "Member list initializers are not supported in a translated lambda.");
}
/// <summary>Member binding lists cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override IEnumerable<MemberBinding> VisitBindingList(ReadOnlyCollection<MemberBinding> original) {
    throw new InvalidOperationException(
        "Member bindings are not supported in a translated lambda.");
}
/// <summary>Collection initializer lists cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override IEnumerable<ElementInit> VisitElementInitializerList(ReadOnlyCollection<ElementInit> original) {
    throw new InvalidOperationException(
        "Collection initializers are not supported in a translated lambda.");
}
/// <summary>Object initializers cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override Expression VisitMemberInit(MemberInitExpression init) {
    throw new InvalidOperationException(
        "Object initializers are not supported in a translated lambda.");
}
/// <summary>List initializers cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override Expression VisitListInit(ListInitExpression init) {
    throw new InvalidOperationException(
        "List initializers are not supported in a translated lambda.");
}
/// <summary>Array creation expressions cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override Expression VisitNewArray(NewArrayExpression na) {
    throw new InvalidOperationException(
        "Array creation expressions are not supported in a translated lambda.");
}
/// <summary>Delegate invocations cannot be translated by this handler.</summary>
/// <exception cref="InvalidOperationException">Always.</exception>
protected override Expression VisitInvocation(InvocationExpression iv) {
    throw new InvalidOperationException(
        "Delegate invocations are not supported in a translated lambda.");
}
/// <summary>
/// Returns the lambda's SQL. The term stack is evaluated on first use only.
/// </summary>
/// <param name="replaceAliases">When true, registered alias keys are substituted by their values.</param>
/// <returns>A fresh StringBuilder; mutating it does not affect the cached SQL.</returns>
public StringBuilder GetExpressionAsString(bool replaceAliases) {
    EvaluateTerms();
    string translated = ReplaceAliases(sb.ToString(), replaceAliases);
    return new StringBuilder(translated);
}
/// <summary>
/// Reduces the post-order term stack to a single constant and appends its text to sb.
/// Does nothing when sb already holds content (the SQL is cached after the first run).
/// </summary>
/// <remarks>
/// NOTE(review): if the translation itself is an empty string, sb stays empty and a second
/// call re-enters the loop with an empty stack, so terms.Pop() would throw.
/// </remarks>
private void EvaluateTerms() {
if (sb.Length > 0) {
// terms have already been evaluated
return;
}
Debug.Assert(terms.Count != 0);
while (terms.Count > 0) {
GetExpression();
// Done once a single constant (the translated SQL) is left.
if (terms.Count == 1 && terms.Peek().NodeType == ExpressionType.Constant) {
break;
}
// NOTE(review): with several terms left and a constant on top, that operand is
// evaluated and discarded — confirm this clean-up is intended.
if (terms.Count > 1 && terms.Peek().NodeType == ExpressionType.Constant) {
GetOperandValue();
}
}
// The last term is expected to be a ConstantExpression (a BoxedConstant or a plain string).
sb.Append((terms.Pop() as ConstantExpression).Value.ToString());
}
/// <summary>
/// Pops the top term (quotes stripped) and translates it. Operators and most leaves push their
/// SQL as a BoxedConstant-typed constant. GetConversion and GetParameterValue push terms that
/// are evaluated again later.
/// </summary>
/// <exception cref="NotSupportedException">The node type has no SQL translation.</exception>
private void GetExpression() {
    var op = StripQuotes(terms.Pop());
    switch (op.NodeType) {
        case ExpressionType.And:
        case ExpressionType.AndAlso:
            GetBinaryOperation(" AND ");
            break;
        case ExpressionType.Or:
        case ExpressionType.OrElse:
            GetBinaryOperation(" OR ");
            break;
        case ExpressionType.Equal:
            GetBinaryOperation(" = ");
            break;
        case ExpressionType.NotEqual:
            GetBinaryOperation(" <> ");
            break;
        case ExpressionType.LessThan:
            GetBinaryOperation(" < ");
            break;
        case ExpressionType.LessThanOrEqual:
            GetBinaryOperation(" <= ");
            break;
        case ExpressionType.GreaterThan:
            GetBinaryOperation(" > ");
            break;
        case ExpressionType.GreaterThanOrEqual:
            GetBinaryOperation(" >= ");
            break;
        case ExpressionType.ExclusiveOr:
            GetXOR();
            break;
        case ExpressionType.Add:
            GetBinaryOperation(" + ");
            break;
        case ExpressionType.Subtract:
            GetBinaryOperation(" - ");
            break;
        case ExpressionType.Multiply:
            GetBinaryOperation(" * ");
            break;
        case ExpressionType.Divide:
            // SQL division is '/'; the previously emitted '\' is not a T-SQL arithmetic operator.
            GetBinaryOperation(" / ");
            break;
        case ExpressionType.Modulo:
            GetBinaryOperation(" % ");
            break;
        case ExpressionType.Not:
            GetUnaryExpression(" NOT ");
            break;
        case ExpressionType.Coalesce:
            GetCoalesce();
            break;
        case ExpressionType.Convert:
            GetConversion(op as UnaryExpression);
            break;
        case ExpressionType.Lambda:
            GetLambda(op as LambdaExpression);
            break;
        case ExpressionType.New:
            GetNew(op as NewExpression);
            break;
        case ExpressionType.MemberAccess:
            GetMemberAccess(op as MemberExpression);
            break;
        case ExpressionType.Parameter:
            GetParameterValue(op as ParameterExpression);
            break;
        case ExpressionType.Constant:
            GetConstantValue(op as ConstantExpression);
            break;
        case ExpressionType.Call:
            GetMethodCall(op as MethodCallExpression);
            break;
        default:
            throw new NotSupportedException(
                string.Format("The operator '{0}' is not supported", op.NodeType));
    }
}
/// <summary>
/// Pops the next operand and pushes "op (operand)" as a translated fragment.
/// </summary>
/// <param name="op">Operator text, including any surrounding spaces.</param>
private void GetUnaryExpression(string op) {
    string operandSql = GetUnaryOperand();
    string fragment = string.Concat(op, " (", operandSql, ")");
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>
/// Translates "left ?? right" into the dialect's null-replacement function,
/// i.e. syntaxHelper.GetIsNullKeyword() + "(left, right)".
/// </summary>
private void GetCoalesce() {
    string leftSql;
    string rightSql;
    GetBinaryOperands(out leftSql, out rightSql);
    string fragment = outerStatement.syntaxHelper.GetIsNullKeyword() +
                      "(" + leftSql + ", " + rightSql + ")";
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>Pops both operands and pushes "(left op right)".</summary>
/// <param name="op">Operator text, including surrounding spaces.</param>
private void GetBinaryOperation(string op) {
    string leftSql;
    string rightSql;
    GetBinaryOperands(out leftSql, out rightSql);
    string fragment = string.Concat("(", leftSql, op, rightSql, ")");
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>
/// Expresses XOR with boolean operators:
/// ((left OR right) AND NOT (left AND right)).
/// </summary>
private void GetXOR() {
    string leftSql;
    string rightSql;
    GetBinaryOperands(out leftSql, out rightSql);
    string fragment = string.Format("(({0} OR {1}) AND NOT ({0} AND {1}))", leftSql, rightSql);
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>
/// Pushes the lambda's text as a fragment. Lambdas with a void body contribute nothing.
/// </summary>
private void GetLambda(LambdaExpression lambda) {
    if (lambda.Body.Type == typeof(void)) {
        return;
    }
    terms.Push(Expression.Constant(new BoxedConstant(lambda.ToString())));
}
/// <summary>
/// Handles a Convert node by dropping the conversion: the target type (by Name) must be one of
/// the listed scalar/nullable types; otherwise NotSupportedException is thrown.
/// </summary>
/// <remarks>
/// The operand is evaluated once via GetUnaryOperand (result discarded), then the raw operand
/// expression is pushed back so it is evaluated a second time later.
/// NOTE(review): for compound operands whose sub-terms were already consumed by the first
/// evaluation this may mis-translate; the existing comment below flags the SQL as wrong.
/// NOTE(review): an enum type's Name is its own name, so the "Enum" case only matches
/// System.Enum itself, not user enums.
/// </remarks>
private void GetConversion(UnaryExpression op) {
switch (op.Type.Name) {
case "Boolean":
case "Char":
case "Enum":
case "Guid":
case "String":
case "DateTime":
case "Decimal":
case "Int16":
case "Int32":
case "Int64":
case "IntPtr":
case "UInt16":
case "UInt32":
case "UInt64":
case "UIntPtr":
case "Byte":
case "SByte":
case "Double":
case "Single":
case "Nullable`1":
GetUnaryOperand();
//wrong emit sql for conversion
terms.Push(op.Operand);
break;
default:
throw new NotSupportedException(
string.Format("The conversion to '{0}' is not supported", op.Type.Name));
}
}
/// <summary>
/// Translates a constant term:
/// a BoxedConstant value is unwrapped and its SQL text pushed as a plain string constant;
/// any other value is replaced by an @pN placeholder. The same ConstantExpression instance
/// always maps to the same placeholder, and each new placeholder advances parameterCount.
/// A Query`1 value pushes its table name (GetTableName) beneath its placeholder.
/// </summary>
private void GetConstantValue(ConstantExpression c) {
    // NOTE(review): a null constant (e.g. a comparison with a null literal) throws
    // NullReferenceException here, as before; SQL NULL handling would need IS NULL anyway.
    Type valueType = c.Value.GetType();
    if (Type.GetTypeCode(valueType) == TypeCode.Object) {
        // Ordinal: this is a CLR type-name check, not linguistic text.
        if (valueType.Name.StartsWith("Query`1", StringComparison.Ordinal)) {
            // NOTE(review): unlike the BoxedConstant branch this does not return, so both the
            // table name and a placeholder are pushed. Parameter numbering may rely on the
            // placeholder, so this is left unchanged — confirm intent.
            terms.Push(Expression.Constant(
                new BoxedConstant(GetTableName(valueType.GetGenericArguments()[0]))));
        }
        else if (valueType == typeof(BoxedConstant)) {
            terms.Push(Expression.Constant(((BoxedConstant)c.Value).Expression));
            return;
        }
    }
    ConstantExpression placeholder;
    if (!visitedConstants.TryGetValue(c, out placeholder)) {
        placeholder = Expression.Constant(new BoxedConstant("@p" + parameterCount.ToString()));
        visitedConstants[c] = placeholder;
        parameterCount++;
    }
    terms.Push(placeholder);
}
/// <summary>
/// Pushes the parameter's name as a plain (string-typed, not BoxedConstant) constant.
/// </summary>
private void GetParameterValue(ParameterExpression p) {
    string parameterName = p.Name;
    terms.Push(Expression.Constant(parameterName));
}
/// <summary>
/// Pushes the hashed column name of an instance member access. A static member (no target
/// expression) yields an empty fragment.
/// </summary>
private void GetMemberAccess(MemberExpression m) {
    string fragment = m.Expression == null ? string.Empty : GetHashedName(m);
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>
/// Translates a call term according to its declaring type: Queryable/Enumerable, string,
/// the Normalizer marker methods, or System.Object.
/// </summary>
/// <exception cref="ArgumentException">The declaring type is not supported.</exception>
private void GetMethodCall(MethodCallExpression m) {
    Type declaringType = m.Method.DeclaringType;
    if (declaringType == typeof(Queryable) || declaringType == typeof(Enumerable)) {
        GetQueryableMethodCall(m);
        return;
    }
    if (declaringType == typeof(string)) {
        GetStringMethodCall(m);
        return;
    }
    if (declaringType == typeof(Normalizer)) {
        switch (m.Method.Name) {
            case "IsNotNull":
                GetIsNotNullCall(m);
                break;
            case "GetDateValue":
                GetDateValueCall(m);
                break;
            case "GetStringLength":
                GetStringLengthCall(m);
                break;
        }
        return;
    }
    if (declaringType == typeof(object)) {
        // Consume the evaluated target (if any) and one operand per argument; only the
        // upper-cased method name is emitted.
        if (m.Object != null) {
            GetUnaryOperand();
        }
        for (int i = 0; i < m.Arguments.Count; i++) {
            GetUnaryOperand();
        }
        // Invariant casing: culture-sensitive ToUpper() turns 'i' into a dotted capital I
        // under tr-TR (e.g. "ToString"), producing different SQL per machine locale.
        terms.Push(Expression.Constant(new BoxedConstant(m.Method.Name.ToUpperInvariant())));
        return;
    }
    throw new ArgumentException("Method calls declared on '" + declaringType +
                                "' are not supported: " + m.Method.Name);
}
/// <summary>
/// Translates a Queryable/Enumerable call. After consuming one evaluated operand per argument,
/// it takes the first matching case below:
/// <list type="bullet">
/// <item>calls not rooted at the lambda's parameter emit just the method name;</item>
/// <item>aggregates become correlated scalar sub-queries;</item>
/// <item>Any and All become count comparisons;</item>
/// <item>anything else becomes a correlated sub-query.</item>
/// </list>
/// </summary>
private void GetQueryableMethodCall(MethodCallExpression m) {
    Debug.Assert(m.Method.DeclaringType == typeof(Queryable) ||
                 m.Method.DeclaringType == typeof(Enumerable));
    Debug.Assert(m.Object == null);
    int pendingOperands = m.Arguments.Count;
    while (pendingOperands-- > 0) {
        GetUnaryOperand();
    }
    if (GetSourceType(m) != lambdaExpression.Parameters[0].Type) {
        terms.Push(Expression.Constant(new BoxedConstant(m.Method.Name)));
    }
    else if (IsAggregateMethod(m)) {
        terms.Push(Expression.Constant(new BoxedConstant(GetAggregate(m))));
    }
    else if (m.Method.Name == "Any") {
        GetAnyCall(m);
    }
    else if (m.Method.Name == "All") {
        GetAllCall(m);
    }
    else {
        GetQueryableCall(m);
    }
}
/// <summary>
/// Translates a System.String method call via syntaxHelper.GetStringMethod. Operands are
/// passed in source order: the instance first (if any), then the arguments.
/// </summary>
private void GetStringMethodCall(MethodCallExpression m) {
    Debug.Assert(m.Method.DeclaringType == typeof(string));
    // The last argument is on top of the stack and the instance lies beneath the first
    // argument, so prepending each popped value restores source order.
    var operands = new List<string>();
    for (int i = 0; i < m.Arguments.Count; i++) {
        operands.Insert(0, GetOperandValue());
    }
    if (m.Object != null) {
        operands.Insert(0, GetOperandValue());
    }
    string fragment = outerStatement.syntaxHelper.GetStringMethod(m.Method, operands);
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>Translates the Normalizer.IsNotNull marker into "operand IS NOT NULL ".</summary>
private void GetIsNotNullCall(MethodCallExpression m) {
    Debug.Assert(m.Method.DeclaringType == typeof(Normalizer));
    string operandSql = GetOperandValue();
    terms.Push(Expression.Constant(new BoxedConstant(string.Concat(operandSql, " IS NOT NULL "))));
}
/// <summary>
/// Translates the Normalizer.GetDateValue(member, datePart) marker via
/// syntaxHelper.GetDateTimeProperty.
/// </summary>
private void GetDateValueCall(MethodCallExpression m) {
    Debug.Assert(m.Method.DeclaringType == typeof(Normalizer));
    // Consume the two evaluated operands; the SQL is rebuilt from the call's arguments.
    GetOperandValue();
    GetOperandValue();
    var dateMember = StripQuotes(m.Arguments[0]) as MemberExpression;
    var datePartConstant = m.Arguments[1] as ConstantExpression;
    var datePart = (string)datePartConstant.Value;
    string columnName = GetHashedName(dateMember);
    string fragment = outerStatement.syntaxHelper.GetDateTimeProperty(columnName, datePart);
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>Translates the Normalizer.GetStringLength marker via syntaxHelper.GetLengthProperty.</summary>
private void GetStringLengthCall(MethodCallExpression m) {
    Debug.Assert(m.Method.DeclaringType == typeof(Normalizer));
    // Consume the evaluated operand; the column name is rebuilt from the argument itself.
    GetOperandValue();
    var stringMember = StripQuotes(m.Arguments[0]) as MemberExpression;
    string fragment = outerStatement.syntaxHelper.GetLengthProperty(GetHashedName(stringMember));
    terms.Push(Expression.Constant(new BoxedConstant(fragment)));
}
/// <summary>
/// Translates a Queryable/Enumerable call over the lambda's parameter (other than aggregates,
/// Any and All) into a correlated sub-query and pushes its SQL.
/// </summary>
/// <remarks>
/// A nested parser (AggregateType.None) first visits a Where call whose condition comes from
/// GetCorrelation, then translates <paramref name="m"/>. Afterwards the nested parser's
/// ParameterCount replaces this handler's counter.
/// NOTE(review): in the first branch GetWhereCall receives the element type
/// (GetGenericArguments()[0]). In the second it receives GetMemberSourceType(m) itself, whereas
/// GetAnyCall/GetAllCall pass that type's first generic argument — confirm which is intended.
/// </remarks>
private void GetQueryableCall(MethodCallExpression m) {
SqlExpressionParser parser = new SqlExpressionParser(indentLevel,
outerStatement,
AggregateType.None,
outerStatement.syntaxHelper);
BinaryExpression correlation = null;
Type enumerableType = null;
if (m.Arguments[0].NodeType != ExpressionType.Call) {
correlation = GetLambdaExpression(
(GetCorrelation(m.Arguments[0].Type,
m.Arguments[0].Type.GetGenericArguments()[0])
.Arguments[1]))
.Body as BinaryExpression;
enumerableType = m.Arguments[0].Type.GetGenericArguments()[0];
}
else {
enumerableType = GetMemberSourceType(m);
correlation = GetLambdaExpression(
(GetCorrelation(enumerableType,
enumerableType.GetGenericArguments()[0]).Arguments[1]))
.Body as BinaryExpression;
}
var whereCall = QueryableMethodsProvider.GetWhereCall(enumerableType, "source",
correlation);
parser.VisitMethodCall(whereCall);
parser.Translate(m);
parameterCount = parser.ParameterCount;
terms.Push(Expression.Constant(
new BoxedConstant(parser.GetSQLStatement())));
}
/// <summary>
/// Translates Any() / Any(predicate) over a related collection into
/// "(correlated sub-query) > 0", where the nested parser uses AggregateType.Count.
/// </summary>
/// <remarks>
/// NOTE(review): the predicate body is read "as BinaryExpression", so a non-binary predicate
/// (e.g. a bare boolean member) passes null to GetWhereCall.
/// NOTE(review): the branches differ in two ways.
/// - The first visits the predicate before translating the correlated Where; the second visits
///   the Where first.
/// - In the second branch the predicate's Where uses enumerableType while the correlation uses
///   enumerableType.GetGenericArguments()[0].
/// Confirm both against GetWhereCall.
/// </remarks>
private void GetAnyCall(MethodCallExpression m) {
SqlExpressionParser parser = new SqlExpressionParser(indentLevel,
outerStatement,
AggregateType.Count,
outerStatement.syntaxHelper);
BinaryExpression correlation = null;
MethodCallExpression lambdaCondition = null;
if (m.Arguments[0].NodeType != ExpressionType.Call) {
correlation = GetLambdaExpression(
(GetCorrelation(m.Arguments[0].Type,
m.Arguments[0].Type.GetGenericArguments()[0])
.Arguments[1]))
.Body as BinaryExpression;
var enumerableType = m.Arguments[0].Type.GetGenericArguments()[0];
var whereCall = QueryableMethodsProvider.GetWhereCall(enumerableType, "source",
correlation);
if (m.Arguments.Count == 2) {
lambdaCondition = QueryableMethodsProvider.GetWhereCall(enumerableType, "source",
GetLambdaExpression(m.Arguments[1]).Body as BinaryExpression);
parser.Visit(lambdaCondition);
}
parser.Translate(whereCall);
}
else {
var enumerableType = GetMemberSourceType(m);
correlation = GetLambdaExpression(
(GetCorrelation(enumerableType,
enumerableType.GetGenericArguments()[0]).Arguments[1]))
.Body as BinaryExpression;
var whereCall = QueryableMethodsProvider.GetWhereCall(enumerableType.GetGenericArguments()[0],
"source", correlation);
parser.VisitMethodCall(whereCall);
if (m.Arguments.Count == 2) {
lambdaCondition = QueryableMethodsProvider.GetWhereCall(enumerableType, "source",
GetLambdaExpression(m.Arguments[1]).Body as BinaryExpression);
parser.Visit(lambdaCondition);
}
parser.Translate(m.Arguments[0] as MethodCallExpression);
}
// Continue @pN numbering after the sub-query's parameters.
parameterCount = parser.ParameterCount;
var statement = Environment.NewLine +
"(" + parser.GetSQLStatement() + ") > 0" +
Environment.NewLine;
terms.Push(Expression.Constant(new BoxedConstant(statement)));
}
/// <summary>
/// Translates All(predicate) over a related collection as "no correlated row violates the
/// predicate".
/// </summary>
/// <remarks>
/// The nested parser (AggregateType.Count) filters on correlation AND NOT(predicate body), and
/// the emitted SQL is "(sub-query) = 0".
/// </remarks>
private void GetAllCall(MethodCallExpression m) {
SqlExpressionParser parser = new SqlExpressionParser(indentLevel,
outerStatement,
AggregateType.Count,
outerStatement.syntaxHelper);
var methodLambda = GetLambdaExpression(m.Arguments[1]);
BinaryExpression correlation = null;
if (m.Arguments[0].NodeType != ExpressionType.Call) {
correlation =
Expression.And(
GetLambdaExpression(
(GetCorrelation(m.Arguments[0].Type,
m.Arguments[0].Type.GetGenericArguments()[0]))
.Arguments[1])
.Body as BinaryExpression,
Expression.Not((methodLambda as LambdaExpression).Body)
);
var enumerableType = m.Arguments[0].Type.GetGenericArguments()[0];
var whereCall = QueryableMethodsProvider.GetWhereCall(enumerableType, "source",
correlation);
parser.Translate(whereCall);
}
else {
var enumerableType = GetMemberSourceType(m);
correlation =
Expression.And(
GetLambdaExpression(
(GetCorrelation(enumerableType,
enumerableType.GetGenericArguments()[0]))
.Arguments[1])
.Body as BinaryExpression,
Expression.Not((methodLambda as LambdaExpression).Body)
);
var whereCall = QueryableMethodsProvider.GetWhereCall(enumerableType.GetGenericArguments()[0],
"source",
correlation);
parser.VisitMethodCall(whereCall);
parser.Translate(m.Arguments[0] as MethodCallExpression);
}
// Continue @pN numbering after the sub-query's parameters.
parameterCount = parser.ParameterCount;
var statement = Environment.NewLine +
"(" + parser.GetSQLStatement() + ") = 0" +
Environment.NewLine;
terms.Push(Expression.Constant(
new BoxedConstant(statement)));
}
/// <summary>
/// Translates a selector-less aggregate (e.g. Count()) over a related collection into a
/// correlated scalar sub-query. The sub-query is registered as an alias keyed by the call's
/// text, and its projection SQL is returned.
/// </summary>
private string GetCount(MethodCallExpression method) {
    Debug.Assert(method.Arguments.Count == 1);
    Type elementType = method.Method.GetGenericArguments()[0];
    MethodCallExpression correlatedWhere = GetCorrelation(method.Arguments[0].Type, elementType);
    MethodCallExpression projection = QueryableMethodsProvider.GetSelectCall(correlatedWhere);
    SqlExpressionParser subQuery = new SqlExpressionParser(indentLevel + 1,
                                                           outerStatement,
                                                           GetAggregateTypeFromName(method.Method.Name),
                                                           outerStatement.syntaxHelper);
    subQuery.Translate(projection);
    parameterCount = subQuery.ParameterCount;
    return AddAlias(method, subQuery.GetSQLStatement());
}
/// <summary>
/// Translates an aggregate (Count, Sum, Min, Max, Average) over a collection related to the
/// lambda's parameter into a correlated scalar sub-query. The sub-query is registered as an
/// alias keyed by the call's text, and its projection SQL is returned.
/// </summary>
private string GetAggregate(MethodCallExpression method) {
    if (method.Arguments.Count == 1) {
        return GetCount(method);
    }
    Debug.Assert(method.Arguments.Count == 2);
    // Queryable aggregates pass the selector as a quoted lambda, Enumerable ones pass it
    // directly. StripQuotes handles both; a plain cast threw InvalidCastException on quotes.
    var accessLambda = (LambdaExpression)StripQuotes(method.Arguments[1]);
    var sourceType = accessLambda.Parameters[0].Type;
    if (sourceType != lambdaExpression.Parameters[0].Type
        && accessLambda.Body.NodeType == ExpressionType.Call) {
        // The selector is itself a call (e.g. an aggregate over a further relation).
        return GetNestedAggregate(method);
    }
    var selectorParam = Expression.Parameter(sourceType,
                                             accessLambda.Parameters[0].Name);
    var projectionSelector = Expression.Lambda(accessLambda.Body, selectorParam);
    var whereCall = GetCorrelation(method.Arguments[0].Type, sourceType);
    var selectCall = QueryableMethodsProvider.GetSelectCall(whereCall, projectionSelector);
    SqlExpressionParser parser =
        new SqlExpressionParser(indentLevel + 1,
                                outerStatement,
                                GetAggregateTypeFromName(method.Method.Name),
                                outerStatement.syntaxHelper);
    parser.Translate(selectCall);
    parameterCount = parser.ParameterCount;
    return AddAlias(method, parser.GetSQLStatement());
}
/// <summary>
/// Translates an aggregate whose selector is itself a method call (e.g. an aggregate over a
/// further relation). It builds a sub-query projecting KeyAggregatePair(foreignKey, inner
/// value), then wraps it as "SELECT AGG(Aggregate) FROM (...) AS aggregateN". The result is
/// registered as an alias keyed by the call's text, and the projection SQL is returned.
/// </summary>
/// <exception cref="ArgumentException">The method is not Count, Sum, Min, Max or Average.</exception>
private string GetNestedAggregate(MethodCallExpression method) {
    // Same quote handling as GetAggregate: the selector may arrive quoted (Queryable) or raw.
    var accessLambda = (LambdaExpression)StripQuotes(method.Arguments[1]);
    var accessLambdaBody = accessLambda.Body as MethodCallExpression;
    Debug.Assert(accessLambdaBody != null);
    var sourceType = accessLambda.Parameters[0].Type;
    var selectorParam = Expression.Parameter(sourceType, "source");
    var whereCall = GetCorrelation(method.Arguments[0].Type, sourceType);
    // GetCorrelationCondition puts the foreign-key member access on the left of the equality.
    var foreignKey = (GetLambdaExpression(whereCall.Arguments[1]).Body
                      as BinaryExpression).Left;
    var keyValueType = typeof(KeyAggregatePair<,>)
        .MakeGenericType(foreignKey.Type, accessLambda.Body.Type);
    var keyValueConstructor =
        keyValueType.GetConstructor(new Type[] { foreignKey.Type, accessLambda.Body.Type });
    var newKeyValue = Expression.New(keyValueConstructor,
                                     new Expression[] { foreignKey, accessLambdaBody },
                                     new PropertyInfo[] {
                                         keyValueType.GetProperty("Key"),
                                         keyValueType.GetProperty("Aggregate")
                                     });
    var projectionSelector = Expression.Lambda(newKeyValue, selectorParam);
    var selectCall = QueryableMethodsProvider.GetSelectCall(whereCall, projectionSelector);
    SqlExpressionParser parser =
        new SqlExpressionParser(indentLevel + 2,
                                outerStatement,
                                AggregateType.None,
                                outerStatement.syntaxHelper);
    parser.Translate(selectCall);
    parameterCount = parser.ParameterCount;
    string aggregate = null;
    switch (method.Method.Name) {
        case "Count":
            aggregate = outerStatement.syntaxHelper.GetCountKeyword();
            break;
        case "Sum":
            aggregate = outerStatement.syntaxHelper.GetSumKeyword();
            break;
        case "Min":
            aggregate = outerStatement.syntaxHelper.GetMinKeyword();
            break;
        case "Max":
            aggregate = outerStatement.syntaxHelper.GetMaxKeyword();
            break;
        case "Average":
            aggregate = outerStatement.syntaxHelper.GetAverageKeyword();
            break;
        default:
            throw new ArgumentException(
                "Unsupported nested aggregate method '" + method.Method.Name + "'.");
    }
    var statement = GetIndentation(indentLevel) + "SELECT " + aggregate + "(Aggregate)" +
                    Environment.NewLine +
                    GetIndentation(indentLevel) + "FROM (" +
                    Environment.NewLine +
                    parser.GetSQLStatement() +
                    GetIndentation(indentLevel) + ") AS aggregate" + indentLevel.ToString();
    return AddAlias(method, statement);
}
/// <summary>
/// Builds a Where call over <paramref name="sourceType"/> whose condition ties the related
/// rows' foreign key to the primary key of the lambda parameter's type. When that type is
/// generic (e.g. an anonymous type from a join), the component type exposing a property of
/// <paramref name="methodType"/> is used, with a two-level GUID alias prefix.
/// </summary>
private MethodCallExpression GetCorrelation(Type methodType, Type sourceType) {
    Type declaringType = lambdaExpression.Parameters[0].Type;
    Type[] componentTypes = declaringType.GetGenericArguments();
    BinaryExpression whereCondition;
    if (componentTypes.Length != 0) {
        // e.g. <>f__AnonymousType0`2[[Order],[Customer]] produced by a join.
        Type ownerType = componentTypes
            .Where(t => t.GetProperties().Any(p => p.PropertyType == methodType))
            .Single();
        whereCondition = GetCorrelationCondition(methodType, sourceType, ownerType,
                                                 declaringType.GUID + "." + ownerType.GUID + ".");
    }
    else {
        whereCondition = GetCorrelationCondition(methodType, sourceType, declaringType,
                                                 declaringType.GUID + ".");
    }
    return QueryableMethodsProvider.GetWhereCall(sourceType, "source", whereCondition);
}
/// <summary>
/// Builds "source.ForeignKey == tableAlias + PrimaryKey". The right side is a BoxedConstant,
/// so it is emitted verbatim as SQL rather than as a parameter.
/// </summary>
private BinaryExpression GetCorrelationCondition(Type methodType,
                                                 Type sourceType,
                                                 Type declaringType,
                                                 string tableAlias) {
    string foreignKeyName = GetForeignKey(declaringType, methodType);
    ParameterExpression rowParameter = Expression.Parameter(sourceType, sourceType.Name);
    MemberExpression foreignKeyAccess =
        Expression.MakeMemberAccess(rowParameter, sourceType.GetProperty(foreignKeyName));
    ConstantExpression primaryKeyReference =
        Expression.Constant(new BoxedConstant(tableAlias + GetPrimaryKey(declaringType)));
    return Expression.Equal(foreignKeyAccess, primaryKeyReference);
}
/// <summary>Maps an aggregate method name to its <see cref="AggregateType"/>.</summary>
/// <exception cref="ArgumentException">
/// <paramref name="name"/> is not Count, Sum, Min, Max or Average.
/// </exception>
private static AggregateType GetAggregateTypeFromName(string name) {
    switch (name) {
        case "Count":
            return AggregateType.Count;
        case "Sum":
            return AggregateType.Sum;
        case "Min":
            return AggregateType.Min;
        case "Max":
            return AggregateType.Max;
        case "Average":
            return AggregateType.Average;
        default:
            throw new ArgumentException("Unsupported aggregate method '" + name + "'.", "name");
    }
}
/// <summary>
/// Translates a constructor call (typically a projection), registers column aliases for it,
/// and pushes newExpression.ToString() as the fragment.
/// </summary>
/// <remarks>
/// If the new expression's type differs from the lambda body's type (a nested constructor),
/// it is translated by a child handler. The child's aliases are copied twice: once prefixed
/// with this lambda's body-type GUID and once unprefixed.
///
/// Otherwise each member-bearing argument is aliased under "bodyGuid.MemberName":
/// - scalar member accesses map to their hashed column name;
/// - aggregate calls map to the member name, and an existing alias for the call's text gets
///   " AS bodyGuid.MemberName" appended;
/// - other arguments are skipped.
///
/// NOTE(review): newExpression.Members is null for constructors built without member info, so
/// members[i] would throw in the second branch. The "get_" stripping presumably handles
/// members reported as property getter methods — confirm.
/// </remarks>
private void GetNew(NewExpression newExpression) {
// Consume one evaluated operand per constructor argument.
foreach (var argument in newExpression.Arguments) {
GetOperandValue();
}
var args = newExpression.Arguments;
var members = newExpression.Members;
if (newExpression.Type != lambdaExpression.Body.Type) {
var lambdaHandler = new LambdaExpressionHandler(indentLevel + 1,
Expression.Lambda(newExpression,
Expression.Parameter(
lambdaExpression.Parameters[0].Type,
"source")),
outerStatement);
foreach (var column in lambdaHandler.aliases) {
aliases[lambaExpressionId + "." + column.Key] = column.Value;
aliases[column.Key] = column.Value;
}
}
else {
for (int i = 0; i < args.Count; i++) {
string memberName = null;
if (members[i].Name.StartsWith("get_")) {
memberName = members[i].Name.Substring(4);
}
else {
memberName = members[i].Name;
}
string key = lambaExpressionId + "." + memberName;
string value = null;
switch (args[i].NodeType) {
case ExpressionType.MemberAccess:
// hack - check if member maps to a column in db
if (!(args[i].Type.IsValueType || args[i].Type == typeof(string))) {
continue;
}
value = GetHashedName((args[i] as MemberExpression));
break;
case ExpressionType.Call:
if (!IsAggregateMethod(args[i] as MethodCallExpression)) {
continue;
}
// An alias registered by AddAlias for this call's text gets the projection name appended.
var key2 = args[i].ToString();
if (aliases.ContainsKey(key2)) {
aliases[key2] = aliases[key2] + " AS " + key;
}
value = memberName;
break;
default:
continue;
}
aliases[key] = value;
}
}
terms.Push(Expression.Constant(
new BoxedConstant(newExpression.ToString())));
}
/// <summary>
/// Pops the two operands of a binary node.
/// </summary>
/// <remarks>
/// Despite the parameter names, the FIRST out parameter receives the SQL of the expression's
/// left-hand side and the second receives the right-hand side. VisitBinary pushes the right
/// operand last, so it is popped first.
/// </remarks>
private void GetBinaryOperands(out string rightOperand, out string leftOperand) {
    Debug.Assert(terms.Count > 1);
    string poppedFirst = GetOperandValue();
    string poppedSecond = GetOperandValue();
    leftOperand = poppedFirst;
    rightOperand = poppedSecond;
}
/// <summary>Pops and returns the SQL of a single operand.</summary>
private string GetUnaryOperand() {
    Debug.Assert(terms.Count > 0);
    string operandSql = GetOperandValue();
    return operandSql;
}
/// <summary>
/// Evaluates terms until a BoxedConstant-typed constant (an already translated fragment) is on
/// top of the stack, then pops it and returns its text.
/// </summary>
private string GetOperandValue() {
    for (;;) {
        if (terms.Peek().Type == typeof(BoxedConstant)) {
            break;
        }
        GetExpression();
    }
    var translated = terms.Pop() as ConstantExpression;
    return translated.Value.ToString();
}
/// <summary>
/// Wraps a sub-query in indented parentheses on their own lines, followed by a trailing space
/// and newline.
/// </summary>
private static string GetProjectionSql(int indentLevel, string projection) {
    var builder = new StringBuilder();
    builder.Append(Environment.NewLine);
    builder.Append(GetIndentation(indentLevel)).Append("(").Append(Environment.NewLine);
    builder.Append(projection);
    builder.Append(GetIndentation(indentLevel)).Append(") ").Append(Environment.NewLine);
    return builder.ToString();
}
/// <summary>
/// Builds a dotted key for a member path. Scalar members (value types and strings) contribute
/// their name, other members their type's GUID, and the path is rooted at the GUID of the
/// first non-member expression's type.
/// </summary>
private static string GetHashedName(MemberExpression m) {
    bool isScalar = m.Type.IsValueType || m.Type == typeof(string);
    string segment = isScalar ? m.Member.Name : m.Type.GUID.ToString();
    string prefix = m.Expression.NodeType == ExpressionType.MemberAccess
        ? GetHashedName((MemberExpression)m.Expression)
        : m.Expression.Type.GUID.ToString();
    return prefix + "." + segment;
}
/// <summary>
/// Walks through member accesses, conversions and query-operator calls (first argument) to
/// the root constant or parameter, and returns that root's type.
/// </summary>
/// <exception cref="ArgumentException">Any other node type is met on the way.</exception>
private Type GetSourceType(Expression expression) {
    Expression current = expression;
    while (true) {
        switch (current.NodeType) {
            case ExpressionType.MemberAccess:
                current = ((MemberExpression)current).Expression;
                break;
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
                current = ((UnaryExpression)current).Operand;
                break;
            case ExpressionType.Constant:
            case ExpressionType.Parameter:
                return current.Type;
            case ExpressionType.Call:
                var call = (MethodCallExpression)current;
                Debug.Assert(call.Method.DeclaringType == typeof(Queryable) ||
                             call.Method.DeclaringType == typeof(Enumerable));
                current = call.Arguments[0];
                break;
            default:
                throw new ArgumentException();
        }
    }
}
/// <summary>
/// Walks through conversions and query-operator calls (first argument) to the first member
/// access, constant or parameter, and returns that node's type.
/// </summary>
/// <exception cref="ArgumentException">Any other node type is met on the way.</exception>
private Type GetMemberSourceType(Expression expression) {
    Expression current = expression;
    while (true) {
        switch (current.NodeType) {
            case ExpressionType.MemberAccess:
            case ExpressionType.Constant:
            case ExpressionType.Parameter:
                return current.Type;
            case ExpressionType.Convert:
            case ExpressionType.ConvertChecked:
                current = ((UnaryExpression)current).Operand;
                break;
            case ExpressionType.Call:
                var call = (MethodCallExpression)current;
                Debug.Assert(call.Method.DeclaringType == typeof(Queryable) ||
                             call.Method.DeclaringType == typeof(Enumerable));
                current = call.Arguments[0];
                break;
            default:
                throw new ArgumentException();
        }
    }
}
/// <summary>
/// Returns the name of the first property whose first ColumnAttribute has IsPrimaryKey set.
/// Throws InvalidOperationException (from First) when there is none.
/// </summary>
private static string GetPrimaryKey(Type sourceType) {
    Type columnAttributeType = typeof(System.Data.Linq.Mapping.ColumnAttribute);
    var primaryKeyNames =
        from property in sourceType.GetProperties()
        let columns = property.GetCustomAttributes(columnAttributeType, false)
        where columns.Length != 0
              && ((System.Data.Linq.Mapping.ColumnAttribute)columns[0]).IsPrimaryKey
        select property.Name;
    return primaryKeyNames.First();
}
/// <summary>
/// Returns the AssociationAttribute.OtherKey of the first property on
/// <paramref name="sourceType"/> whose type is <paramref name="memberType"/>.
/// </summary>
private static string GetForeignKey(Type sourceType, Type memberType) {
    Type associationAttributeType = typeof(System.Data.Linq.Mapping.AssociationAttribute);
    return sourceType.GetProperties()
        .Where(property => property.PropertyType == memberType)
        .Select(property => ((System.Data.Linq.Mapping.AssociationAttribute)
                             property.GetCustomAttributes(associationAttributeType, false)[0])
                            .OtherKey)
        .First();
}
/// <summary>Returns the distinct column keys recorded while visiting the lambda.</summary>
public string[] GetAccessedFields() {
    IEnumerable<string> uniqueColumns = accessedColumns.Distinct();
    return uniqueColumns.ToArray();
}
/// <summary>
/// Registers the parenthesized projection SQL under the call's text and returns it.
/// </summary>
private string AddAlias(MethodCallExpression method, string projection) {
    string wrappedSql = GetProjectionSql(indentLevel, projection);
    aliases[method.ToString()] = wrappedSql;
    return wrappedSql;
}
/// <summary>Returns <paramref name="expression"/> with every alias key replaced by its value.</summary>
public string ReplaceAliases(string expression) {
    const bool substitute = true;
    return ReplaceAliases(expression, substitute);
}
/// <summary>
/// When <paramref name="replaceAliases"/> is true, replaces every occurrence of each alias key
/// (in dictionary enumeration order) with its value; otherwise returns the input unchanged.
/// </summary>
public string ReplaceAliases(string expression, bool replaceAliases) {
    if (replaceAliases) {
        var builder = new StringBuilder(expression);
        foreach (KeyValuePair<string, string> alias in aliases) {
            builder.Replace(alias.Key, alias.Value);
        }
        return builder.ToString();
    }
    return expression;
}
/// <summary>
/// Rewrites a lambda before translation. The rewrites are:
/// - Nullable Value becomes the underlying expression;
/// - Nullable HasValue becomes IsNotNull(expr);
/// - DateTime member access becomes GetDateValue(member, memberName);
/// - string member access (asserted to be Length) becomes GetStringLength(member);
/// - a conditional with a constant test becomes its chosen branch.
/// The static marker methods throw if executed. LambdaExpressionHandler.GetMethodCall
/// recognizes them by declaring type and name.
/// </summary>
private class Normalizer : ExpressionVisitor {
private readonly static MethodInfo isNotNullMethod = typeof(Normalizer)
.GetMethod("IsNotNull",
BindingFlags.Static |
BindingFlags.Public);
private readonly static MethodInfo getDateValue = typeof(Normalizer)
.GetMethod("GetDateValue",
BindingFlags.Static |
BindingFlags.Public);
private readonly static MethodInfo getStringLength = typeof(Normalizer)
.GetMethod("GetStringLength",
BindingFlags.Static |
BindingFlags.Public);
/// <summary>Returns the rewritten expression tree.</summary>
public Expression Normalize(Expression e) {
return this.Visit(e);
}
// NOTE(review): this override never visits m.Expression, so nested sub-expressions under a
// member access are not normalized.
protected override Expression VisitMemberAccess(MemberExpression m) {
if (m.Member.DeclaringType.Name == "Nullable`1") {
switch (m.Member.Name) {
case "Value":
return m.Expression;
case "HasValue":
return Expression.Call(isNotNullMethod.MakeGenericMethod(m.Expression.Type),
m.Expression);
}
}
// Walk down to the innermost MemberExpression of the chain (e.g. x.A for x.A.B.Year).
// NOTE(review): for multi-level paths this drops the intermediate members — confirm intended.
var member = m;
while (member.Expression as MemberExpression != null) {
member = (MemberExpression)member.Expression;
}
if (m.Member.DeclaringType == typeof(DateTime)) {
// GetDateValue's first parameter is typed Expression; GetDateValueCall applies
// StripQuotes when reading this argument back.
return Expression.Call(getDateValue.MakeGenericMethod(m.Type),
new Expression[]{member,
Expression.Constant(m.Member.Name)});
}
if (m.Member.DeclaringType == typeof(string)) {
Debug.Assert(m.Member.Name == "Length");
return Expression.Call(getStringLength,
member);
}
return m;
}
// Only constant tests are supported; a non-constant test makes the cast below yield null
// (NullReferenceException in release builds). The chosen branch is returned un-normalized.
protected override Expression VisitConditional(ConditionalExpression c) {
Debug.Assert(c.Test as ConstantExpression != null);
if ((bool)(c.Test as ConstantExpression).Value == true) {
return c.IfTrue;
}
return c.IfFalse;
}
// Marker method: only appears in rewritten trees, never executed.
public static bool IsNotNull<T>(T o) {
throw new NotImplementedException();
}
// Marker method: only appears in rewritten trees, never executed.
public static T GetDateValue<T>(Expression expression, string datePart) {
throw new NotImplementedException();
}
// Marker method: only appears in rewritten trees, never executed.
public static int GetStringLength(Expression expression) {
throw new NotImplementedException();
}
}
/// <summary>
/// Immutable (key, aggregate) pair. GetNestedAggregate builds it through reflection (the
/// two-argument constructor plus the "Key" and "Aggregate" properties) to project a foreign
/// key next to a nested aggregate value.
/// </summary>
/// <typeparam name="T">Key type.</typeparam>
/// <typeparam name="V">Aggregate value type.</typeparam>
public struct KeyAggregatePair<T, V> {
    private readonly T keyPart;
    private readonly V aggregatePart;

    public KeyAggregatePair(T key, V aggregate) {
        keyPart = key;
        aggregatePart = aggregate;
    }

    /// <summary>The grouping / correlation key.</summary>
    public T Key {
        get { return keyPart; }
    }

    /// <summary>The aggregated value for <see cref="Key"/>.</summary>
    public V Aggregate {
        get { return aggregatePart; }
    }
}
}
/// <summary>
/// Wraps an already translated SQL fragment so it can travel through expression trees as a
/// ConstantExpression whose Type is BoxedConstant. That type is the marker GetOperandValue
/// uses to stop evaluating terms.
/// </summary>
private class BoxedConstant {
    private readonly string sqlFragment;

    public BoxedConstant(string expression) {
        sqlFragment = expression;
    }

    /// <summary>The wrapped SQL text.</summary>
    public string Expression {
        get { return sqlFragment; }
    }

    // These operators presumably exist so Expression.Equal/NotEqual can build comparisons
    // between a column (string, int, int?) and a BoxedConstant, as GetCorrelationCondition does.
    // They are never meant to run and always throw.
    public static bool operator ==(string s, BoxedConstant bc) {
        throw new InvalidOperationException();
    }

    public static bool operator !=(string s, BoxedConstant bc) {
        throw new InvalidOperationException();
    }

    public static bool operator ==(int i, BoxedConstant bc) {
        throw new InvalidOperationException();
    }

    public static bool operator !=(int i, BoxedConstant bc) {
        throw new InvalidOperationException();
    }

    public static bool operator ==(int? i, BoxedConstant bc) {
        throw new InvalidOperationException();
    }

    public static bool operator !=(int? i, BoxedConstant bc) {
        throw new InvalidOperationException();
    }

    /// <summary>Returns the wrapped SQL text.</summary>
    public override string ToString() {
        return sqlFragment;
    }
}
/// <summary>
/// Builds closed generic <see cref="Queryable"/> method calls (<c>Select</c>, <c>Where</c>,
/// <c>Count</c>) for element types known only at run time.
/// </summary>
private static class QueryableMethodsProvider {
    private static readonly MethodInfo[] queryableMethods = typeof(Queryable).GetMethods();
    // Select<TSource, TResult>(IQueryable<TSource>, Expression<Func<TSource, TResult>>).
    // Type.GetMethods returns methods in no particular order, and the indexed overload
    // (Expression<Func<TSource, int, TResult>>) has the same name and generic arity, so it
    // must be excluded explicitly instead of relying on whichever First() happens to return.
    private static readonly MethodInfo selectMethod =
        (from q in queryableMethods
         where q.Name == "Select" && q.GetGenericArguments().Length == 2 &&
               TakesSingleArgumentLambda(q)
         select q.GetGenericMethodDefinition()).First();
    // Where<TSource>(IQueryable<TSource>, Expression<Func<TSource, bool>>); same ambiguity
    // as above with the indexed Where overload.
    private static readonly MethodInfo whereMethod =
        (from q in queryableMethods
         where q.Name == "Where" && q.GetGenericArguments().Length == 1 &&
               TakesSingleArgumentLambda(q)
         select q.GetGenericMethodDefinition()).First();
    // Count<TSource>(IQueryable<TSource>) - the only Count overload with a single parameter.
    private static readonly MethodInfo countMethod =
        queryableMethods.First(a => a.Name == "Count" && a.GetParameters().Length == 1);
    // Open generic IQueryable<>.
    private static readonly Type queryableType = typeof(IQueryable<>);

    /// <summary>
    /// Builds <c>source.Select(param =&gt; param)</c> over a fresh <c>IQueryable&lt;sourceType&gt;</c>
    /// parameter named "source".
    /// </summary>
    public static MethodCallExpression GetSelectCall(Type sourceType) {
        var queryableType = QueryableMethodsProvider.GetQueryableType(sourceType);
        var sourceParam = Expression.Parameter(queryableType, "source");
        var selectorParam = Expression.Parameter(sourceType, "param");
        var projectionSelector = Expression.Lambda(selectorParam, selectorParam);
        return GetSelectCall(sourceParam, projectionSelector);
    }

    /// <summary>
    /// Builds an identity <c>Select</c> over <paramref name="source"/>, whose type must be a
    /// generic type whose first type argument is the element type.
    /// </summary>
    public static MethodCallExpression GetSelectCall(Expression source) {
        var sourceType = source.Type.GetGenericArguments()[0];
        var selectorParam = Expression.Parameter(sourceType, "param");
        var projectionSelector = Expression.Lambda(selectorParam, selectorParam);
        return GetSelectCall(source, projectionSelector);
    }

    /// <summary>
    /// Builds <c>Queryable.Select(source, projectionSelector)</c>. The selector is passed as a
    /// <see cref="ConstantExpression"/> wrapping the lambda (not as a Quote).
    /// NOTE(review): presumably the parser expects that form; GetWhereCall passes its lambda directly.
    /// </summary>
    public static MethodCallExpression GetSelectCall(Expression source, LambdaExpression projectionSelector) {
        var selectQuery = QueryableMethodsProvider
            .GetSelectMethod(source.Type.GetGenericArguments()[0],
                             projectionSelector.Type.GetGenericArguments()[1]);
        return Expression.Call(selectQuery, source, Expression.Constant(projectionSelector));
    }

    /// <summary>
    /// Builds <c>Queryable.Select</c> with <paramref name="projectionSelector"/> over a fresh
    /// <c>IQueryable&lt;sourceType&gt;</c> parameter named "source".
    /// </summary>
    public static MethodCallExpression GetSelectCall(Type sourceType, LambdaExpression projectionSelector) {
        var queryableType = QueryableMethodsProvider.GetQueryableType(sourceType);
        var sourceParam = Expression.Parameter(queryableType, "source");
        return GetSelectCall(sourceParam, projectionSelector);
    }

    /// <summary>
    /// Builds <c>Queryable.Where(source, sourceName =&gt; condition)</c> over a fresh
    /// <c>IQueryable&lt;sourceType&gt;</c> parameter named "source".
    /// NOTE(review): the lambda parameter is a new ParameterExpression that matches the one inside
    /// <paramref name="condition"/> by name only, not by reference. Presumably the SQL parser
    /// resolves parameters by name; the lambda cannot be compiled as-is.
    /// </summary>
    public static MethodCallExpression GetWhereCall(Type sourceType, string sourceName, BinaryExpression condition) {
        var queryableType = QueryableMethodsProvider.GetQueryableType(sourceType);
        var whereLambda = Expression.Lambda(condition, Expression.Parameter(sourceType, sourceName));
        var whereQuery = QueryableMethodsProvider.GetWhereMethod(sourceType);
        var queryableSource = Expression.Parameter(queryableType, "source");
        var whereCall = Expression.Call(whereQuery, queryableSource, whereLambda);
        return whereCall;
    }

    /// <summary>Returns <c>IQueryable&lt;tableType&gt;</c>.</summary>
    public static Type GetQueryableType(Type tableType) {
        return queryableType.MakeGenericType(tableType);
    }

    /// <summary>Returns the closed <c>Queryable.Count&lt;tableType&gt;(IQueryable&lt;tableType&gt;)</c> method.</summary>
    public static MethodInfo GetCountMethod(Type tableType) {
        return countMethod.MakeGenericMethod(tableType);
    }

    private static MethodInfo GetSelectMethod(Type tableType, Type projectionSelectorType) {
        return selectMethod.MakeGenericMethod(tableType, projectionSelectorType);
    }

    private static MethodInfo GetWhereMethod(Type tableType) {
        return whereMethod.MakeGenericMethod(tableType);
    }

    // True when the method's second parameter is Expression<Func<TArg, TResult>>, i.e. a lambda
    // taking exactly one argument (rules out the indexed Func<TArg, int, TResult> overloads).
    private static bool TakesSingleArgumentLambda(MethodInfo method) {
        var parameters = method.GetParameters();
        if (parameters.Length != 2) {
            return false;
        }
        var lambdaType = parameters[1].ParameterType;
        if (!lambdaType.IsGenericType) {
            return false;
        }
        var delegateType = lambdaType.GetGenericArguments()[0];
        return delegateType.IsGenericType && delegateType.GetGenericArguments().Length == 2;
    }
}
/// <summary>
/// Generates, compiles and caches a delegate taking a <see cref="DbDataReader"/> that
/// materializes one result row for a given single-parameter LINQ selector.
/// </summary>
private class Binder : ExpressionVisitor {
    // The selector being translated; it has exactly one parameter (asserted in the constructor).
    private readonly LambdaExpression selector = null;
    // The generated reader => projection lambda and its compiled delegate.
    private readonly LambdaExpression binderLambda = null;
    private readonly Delegate binderMethod = null;
    // Reader ordinal assigned to each distinct field key, in first-seen order.
    private readonly Dictionary<string, int> fieldPositions = new Dictionary<string, int>();
    // The single parameter of every generated binder lambda.
    private readonly ParameterExpression reader = Expression.Parameter(typeof(DbDataReader),
                                                                       "reader");
    private static readonly MethodInfo getBoolean = typeof(DbDataReader).GetMethod("GetBoolean");
    private static readonly MethodInfo getByte = typeof(DbDataReader).GetMethod("GetByte");
    private static readonly MethodInfo getChar = typeof(DbDataReader).GetMethod("GetChar");
    private static readonly MethodInfo getDateTime = typeof(DbDataReader).GetMethod("GetDateTime");
    private static readonly MethodInfo getDecimal = typeof(DbDataReader).GetMethod("GetDecimal");
    private static readonly MethodInfo getDouble = typeof(DbDataReader).GetMethod("GetDouble");
    private static readonly MethodInfo getGUID = typeof(DbDataReader).GetMethod("GetGuid");
    private static readonly MethodInfo getInt16 = typeof(DbDataReader).GetMethod("GetInt16");
    private static readonly MethodInfo getInt32 = typeof(DbDataReader).GetMethod("GetInt32");
    private static readonly MethodInfo getInt64 = typeof(DbDataReader).GetMethod("GetInt64");
    private static readonly MethodInfo getString = typeof(DbDataReader).GetMethod("GetString");
    private static readonly MethodInfo getValue = typeof(DbDataReader).GetMethod("GetValue");
    private static readonly MethodInfo isDbNull = typeof(DbDataReader).GetMethod("IsDBNull");
    private static readonly ThreadSafeCache<string, Binder> binderCache =
        new ThreadSafeCache<string, Binder>();
    // Open generic definition of Binder.Convert<T>; closed over the IQueryable type in VisitMethodCall.
    private static readonly MethodInfo convert =
        (from m in typeof(Binder).GetMethods(BindingFlags.NonPublic |
                                             BindingFlags.Static)
         where m.Name == "Convert"
         select m).First().GetGenericMethodDefinition();
    // The single-parameter overload of Evaluator.PartialEval.
    private static readonly MethodInfo partialEval =
        (from partial in typeof(Evaluator).GetMethods()
         where partial.Name == "PartialEval" && partial.GetParameters().Length == 1
         select partial).First();

    private Binder(LambdaExpression selector) {
        Debug.Assert(selector != null);
        Debug.Assert(selector.Parameters.Count == 1);
        this.selector = selector;
        if (selector.Body.NodeType != ExpressionType.Parameter) {
            // Rewrite member accesses / aggregate calls in the body into reader calls.
            binderLambda = Expression.Lambda(((LambdaExpression)this.Visit(selector)).Body,
                                             reader);
        }
        else {
            // Identity selector (x => x): bind columns to properties or constructor arguments.
            binderLambda = GetBindingLambda(selector);
        }
        binderMethod = binderLambda.Compile();
    }

    /// <summary>
    /// Returns the compiled binder delegate for <paramref name="selector"/>, building and caching
    /// it on first use. The cache key combines the parameter type GUID, the selector text and the
    /// selector's delegate type GUID.
    /// NOTE(review): two threads may both build a binder for the same key; which one is kept is
    /// decided by ThreadSafeCache.TryAdd.
    /// </summary>
    public static Delegate GetBinder(LambdaExpression selector) {
        string key = selector.Parameters[0].Type.GUID +
                     selector.ToString() +
                     selector.Type.GUID;
        if (binderCache.ContainsKey(key)) {
            return binderCache[key].binderMethod;
        }
        Binder binder = new Binder(selector);
        Debug.Assert(binder.binderMethod != null);
        binderCache.TryAdd(key, binder);
        return binder.binderMethod;
    }

    /// <summary>
    /// Replaces a member access rooted at the selector's parameter (e.g. <c>x.A</c> or
    /// <c>x.A.B</c>) with a read of that member's column; any other access is returned unchanged.
    /// </summary>
    protected override Expression VisitMemberAccess(MemberExpression m) {
        Debug.Assert(selector.Parameters.Count == 1);
        // The previous implementation walked the member chain by DeclaringType before doing this.
        // Every link of a chain shares the same root type, so its fallback branch could never run,
        // and the walk threw InvalidCastException for members declared on a base type or interface.
        if (GetAccessedType(m) != selector.Parameters[0].Type) {
            return m;
        }
        int fieldPosition = GetFieldPosition(m);
        return GetFieldReader(m, fieldPosition);
    }

    /// <summary>
    /// Handles method calls in the selector body:
    /// <list type="bullet">
    /// <item>Non-aggregate Queryable/Enumerable calls typed <c>IQueryable`1</c> are routed through
    /// <c>Evaluator.PartialEval</c> and <c>Binder.Convert&lt;T&gt;</c>.</item>
    /// <item>Aggregate calls whose first argument is a member of the selector parameter become a
    /// column read, keyed by the call's text.</item>
    /// <item>Everything else is visited normally or returned unchanged.</item>
    /// </list>
    /// </summary>
    protected override Expression VisitMethodCall(MethodCallExpression m) {
        if (!IsAggregateMethod(m)) {
            if ((m.Method.DeclaringType == typeof(Queryable) ||
                 m.Method.DeclaringType == typeof(Enumerable))
                && m.Type.Name == "IQueryable`1") {
                var converter = convert.MakeGenericMethod(m.Type);
                return Expression.Convert(Expression.Call(partialEval,
                                                          base.VisitMethodCall(m)),
                                          m.Type,
                                          converter);
            }
            return base.VisitMethodCall(m);
        }
        // Checked before the argument list is indexed, not after.
        Debug.Assert(m.Arguments.Count > 0);
        if (m.Arguments[0].NodeType != ExpressionType.MemberAccess) {
            return base.VisitMethodCall(m);
        }
        if (GetAccessedType((MemberExpression)m.Arguments[0]) != selector.Parameters[0].Type) {
            return m;
        }
        int fieldPosition = GetFieldPosition(m.ToString());
        return GetFieldReader(m, fieldPosition);
    }

    /// <summary>
    /// Builds <c>!reader.IsDBNull(pos) ? read(pos) : (T)null</c> for the column at
    /// <paramref name="fieldPosition"/>, typed as <paramref name="m"/>.Type.
    /// NOTE(review): for a non-nullable value type the null branch unboxes a null object, which
    /// fails at run time when the column is NULL. Confirm that this is intended.
    /// </summary>
    private Expression GetFieldReader(Expression m, int fieldPosition) {
        var field = Expression.Constant(fieldPosition, typeof(int));
        var readerExpression = GetReaderExpression(m, field);
        var isDbNullExpression = Expression.Call(reader, isDbNull, field);
        var conditionalExpression =
            Expression.Condition(Expression.Not(isDbNullExpression),
                                 readerExpression,
                                 Expression.Convert(Expression.Constant(null),
                                                    readerExpression.Type));
        return conditionalExpression;
    }

    // reader.GetXxx(field), converted to m.Type when the getter's return type differs.
    private Expression GetReaderExpression(Expression m, ConstantExpression field) {
        MethodInfo getReaderMethod = GetReaderMethod(m);
        var readerExpression = Expression.Call(reader, getReaderMethod, field);
        if (getReaderMethod.ReturnType == m.Type) {
            return readerExpression;
        }
        return Expression.Convert(readerExpression, m.Type);
    }

    // Picks the typed DbDataReader getter for m.Type (Nullable<> unwrapped); GetValue otherwise.
    private static MethodInfo GetReaderMethod(Expression m) {
        Type memberType = GetMemberType(m);
        MethodInfo getMethod = null;
        switch (Type.GetTypeCode(memberType)) {
            case TypeCode.Boolean:
                getMethod = getBoolean;
                break;
            case TypeCode.Byte:
                getMethod = getByte;
                break;
            case TypeCode.Char:
                getMethod = getChar;
                break;
            case TypeCode.DateTime:
                getMethod = getDateTime;
                break;
            case TypeCode.Decimal:
                getMethod = getDecimal;
                break;
            case TypeCode.Double:
                getMethod = getDouble;
                break;
            case TypeCode.Int16:
                getMethod = getInt16;
                break;
            case TypeCode.Int32:
                getMethod = getInt32;
                break;
            case TypeCode.Int64:
                getMethod = getInt64;
                break;
            case TypeCode.String:
                getMethod = getString;
                break;
            case TypeCode.Object:
                // Guid reports TypeCode.Object, so it has to be recognised here; checking it in
                // `default` (as before) meant GetGuid was never used. memberType (not m.Type) is
                // compared so that Guid? also maps to GetGuid.
                getMethod = memberType == typeof(Guid) ? getGUID : getValue;
                break;
            default:
                getMethod = getValue;
                break;
        }
        return getMethod;
    }

    // Field key for a member access: accessed root type GUID + member expression text.
    private int GetFieldPosition(MemberExpression m) {
        return GetFieldPosition(GetAccessedType(m).GUID + m.ToString());
    }

    // Returns the ordinal already assigned to fieldName, or assigns the next free one.
    private int GetFieldPosition(string fieldName) {
        int fieldPosition;
        if (!fieldPositions.TryGetValue(fieldName, out fieldPosition)) {
            fieldPosition = fieldPositions.Count;
            fieldPositions.Add(fieldName, fieldPosition);
        }
        return fieldPosition;
    }

    // m.Type with Nullable<T> unwrapped to T.
    private static Type GetMemberType(Expression m) {
        return Nullable.GetUnderlyingType(m.Type) ?? m.Type;
    }

    // Walks a member chain (a.B.C) down to its root expression and returns the root's type.
    // Returns null for static members, which have no instance root.
    private static Type GetAccessedType(MemberExpression m) {
        while (m.Expression != null && m.Expression.NodeType == ExpressionType.MemberAccess) {
            m = (MemberExpression)m.Expression;
        }
        return m.Expression == null ? null : m.Expression.Type;
    }

    /// <summary>
    /// For an identity selector: binds the value-type and string properties of the result type,
    /// ordered by name, to reader ordinals 0..n-1 through a member initializer. If any such
    /// property is not writable, falls back to <see cref="GetBindingConstructor"/>.
    /// </summary>
    private LambdaExpression GetBindingLambda(LambdaExpression selector) {
        var instanceType = selector.Body.Type;
        var properties = (from property in instanceType.GetProperties()
                          where property.PropertyType.IsValueType ||
                                property.PropertyType == typeof(string)
                          orderby property.Name
                          select instanceType.GetProperty(property.Name,
                                                          BindingFlags.Instance |
                                                          BindingFlags.Public |
                                                          BindingFlags.NonPublic))
                         .ToArray();
        if (properties.Any(p => p.CanWrite == false)) {
            return GetBindingConstructor(instanceType);
        }
        var bindings = new MemberBinding[properties.Length];
        for (int i = 0; i < properties.Length; i++) {
            var callMethod = GetFieldReader(
                Expression.MakeMemberAccess(
                    Expression.Parameter(instanceType, "param"),
                    properties[i]),
                i);
            bindings[i] = Expression.Bind(properties[i], callMethod);
        }
        return Expression.Lambda(Expression.MemberInit(Expression.New(instanceType),
                                                       bindings),
                                 reader);
    }

    /// <summary>
    /// Binds the single public constructor's parameters, by position, to reader ordinals 0..n-1.
    /// NOTE(review): unlike GetBindingLambda there is no IsDBNull check here.
    /// </summary>
    private LambdaExpression GetBindingConstructor(Type instanceType) {
        var constructors = instanceType.GetConstructors();
        Debug.Assert(constructors.Length == 1);
        var constructorParameters = constructors[0].GetParameters();
        var arguments = new List<Expression>();
        for (int i = 0; i < constructorParameters.Length; i++) {
            var callMethod = GetReaderExpression(
                Expression.Parameter(constructorParameters[i].ParameterType, "arg"),
                Expression.Constant(i));
            arguments.Add(callMethod);
        }
        var constructor = Expression.New(constructors[0], arguments);
        return Expression.Lambda(constructor,
                                 reader);
    }

    // Compiles the given method-call expression into a parameterless lambda, invokes it and casts
    // the result to T. Referenced reflectively through the `convert` field.
    private static T Convert<T>(Expression m) {
        var methodCall = m as MethodCallExpression;
        Debug.Assert(methodCall != null);
        return (T)Expression.Lambda(methodCall).Compile().DynamicInvoke();
    }
}
/// <summary>
/// Lazily executes the SQL produced by a <see cref="SqlExpressionParser"/> and materializes each
/// row through a compiled binder delegate. Rows are cached after the first successful run.
/// Query parameter values are collected from the expression tree's constants at construction.
/// </summary>
/// <typeparam name="T">Element type produced by the binder.</typeparam>
private class Executor<T> : ExpressionVisitor, IEnumerable<T> {
    // Prototype connection; every execution runs on a clone of it.
    private readonly DbConnection cachedConnection = null;
    private readonly SqlExpressionParser sqlExpressionParser = null;
    private readonly Func<DbDataReader, T> binder = null;
    // Constant values gathered by VisitConstant, bound to the command by PopulateParameters.
    private readonly List<object> parameters = new List<object>();
    // Null until the query has completed successfully once.
    private List<T> result = null;

    /// <param name="connection">Prototype connection; must implement ICloneable.</param>
    /// <param name="sqlExpressionParser">Supplies the SQL text via GetSQLStatement().</param>
    /// <param name="expression">Tree whose constants become the command parameters.</param>
    /// <param name="binder">Must be a Func&lt;DbDataReader, T&gt;.</param>
    public Executor(DbConnection connection,
                    SqlExpressionParser sqlExpressionParser,
                    Expression expression,
                    Delegate binder) {
        Debug.Assert(connection != null && sqlExpressionParser != null);
        Debug.Assert(binder as Func<DbDataReader, T> != null);
        // Walk the tree first so VisitConstant collects the parameter values.
        this.Visit(expression);
        this.cachedConnection = connection;
        this.sqlExpressionParser = sqlExpressionParser;
        this.binder = (Func<DbDataReader, T>)binder;
    }

    public IEnumerator<T> GetEnumerator() {
        GetResult();
        foreach (var element in result) {
            yield return element;
        }
    }

    /// <summary>
    /// Runs the query once and caches the rows. The cloned connection, the command and the reader
    /// are disposed even if execution or binding throws; previously the connection was left open
    /// whenever the result set was empty or an exception occurred. <c>result</c> is assigned only
    /// on success, so a failed run is retried on the next enumeration instead of silently yielding
    /// a partial list.
    /// </summary>
    private void GetResult() {
        if (result != null) {
            return;
        }
        var rows = new List<T>();
        using (DbConnection connection = GetConnection()) {
            connection.Open();
            using (DbCommand cmd = connection.CreateCommand()) {
                cmd.CommandText = sqlExpressionParser.GetSQLStatement();
                PopulateParameters(cmd);
                using (DbDataReader reader = cmd.ExecuteReader()) {
                    // Read() returns false immediately for an empty result set.
                    while (reader.Read()) {
                        rows.Add(binder(reader));
                    }
                }
            }
        }
        result = rows;
    }

    // Returns a fresh clone of the prototype connection (the caller owns and disposes it).
    private DbConnection GetConnection() {
        return (DbConnection)(cachedConnection as ICloneable).Clone();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() {
        return this.GetEnumerator();
    }

    // Binds parameters[i] under the name "@p" + (Count - 1 - i): numbering runs in reverse of list
    // order. NOTE(review): presumably this mirrors the order in which the SQL generator emits @pN.
    private void PopulateParameters(DbCommand cmd) {
        for (int i = 0; i < parameters.Count; i++) {
            var parameter = cmd.CreateParameter();
            parameter.ParameterName = "@p" + (parameters.Count - (i + 1));
            parameter.Value = parameters[i];
            cmd.Parameters.Add(parameter);
        }
    }

    // Collects constant values as parameters:
    // - null is inserted at the front as the string "NULL";
    // - booleans are appended as 1/0;
    // - values whose TypeCode is Object are skipped;
    // - everything else is appended as-is.
    // NOTE(review): the string "NULL" is not DBNull. Verify how the SQL generator consumes it.
    protected override Expression VisitConstant(ConstantExpression c) {
        if (c.Value == null) {
            parameters.Insert(0, "NULL");
        }
        else {
            switch (Type.GetTypeCode(c.Value.GetType())) {
                case TypeCode.Object:
                    break;
                case TypeCode.Boolean:
                    if ((bool)c.Value) { // true -> 1
                        parameters.Add(1);
                    }
                    else { // false -> 0
                        parameters.Add(0);
                    }
                    break;
                default:
                    parameters.Add(c.Value);
                    break;
            }
        }
        return c;
    }

    // Visits only the branch selected by the constant test, so constants in the other branch are
    // not collected as parameters.
    protected override Expression VisitConditional(ConditionalExpression c) {
        Debug.Assert(c.Test as ConstantExpression != null);
        if ((bool)(c.Test as ConstantExpression).Value == true) {
            return this.Visit(c.IfTrue);
        }
        return this.Visit(c.IfFalse);
    }
}
/// <summary>
/// A sequence of <typeparamref name="T"/> that never yields an element, through both the
/// generic and the non-generic enumerator.
/// </summary>
private class ConstantEnumerable<T> : IEnumerable<T> {
    public ConstantEnumerable() {
    }

    /// <summary>Returns an iterator whose first MoveNext() call returns false.</summary>
    public IEnumerator<T> GetEnumerator() {
        yield break;
    }

    // The non-generic enumerator simply forwards to the generic one.
    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() {
        IEnumerable<T> typed = this;
        return typed.GetEnumerator();
    }
}
/// <summary>
/// Rewrites a two-parameter selector <c>(left, right) =&gt; body</c> into a one-parameter selector
/// over <c>KeyValuePair&lt;TLeft, TRight&gt;</c> named "source". Member accesses
/// <c>left.M</c> become <c>source.Key.M</c>, and <c>right.M</c> become <c>source.Value.M</c>.
/// NOTE(review): a bare reference to <c>left</c> or <c>right</c> (not through a member) is not
/// rewritten here and would leave an unbound parameter in the result.
/// </summary>
private class MakeSelector : ExpressionVisitor {
    private readonly LambdaExpression selector;
    private readonly ParameterExpression leftSource;
    private readonly ParameterExpression rightSource;
    private readonly ParameterExpression selectorParam;
    private readonly PropertyInfo leftMemberInfo;
    private readonly PropertyInfo rightMemberInfo;

    /// <summary>The rewritten single-parameter selector.</summary>
    public LambdaExpression Selector {
        get { return selector; }
    }

    public MakeSelector(LambdaExpression selector) {
        Debug.Assert(selector.Parameters.Count == 2);
        leftSource = selector.Parameters[0];
        rightSource = selector.Parameters[1];
        Type pairType = typeof(KeyValuePair<,>).MakeGenericType(leftSource.Type, rightSource.Type);
        selectorParam = Expression.Parameter(pairType, "source");
        leftMemberInfo = pairType.GetProperty("Key");
        rightMemberInfo = pairType.GetProperty("Value");
        // The fields above must be set before visiting, since VisitMemberAccess reads them.
        var rewritten = (LambdaExpression)this.VisitLambda(selector);
        this.selector = Expression.Lambda(rewritten.Body, selectorParam);
    }

    protected override Expression VisitMemberAccess(MemberExpression m) {
        PropertyInfo pairMember;
        if (m.Expression == leftSource) {
            pairMember = leftMemberInfo;
        }
        else if (m.Expression == rightSource) {
            pairMember = rightMemberInfo;
        }
        else {
            return base.VisitMemberAccess(m);
        }
        var pairAccess = Expression.MakeMemberAccess(selectorParam, pairMember);
        return Expression.MakeMemberAccess(pairAccess, m.Member);
    }
}
}
}