Click here to Skip to main content
15,884,298 members
Articles / Desktop Programming / WPF

Building WPF Applications with Self-Tracking Entity Generator and Visual Studio 2012 - Project Setup

Rate me:
Please Sign up or sign in to vote.
5.00/5 (14 votes)
17 Mar 2013 · CPOL · 8 min read · 68.4K · 3.5K · 44
This article describes the project setup of building a WPF sample application with Self-Tracking Entity Generator and Visual Studio 2012.
<#@ template language="C#" debug="false" hostspecific="true"#>
<#@ include file="EF.Utility.CS.ttinclude"#><#@
 output extension=".cs"#><#

// Template entry: load the EDMX model, then emit the ObjectContext (this file) and, further below,
// the ".Extensions.cs" companion file.
// DefineMetadata() is defined outside this view (presumably in the remainder of the template).
DefineMetadata();

CodeGenerationTools code = new CodeGenerationTools(this);
MetadataTools ef = new MetadataTools(this);
MetadataLoader loader = new MetadataLoader(this);
CodeRegion region = new CodeRegion(this);

// Manages the header block and the additional output files (see StartNewFile/Process at the end).
EntityFrameworkTemplateFileManager fileManager = EntityFrameworkTemplateFileManager.Create(this);

// The EDMX model the code is generated from.
string inputFile = @"SchoolModel.edmx";
EdmItemCollection ItemCollection = loader.CreateEdmItemCollection(inputFile);
// ModelNamespace and DefaultSummaryComment are members declared outside this view.
ModelNamespace = loader.GetModelNamespace(inputFile);
DefaultSummaryComment = CodeGenerationTools.GetResourceString("Template_CommentNoDocumentation");
string namespaceName = code.EscapeNamespace(@"SchoolSample.EntityModel");
UpdateObjectNamespaceMap(namespaceName);
// String flag: when it equals "true", the client-query helper classes are emitted below.
string codeForClientQuery = @"true";

EntityContainer container = ItemCollection.GetItems<EntityContainer>().FirstOrDefault();
if (container == null)
{
    // Nothing to generate; this text becomes the entire template output.
    return "// No EntityContainer exists in the model, so no code was generated";
}

// Banner + using directives, then "namespace ... {" (closed again by EndNamespace further below).
WriteHeader(fileManager);
BeginNamespace(namespaceName, code);

#>
<#=Accessibility.ForType(container)#> partial class <#=code.Escape(container)#> : ObjectContext
{
    public const string ConnectionString = "name=<#=container.Name#>";
    public const string ContainerName = "<#=container.Name#>";

<#
// Three constructors (default connection string, explicit string, explicit EntityConnection), all calling Initialize().
// CodeRegion presumably wraps the group in #region/#endregion (it is defined in EF.Utility.CS.ttinclude).
region.Begin("Constructors", 2);
#>

    public <#=code.Escape(container)#>()
        : base(ConnectionString, ContainerName)
    {
        Initialize();
    }

    public <#=code.Escape(container)#>(string connectionString)
        : base(connectionString, ContainerName)
    {
        Initialize();
    }

    public <#=code.Escape(container)#>(EntityConnection connection)
        : base(connection, ContainerName)
    {
        Initialize();
    }

    private void Initialize()
    {
        // Creating proxies requires the use of the ProxyDataContractResolver and
        // may allow lazy loading which can expand the loaded graph during serialization.
        ContextOptions.ProxyCreationEnabled = false;
        ObjectMaterialized += new ObjectMaterializedEventHandler(HandleObjectMaterialized);
    }

    private void HandleObjectMaterialized(object sender, ObjectMaterializedEventArgs e)
    {
        // Freshly materialized self-tracking entities are marked Unchanged. The entity's
        // ChangeTrackingEnabled flag is saved and restored around MarkAsUnchanged so that call
        // cannot leave it altered, even if it throws.
        var entity = e.Entity as IObjectWithChangeTracker;
        if (entity != null)
        {
            bool changeTrackingEnabled = entity.ChangeTracker.ChangeTrackingEnabled;
            try
            {
                entity.MarkAsUnchanged();
            }
            finally
            {
                entity.ChangeTracker.ChangeTrackingEnabled = changeTrackingEnabled;
            }
            this.StoreReferenceKeyValues(entity);
        }
    }
<#
        region.End();

        // One ObjectSet<T> property per entity set, created lazily on first access and cached in a backing field.
        region.Begin("ObjectSet Properties");

        foreach (EntitySet entitySet in container.BaseEntitySets.OfType<EntitySet>())
        {
#>

    <#=Accessibility.ForReadOnlyProperty(entitySet)#> ObjectSet<<#=code.GetTypeName(entitySet.ElementType)#>> <#=code.Escape(entitySet)#>
    {
        get { return <#=code.FieldName(entitySet) #>  ?? (<#=code.FieldName(entitySet)#> = CreateObjectSet<<#=code.GetTypeName(entitySet.ElementType)#>>("<#=entitySet.Name#>")); }
    }
    private ObjectSet<<#=code.GetTypeName(entitySet.ElementType)#>> <#=code.FieldName(entitySet)#>;
<#
        }

        region.End();

        // Function imports:
        //  - composable ones become IQueryable<T> methods built with CreateQuery over an Entity SQL call;
        //  - non-composable ones call ExecuteFunction and return int (no result) or ObjectResult<T>;
        //  - non-composable ones returning entity types also get an overload taking a MergeOption.
        region.Begin("Function Imports");

        foreach (EdmFunction edmFunction in container.FunctionImports)
        {
            var parameters = FunctionImportParameter.Create(edmFunction.Parameters, code, ef);
            string paramList = String.Join(", ", parameters.Select(p => p.FunctionParameterType + " " + p.FunctionParameterName).ToArray());
            TypeUsage returnType = edmFunction.ReturnParameters.Count == 0 ? null : ef.GetElementType(edmFunction.ReturnParameters[0].TypeUsage);
            if (edmFunction.IsComposableAttribute)
            {
#>

    /// <summary>
    /// <#=SummaryComment(edmFunction)#>
    /// </summary><#=LongDescriptionCommentElement(edmFunction, region.CurrentIndentLevel)#><#=ParameterComments(parameters.Select(p => new Tuple<string, string>(p.RawFunctionParameterName, SummaryComment(p.Source))), region.CurrentIndentLevel)#>
    [EdmFunction("<#=edmFunction.NamespaceName#>", "<#=edmFunction.Name#>")]
    <#=code.SpaceAfter(NewModifier(edmFunction))#><#=AccessibilityAndVirtual(Accessibility.ForMethod(edmFunction))#> <#="IQueryable<" + code.GetTypeName(returnType, ModelNamespace) + ">"#> <#=code.Escape(edmFunction)#>(<#=paramList#>)
    {
<#
                WriteFunctionParameters(parameters);
#>
        return base.CreateQuery<<#=code.GetTypeName(returnType, ModelNamespace)#>>("[<#=edmFunction.NamespaceName#>].[<#=edmFunction.Name#>](<#=string.Join(", ", parameters.Select(p => "@" + p.EsqlParameterName).ToArray())#>)"<#=code.StringBefore(", ", string.Join(", ", parameters.Select(p => p.ExecuteParameterName).ToArray()))#>);
    }
<#
            }
            else
            {
#>

    /// <summary>
    /// <#=SummaryComment(edmFunction)#>
    /// </summary><#=LongDescriptionCommentElement(edmFunction, region.CurrentIndentLevel)#><#=ParameterComments(parameters.Select(p => new Tuple<string, string>(p.RawFunctionParameterName, SummaryComment(p.Source))), region.CurrentIndentLevel)#>
    <#=code.SpaceAfter(NewModifier(edmFunction))#><#=AccessibilityAndVirtual(Accessibility.ForMethod(edmFunction))#> <#=returnType == null ? "int" : "ObjectResult<" + code.GetTypeName(returnType, ModelNamespace) + ">"#> <#=code.Escape(edmFunction)#>(<#=paramList#>)
    {
<#
                WriteFunctionParameters(parameters);
#>
        return base.ExecuteFunction<#=returnType == null ? "" : "<" + code.GetTypeName(returnType, ModelNamespace) + ">"#>("<#=edmFunction.Name#>"<#=code.StringBefore(", ", String.Join(", ", parameters.Select(p => p.ExecuteParameterName).ToArray()))#>);
    }
<#
                // Entity results additionally get an overload that lets the caller pick the MergeOption.
                if(returnType != null && returnType.EdmType.BuiltInTypeKind == BuiltInTypeKind.EntityType)
                {
#>

    /// <summary>
    /// <#=SummaryComment(edmFunction)#>
    /// </summary><#=LongDescriptionCommentElement(edmFunction, region.CurrentIndentLevel)#>
    /// <param name="mergeOption"></param><#=ParameterComments(parameters.Select(p => new Tuple<string, string>(p.RawFunctionParameterName, SummaryComment(p.Source))), region.CurrentIndentLevel)#>
    <#=code.SpaceAfter(NewModifier(edmFunction))#><#=Accessibility.ForMethod(edmFunction)#> <#=returnType == null ? "int" : "ObjectResult<" + code.GetTypeName(returnType, ModelNamespace) + ">"#> <#=code.Escape(edmFunction)#>(<#=code.StringAfter(paramList, ", ")#>MergeOption mergeOption)
    {
<#
                    WriteFunctionParameters(parameters);
#>
        return base.<#=returnType == null ? "ExecuteFunction" : "ExecuteFunction<" + code.GetTypeName(returnType, ModelNamespace) + ">"#>("<#=edmFunction.Name#>", mergeOption<#=code.StringBefore(", ", string.Join(", ", parameters.Select(p => p.ExecuteParameterName).ToArray()))#>);
    }
<#
                }
            }
        }
        region.End();
#>
}
<#
    // The client-query helpers (ObjectQueryExtension, SourceExpressionVisitor, Deserializer and
    // TypeResolver) are emitted only when the codeForClientQuery flag is the string "true".
    if (codeForClientQuery == "true")
    {
#>

public static class ObjectQueryExtension
{
    /// <summary>
    /// ApplyIncludePath takes the list of include paths from a ClientQuery object
    /// and calls Include() on the ObjectQuery.
    /// </summary>
    /// <typeparam name="TEntity">Expected type of the ObjectQuery</typeparam>
    /// <param name="source">The ObjectQuery to which the list of include paths will be applied</param>
    /// <param name="clientQuery">The ClientQuery object that contains the list of include paths</param>
    /// <returns>The ObjectQuery with every include path applied, in list order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="clientQuery"/> is null.</exception>
    public static ObjectQuery<TEntity> ApplyIncludePath<TEntity>(this ObjectQuery<TEntity> source,
                                                                 ClientQuery clientQuery) where TEntity : class
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }
        if (clientQuery == null)
        {
            throw new ArgumentNullException("clientQuery");
        }
        // Include() returns a new ObjectQuery, so fold the paths over the source query.
        return clientQuery.IncludeList.Aggregate(source, (current, includeItem) => current.Include(includeItem));
    }

    /// <summary>
    /// ApplyClientQuery takes both the list of include paths and the serialized expression tree
    /// from a ClientQuery object and applies them to the ObjectQuery.
    /// </summary>
    /// <typeparam name="TEntity">Expected type of the ObjectQuery</typeparam>
    /// <param name="source">The ObjectQuery to which both the list of include paths and serialized Expression tree will be applied</param>
    /// <param name="clientQuery">The ClientQuery object that contains both the list of include paths and serialized Expression tree</param>
    /// <param name="assemblies">Additional assemblies where types will be searched</param>
    /// <returns>The query produced by running the deserialized expression against the source's context.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="clientQuery"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="clientQuery"/> has no serialized expression.</exception>
    public static ObjectQuery<TEntity> ApplyClientQuery<TEntity>(this ObjectQuery<TEntity> source,
                                                                 ClientQuery clientQuery,
                                                                 IEnumerable<Assembly> assemblies = null)
        where TEntity : class
    {
        // Validates source/clientQuery and applies the include paths first.
        source = source.ApplyIncludePath(clientQuery);
        if (clientQuery.XmlExpression == null)
        {
            throw new ArgumentException("ClientQuery.XmlExpression cannot be null.", "clientQuery");
        }

        // Rewrite the query roots of the source into accesses on a single context parameter.
        var sourceExpressionVisitor = new SourceExpressionVisitor();
        var sourceExpression = sourceExpressionVisitor.Visit(((IOrderedQueryable) source).Expression);

        // The deserializer splices sourceExpression in place of the serialized "ClientQuery" root node.
        var deserializer = new Deserializer(sourceExpression, assemblies);
        var expression = deserializer.Deserialize(clientQuery.XmlExpression);
        var lambda = Expression.Lambda<Func<<#=code.Escape(container)#>, IQueryable<TEntity>>>(expression,
                                                                                  sourceExpressionVisitor.
                                                                                      ContextParameterExpression);
        // Compile against the context type and run it on the source query's own context instance.
        var compiledQuery = CompiledQuery.Compile(lambda);
        return (ObjectQuery<TEntity>) compiledQuery.Invoke((<#=code.Escape(container)#>) source.Context);
    }
}

internal class SourceExpressionVisitor : ExpressionVisitor
{
    #region Private Data Member

    // The single "context" parameter shared by every query root this visitor rewrites.
    private ParameterExpression _contextParameterExpression;

    #endregion Private Data Member

    #region Public Property

    /// <summary>
    /// The context parameter referenced by the rewritten source expression. Use it as the
    /// parameter of the lambda that wraps the deserialized query.
    /// </summary>
    /// <exception cref="InvalidOperationException">No query root has been visited yet.</exception>
    public ParameterExpression ContextParameterExpression
    {
        get
        {
            if (_contextParameterExpression == null)
            {
                throw new InvalidOperationException("ContextParameterExpression cannot be null.");
            }
            return _contextParameterExpression;
        }
    }

    #endregion Public Property

    #region Protected & Private Method

    /// <summary>
    /// Replaces a constant query root (a non-null IOrderedQueryable value) with the equivalent
    /// entity-set access on the context parameter. All other constants are left untouched.
    /// </summary>
    protected override Expression VisitConstant(ConstantExpression node)
    {
        if (node.Value != null && typeof (IOrderedQueryable).IsAssignableFrom(node.Type))
        {
            var elementType = ((IOrderedQueryable) node.Value).ElementType;
            return GetExpression(elementType);
        }
        return base.VisitConstant(node);
    }

    /// <summary>
    /// Returns "context.EntitySet" for a root entity type, or "context.RootSet.OfType&lt;T&gt;()" for a
    /// derived type, bound to the shared context parameter.
    /// </summary>
    /// <exception cref="NotSupportedException"><paramref name="elementType"/> is not an entity type of this model.</exception>
    private Expression GetExpression(Type elementType)
    {
<#
        foreach (EntityType entity in ItemCollection.GetItems<EntityType>().OrderBy(e => e.Name))
        {
            // Derived types have no ObjectSet of their own: query the root type's set and narrow it with OfType<T>().
            string ofTypeSuffix = entity.BaseType == null ? String.Empty : ".OfType<" + code.Escape(entity) + ">()";
#>
        if (elementType == typeof (<#=code.Escape(entity)#>))
        {
            Expression<Func<<#=code.Escape(container)#>, ObjectQuery<<#=code.Escape(entity)#>>>> query =
                context => context.<#=code.Escape(GetEntitySetFromEntityType(container, entity))#><#=ofTypeSuffix#>;
            return BindContextParameter(query);
        }
<#
        }
#>
        throw new NotSupportedException("no match for elementType type: " + elementType.FullName);
    }

    /// <summary>
    /// Returns the body of <paramref name="query"/> bound to the shared context parameter.
    /// The first root visited supplies that parameter. Later roots (e.g. both sides of a join)
    /// are re-bound to it, so the final lambda has exactly one parameter that every root references.
    /// </summary>
    private Expression BindContextParameter(LambdaExpression query)
    {
        var queryParameter = query.Parameters[0];
        if (_contextParameterExpression == null)
        {
            _contextParameterExpression = queryParameter;
            return query.Body;
        }
        return new ParameterReplacer(queryParameter, _contextParameterExpression).Visit(query.Body);
    }

    #endregion Protected & Private Method

    /// <summary>
    /// Rewrites every reference to one ParameterExpression into a reference to another.
    /// </summary>
    private sealed class ParameterReplacer : ExpressionVisitor
    {
        private readonly ParameterExpression _from;
        private readonly ParameterExpression _to;

        public ParameterReplacer(ParameterExpression from, ParameterExpression to)
        {
            _from = from;
            _to = to;
        }

        protected override Expression VisitParameter(ParameterExpression node)
        {
            return node == _from ? _to : base.VisitParameter(node);
        }
    }
}
<#
        // Helper classes used by the generated ObjectQueryExtension.ApplyClientQuery.
        WriteDeserializer();
        WriteTypeResolver();
    }

    // Closes the namespace opened by BeginNamespace at the top of the template.
    EndNamespace(namespaceName);

    // Second output file, "<template name>.Extensions.cs". WriteApplyChanges is defined outside this
    // view (presumably it emits the ApplyChanges support code).
    fileManager.StartNewFile(Path.GetFileNameWithoutExtension(Host.TemplateFile) + ".Extensions.cs");
    BeginNamespace(namespaceName, code);
    WriteApplyChanges(code);
    EndNamespace(namespaceName);

    // Presumably flushes all accumulated output files (see EF.Utility.CS.ttinclude to confirm).
    fileManager.Process();
#>

<#+
/// <summary>
/// Writes the auto-generated banner and the using directives shared by the generated code.
/// Each entry of <paramref name="extraUsings"/> adds one more "using X;" line.
/// </summary>
private void WriteHeader(EntityFrameworkTemplateFileManager fileManager, params string[] extraUsings)
{
    // Output written between StartHeader() and EndBlock() is recorded by the file manager as the
    // header block (presumably repeated in every file it produces; confirm in EF.Utility.CS.ttinclude).
    fileManager.StartHeader();
#>
//------------------------------------------------------------------------------
// <auto-generated>
//     This code was generated from a template.
//
//     Changes to this file may cause incorrect behavior and will be lost if
//     the code is regenerated.
// </auto-generated>
//------------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Data.Common;
using System.Data.EntityClient;
using System.Data.Metadata.Edm;
using System.Data.Objects;
using System.Data.Objects.DataClasses;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.Serialization;
using System.Text;
using System.Threading;
using System.Xml.Linq;
<#=String.Join(String.Empty, extraUsings.Select(u => "using " + u + ";" + Environment.NewLine).ToArray())#>
<#+
    fileManager.EndBlock();
}

/// <summary>
/// Writes "namespace X {" and indents subsequent output by one level.
/// Writes nothing when <paramref name="namespaceName"/> is null or empty, so the generated
/// types end up at global scope. Pair every call with EndNamespace.
/// </summary>
void BeginNamespace(string namespaceName, CodeGenerationTools code)
{
    // The CodeRegion previously created here was never used; it has been removed.
    if (!String.IsNullOrEmpty(namespaceName))
    {
#>
namespace <#=code.EscapeNamespace(namespaceName)#>
{
<#+
        PushIndent(CodeRegion.GetIndent(1));
    }
}

/// <summary>
/// Counterpart of BeginNamespace: removes the namespace indent and writes the closing brace.
/// Does nothing for a null or empty <paramref name="namespaceName"/>.
/// </summary>
void EndNamespace(string namespaceName)
{
    if (String.IsNullOrEmpty(namespaceName))
    {
        return;
    }

    PopIndent();
#>
}
<#+
}

/// <summary>
/// Appends " virtual" to an accessibility keyword, except for "private", which cannot be combined with virtual.
/// </summary>
string AccessibilityAndVirtual(string accessibility)
{
    bool canBeVirtual = accessibility != "private";
    return canBeVirtual ? accessibility + " virtual" : accessibility;
}

/// <summary>
/// Returns the entity set that stores <paramref name="entity"/>. This is the set whose element type
/// is the root of the entity's inheritance chain, because derived types share their root type's set.
/// </summary>
/// <exception cref="InvalidOperationException">Zero sets, or more than one set, use the root type.</exception>
EntitySet GetEntitySetFromEntityType(EntityContainer container, EntityType entity)
{
    EdmType rootType = entity;
    while (rootType.BaseType != null)
    {
        rootType = rootType.BaseType;
    }

    // Compare namespace-qualified names, so identically named types from different model
    // namespaces cannot be matched by mistake.
    return container.BaseEntitySets.OfType<EntitySet>()
        .Single(set => String.Equals(set.ElementType.FullName, rootType.FullName, StringComparison.Ordinal));
}

/// <summary>
/// Emits the Deserializer class, which turns the XML form of a client query back into a LINQ expression tree.
/// </summary>
void WriteDeserializer()
{
#>

/// <summary>
/// Rebuilds a LINQ expression tree from its XML representation. A "ClientQuery" node carrying an
/// elementType attribute is replaced by the substitute expression passed to the constructor, so the
/// serialized operators are applied on top of the server-side query source.
/// </summary>
/// <remarks>
/// NOTE(review): the XML names arbitrary types, methods, constructors, fields and properties. These are
/// resolved through TypeResolver, whose default search set is mscorlib, System.Core, System.Xml.Linq and
/// the executing assembly, plus any extra assemblies supplied. Only accept XML from trusted clients, or
/// add a whitelist of allowed members before exposing this to untrusted callers.
/// </remarks>
internal class Deserializer
{
    #region Private Data Member

    // ParameterExpressions seen so far, keyed by name and type, so that every reference to a
    // lambda parameter resolves to the same instance.
    private readonly Dictionary<string, ParameterExpression> _parameters =
        new Dictionary<string, ParameterExpression>();

    private readonly Expression _substituteExpression;
    private readonly TypeResolver _resolver;

    #endregion Private Data Member

    #region Constructor

    /// <summary>
    /// Creates a deserializer.
    /// </summary>
    /// <param name="substituteExpression">Expression that replaces the "ClientQuery" root node.</param>
    /// <param name="assemblies">Additional assemblies searched when resolving type names.</param>
    public Deserializer(Expression substituteExpression, IEnumerable<Assembly> assemblies = null)
    {
        _substituteExpression = substituteExpression;
        _resolver = new TypeResolver(assemblies);
    }

    #endregion Constructor

    #region Public Deserializer Method

    /// <summary>
    /// Deserializes one expression tree.
    /// </summary>
    /// <param name="xml">Root element of the serialized expression.</param>
    /// <returns>The rebuilt expression.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="xml"/> is null.</exception>
    /// <exception cref="NotSupportedException">The XML contains an unsupported node.</exception>
    public Expression Deserialize(XElement xml)
    {
        if (xml == null)
        {
            throw new ArgumentNullException("xml");
        }
        // Parameters are shared within one tree, never across separate Deserialize calls.
        _parameters.Clear();
        return ParseExpressionFromXmlNonNull(xml);
    }

    #endregion Public Deserializer Method

    #region Private Deserializer Method

    /// <summary>
    /// Dispatches on the element name to the matching node parser.
    /// </summary>
    private Expression ParseExpressionFromXmlNonNull(XElement xml)
    {
        if (xml.Name.LocalName == "ClientQuery" && xml.Attribute("elementType") != null)
        {
            return _substituteExpression;
        }
        switch (xml.Name.LocalName)
        {
            case "BinaryExpression":
                return ParseBinaryExpressionFromXml(xml);
            case "ConditionalExpression":
                return ParseConditionalExpressionFromXml(xml);
            case "ConstantExpression":
                return ParseConstantExpressionFromXml(xml);
            case "InvocationExpression":
                return ParseInvocationExpressionFromXml(xml);
            case "LambdaExpression":
                return ParseLambdaExpressionFromXml(xml);
            case "ListInitExpression":
                return ParseListInitExpressionFromXml(xml);
            case "MemberExpression":
                return ParseMemberExpressionFromXml(xml);
            case "MemberInitExpression":
                return ParseMemberInitExpressionFromXml(xml);
            case "MethodCallExpression":
                return ParseMethodCallExpressionFromXml(xml);
            case "NewArrayExpression":
                return ParseNewArrayExpressionFromXml(xml);
            case "NewExpression":
                return ParseNewExpressionFromXml(xml);
            case "ParameterExpression":
                return ParseParameterExpressionFromXml(xml);
            case "TypeBinaryExpression":
                return ParseTypeBinaryExpressionFromXml(xml);
            case "UnaryExpression":
                return ParseUnaryExpressionFromXml(xml);
            default:
                throw new NotSupportedException(xml.Name.LocalName);
        }
    }

    /// <summary>
    /// Parse BinaryExpression From XElement
    /// </summary>
    /// <param name="xml">A "BinaryExpression" element.</param>
    /// <returns>The binary expression, with optional operator method and conversion lambda.</returns>
    private BinaryExpression ParseBinaryExpressionFromXml(XElement xml)
    {
        var right = ParseExpressionFromXml(xml.Element("Right"));
        var left = ParseExpressionFromXml(xml.Element("Left"));
        var method = ParseMethodInfoFromXml(xml.Element("Method"));
        var conversion = ParseExpressionFromXml(xml.Element("Conversion")) as LambdaExpression;
        var isLiftedToNull = ParseWithDataContractSerializer<bool>(xml.Element("IsLiftedToNull"));
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        return Expression.MakeBinary(expressionType, left, right, isLiftedToNull, method, conversion);
    }

    /// <summary>
    /// Parse ConditionalExpression From XElement
    /// </summary>
    /// <param name="xml">A "ConditionalExpression" element.</param>
    /// <returns>The conditional (test ? ifTrue : ifFalse) expression.</returns>
    private ConditionalExpression ParseConditionalExpressionFromXml(XElement xml)
    {
        var test = ParseExpressionFromXml(xml.Element("Test"));
        var ifTrue = ParseExpressionFromXml(xml.Element("IfTrue"));
        var ifFalse = ParseExpressionFromXml(xml.Element("IfFalse"));
        var type = ParseTypeFromXml(xml.Element("Type"));
        return Expression.Condition(test, ifTrue, ifFalse, type);
    }

    /// <summary>
    /// Parse ConstantExpression From XElement
    /// </summary>
    /// <param name="xml">A "ConstantExpression" element.</param>
    /// <returns>The constant, with its value read by DataContractSerializer as the declared type.</returns>
    private ConstantExpression ParseConstantExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        var result = ParseWithDataContractSerializer(xml.Element("Value"), type);
        return Expression.Constant(result, type);
    }

    /// <summary>
    /// Parse InvocationExpression From XElement
    /// </summary>
    /// <param name="xml">An "InvocationExpression" element.</param>
    /// <returns>The invocation of the delegate/lambda expression with its arguments.</returns>
    private InvocationExpression ParseInvocationExpressionFromXml(XElement xml)
    {
        var expression = ParseExpressionFromXml(xml.Element("Expression"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments");
        return Expression.Invoke(expression, arguments);
    }

    /// <summary>
    /// Parse LambdaExpression From XElement
    /// </summary>
    /// <param name="xml">A "LambdaExpression" element.</param>
    /// <returns>The lambda; its parameters are the same instances used in its body.</returns>
    private LambdaExpression ParseLambdaExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        var body = ParseExpressionFromXml(xml.Element("Body"));
        var name = ParseWithDataContractSerializer<string>(xml.Element("Name"));
        var tailCall = ParseWithDataContractSerializer<bool>(xml.Element("TailCall"));
        var parameters = ParseExpressionListFromXml<ParameterExpression>(xml, "Parameters");
        return Expression.Lambda(type, body, name, tailCall, parameters);
    }

    /// <summary>
    /// Parse ListInitExpression From XElement
    /// </summary>
    /// <param name="xml">A "ListInitExpression" element.</param>
    /// <returns>The collection-initializer expression.</returns>
    private ListInitExpression ParseListInitExpressionFromXml(XElement xml)
    {
        var newExpression = ParseExpressionFromXml(xml.Element("NewExpression")) as NewExpression;
        if (newExpression == null) throw new Exception("Expected a NewExpression");
        var initializers = ParseElementInitListFromXml(xml, "Initializers");
        return Expression.ListInit(newExpression, initializers);
    }

    /// <summary>
    /// Parse MemberExpression From XElement
    /// </summary>
    /// <param name="xml">A "MemberExpression" element.</param>
    /// <returns>The member access; the instance is null for static members.</returns>
    private MemberExpression ParseMemberExpressionFromXml(XElement xml)
    {
        var expression = ParseExpressionFromXml(xml.Element("Expression"));
        var member = ParseMemberInfoFromXml(xml.Element("Member"));
        return Expression.MakeMemberAccess(expression, member);
    }

    /// <summary>
    /// Parse MemberInitExpression From XElement
    /// </summary>
    /// <param name="xml">A "MemberInitExpression" element.</param>
    /// <returns>The object-initializer expression.</returns>
    private MemberInitExpression ParseMemberInitExpressionFromXml(XElement xml)
    {
        var newExpression = ParseExpressionFromXml(xml.Element("NewExpression")) as NewExpression;
        if (newExpression == null) throw new Exception("Expected a NewExpression");
        // ParseMemberBindingListFromXml throws instead of returning null, so no null check is needed.
        var bindings = ParseMemberBindingListFromXml(xml, "Bindings");
        return Expression.MemberInit(newExpression, bindings);
    }

    /// <summary>
    /// Parse MethodCallExpression From XElement
    /// </summary>
    /// <param name="xml">A "MethodCallExpression" element.</param>
    /// <returns>The method call; the instance is null for static methods.</returns>
    private MethodCallExpression ParseMethodCallExpressionFromXml(XElement xml)
    {
        var instance = ParseExpressionFromXml(xml.Element("Object"));
        var method = ParseMethodInfoFromXml(xml.Element("Method"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments");
        return Expression.Call(instance, method, arguments);
    }

    /// <summary>
    /// Parse NewArrayExpression From XElement
    /// </summary>
    /// <param name="xml">A "NewArrayExpression" element.</param>
    /// <returns>A NewArrayInit or NewArrayBounds expression, depending on NodeType.</returns>
    private NewArrayExpression ParseNewArrayExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        if (type == null) throw new Exception("Type cannot be null");
        if (!type.IsArray) throw new Exception("Expected array type");
        var elemType = type.GetElementType();
        var expressions = ParseExpressionListFromXml<Expression>(xml, "Expressions");
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        switch (expressionType)
        {
            case ExpressionType.NewArrayInit:
                return Expression.NewArrayInit(elemType, expressions);
            case ExpressionType.NewArrayBounds:
                return Expression.NewArrayBounds(elemType, expressions);
            default:
                throw new Exception("Expected NewArrayInit or NewArrayBounds");
        }
    }

    /// <summary>
    /// Parse NewExpression From XElement
    /// </summary>
    /// <param name="xml">A "NewExpression" element.</param>
    /// <returns>The constructor call, with member mapping when Members is non-empty (anonymous types).</returns>
    private NewExpression ParseNewExpressionFromXml(XElement xml)
    {
        var constructor = ParseConstructorInfoFromXml(xml.Element("Constructor"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments").ToArray();
        var members = ParseMemberInfoListFromXml<MemberInfo>(xml, "Members").ToArray();
        return members.Length == 0
                   ? Expression.New(constructor, arguments)
                   : Expression.New(constructor, arguments, members);
    }

    /// <summary>
    /// Parse ParameterExpression From XElement
    /// </summary>
    /// <param name="xml">A "ParameterExpression" element.</param>
    /// <returns>The cached ParameterExpression for this name and type, created on first use.</returns>
    private ParameterExpression ParseParameterExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        var name = ParseWithDataContractSerializer<string>(xml.Element("Name"));
        // Each (name, type) pair must map to one ParameterExpression instance, or lambdas would not bind
        // to the parameters used in their bodies. The separator stops e.g. ("a", "bc.D") and ("ab", "c.D")
        // from producing the same key.
        var id = name + "|" + type.FullName;
        ParameterExpression parameter;
        if (!_parameters.TryGetValue(id, out parameter))
        {
            parameter = Expression.Parameter(type, name);
            _parameters.Add(id, parameter);
        }
        return parameter;
    }

    /// <summary>
    /// Parse TypeBinaryExpression From XElement
    /// </summary>
    /// <param name="xml">A "TypeBinaryExpression" element.</param>
    /// <returns>A TypeEqual or TypeIs expression, depending on NodeType.</returns>
    private TypeBinaryExpression ParseTypeBinaryExpressionFromXml(XElement xml)
    {
        var expression = ParseExpressionFromXml(xml.Element("Expression"));
        var typeOperand = ParseTypeFromXml(xml.Element("TypeOperand"));
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        switch (expressionType)
        {
            case ExpressionType.TypeEqual:
                return Expression.TypeEqual(expression, typeOperand);
            case ExpressionType.TypeIs:
                return Expression.TypeIs(expression, typeOperand);
            default:
                throw new Exception("Expected TypeEqual or TypeIs");
        }
    }

    /// <summary>
    /// Parse UnaryExpression From XElement
    /// </summary>
    /// <param name="xml">A "UnaryExpression" element.</param>
    /// <returns>The unary expression, with optional operator method.</returns>
    private UnaryExpression ParseUnaryExpressionFromXml(XElement xml)
    {
        var operand = ParseExpressionFromXml(xml.Element("Operand"));
        var type = ParseTypeFromXml(xml.Element("Type"));
        var method = ParseMethodInfoFromXml(xml.Element("Method"));
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        return Expression.MakeUnary(expressionType, operand, type, method);
    }

    #endregion Private Deserializer Method

    #region Private Deserializer Helper Method

    // Returns null for a missing or empty element, e.g. when a node has no constructor.
    private ConstructorInfo ParseConstructorInfoFromXml(XElement xml)
    {
        if (xml == null || xml.IsEmpty) return null;
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        var xElement = xml.Element("Parameters");
        if (xElement == null) throw new Exception("Parameters not found.");
        var parameters = from paramXml in xElement.Elements()
                         select ParseTypeFromXml(paramXml);
        return declaringType.GetConstructor(parameters.ToArray());
    }

    private ElementInit ParseElementInitFromXml(XElement xml)
    {
        var addMethod = ParseMethodInfoFromXml(xml.Element("AddMethod"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments");
        return Expression.ElementInit(addMethod, arguments);
    }

    // Parses the single child of a wrapper element. A missing, self-closing (<X/>) or childless
    // (<X></X>) wrapper means "no expression" and yields null.
    private Expression ParseExpressionFromXml(XElement xml)
    {
        var child = xml == null ? null : xml.Elements().FirstOrDefault();
        return child == null ? null : ParseExpressionFromXmlNonNull(child);
    }

    private FieldInfo ParseFieldInfoFromXml(XElement xml)
    {
        var xAttribute = xml.Attribute("FieldName");
        if (xAttribute == null) throw new Exception("FieldName not found.");
        var fieldName = xAttribute.Value;
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        return declaringType.GetField(fieldName);
    }

    private MemberBinding ParseMemberBindingFromXml(XElement xml)
    {
        var member = ParseMemberInfoFromXml(xml.Element("Member"));
        switch (xml.Name.LocalName)
        {
            case "MemberAssignment":
                var expression = ParseExpressionFromXml(xml.Element("Expression"));
                return Expression.Bind(member, expression);
            case "MemberMemberBinding":
                var bindings = ParseMemberBindingListFromXml(xml, "Bindings");
                return Expression.MemberBind(member, bindings);
            case "MemberListBinding":
                var initializers = ParseElementInitListFromXml(xml, "Initializers");
                return Expression.ListBind(member, initializers);
            default:
                throw new NotSupportedException(xml.Name.LocalName);
        }
    }

    // Dispatches on the MemberType attribute (a System.Reflection.MemberTypes name).
    private MemberInfo ParseMemberInfoFromXml(XElement xml)
    {
        var xAttribute = xml.Attribute("MemberType");
        if (xAttribute == null) throw new Exception("MemberType not found.");
        var memberType = (MemberTypes) Enum.Parse(typeof (MemberTypes), xAttribute.Value, false);
        switch (memberType)
        {
            case MemberTypes.Field:
                return ParseFieldInfoFromXml(xml);
            case MemberTypes.Property:
                return ParsePropertyInfoFromXml(xml);
            case MemberTypes.Method:
                return ParseMethodInfoFromXml(xml);
            case MemberTypes.Constructor:
                return ParseConstructorInfoFromXml(xml);
            default:
                throw new NotSupportedException(string.Format("MemberType {0} not supported", memberType));
        }
    }

    // Returns null for a missing or empty element (e.g. a binary operator with no custom method).
    private MethodInfo ParseMethodInfoFromXml(XElement xml)
    {
        if (xml == null || xml.IsEmpty) return null;
        var xAttribute = xml.Attribute("MethodName");
        if (xAttribute == null) throw new Exception("MethodName not found.");
        var name = xAttribute.Value;
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        var xElement = xml.Element("Parameters");
        if (xElement == null) throw new Exception("Parameters not found.");
        var parameters = from paramXml in xElement.Elements()
                         select ParseTypeFromXml(paramXml);
        xElement = xml.Element("GenericArgTypes");
        if (xElement == null) throw new Exception("GenericArgTypes not found.");
        var genArgs = from argXml in xElement.Elements()
                      select ParseTypeFromXml(argXml);
        return _resolver.GetMethod(declaringType, name, parameters.ToArray(), genArgs.ToArray());
    }

    private PropertyInfo ParsePropertyInfoFromXml(XElement xml)
    {
        var xAttribute = xml.Attribute("PropertyName");
        if (xAttribute == null) throw new Exception("PropertyName not found.");
        var propertyName = xAttribute.Value;
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        return declaringType.GetProperty(propertyName);
    }

    // The wrapper element's first child is the "Type" or "AnonymousType" description.
    private Type ParseTypeFromXml(XElement xml)
    {
        return ParseTypeFromXmlCore(xml.Elements().First());
    }

    // The list helpers below parse eagerly (ToList), so the XML is parsed exactly once even when
    // the Expression factory methods enumerate the result more than once.
    private IEnumerable<ElementInit> ParseElementInitListFromXml(XElement xml, string elemName)
    {
        var xElement = xml.Element(elemName);
        if (xElement == null) throw new Exception(elemName + " not found.");
        return xElement.Elements().Select(tXml => ParseElementInitFromXml(tXml)).ToList();
    }

    private IEnumerable<T> ParseExpressionListFromXml<T>(XElement xml, string elemName) where T : Expression
    {
        var xElement = xml.Element(elemName);
        if (xElement == null) throw new Exception(elemName + " not found.");
        return xElement.Elements().Select(tXml => (T) ParseExpressionFromXmlNonNull(tXml)).ToList();
    }

    private IEnumerable<MemberBinding> ParseMemberBindingListFromXml(XElement xml, string elemName)
    {
        var xElement = xml.Element(elemName);
        if (xElement == null) throw new Exception(elemName + " not found.");
        return xElement.Elements().Select(tXml => ParseMemberBindingFromXml(tXml)).ToList();
    }

    private IEnumerable<T> ParseMemberInfoListFromXml<T>(XElement xml, string elemName) where T : MemberInfo
    {
        var xElement = xml.Element(elemName);
        if (xElement == null) throw new Exception(elemName + " not found.");
        return xElement.Elements().Select(tXml => (T) ParseMemberInfoFromXml(tXml)).ToList();
    }

    // Uses LocalName, consistent with the other element-name dispatchers in this class.
    private Type ParseTypeFromXmlCore(XElement xml)
    {
        switch (xml.Name.LocalName)
        {
            case "Type":
                return ParseNormalTypeFromXmlCore(xml);
            case "AnonymousType":
                return ParseAnonymousTypeFromXmlCore(xml);
            default:
                throw new ArgumentException("Expected 'Type' or 'AnonymousType'");
        }
    }

    // A "Type" element names the type; child elements, when present, are its generic type arguments.
    private Type ParseNormalTypeFromXmlCore(XElement xml)
    {
        var xAttribute = xml.Attribute("Name");
        if (xAttribute == null) throw new Exception("Name not found.");
        if (!xml.HasElements)
        {
            return _resolver.GetType(xAttribute.Value);
        }

        var genericArgumentTypes = xml.Elements().Select(genArgXml => ParseTypeFromXmlCore(genArgXml)).ToList();
        return _resolver.GetType(xAttribute.Value, genericArgumentTypes);
    }

    // Anonymous types are looked up (or emitted) by TypeResolver from their property and
    // constructor-parameter name/type lists.
    private Type ParseAnonymousTypeFromXmlCore(XElement xElement)
    {
        var xAttribute = xElement.Attribute("Name");
        if (xAttribute == null) throw new Exception("Name not found.");
        var name = xAttribute.Value;
        var properties = from propXml in xElement.Elements("Property")
                         let attribute = propXml.Attribute("Name")
                         where attribute != null
                         select new TypeResolver.NameTypePair
                                    {
                                        Name = attribute.Value,
                                        Type = ParseTypeFromXml(propXml)
                                    };
        var ctrParams = from propXml in xElement.Elements("Constructor").Elements("Parameter")
                        let attribute = propXml.Attribute("Name")
                        where attribute != null
                        select new TypeResolver.NameTypePair
                                   {
                                       Name = attribute.Value,
                                       Type = ParseTypeFromXml(propXml)
                                   };
        return _resolver.GetOrCreateAnonymousType(name, properties.ToArray(), ctrParams.ToArray());
    }

    private T ParseWithDataContractSerializer<T>(XElement xml)
    {
        return (T) ParseWithDataContractSerializer(xml, typeof (T));
    }

    // The element's text content (xml.Value) is expected to hold the DataContractSerializer XML for
    // the value; it is re-encoded as UTF-8 and read back as expectedType.
    private object ParseWithDataContractSerializer(XElement xml, Type expectedType)
    {
        if (xml == null) throw new Exception("XElement cannot be null.");
        using (var stream = new MemoryStream(Encoding.UTF8.GetBytes(xml.Value)))
        {
            var deserializer = new DataContractSerializer(expectedType);
            return deserializer.ReadObject(stream);
        }
    }

    #endregion Private Deserializer Helper Method
}
<#+
}

void WriteTypeResolver()
{
#>

internal sealed class TypeResolver
{
    #region Private Data Member

    // Cache of emitted anonymous types, keyed by type name plus property and constructor signature.
    private readonly Dictionary<AnonTypeId, Type> _anonymousTypes = new Dictionary<AnonTypeId, Type>();

    // Dynamic module that receives every anonymous type emitted by this resolver.
    private readonly ModuleBuilder _moduleBuilder;

    // Running counter that keeps emitted type names unique within _moduleBuilder.
    private int _anonymousTypeIndex;

    // Assemblies probed (before falling back to Type.GetType) when resolving a type name.
    private readonly HashSet<Assembly> _assemblies = new HashSet<Assembly>
                                                         {
                                                             typeof (String).Assembly,         // mscorlib.dll
                                                             typeof (ExpressionType).Assembly, // System.Core.dll
                                                             typeof (XElement).Assembly,       // System.Xml.Linq.dll
                                                             Assembly.GetExecutingAssembly(),
                                                         };

    #endregion Private Data Member

    #region Constructor

    /// <summary>
    /// Creates a resolver backed by a fresh run-only dynamic assembly for anonymous types.
    /// </summary>
    /// <param name="assemblies">Optional extra assemblies to search when resolving type names.</param>
    public TypeResolver(IEnumerable<Assembly> assemblies = null)
    {
        var asmname = new AssemblyName {Name = "AnonymousTypes"};
        var assemblyBuilder = Thread.GetDomain().DefineDynamicAssembly(asmname,
                                                                       AssemblyBuilderAccess.Run);
        _moduleBuilder = assemblyBuilder.DefineDynamicModule("AnonymousTypes");

        if (assemblies != null)
        {
            foreach (var assembly in assemblies)
            {
                _assemblies.Add(assembly);
            }
        }
    }

    #endregion Constructor

    #region Public Method

    /// <summary>
    /// Resolves a generic type definition by name and closes it over the given type arguments.
    /// </summary>
    /// <param name="typeName">Name of the generic type definition (e.g. "System.Collections.Generic.List`1").</param>
    /// <param name="genericArgumentTypes">Type arguments, in declaration order.</param>
    /// <returns>The constructed generic type.</returns>
    public Type GetType(string typeName, IEnumerable<Type> genericArgumentTypes)
    {
        return GetType(typeName).MakeGenericType(genericArgumentTypes.ToArray());
    }

    /// <summary>
    /// Resolves a type by name. A trailing "[]" is handled by resolving the element type and
    /// building an array type. Otherwise the known assemblies are probed, then
    /// <see cref="Type.GetType(string, bool, bool)"/> is tried case-insensitively.
    /// </summary>
    /// <param name="typeName">The type name to resolve.</param>
    /// <returns>The resolved type.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="typeName"/> is null or empty.</exception>
    /// <exception cref="ArgumentException">No matching type could be found.</exception>
    public Type GetType(string typeName)
    {
        if (string.IsNullOrEmpty(typeName))
        {
            throw new ArgumentNullException("typeName");
        }

        // array type
        if (typeName.EndsWith("[]", StringComparison.Ordinal))
        {
            return GetType(typeName.Substring(0, typeName.Length - 2)).MakeArrayType();
        }

        // load type from the known assemblies first
        foreach (var assembly in _assemblies)
        {
            var candidate = assembly.GetType(typeName);
            if (candidate != null)
            {
                return candidate;
            }
        }

        // fall back to Type.GetType() (assembly-qualified names, ignoring case)
        var type = Type.GetType(typeName, false, true);
        if (type != null)
        {
            return type;
        }

        // The second ArgumentException argument is the parameter *name*, not its value.
        throw new ArgumentException("Could not find a matching type: " + typeName, "typeName");
    }

    /// <summary>
    /// Finds a public method on <paramref name="declaringType"/> by name whose parameter types
    /// match <paramref name="parameterTypes"/> exactly. Generic method definitions are first closed
    /// over <paramref name="genArgTypes"/>.
    /// </summary>
    /// <param name="declaringType">Type whose public methods are searched.</param>
    /// <param name="name">Method name.</param>
    /// <param name="parameterTypes">Expected parameter types, in order.</param>
    /// <param name="genArgTypes">Type arguments used to close generic candidates.</param>
    /// <returns>The matching (possibly constructed) method, or null when none matches.</returns>
    public MethodInfo GetMethod(Type declaringType, string name, Type[] parameterTypes, Type[] genArgTypes)
    {
        var candidates = declaringType.GetMethods().Where(mi => mi.Name == name);
        foreach (var candidate in candidates)
        {
            try
            {
                var realMethod = candidate.IsGenericMethod
                                     ? candidate.MakeGenericMethod(genArgTypes)
                                     : candidate;
                var methodParameterTypes = realMethod.GetParameters().Select(p => p.ParameterType);
                if (MatchPiecewise(parameterTypes, methodParameterTypes))
                {
                    return realMethod;
                }
            }
            catch (ArgumentException)
            {
                // Deliberately ignored: MakeGenericMethod throws ArgumentException (including
                // ArgumentNullException) when the generic arguments do not fit this overload,
                // which simply means this candidate is not the one we are looking for.
            }
        }
        return null;
    }

    /// <summary>
    /// Returns a cached anonymous-like type for the given name and signature, emitting a new one
    /// into the dynamic module on first request. The emitted type has one read-only property per
    /// entry in <paramref name="properties"/>. It also has a constructor whose i-th parameter
    /// initializes the i-th property's backing field.
    /// </summary>
    /// <param name="name">Logical name of the anonymous type (part of the cache key only).</param>
    /// <param name="properties">Properties to define, in order.</param>
    /// <param name="ctrParams">Constructor parameters; positionally mapped onto <paramref name="properties"/>.</param>
    /// <returns>The emitted (or cached) type.</returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="ctrParams"/> has more entries than <paramref name="properties"/>.
    /// </exception>
    public Type GetOrCreateAnonymousType(string name, NameTypePair[] properties, NameTypePair[] ctrParams)
    {
        var id = new AnonTypeId(name, properties.Concat(ctrParams));
        Type cached;
        if (_anonymousTypes.TryGetValue(id, out cached))
        {
            return cached;
        }

        // Each constructor parameter stores into the property field at the same index.
        if (ctrParams.Length > properties.Length)
        {
            throw new ArgumentException("Each constructor parameter must correspond to a property.", "ctrParams");
        }

        const string anonPrefix = "<>f__AnonymousType";
        var anonTypeBuilder = _moduleBuilder.DefineType(anonPrefix + _anonymousTypeIndex++,
                                                        TypeAttributes.Public | TypeAttributes.Class);

        var fieldBuilders = new FieldBuilder[properties.Length];

        for (var i = 0; i < properties.Length; i++)
        {
            var property = properties[i];
            fieldBuilders[i] = anonTypeBuilder.DefineField("_generatedfield_" + property.Name,
                                                           property.Type, FieldAttributes.Private);
            var propertyBuilder = anonTypeBuilder.DefineProperty(property.Name,
                                                                 System.Reflection.PropertyAttributes.None,
                                                                 property.Type, Type.EmptyTypes);

            // Property accessors are conventionally public hidebysig specialname methods.
            var getter = anonTypeBuilder.DefineMethod("get_" + property.Name,
                                                      MethodAttributes.Public | MethodAttributes.SpecialName |
                                                      MethodAttributes.HideBySig,
                                                      property.Type, Type.EmptyTypes);
            var getIl = getter.GetILGenerator();
            getIl.Emit(OpCodes.Ldarg_0);
            getIl.Emit(OpCodes.Ldfld, fieldBuilders[i]);
            getIl.Emit(OpCodes.Ret);
            propertyBuilder.SetGetMethod(getter);
        }

        var constructorBuilder = anonTypeBuilder.DefineConstructor(
            MethodAttributes.Public | MethodAttributes.HideBySig,
            CallingConventions.Standard, ctrParams.Select(prop => prop.Type).ToArray());
        var ctorIl = constructorBuilder.GetILGenerator();

        // Chain to System.Object's parameterless constructor, as compiler-generated constructors do.
        ctorIl.Emit(OpCodes.Ldarg_0);
        ctorIl.Emit(OpCodes.Call, typeof (object).GetConstructor(Type.EmptyTypes));

        for (var i = 0; i < ctrParams.Length; i++)
        {
            // Argument 0 is 'this', so parameter i lives in argument slot i + 1. ldarg takes an
            // unsigned 16-bit operand, so the index must be emitted as a short, not an int.
            var argumentIndex = (short) (i + 1);
            ctorIl.Emit(OpCodes.Ldarg_0);
            ctorIl.Emit(OpCodes.Ldarg, argumentIndex);
            ctorIl.Emit(OpCodes.Stfld, fieldBuilders[i]);
            constructorBuilder.DefineParameter(i + 1, ParameterAttributes.None, ctrParams[i].Name);
        }
        ctorIl.Emit(OpCodes.Ret);

        var anonType = anonTypeBuilder.CreateType();
        _anonymousTypes.Add(id, anonType);
        return anonType;
    }

    #endregion Public Method

    #region Private Method

    // True when both sequences have the same length and equal elements at every position.
    private bool MatchPiecewise<T>(IEnumerable<T> first, IEnumerable<T> second)
    {
        return first.SequenceEqual(second);
    }

    #endregion Private Method

    #region nested classes

    /// <summary>
    /// A (member name, member type) pair describing a property or constructor parameter.
    /// Equality compares both the name (ordinally) and the type.
    /// </summary>
    public class NameTypePair
    {
        public string Name { get; set; }
        public Type Type { get; set; }

        public override int GetHashCode()
        {
            // unchecked: hash combination must never throw, even under /checked builds.
            unchecked
            {
                var nameHash = Name == null ? 0 : Name.GetHashCode();
                var typeHash = Type == null ? 0 : Type.GetHashCode();
                return (nameHash * 397) ^ typeHash;
            }
        }

        public override bool Equals(object obj)
        {
            var other = obj as NameTypePair;
            if (other == null)
            {
                return false;
            }
            return string.Equals(Name, other.Name) && object.Equals(Type, other.Type);
        }
    }

    /// <summary>
    /// Cache key for emitted anonymous types: the type name plus the ordered sequence of
    /// property and constructor-parameter descriptors.
    /// </summary>
    private sealed class AnonTypeId
    {
        public string Name { get; private set; }
        public IEnumerable<NameTypePair> Properties { get; private set; }

        public AnonTypeId(string name, IEnumerable<NameTypePair> properties)
        {
            Name = name;
            // Snapshot the sequence so hashing/equality do not re-run a deferred query.
            Properties = properties.ToArray();
        }

        public override int GetHashCode()
        {
            // Enumerable.Sum is checked and throws OverflowException once the summed hash codes
            // leave the int range, so combine explicitly in an unchecked context. The combination
            // is order-sensitive, matching the SequenceEqual used by Equals.
            unchecked
            {
                var hash = Name == null ? 0 : Name.GetHashCode();
                foreach (var pair in Properties)
                {
                    hash = hash * 31 + (pair == null ? 0 : pair.GetHashCode());
                }
                return hash;
            }
        }

        public override bool Equals(object obj)
        {
            var other = obj as AnonTypeId;
            if (other == null)
            {
                return false;
            }
            return string.Equals(Name, other.Name)
                   && Properties.SequenceEqual(other.Properties);
        }
    }

    #endregion
}
<#+
}

void WriteApplyChanges(CodeGenerationTools code)
{
#>
public static class SelfTrackingEntitiesContextExtensions
{
    /// <summary>
    /// Applies the changes tracked in a connected graph of self-tracking entities to the
    /// ObjectContext that owns <paramref name="objectSet"/>.
    /// </summary>
    /// <typeparam name="TEntity">Entity type of the ObjectSet.</typeparam>
    /// <param name="objectSet">ObjectSet whose context and entity set receive the changes.</param>
    /// <param name="entity">Root of the object graph carrying the changes.</param>
    /// <exception cref="ArgumentNullException"><paramref name="objectSet"/> is null.</exception>
    public static void ApplyChanges<TEntity>(this ObjectSet<TEntity> objectSet, TEntity entity) where TEntity : class, IObjectWithChangeTracker
    {
        if (objectSet == null)
        {
            throw new ArgumentNullException("objectSet");
        }

        // The context overload expects a container-qualified entity set name: "Container.EntitySet".
        EntitySet targetSet = objectSet.EntitySet;
        string qualifiedSetName = targetSet.EntityContainer.Name + "." + targetSet.Name;

        ObjectContext owningContext = objectSet.Context;
        owningContext.ApplyChanges<TEntity>(qualifiedSetName, entity);
    }

    /// <summary>
    /// ApplyChanges takes the changes in a connected set of entities and applies them to an ObjectContext.
    /// </summary>
    /// <typeparam name="TEntity">Expected type of the EntitySet</typeparam>
    /// <param name="context">The ObjectContext to which changes will be applied.</param>
    /// <param name="entitySetName">The EntitySet name of the entity.</param>
    /// <param name="entity">The entity serving as the entry point of the object graph that contains changes.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entity"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="entitySetName"/> is null or empty.</exception>
    /// <remarks>
    /// Processing order: deleted entities first, then all other entities, then independent
    /// (non-foreign-key) associations, and finally every relationship still left in the working set.
    /// Lazy loading is disabled for the duration. It is restored to its previous value in a
    /// finally block, so it is restored even if an exception is thrown.
    /// </remarks>
    public static void ApplyChanges<TEntity>(this ObjectContext context, string entitySetName, TEntity entity) where TEntity : IObjectWithChangeTracker
    {
        if (context == null)
        {
            throw new ArgumentNullException("context");
        }

        if (String.IsNullOrEmpty(entitySetName))
        {
            throw new ArgumentException("String parameter cannot be null or empty.", "entitySetName");
        }

        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }

        // Remember the caller's lazy-loading setting; it is restored in the finally block below.
        bool lazyLoadingSetting = context.ContextOptions.LazyLoadingEnabled;
        try
        {
            // Presumably disabled so that enumerating related ends below does not load
            // additional related entities from the store — NOTE(review): confirm intent.
            context.ContextOptions.LazyLoadingEnabled = false;

            // NOTE(review): AddHelper, EntityIndex and RelationshipSet are defined elsewhere in this
            // file. Presumably this attaches every entity reachable from `entity` and indexes them,
            // and allRelationships collects the relationships among those entities — confirm.
            // The helpers below remove entries from allRelationships as they handle them.
            EntityIndex entityIndex = AddHelper.AddAllEntities(context, entitySetName, entity);
            RelationshipSet allRelationships = new RelationshipSet(context, entityIndex.AllEntities);

            #region Handle Initial Entity State

            // Deleted entities are handled before all others.
            foreach (IObjectWithChangeTracker changedEntity in entityIndex.AllEntities.Where(x => x.ChangeTracker.State == ObjectState.Deleted))
            {
                HandleDeletedEntity(context, entityIndex, allRelationships, changedEntity);
            }

            foreach (IObjectWithChangeTracker changedEntity in entityIndex.AllEntities.Where(x => x.ChangeTracker.State != ObjectState.Deleted))
            {
                HandleEntity(context, entityIndex, allRelationships, changedEntity);
            }

            #endregion

            #region Loop through each object state entries

            // For every navigation property of every entity, apply tracked changes for independent
            // associations only; foreign-key associations (IsForeignKey == true) are skipped here.
            foreach (IObjectWithChangeTracker changedEntity in entityIndex.AllEntities)
            {
                ObjectStateEntry entry = context.ObjectStateManager.GetObjectStateEntry(changedEntity);

                EntityType entityType = context.MetadataWorkspace.GetCSpaceEntityType(changedEntity);
                foreach (NavigationProperty navProp in entityType.NavigationProperties)
                {
                    RelatedEnd relatedEnd = entry.GetRelatedEnd(navProp.Name);
                    if(!((AssociationType)relatedEnd.RelationshipSet.ElementType).IsForeignKey)
                    {
                        ApplyChangesToIndependentAssociation(context, (IObjectWithChangeTracker)changedEntity, entry, navProp, relatedEnd, allRelationships);
                    }

                }
            }
            #endregion

            // Change all the remaining relationships to the appropriate state
            // (each relationship not removed above is moved to the state recorded for it).
            foreach (var relationship in allRelationships)
            {
                context.ObjectStateManager.ChangeRelationshipState(
                    relationship.End0,
                    relationship.End1,
                    relationship.AssociationSet.ElementType.FullName,
                    relationship.AssociationEndMembers[1].Name,
                    relationship.State);
            }
        }
        finally
        {
            // Restore the caller's original lazy-loading setting.
            context.ContextOptions.LazyLoadingEnabled = lazyLoadingSetting;
        }
    }

    // Applies the change-tracking information recorded for one navigation property of an
    // independent (non-foreign-key) association to the ObjectStateManager. Every relationship
    // entry this method sets explicitly is removed from allRelationships, so the final pass in
    // ApplyChanges does not override it.
    //   - Added entity: every relationship through relatedEnd is marked Added.
    //   - Collection navigation (to-end multiplicity Many): items recorded in
    //     ObjectsRemovedFromCollectionProperties become Deleted relationships; items in
    //     ObjectsAddedToCollectionProperties are added unless an Unchanged relationship exists.
    //   - Reference navigation: handled only when OriginalValues has an entry for the property;
    //     a non-null original value becomes a Deleted relationship and a non-null current value
    //     is added unless an Unchanged relationship already exists.
    private static void ApplyChangesToIndependentAssociation(ObjectContext context, IObjectWithChangeTracker changedEntity, ObjectStateEntry entry, NavigationProperty navProp,
        IRelatedEnd relatedEnd, RelationshipSet allRelationships)
    {
        ObjectChangeTracker changeTracker = changedEntity.ChangeTracker;

        if (changeTracker.State == ObjectState.Added)
        {
            // Relationships should remain added so remove them from the list of allRelationships
            foreach (object relatedEntity in relatedEnd)
            {
                ObjectStateEntry addedRelationshipEntry =
                            context.ObjectStateManager.ChangeRelationshipState(
                                changedEntity,
                                relatedEntity,
                                navProp.Name,
                                EntityState.Added);

                allRelationships.Remove(addedRelationshipEntry);
            }
        }
        else
        {
            if (navProp.ToEndMember.RelationshipMultiplicity == RelationshipMultiplicity.Many)
            {
                //Handle removal to FixupCollections
                ObjectList collectionPropertyChanges = null;
                if (changeTracker.ObjectsRemovedFromCollectionProperties.TryGetValue(navProp.Name, out collectionPropertyChanges))
                {
                    foreach (var removedEntityFromAssociation in collectionPropertyChanges)
                    {
                        ObjectStateEntry deletedRelationshipEntry =
                            context.ObjectStateManager.ChangeRelationshipState(
                                changedEntity,
                                removedEntityFromAssociation,
                                navProp.Name,
                                EntityState.Deleted);

                        allRelationships.Remove(deletedRelationshipEntry);
                    }
                }

                //Handle addition to FixupCollection
                if (changeTracker.ObjectsAddedToCollectionProperties.TryGetValue(navProp.Name, out collectionPropertyChanges))
                {
                    foreach (var addedEntityFromAssociation in collectionPropertyChanges)
                    {
                        allRelationships.Remove(AddRelationshipUnlessExists(context, relatedEnd, entry, addedEntityFromAssociation, navProp.Name));
                    }
                }
            }
            else
            {

                // Handle original relationship values
                // (only when the change tracker recorded an original value for this reference)
                object originalReferenceValue;
                if (changeTracker.OriginalValues.TryGetValue(navProp.Name, out originalReferenceValue))
                {
                    if (originalReferenceValue != null)
                    {
                        //Capture the deletion of association
                        ObjectStateEntry deletedRelationshipEntry =
                            context.ObjectStateManager.ChangeRelationshipState(
                                entry.Entity,
                                originalReferenceValue,
                                navProp.Name,
                                EntityState.Deleted);

                        allRelationships.Remove(deletedRelationshipEntry);
                    }

                    //Capture the Addition of association
                    // A reference end yields at most one entity; take the first, if any.
                    object currentReferenceValue = null;
                    foreach (object o in relatedEnd)
                    {
                        currentReferenceValue = o;
                        break;
                    }
                    if (currentReferenceValue != null)
                    {
                        allRelationships.Remove(AddRelationshipUnlessExists(context, relatedEnd, entry, currentReferenceValue, navProp.Name));
                    }
                    // if the current value of the reference is null, then the user must set the entity reference to null
                    // which is already being handled by the deletion of the relationship
                }
            }
        }
    }

    // Ensures a relationship between fromEntry's entity and toEntity exists. An existing Unchanged
    // relationship entry is returned untouched; otherwise the relationship (looked up through the
    // navigation property named navPropName) is put into the Added state and that entry is returned.
    private static ObjectStateEntry AddRelationshipUnlessExists(ObjectContext context, IRelatedEnd relatedEnd, ObjectStateEntry fromEntry, object toEntity, string navPropName)
    {
        ObjectStateEntry targetEntry = context.ObjectStateManager.GetObjectStateEntry(toEntity);

        AssociationSet assocSet = (AssociationSet)relatedEnd.RelationshipSet;
        AssociationEndMember sourceEnd = assocSet.AssociationSetEnds[relatedEnd.SourceRoleName].CorrespondingAssociationEndMember;
        AssociationEndMember targetEnd = assocSet.AssociationSetEnds[relatedEnd.TargetRoleName].CorrespondingAssociationEndMember;

        ObjectStateEntry relationshipEntry;
        bool alreadyUnchanged =
            context.TryGetObjectStateEntry(fromEntry.EntityKey, targetEntry.EntityKey, assocSet, sourceEnd, targetEnd, out relationshipEntry) &&
            relationshipEntry.State == EntityState.Unchanged;

        if (alreadyUnchanged)
        {
            return relationshipEntry;
        }

        return context.ObjectStateManager.ChangeRelationshipState(
                fromEntry.Entity,
                toEntity,
                navPropName,
                EntityState.Added);
    }

    // Extracts the relationship key information from the ExtendedProperties and OriginalValues records of each ObjectChangeTracker
    // This is done by:
    //  1. Creating any existing relationship specified in the ExtendedProperties
    //  2. Determining if there was a previous relationship, and if there was, creating a deleted relationship between the entity and the previous entity or key value
    // Only entities in the Unchanged, Modified or Deleted state are processed; Added entities are skipped.
    private static void HandleRelationshipKeys(ObjectContext context, EntityIndex entityIndex, RelationshipSet allRelationships, IObjectWithChangeTracker entity)
    {
        ObjectChangeTracker changeTracker = entity.ChangeTracker;
        if (changeTracker.State == ObjectState.Unchanged ||
            changeTracker.State == ObjectState.Modified ||
            changeTracker.State == ObjectState.Deleted)
        {
            ObjectStateEntry entry = context.ObjectStateManager.GetObjectStateEntry(entity);
            EntityType entityType = context.MetadataWorkspace.GetCSpaceEntityType(entity);
            RelationshipManager relationshipManager = context.ObjectStateManager.GetRelationshipManager(entity);

            // NOTE(review): EnumerateSaveReferences is defined elsewhere in this file; presumably it
            // yields the EntityReference ends of this entity whose keys may be saved — confirm.
            foreach (var entityReference in EnumerateSaveReferences(relationshipManager))
            {
                AssociationSet associationSet = ((AssociationSet)entityReference.RelationshipSet);
                AssociationEndMember fromEnd = associationSet.AssociationSetEnds[entityReference.SourceRoleName].CorrespondingAssociationEndMember;
                AssociationEndMember toEnd = associationSet.AssociationSetEnds[entityReference.TargetRoleName].CorrespondingAssociationEndMember;

                // Find if there is a NavigationProperty for this candidate
                // (null when the entity type exposes no navigation property for this association end)
                NavigationProperty navigationProperty = entityType.NavigationProperties.
                                           SingleOrDefault(x => x.RelationshipType == associationSet.ElementType &&
                                                           x.FromEndMember == fromEnd &&
                                                           x.ToEndMember == toEnd);

                // Only handle relationship keys in one of these cases
                // 1. There is no navigation property
                // 2. The navigation property has a null current reference value and there are no removes or adds
                // 3. The navigation property has a current reference value, but there is no remove

                // Key currently saved for this reference in ExtendedProperties (null if none was saved).
                EntityKey currentKey = GetSavedReferenceKey(entityIndex, entityReference, entity, navigationProperty, changeTracker.ExtendedProperties);

                // Get any original value from the change tracking information
                object originalValue = null;
                EntityKey originalKey = null;
                bool hasOriginalValue = false;
                if (changeTracker.OriginalValues != null)
                {
                    // Try to get the original value from the NavigationProperty first
                    if (navigationProperty != null)
                    {
                        hasOriginalValue = changeTracker.OriginalValues.TryGetValue(navigationProperty.Name, out originalValue);
                    }
                    // Try to get the original value from the reference key second
                    if (!hasOriginalValue || originalValue == null)
                    {
                        originalKey = GetSavedReferenceKey(entityIndex, entityReference, entity, navigationProperty, changeTracker.OriginalValues);
                    }
                }

                // Create the current relationship
                if (currentKey != null)
                {
                    // If the key is for a deleted entity, move that key to an originalValue and fixup the entities key values
                    // Otherwise create a new relationship
                    // (Unchanged when there is no original key, Added when an original key is also present)
                    ObjectStateEntry currentEntry;
                    if (context.ObjectStateManager.TryGetObjectStateEntry(currentKey, out currentEntry) &&
                       currentEntry.Entity != null &&
                       currentEntry.State == EntityState.Deleted)
                    {
                        entityReference.EntityKey = null;
                        MoveSavedReferenceKey(entityReference, entity, navigationProperty, changeTracker.ExtendedProperties, changeTracker.OriginalValues);
                        originalKey = currentKey;
                    }
                    else
                    {
                        CreateRelationship(context, entityReference, entry.EntityKey, currentKey, originalKey == null ? EntityState.Unchanged : EntityState.Added);
                    }
                }
                else
                {
                    // Find the current key
                    // Cannot get the EntityKey directly because this is null when it points to an Added entity
                    currentKey = entityReference.GetCurrentEntityKey(context);
                }

                // Create the original relationship
                if (originalKey != null)
                {
                    // If the key is for a deleted entity, remember to create a deleted relationship,
                    // otherwise use the entityReference to setup the deleted relationship
                    ObjectStateEntry originalEntry = null;
                    ObjectStateEntry deletedRelationshipEntry = null;
                    if (context.ObjectStateManager.TryGetObjectStateEntry(originalKey, out originalEntry) &&
                       originalEntry.Entity != null &&
                       originalEntry.State == EntityState.Deleted)
                    {
                        allRelationships.Add(entityReference, entry.Entity, originalEntry.Entity, EntityState.Deleted);
                    }
                    else
                    {
                        // To create a deleted relationship to a key, first detach the existing relationship between entry and currentKey
                        EntityState currentRelationshipState = DetachRelationship(context, entityReference, entry, currentKey);

                        // If the relationship is 1 to 0..1, detach the relationship from currentKey to its target (targetKey)
                        EntityState targetRelationshipState = EntityState.Detached;
                        EntityReference targetReference = null;
                        EntityKey targetKey = null;
                        if (originalEntry != null &&
                            originalEntry.Entity != null &&
                            originalEntry.RelationshipManager != null &&
                            associationSet.AssociationSetEnds[fromEnd.Name].CorrespondingAssociationEndMember.RelationshipMultiplicity != RelationshipMultiplicity.Many)
                        {
                            // NOTE(review): the `as` cast can yield null, and targetReference is then
                            // dereferenced without a check on the next line — confirm it cannot be null here.
                            targetReference = originalEntry.RelationshipManager.GetRelatedEnd(entityReference.RelationshipName, entityReference.SourceRoleName) as EntityReference;
                            targetKey = targetReference.GetCurrentEntityKey(context);
                            if (targetKey != null)
                            {
                                targetRelationshipState = DetachRelationship(context, targetReference, originalEntry, targetKey);
                            }
                        }


                        // Create the deleted relationship between entry and originalKey
                        deletedRelationshipEntry = CreateRelationship(context, entityReference, entry.EntityKey, originalKey, EntityState.Deleted);

                        // Set the previous relationship between entry and currentKey back
                        // (CreateRelationship does nothing when the saved state is Detached)
                        CreateRelationship(context, entityReference, entry.EntityKey, currentKey, currentRelationshipState);

                        // Set the previous relationship between originalEntry and targetKey back
                        if (targetKey != null)
                        {
                            CreateRelationship(context, targetReference, originalEntry.EntityKey, targetKey, targetRelationshipState);
                        }
                    }
                    if (deletedRelationshipEntry != null)
                    {
                        // Remove the deleted relationship from those that need to be processed later in ApplyChanges
                        allRelationships.Remove(deletedRelationshipEntry);
                    }
                }
                else if (currentKey == null && originalValue != null && entityReference.IsDependentEndOfReferentialConstraint())
                {
                    // the graph won't have this hooked up because there is no current value, but there is an original value,
                    // so the relationship processing code will want to delete a relationship.
                    // we can add this one so it has a relationship to change to deleted.
                    context.ObjectStateManager.ChangeRelationshipState(
                                                        entry.Entity,
                                                        originalValue,
                                                        entityReference.RelationshipName,
                                                        entityReference.TargetRoleName,
                                                        EntityState.Added);
                }
            }
        }
    }

    // Creates the relationship fromKey -> toKey through entityReference and moves its state entry to
    // the requested state. Returns null (and does nothing) when state is Detached; otherwise
    // returns the relationship's ObjectStateEntry as found after it was created.
    //   - Added:     the entry is left as created.
    //   - Unchanged: AcceptChanges is called on the entry.
    //   - Deleted:   AcceptChanges is called, then the reference's EntityKey is cleared —
    //                presumably leaving the relationship entry in the Deleted state (NOTE(review): confirm).
    // Any other state (e.g. Modified) falls through the switch and leaves the entry as created.
    private static ObjectStateEntry CreateRelationship(ObjectContext context, EntityReference entityReference, EntityKey fromKey, EntityKey toKey, EntityState state)
    {
        if (state != EntityState.Detached)
        {
            AssociationSet associationSet = ((AssociationSet)entityReference.RelationshipSet);
            AssociationEndMember fromEnd = associationSet.AssociationSetEnds[entityReference.SourceRoleName].CorrespondingAssociationEndMember;
            AssociationEndMember toEnd = associationSet.AssociationSetEnds[entityReference.TargetRoleName].CorrespondingAssociationEndMember;

            // set the relationship to the original relationship in the unchanged state
            Debug.Assert(toKey != null, "why/how would we do a delete with a null originalKey?");

            if (toKey.IsTemporary)
            {
                // Clear any existing relationship
                entityReference.EntityKey = null;

                // If the target entity is Added, use Add on RelatedEnd
                // (a temporary key cannot be assigned to EntityReference.EntityKey directly)
                ObjectStateEntry targetEntry;
                context.ObjectStateManager.TryGetObjectStateEntry(toKey, out targetEntry);
                Debug.Assert(targetEntry != null, "Should have found the state entry");
                ((IRelatedEnd)entityReference).Add(targetEntry.Entity);
            }
            else
            {
                entityReference.EntityKey = toKey;
            }

            ObjectStateEntry relationshipEntry;
            bool found = context.TryGetObjectStateEntry(fromKey, toKey, associationSet, fromEnd, toEnd, out relationshipEntry);
            Debug.Assert(found, "Did not find the created relationship.");

            switch (state)
            {
                case EntityState.Added:
                    break;
                case EntityState.Unchanged:
                    relationshipEntry.AcceptChanges();
                    break;
                case EntityState.Deleted:
                    relationshipEntry.AcceptChanges();
                    entityReference.EntityKey = null;
                    break;
            }
            return relationshipEntry;
        }
        return null;
    }

    // Removes the relationship fromEntry -> toKey held by entityReference and reports the state that
    // relationship was in beforehand. Returns Detached when toKey is null or no such relationship
    // entry exists. A relationship that ends up Deleted after the reference key is cleared has
    // AcceptChanges called on it so that it becomes fully detached.
    private static EntityState DetachRelationship(ObjectContext context, EntityReference entityReference, ObjectStateEntry fromEntry, EntityKey toKey)
    {
        if (toKey == null)
        {
            return EntityState.Detached;
        }

        AssociationSet assocSet = (AssociationSet)entityReference.RelationshipSet;
        AssociationEndMember sourceEnd = assocSet.AssociationSetEnds[entityReference.SourceRoleName].CorrespondingAssociationEndMember;
        AssociationEndMember targetEnd = assocSet.AssociationSetEnds[entityReference.TargetRoleName].CorrespondingAssociationEndMember;

        ObjectStateEntry relationshipEntry;
        if (!context.TryGetObjectStateEntry(fromEntry.EntityKey, toKey, assocSet, sourceEnd, targetEnd, out relationshipEntry))
        {
            return EntityState.Detached;
        }

        EntityState stateBeforeDetach = relationshipEntry.State;

        entityReference.EntityKey = null;
        // Re-read the state: clearing the key above may have changed it.
        if (relationshipEntry.State == EntityState.Deleted)
        {
            relationshipEntry.AcceptChanges();
        }
        Debug.Assert(relationshipEntry.State == EntityState.Detached, "relationship was not detached");

        return stateBeforeDetach;
    }

    // Builds the dictionary key under which a reference's key member value is stored:
    // "<NavigationProperty>.<KeyMember>" when a navigation property exists, otherwise
    // "Navigate(<AssociationFullName>.<TargetRole>).<KeyMember>".
    private static string CreateReferenceKeyLookup(string keyMemberName, EntityReference reference, NavigationProperty navigationProperty)
    {
        bool hasNavigationProperty = navigationProperty != null;
        return hasNavigationProperty
            ? String.Format(CultureInfo.InvariantCulture, "{0}.{1}", navigationProperty.Name, keyMemberName)
            : String.Format(CultureInfo.InvariantCulture, "Navigate({0}.{1}).{2}", reference.RelationshipSet.ElementType.FullName, reference.TargetRoleName, keyMemberName);
    }

    // Rebuilds the EntityKey of the entity targeted by the passed in EntityReference from the
    // saved key member values in values (these keys can be set during the ObjectMaterialized
    // event or through relationship fixup). Returns null when no key member is saved, throws
    // InvalidOperationException when only some are saved, and otherwise returns the key after
    // translating it with entityIndex.ConvertEntityKey.
    private static EntityKey GetSavedReferenceKey(EntityIndex entityIndex, EntityReference reference, object entity, NavigationProperty navigationProperty, IDictionary<string, object> values)
    {
        Debug.Assert(navigationProperty == null || reference.RelationshipSet.ElementType == navigationProperty.RelationshipType, "the reference and navigationProperty should correspond");

        AssociationSet relationshipSet = (AssociationSet)reference.RelationshipSet;
        EntitySet targetSet = relationshipSet.AssociationSetEnds[reference.TargetRoleName].EntitySet;
        var keyMembers = targetSet.ElementType.KeyMembers;

        List<EntityKeyMember> savedMembers = new List<EntityKeyMember>(1);
        foreach (var member in keyMembers)
        {
            object savedValue;
            if (values.TryGetValue(CreateReferenceKeyLookup(member.Name, reference, navigationProperty), out savedValue))
            {
                savedMembers.Add(new EntityKeyMember(member.Name, savedValue));
            }
        }

        if (savedMembers.Count == 0)
        {
            // nothing was saved for this reference
            return null;
        }

        if (savedMembers.Count != keyMembers.Count)
        {
            throw new InvalidOperationException(
                String.Format(
                    CultureInfo.CurrentCulture,
                    "The OriginalValues or ExtendedProperties collections on the type '{0}' contained only a partial key to satisfy the relationship '{1}' targeting the role '{2}'",
                    ObjectContext.GetObjectType(entity.GetType()).FullName,
                    reference.RelationshipName,
                    reference.TargetRoleName));
        }

        return entityIndex.ConvertEntityKey(new EntityKey(reference.GetEntitySetName(), savedMembers));
    }

    // Moves the key member values saved for the passed in EntityReference from sourceValues to
    // targetValues. A value already stored in targetValues under the same lookup key is
    // overwritten.
    //
    // Throws InvalidOperationException when sourceValues does not hold a value for every key
    // member of the target entity set. All values are collected before anything is moved, so
    // on that error neither dictionary is modified (no half-moved key is left behind).
    private static void MoveSavedReferenceKey(EntityReference reference, object entity, NavigationProperty navigationProperty, IDictionary<string, object> sourceValues, IDictionary<string, object> targetValues)
    {
        Debug.Assert(navigationProperty == null || reference.RelationshipSet.ElementType == navigationProperty.RelationshipType, "the reference and navigationProperty should correspond");

        EntitySet entitySet = ((AssociationSet)reference.RelationshipSet).AssociationSetEnds[reference.TargetRoleName].EntitySet;

        // Phase 1: gather the complete key, failing before any dictionary is mutated.
        List<KeyValuePair<string, object>> keyValues = new List<KeyValuePair<string, object>>();
        foreach (var keyMember in entitySet.ElementType.KeyMembers)
        {
            string lookupKey = CreateReferenceKeyLookup(keyMember.Name, reference, navigationProperty);
            object value;
            if (!sourceValues.TryGetValue(lookupKey, out value))
            {
                throw new InvalidOperationException(
                    String.Format(
                        CultureInfo.CurrentCulture,
                        "The OriginalValues or ExtendedProperties collections on the type '{0}' contained only a partial key to satisfy the relationship '{1}' targeting the role '{2}'",
                        ObjectContext.GetObjectType(entity.GetType()).FullName,
                        reference.RelationshipName,
                        reference.TargetRoleName));
            }
            keyValues.Add(new KeyValuePair<string, object>(lookupKey, value));
        }

        // Phase 2: move every value; the indexer adds a new entry or replaces an existing one.
        foreach (KeyValuePair<string, object> keyValue in keyValues)
        {
            targetValues[keyValue.Key] = keyValue.Value;
            sourceValues.Remove(keyValue.Key);
        }
    }

    // Yields the EntityReferences of manager whose source end multiplicity is not One and whose
    // association is not a foreign key association. These are the references whose keys
    // StoreReferenceKeyValues saves.
    private static IEnumerable<EntityReference> EnumerateSaveReferences(RelationshipManager manager)
    {
        return from reference in manager.GetAllRelatedEnds().OfType<EntityReference>()
               where reference.RelationshipSet.ElementType.RelationshipEndMembers[reference.SourceRoleName].RelationshipMultiplicity != RelationshipMultiplicity.One
               where !((AssociationSet)reference.RelationshipSet).ElementType.IsForeignKey
               select reference;
    }

    // Saves, into entity.ChangeTracker.ExtendedProperties, the key values of each reference
    // (from EnumerateSaveReferences) that has an EntityKey but either has no matching
    // navigation property on the entity type or has no related object attached. Does nothing
    // when the context is not tracking the entity.
    //
    // Throws ArgumentNullException when entity is null.
    internal static void StoreReferenceKeyValues(this ObjectContext context, IObjectWithChangeTracker entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }

        ObjectStateEntry entry;
        if (!context.ObjectStateManager.TryGetObjectStateEntry(entity, out entry))
        {
            // must be a no tracking query, the reference key info won't be available
            return;
        }

        var relationshipManager = entry.RelationshipManager;
        EntityType entityType = context.MetadataWorkspace.GetCSpaceEntityType(entity);
        foreach (EntityReference entityReference in EnumerateSaveReferences(relationshipManager))
        {
            NavigationProperty navigationProperty = entityType.NavigationProperties.FirstOrDefault(n => n.RelationshipType == entityReference.RelationshipSet.ElementType &&
                    n.FromEndMember.Name == entityReference.SourceRoleName &&
                    n.ToEndMember.Name == entityReference.TargetRoleName);

            object value = entityReference.GetValue();
            if ((navigationProperty == null || value == null) && entityReference.EntityKey != null)
            {
                // EntityKeyValues is null for temporary keys; there is nothing to save then.
                var keyValues = entityReference.EntityKey.EntityKeyValues;
                if (keyValues == null)
                {
                    continue;
                }

                foreach (var item in keyValues)
                {
                    string key = CreateReferenceKeyLookup(item.Key, entityReference, navigationProperty);
                    // Assign via the indexer rather than Add: if keys were already stored for this
                    // entity (e.g. this runs a second time), the newer value replaces the old
                    // instead of throwing ArgumentException.
                    entity.ChangeTracker.ExtendedProperties[key] = item.Value;
                }
            }
        }
    }

    // Applies one entity's client-side tracking information to the context in three steps:
    // set the state entry to match ChangeTracker.State, handle the entity's relationship keys,
    // then copy its saved original values.
    // NOTE(review): presumably used for entities that are not being deleted. HandleDeletedEntity
    // runs the same steps but handles relationship keys before changing the state, so keep this
    // order as is.
    private static void HandleEntity(ObjectContext context, EntityIndex entityIndex, RelationshipSet allRelationships, IObjectWithChangeTracker entity)
    {
        ChangeEntityStateBasedOnObjectState(context, entity);
        HandleRelationshipKeys(context, entityIndex, allRelationships, entity);
        UpdateOriginalValues(context, entity);
    }

    // Same steps as HandleEntity, except that relationship keys are handled BEFORE the state
    // entry is changed to match ChangeTracker.State; original values are copied last.
    // NOTE(review): presumably the relationships must be processed while the entity is not yet
    // marked Deleted. Confirm before reordering these calls.
    private static void HandleDeletedEntity(ObjectContext context, EntityIndex entityIndex, RelationshipSet allRelationships, IObjectWithChangeTracker entity)
    {
        HandleRelationshipKeys(context, entityIndex, allRelationships, entity);
        ChangeEntityStateBasedOnObjectState(context, entity);
        UpdateOriginalValues(context, entity);
    }

    // Copies the original values saved in entity.ChangeTracker.OriginalValues into the state
    // entry's updatable original values. Only scalar and complex properties are handled;
    // complex properties are walked recursively by the ComplexType overload.
    //
    // Skipped for Unchanged or Added entities and when no OriginalValues collection exists. The
    // saved originals are expected to be sparse: only the values that are required are
    // captured (concurrency, sproc, condition, ...).
    private static void UpdateOriginalValues(ObjectContext context, IObjectWithChangeTracker entity)
    {
        var tracker = entity.ChangeTracker;
        bool nothingToUpdate = tracker.State == ObjectState.Unchanged
                               || tracker.State == ObjectState.Added
                               || tracker.OriginalValues == null;
        if (nothingToUpdate)
        {
            return;
        }

        ObjectStateEntry stateEntry = context.ObjectStateManager.GetObjectStateEntry(entity);
        OriginalValueRecord originals = stateEntry.GetUpdatableOriginalValues();
        EntityType conceptualType = context.MetadataWorkspace.GetCSpaceEntityType(entity);

        foreach (EdmProperty property in conceptualType.Properties)
        {
            var propertyType = property.TypeUsage.EdmType;

            if (propertyType is SimpleType)
            {
                object savedOriginal;
                if (tracker.OriginalValues.TryGetValue(property.Name, out savedOriginal))
                {
                    originals.SetValue(property, savedOriginal);
                }
            }
            else if (propertyType is ComplexType)
            {
                OriginalValueRecord nestedRecord = originals.GetOriginalValueRecord(property.Name);
                string entityTypeName = ObjectContext.GetObjectType(entity.GetType()).FullName;
                UpdateOriginalValues((ComplexType)propertyType, entityTypeName, property.Name, tracker.OriginalValues, nestedRecord);
            }
        }
    }

    // Copies saved original values for the members of a complex type into its original value
    // record, recursing into nested complex types. Saved values are looked up by dotted path
    // ("Prop.Member.SubMember") in originalValueSource.
    //
    // complexOriginalValueRecord may be null: that only happens when a null reference was
    // assigned to a ComplexType property and then given to ApplyChanges. In that case a saved
    // original that is not null cannot be applied, and InvalidOperationException is thrown.
    private static void UpdateOriginalValues(ComplexType complexType, string entityTypeName, string propertyPathToType, IDictionary<string, object> originalValueSource, OriginalValueRecord complexOriginalValueRecord)
    {
        foreach (EdmProperty member in complexType.Properties)
        {
            string memberPath = String.Format(CultureInfo.InvariantCulture, "{0}.{1}", propertyPathToType, member.Name);
            var memberType = member.TypeUsage.EdmType;

            if (memberType is ComplexType)
            {
                OriginalValueRecord nestedRecord = complexOriginalValueRecord == null
                    ? null
                    : complexOriginalValueRecord.GetOriginalValueRecord(member.Name);
                // recurse down the chain of complex types
                UpdateOriginalValues((ComplexType)memberType, entityTypeName, memberPath, originalValueSource, nestedRecord);
                continue;
            }

            if (!(memberType is SimpleType))
            {
                continue;
            }

            object savedOriginal;
            if (!originalValueSource.TryGetValue(memberPath, out savedOriginal))
            {
                continue;
            }

            if (complexOriginalValueRecord != null)
            {
                complexOriginalValueRecord.SetValue(member, savedOriginal);
            }
            else if (savedOriginal != null)
            {
                // there is no record to hold a non-null original value
                throw new InvalidOperationException(
                    String.Format(
                    CultureInfo.CurrentCulture,
                    "Can not set the original value on the object stored in the property '{0}' on the type '{1}' because the property is null.",
                    propertyPathToType,
                    entityTypeName));
            }
        }
    }

    // Returns the nested OriginalValueRecord held by the named member of record. Returns null
    // when that member's value is DBNull, or when the member is not an OriginalValueRecord.
    private static OriginalValueRecord GetOriginalValueRecord(this OriginalValueRecord record, string name)
    {
        int ordinal = record.GetOrdinal(name);
        return record.IsDBNull(ordinal)
            ? null
            : record.GetDataRecord(ordinal) as OriginalValueRecord;
    }

    // Sets the original value of edmProperty in record. A null value for a non-nullable
    // primitive property whose CLR type is a value type is silently skipped, because the
    // ObjectStateEntry won't allow it.
    private static void SetValue(this OriginalValueRecord record, EdmProperty edmProperty, object value)
    {
        if (value == null)
        {
            PrimitiveType primitive = edmProperty.TypeUsage.EdmType as PrimitiveType;
            bool nullNotAllowed = primitive != null
                                  && primitive.ClrEquivalentType.IsValueType
                                  && !edmProperty.Nullable;
            if (nullNotAllowed)
            {
                return;
            }
        }

        record.SetValue(record.GetOrdinal(edmProperty.Name), value);
    }


    // Makes the context's state entry for entity match entity.ChangeTracker.State.
    // - Added: the entry is only checked (Debug.Assert); it is expected to already be Added.
    // - Unchanged / Modified / Deleted: the entry is moved to the matching EntityState.
    // - Any other ObjectState value: the entry is left untouched.
    private static void ChangeEntityStateBasedOnObjectState(ObjectContext context, IObjectWithChangeTracker entity)
    {
        ObjectState trackedState = entity.ChangeTracker.State;

        if (trackedState == ObjectState.Added)
        {
            // No-op: the state entry is already marked as added
            Debug.Assert(context.ObjectStateManager.GetObjectStateEntry(entity).State == EntityState.Added, "State should have been Added");
            return;
        }

        EntityState targetState;
        if (trackedState == ObjectState.Unchanged)
        {
            targetState = EntityState.Unchanged;
        }
        else if (trackedState == ObjectState.Modified)
        {
            targetState = EntityState.Modified;
        }
        else if (trackedState == ObjectState.Deleted)
        {
            targetState = EntityState.Deleted;
        }
        else
        {
            return;
        }

        context.ObjectStateManager.ChangeObjectState(entity, targetState);
    }

    // Resolves the conceptual (CSpace) EntityType for the CLR type of entity. Proxy types are
    // unwrapped with ObjectContext.GetObjectType. The lookup goes first through the object
    // space (OSpace) type, then through its mapped EDM space type. Throws ArgumentException
    // when no CSpace EntityType can be found.
    private static EntityType GetCSpaceEntityType(this MetadataWorkspace workspace, object entity)
    {
        Type clrType = ObjectContext.GetObjectType(entity.GetType());

        EntityType objectSpaceType;
        StructuralType conceptualType;
        if (!workspace.TryGetItem<EntityType>(clrType.FullName, DataSpace.OSpace, out objectSpaceType) ||
            !workspace.TryGetEdmSpaceType(objectSpaceType, out conceptualType))
        {
            conceptualType = null;
        }

        EntityType conceptualEntityType = conceptualType as EntityType;
        if (conceptualEntityType == null)
        {
            throw new ArgumentException(String.Format(CultureInfo.CurrentCulture, "Unable to find a CSpace type for type {0}", clrType.FullName));
        }
        return conceptualEntityType;
    }

    // Returns the first (and for a reference, only) object enumerated by entityReference, or
    // null when it enumerates nothing.
    private static object GetValue(this System.Data.Objects.DataClasses.EntityReference entityReference)
    {
        object first = null;
        foreach (object candidate in entityReference)
        {
            first = candidate;
            break;
        }
        return first;
    }

    // Returns the key of the entity the reference currently points to. If an object is
    // attached, its key comes from its state entry in context; otherwise the reference's own
    // EntityKey (possibly null) is returned.
    private static EntityKey GetCurrentEntityKey(this System.Data.Objects.DataClasses.EntityReference entityReference, ObjectContext context)
    {
        object related = entityReference.GetValue();
        if (related == null)
        {
            return entityReference.EntityKey;
        }
        return context.ObjectStateManager.GetObjectStateEntry(related).EntityKey;
    }

    // Returns the RelatedEnd of entry's entity that corresponds to the navigation property named
    // navigationPropertyIdentity. Returns null if the related end is not a RelatedEnd; throws
    // (via GetNavigationProperty) if no such navigation property exists.
    private static RelatedEnd GetRelatedEnd(this ObjectStateEntry entry, string navigationPropertyIdentity)
    {
        var workspace = entry.ObjectStateManager.MetadataWorkspace;
        var entityType = workspace.GetCSpaceEntityType(entry.Entity);
        NavigationProperty navigationProperty = GetNavigationProperty(entityType, navigationPropertyIdentity);

        var relatedEnd = entry.RelationshipManager.GetRelatedEnd(
            navigationProperty.RelationshipType.FullName,
            navigationProperty.ToEndMember.Name);
        return relatedEnd as RelatedEnd;
    }

    // Looks up a navigation property of entityType by exact (case-sensitive) name. Throws
    // InvalidOperationException when the entity type has no such navigation property.
    private static NavigationProperty GetNavigationProperty(this EntityType entityType, string navigationPropertyIdentity)
    {
        NavigationProperty found;
        bool exists = entityType.NavigationProperties.TryGetValue(navigationPropertyIdentity, false, out found);
        if (exists)
        {
            return found;
        }

        throw new InvalidOperationException(
            String.Format(
                CultureInfo.CurrentCulture,
                "Could not find navigation property '{0}' in EntityType '{1}'.",
                navigationPropertyIdentity,
                entityType.FullName));
    }

    // Returns the container-qualified name ("Container.EntitySet") of the entity set at the
    // target end of relatedEnd's association set.
    private static string GetEntitySetName(this RelatedEnd relatedEnd)
    {
        AssociationSet associationSet = (AssociationSet)relatedEnd.RelationshipSet;
        EntitySet targetSet = associationSet.AssociationSetEnds[relatedEnd.TargetRoleName].EntitySet;
        return String.Concat(targetSet.EntityContainer.Name, ".", targetSet.Name);
    }

    // Returns true when relatedEnd's source role is the dependent (To) role of one of the
    // referential constraints on its association. Returns false when relatedEnd has no
    // relationship set.
    //
    // Example:
    //    Client<C_ID> --- Order<O_ID, Client_ID>
    //    RI Constraint: Principal/From <Client.C_ID>,  Dependent/To <Order.Client_ID>
    // When the current RelatedEnd belongs to Order's relationships,
    // constraint.ToRole == the source role == Order.
    private static bool IsDependentEndOfReferentialConstraint(this RelatedEnd relatedEnd)
    {
        if (relatedEnd.RelationshipSet == null)
        {
            return false;
        }

        // the constraints collection usually contains zero or one element, so a scan is cheap
        AssociationType association = (AssociationType)relatedEnd.RelationshipSet.ElementType;
        return association.ReferentialConstraints.Any(constraint => constraint.ToRole.Name == relatedEnd.SourceRoleName);
    }

    // Searches the Added and Unchanged relationship entries belonging to associationSet for the
    // one whose fromEnd value equals the key from and whose toEnd value equals the key to. On a
    // match, returns true with entry set to it; otherwise returns false with entry set to null.
    private static bool TryGetObjectStateEntry(this ObjectContext context, EntityKey from, EntityKey to, AssociationSet associationSet, AssociationEndMember fromEnd, AssociationEndMember toEnd, out ObjectStateEntry entry)
    {
        var candidates = context.ObjectStateManager
            .GetObjectStateEntries(EntityState.Added | EntityState.Unchanged)
            .Where(e => e.IsRelationship && e.EntitySet == associationSet);

        foreach (ObjectStateEntry candidate in candidates)
        {
            CurrentValueRecord values = candidate.CurrentValues;
            EntityKey candidateFrom = (EntityKey)values.GetValue(values.GetOrdinal(fromEnd.Name));
            EntityKey candidateTo = (EntityKey)values.GetValue(values.GetOrdinal(toEnd.Name));
            if (candidateFrom == from && candidateTo == to)
            {
                entry = candidate;
                return true;
            }
        }

        entry = null;
        return false;
    }

    // Adds an entity graph to the context. While attached, it handles ObjectStateManagerChanged
    // Add notifications. For each added entity it disables change tracking, records the entity
    // in an EntityIndex, and queues the entities listed as removed so they are added to the
    // context too. "Removed" entities are OriginalValues entries keyed by a navigation property
    // name, plus the items in ObjectsRemovedFromCollectionProperties.
    private sealed class AddHelper
    {
        private readonly ObjectContext _context;
        private readonly EntityIndex _entityIndex;

        // Used during add processing
        // (entity set name, entity) pairs still waiting to be passed to ObjectContext.AddObject
        private readonly Queue<Tuple<string, IObjectWithChangeTracker>> _entitiesToAdd;
        // (state entry, navigation property name, removed related entities) recorded by
        // HandleStateManagerChange; ProcessNewAdds turns them into _entitiesToAdd items
        private readonly Queue<Tuple<ObjectStateEntry, string, IEnumerable<object>>> _entitiesDuringAdd;

        // Adds entity to entitySetName, then every entity queued while processing it, skipping
        // entities the context already tracks. The state manager handler is detached in the
        // finally block, so it is removed even if adding throws. Returns the index of the
        // entities recorded by the handler.
        public static EntityIndex AddAllEntities(ObjectContext context, string entitySetName, IObjectWithChangeTracker entity)
        {
            AddHelper addHelper = new AddHelper(context);

            try
            {
                // Include the root element to start the Apply
                addHelper.QueueAdd(entitySetName, entity);

                // Add everything
                // (HasMore drains the handler's pending work first, so entities discovered
                // while adding are picked up by this loop)
                while (addHelper.HasMore)
                {
                    Tuple<string, IObjectWithChangeTracker> entityInSet = addHelper.NextAdd();
                    // Only add the object if it's not already in the context
                    ObjectStateEntry entry = null;
                    if (!context.ObjectStateManager.TryGetObjectStateEntry(entityInSet.Item2, out entry))
                    {
                        context.AddObject(entityInSet.Item1, entityInSet.Item2);
                    }
                }
            }
            finally
            {
                addHelper.Detach();
            }
            return addHelper.EntityIndex;
        }

        // Subscribes HandleStateManagerChange to the context's ObjectStateManagerChanged event;
        // Detach removes the subscription.
        private AddHelper(ObjectContext context)
        {
            _context = context;
            _context.ObjectStateManager.ObjectStateManagerChanged += this.HandleStateManagerChange;

            _entityIndex = new EntityIndex(context);
            _entitiesToAdd = new Queue<Tuple<string, IObjectWithChangeTracker>>();
            _entitiesDuringAdd = new Queue<Tuple<ObjectStateEntry, string, IEnumerable<object>>>();
        }

        // Unsubscribes the handler registered by the constructor.
        private void Detach()
        {
            _context.ObjectStateManager.ObjectStateManagerChanged -= this.HandleStateManagerChange;
        }

        // Called for every ObjectStateManager change while attached; only Add notifications are
        // processed.
        private void HandleStateManagerChange(object sender, CollectionChangeEventArgs args)
        {
            if (args.Action == CollectionChangeAction.Add)
            {
                // NOTE(review): entity is null when the added element is not an
                // IObjectWithChangeTracker; the lookups below would then throw.
                IObjectWithChangeTracker entity = args.Element as IObjectWithChangeTracker;
                ObjectStateEntry entry = _context.ObjectStateManager.GetObjectStateEntry(entity);
                ObjectChangeTracker changeTracker = entity.ChangeTracker;

                // NOTE(review): change tracking is switched off here; presumably the ApplyChanges
                // caller re-enables it after processing — confirm.
                changeTracker.ChangeTrackingEnabled = false;
                _entityIndex.Add(entry, changeTracker);

                // Queue removed reference values
                // (OriginalValues entries whose key is a navigation property name; a non-null
                // value is the entity that was removed from that reference)
                var navPropNames = _context.MetadataWorkspace.GetCSpaceEntityType(entity).NavigationProperties.Select(n => n.Name);
                var entityRefOriginalValues = changeTracker.OriginalValues.Where(kvp => navPropNames.Contains(kvp.Key));
                foreach (KeyValuePair<string, object> originalValueWithName in entityRefOriginalValues)
                {
                    if (originalValueWithName.Value != null)
                    {
                        _entitiesDuringAdd.Enqueue(new Tuple<ObjectStateEntry, string, IEnumerable<object>>(
                            entry,
                            originalValueWithName.Key,
                            new object[] { originalValueWithName.Value }));
                    }
                }

                // Queue removed collection values
                foreach (KeyValuePair<string, ObjectList> collectionPropertyChangesWithName in changeTracker.ObjectsRemovedFromCollectionProperties)
                {
                    _entitiesDuringAdd.Enqueue(new Tuple<ObjectStateEntry, string, IEnumerable<object>>(
                        entry,
                        collectionPropertyChangesWithName.Key,
                        collectionPropertyChangesWithName.Value));
                }
            }
        }

        // Index of every entity recorded by HandleStateManagerChange.
        private EntityIndex EntityIndex
        {
            get { return _entityIndex; }
        }

        // True while entities remain to be added; pending handler work is processed first so it
        // is counted.
        private bool HasMore
        {
            get { ProcessNewAdds(); return _entitiesToAdd.Count > 0; }
        }

        // Queues entity for adding to entitySetName unless the index already recorded it.
        private void QueueAdd(string entitySetName, IObjectWithChangeTracker entity)
        {
            if (!_entityIndex.Contains(entity))
            {
                // Queue the entity so that we can add the 'removed collection' items
                _entitiesToAdd.Enqueue(new Tuple<string, IObjectWithChangeTracker>(entitySetName, entity));
            }
        }

        // Processes pending handler work, then dequeues the next (entity set name, entity) pair.
        private Tuple<string, IObjectWithChangeTracker> NextAdd()
        {
            ProcessNewAdds();
            return _entitiesToAdd.Dequeue();
        }

        // For each recorded (entry, navigation property, removed entities) item, resolves the
        // entity set at the navigation property's target end and queues the removed entities for
        // that set.
        private void ProcessNewAdds()
        {
            while (_entitiesDuringAdd.Count > 0)
            {
                Tuple<ObjectStateEntry, string, IEnumerable<object>> relatedEntities = _entitiesDuringAdd.Dequeue();
                RelatedEnd relatedEnd = relatedEntities.Item1.GetRelatedEnd(relatedEntities.Item2);
                string entitySetName = relatedEnd.GetEntitySetName();

                foreach (var targetEntity in relatedEntities.Item3)
                {
                    QueueAdd(entitySetName, targetEntity as IObjectWithChangeTracker);
                }
            }
        }
    }

    // Tracks the entities handled by one ApplyChanges call. It also maps the key each entity
    // will finally use in the context to the key it had when it was added, so that keys
    // referring to yet-to-be-processed Added entities can be translated.
    private sealed class EntityIndex
    {
        private readonly ObjectContext _context;

        // Set of all entities
        private readonly HashSet<IObjectWithChangeTracker> _allEntities;

        // Index of the final key that will be used in the context (could be real for non-added, could be temporary for added)
        // to the initial temporary key
        private readonly Dictionary<EntityKey, EntityKey> _temporaryKeyMap;

        public EntityIndex(ObjectContext context)
        {
            _context = context;

            _allEntities = new HashSet<IObjectWithChangeTracker>();
            _temporaryKeyMap = new Dictionary<EntityKey, EntityKey>();
        }

        // Records the entity of entry and maps its final key to the entry's current key.
        // Added entities keep their current key as the final key. For other entities the final
        // key is created from the entity's values with ObjectContext.CreateEntityKey. The first
        // mapping recorded for a given final key wins.
        public void Add(ObjectStateEntry entry, ObjectChangeTracker changeTracker)
        {
            EntityKey temporaryKey = entry.EntityKey;
            EntityKey finalKey;

            // Track that this Apply will be handling this entity. HashSet.Add already ignores
            // duplicates. Calling _allEntities.Contains(entry.Entity) with the object-typed Entity
            // would bind to the linear Enumerable.Contains<object>, not the O(1) hash lookup.
            _allEntities.Add(entry.Entity as IObjectWithChangeTracker);

            if (changeTracker.State == ObjectState.Added)
            {
                finalKey = temporaryKey;
            }
            else
            {
                finalKey = _context.CreateEntityKey(temporaryKey.EntityContainerName + "." + temporaryKey.EntitySetName, entry.Entity);
            }
            if (!_temporaryKeyMap.ContainsKey(finalKey))
            {
                _temporaryKeyMap.Add(finalKey, temporaryKey);
            }
        }

        // Returns true when entity was recorded by Add.
        public bool Contains(object entity)
        {
            // Narrow to the element type so the HashSet's hash lookup is used; an object-typed
            // argument would otherwise bind to Enumerable.Contains<object> and scan every entry.
            IObjectWithChangeTracker trackedEntity = entity as IObjectWithChangeTracker;
            if (trackedEntity == null && entity != null)
            {
                // Add stores entities as IObjectWithChangeTracker, so other objects are never present
                return false;
            }
            return _allEntities.Contains(trackedEntity);
        }

        public IEnumerable<IObjectWithChangeTracker> AllEntities
        {
            get { return _allEntities; }
        }

        // Converts the passed in EntityKey to the EntityKey that is usable by the current state of ApplyChanges
        public EntityKey ConvertEntityKey(EntityKey targetKey)
        {
            ObjectStateEntry targetEntry;
            if (!_context.ObjectStateManager.TryGetObjectStateEntry(targetKey, out targetEntry))
            {
                // If no entry exists, then either:
                // 1. This is an EntityKey that is not represented in the set of entities being dealt with during the Apply
                // 2. This is an EntityKey that will represent one of the yet-to-be-processed Added entries, so look it up
                EntityKey temporaryKey;
                if (_temporaryKeyMap.TryGetValue(targetKey, out temporaryKey))
                {
                    targetKey = temporaryKey;
                }
            }
            return targetKey;
        }
    }

    // The RelationshipSet collects the distinct relationships (as RelationshipWrapper values)
    // of an initial set of tracked entities, skipping foreign key associations. Individual
    // relationships can later be added, or removed based on a relationship ObjectStateEntry.
    private sealed class RelationshipSet : IEnumerable<RelationshipWrapper>
    {
        private readonly HashSet<RelationshipWrapper> _relationships;
        private readonly ObjectContext _context;

        public RelationshipSet(ObjectContext context, IEnumerable<object> allEntities)
        {
            _context = context;
            _relationships = new HashSet<RelationshipWrapper>();

            foreach (object sourceEntity in allEntities)
            {
                var relationshipManager = context.ObjectStateManager.GetObjectStateEntry(sourceEntity).RelationshipManager;
                foreach (IRelatedEnd end in relationshipManager.GetAllRelatedEnds())
                {
                    if (((AssociationType)end.RelationshipSet.ElementType).IsForeignKey)
                    {
                        // foreign key associations are not collected
                        continue;
                    }

                    foreach (object relatedEntity in end)
                    {
                        Add(end, sourceEntity, relatedEntity, EntityState.Unchanged);
                    }
                }
            }
        }

        // Adds the relationship between sourceEntity and targetEntity described by relatedEnd;
        // HashSet.Add leaves the set unchanged when an equal wrapper is already present.
        public void Add(IRelatedEnd relatedEnd, object sourceEntity, object targetEntity, EntityState state)
        {
            _relationships.Add(new RelationshipWrapper(
                (AssociationSet)relatedEnd.RelationshipSet,
                relatedEnd.SourceRoleName,
                sourceEntity,
                relatedEnd.TargetRoleName,
                targetEntity,
                state));
        }

        // Removes the relationship described by a relationship ObjectStateEntry. Deleted entries
        // are read through OriginalValues, others through CurrentValues. Nothing is removed when
        // either end's state entry has no entity.
        public void Remove(ObjectStateEntry relationshipEntry)
        {
            Debug.Assert(relationshipEntry.IsRelationship);

            AssociationSet associationSet = (AssociationSet)relationshipEntry.EntitySet;
            DbDataRecord record = relationshipEntry.State == EntityState.Deleted ? relationshipEntry.OriginalValues : relationshipEntry.CurrentValues;
            var endMembers = associationSet.ElementType.AssociationEndMembers;
            string role0 = endMembers[0].Name;
            string role1 = endMembers[1].Name;

            object end0Entity = _context.ObjectStateManager.GetObjectStateEntry((EntityKey)record.GetValue(record.GetOrdinal(role0))).Entity;
            object end1Entity = _context.ObjectStateManager.GetObjectStateEntry((EntityKey)record.GetValue(record.GetOrdinal(role1))).Entity;

            if (end0Entity == null || end1Entity == null)
            {
                return;
            }

            _relationships.Remove(new RelationshipWrapper(associationSet, role0, end0Entity, role1, end1Entity, EntityState.Unchanged));
        }

        #region IEnumerable<RelationshipWrapper>

        public IEnumerator<RelationshipWrapper> GetEnumerator()
        {
            return _relationships.GetEnumerator();
        }

        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        #endregion
    }

    // A RelationshipWrapper is used to identify a relationship between two entities.
    // The relationship is identified by the AssociationSet and the two entity instances, ordered
    // by the roles they play: End0 plays the association's first end member, End1 the second.
    // Equality compares the AssociationSet and both entities by reference; State is carried
    // along but is not part of equality.
    private sealed class RelationshipWrapper : IEquatable<RelationshipWrapper>
    {
        internal readonly AssociationSet AssociationSet;
        internal readonly object End0;
        internal readonly object End1;
        internal readonly EntityState State;

        internal RelationshipWrapper(AssociationSet extent,
                                     string role0, object end0,
                                     string role1, object end1,
                                     EntityState state)
        {
            Debug.Assert(null != extent, "null AssociationSet");
            Debug.Assert(null != (object)end0, "null end0");
            Debug.Assert(null != (object)end1, "null end1");

            AssociationSet = extent;
            Debug.Assert(extent.ElementType.AssociationEndMembers.Count == 2, "only 2 ends are supported");

            State = state;

            // normalize the order so that End0/End1 always follow the association's end members
            if (extent.ElementType.AssociationEndMembers[0].Name == role0)
            {
                Debug.Assert(extent.ElementType.AssociationEndMembers[1].Name == role1, "a)roleAndKey1 Name differs");
                End0 = end0;
                End1 = end1;
            }
            else
            {
                Debug.Assert(extent.ElementType.AssociationEndMembers[0].Name == role1, "b)roleAndKey1 Name differs");
                Debug.Assert(extent.ElementType.AssociationEndMembers[1].Name == role0, "b)roleAndKey0 Name differs");
                End0 = end1;
                End1 = end0;
            }
        }

        internal ReadOnlyMetadataCollection<AssociationEndMember> AssociationEndMembers
        {
            get { return this.AssociationSet.ElementType.AssociationEndMembers; }
        }

        // Equals compares End0/End1 by reference, so the hash uses their identity hash codes too.
        // Calling an entity's own GetHashCode could yield a different value over time if the
        // entity type overrides it based on mutable state. The wrapper could then no longer be
        // found in (or removed from) the HashSet that stores it.
        public override int GetHashCode()
        {
            return this.AssociationSet.Name.GetHashCode() ^
                (System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(this.End0) +
                 System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(this.End1));
        }

        public override bool Equals(object obj)
        {
            return Equals(obj as RelationshipWrapper);
        }

        public bool Equals(RelationshipWrapper wrapper)
        {
            return (Object.ReferenceEquals(this, wrapper) ||
                    ((null != wrapper) &&
                     Object.ReferenceEquals(this.AssociationSet, wrapper.AssociationSet) &&
                     Object.ReferenceEquals(this.End0, wrapper.End0) &&
                     Object.ReferenceEquals(this.End1, wrapper.End1)));
        }
    }
}
<#+
}
#>
<#+
// Fallback summary text used by SummaryComment when an item has no documentation summary.
string DefaultSummaryComment { get; set; }

// Returns the item's documentation summary, XML-escaped, with every continuation line
// prefixed by "/// ". Falls back to DefaultSummaryComment, and then to an empty string.
string SummaryComment(MetadataItem item)
{
    var documentation = item.Documentation;
    bool hasSummary = documentation != null && documentation.Summary != null;
    if (hasSummary)
    {
        return PrefixLinesOfMultilineComment(XMLCOMMENT_START + " ", XmlEntityize(documentation.Summary));
    }

    return DefaultSummaryComment ?? string.Empty;
}

// Builds a LongDescription XML doc-comment element for the item's long description. The
// element starts on a new line, and each of its lines begins with the indentation for
// indentLevel followed by "/// ". Returns an empty string when there is no long description.
string LongDescriptionCommentElement(MetadataItem item, int indentLevel)
{
    var documentation = item.Documentation;
    if (documentation == null || String.IsNullOrEmpty(documentation.LongDescription))
    {
        return string.Empty;
    }

    string lineStart = CodeRegion.GetIndent(indentLevel) + XMLCOMMENT_START + " ";
    string body = PrefixLinesOfMultilineComment(lineStart, XmlEntityize(documentation.LongDescription));

    System.Text.StringBuilder element = new System.Text.StringBuilder();
    element.Append(Environment.NewLine);
    element.Append(lineStart).Append("<LongDescription>").Append(Environment.NewLine);
    element.Append(lineStart).Append(body).Append(Environment.NewLine);
    element.Append(lineStart).Append("</LongDescription>");
    return element.ToString();
}

// The NewModifier overloads return "new" when a generated member's name matches a member of
// the relevant base type, so the generated member should hide it. The base types are
// EntityObject, ComplexObject or ObjectContext. They return an empty string otherwise.
string NewModifier(NavigationProperty navigationProperty)
{
    return NewModifier(typeof(EntityObject), navigationProperty.Name);
}

string NewModifier(EdmFunction edmFunction)
{
    return NewModifier(typeof(ObjectContext), edmFunction.Name);
}

string NewModifier(EntitySet set)
{
    return NewModifier(typeof(ObjectContext), set.Name);
}

string NewModifier(EdmProperty property)
{
    // properties of entity types are checked against EntityObject, all others against ComplexObject
    Type baseType = property.DeclaringType.BuiltInTypeKind == BuiltInTypeKind.EntityType
        ? typeof(EntityObject)
        : typeof(ComplexObject);
    return NewModifier(baseType, property.Name);
}

string NewModifier(Type type, string memberName)
{
    return HasBaseMemberWithMatchingName(type, memberName) ? "new" : string.Empty;
}

// Inserts prefix after every Environment.NewLine in comment, so each continuation line starts
// with it; the first line is left unprefixed.
string PrefixLinesOfMultilineComment(string prefix, string comment)
{
    string[] lines = comment.Split(new string[] { Environment.NewLine }, StringSplitOptions.None);
    return String.Join(Environment.NewLine + prefix, lines);
}

// Emits one XML doc param element per (name, description) pair. Each element is preceded by a
// line break and the indentation for indentLevel, followed by "///".
string ParameterComments(IEnumerable<Tuple<string, string>> parameters, int indentLevel)
{
    System.Text.StringBuilder comments = new System.Text.StringBuilder();
    foreach (Tuple<string, string> parameter in parameters)
    {
        string paramElement = String.Format(CultureInfo.InvariantCulture, " <param name=\"{0}\">{1}</param>", parameter.Item1, parameter.Item2);
        comments.AppendLine()
                .Append(CodeRegion.GetIndent(indentLevel))
                .Append(XMLCOMMENT_START)
                .Append(paramElement);
    }
    return comments.ToString();
}

// Writes, into the template output, one local ObjectParameter declaration for each
// function-import parameter whose NeedsLocalVariable flag is set; other parameters
// produce no output. The emitted code initializes the local from the argument value
// when the argument has one (HasValue for Nullable<T> parameters, != null otherwise),
// and from the parameter's raw CLR type when it does not.
private void WriteFunctionParameters(IEnumerable<FunctionImportParameter> parameters)
{
    foreach (FunctionImportParameter parameter in parameters)
    {
        if (!parameter.NeedsLocalVariable)
        {
            continue;
        }
        // The text block below is template output, not code executed here. It is
        // written once for every parameter that reaches this point.
#>
        ObjectParameter <#=parameter.LocalVariableName#>;
        if (<#=parameter.IsNullableOfT ? parameter.FunctionParameterName + ".HasValue" : parameter.FunctionParameterName + " != null"#>)
        {
            <#=parameter.LocalVariableName#> = new ObjectParameter("<#=parameter.EsqlParameterName#>", <#=parameter.FunctionParameterName#>);
        }
        else
        {
            <#=parameter.LocalVariableName#> = new ObjectParameter("<#=parameter.EsqlParameterName#>", typeof(<#=parameter.RawClrTypeName#>));
        }

<#+
    }
}

// Escapes text for use inside an XML documentation comment:
//   & < > ' "  become  &amp; &lt; &gt; &apos; &quot;
// Complete Environment.NewLine sequences are kept as real line breaks. Any stray
// '\r' or '\n' that is not part of one is encoded as &#xD; / &#xA;.
// Returns "" for null or empty input.
string XmlEntityize(string text)
{
    if (text == null || text.Length == 0)
    {
        return string.Empty;
    }

    string newLine = Environment.NewLine;
    var escaped = new System.Text.StringBuilder(text.Length);
    int position = 0;
    while (position < text.Length)
    {
        // Keep whole platform line breaks intact; only stray CR/LF chars are encoded.
        bool atNewLine = position + newLine.Length <= text.Length
            && string.CompareOrdinal(text, position, newLine, 0, newLine.Length) == 0;
        if (atNewLine)
        {
            escaped.Append(newLine);
            position += newLine.Length;
            continue;
        }

        char current = text[position];
        switch (current)
        {
            case '&': escaped.Append("&amp;"); break;
            case '<': escaped.Append("&lt;"); break;
            case '>': escaped.Append("&gt;"); break;
            case '\r': escaped.Append("&#xD;"); break;
            case '\n': escaped.Append("&#xA;"); break;
            case '\'': escaped.Append("&apos;"); break;
            case '"': escaped.Append("&quot;"); break;
            default: escaped.Append(current); break;
        }
        position++;
    }
    return escaped.ToString();
}

// Marker that begins each XML documentation comment line built by these helpers.
private const string XMLCOMMENT_START = "///";

// Namespace of the conceptual model. It is assigned from the loaded .edmx at the top
// of the template and used as the key by UpdateObjectNamespaceMap.
public string ModelNamespace { get; set; }

// Translates a CSDL (conceptual model) namespace into the CLR namespace used for
// generated code. A namespace with no entry in EdmToObjectNamespaceMap is returned
// unchanged.
string GetObjectNamespace(string csdlNamespaceName)
{
    string mapped;
    bool found = EdmToObjectNamespaceMap.TryGetValue(csdlNamespaceName, out mapped);
    return found ? mapped : csdlNamespaceName;
}

// Reports whether any member returned by type.GetMembers for the flags below
// (public and non-public, instance and static, with FlattenHierarchy) has exactly the
// name memberName (ordinal, case-sensitive) and is accepted by IsVisibleMember.
static bool HasBaseMemberWithMatchingName(Type type, string memberName)
{
    const BindingFlags everyMember =
        BindingFlags.Public | BindingFlags.NonPublic |
        BindingFlags.Instance | BindingFlags.Static |
        BindingFlags.FlattenHierarchy;

    foreach (MemberInfo candidate in type.GetMembers(everyMember))
    {
        if (candidate.Name == memberName && IsVisibleMember(candidate))
        {
            return true;
        }
    }
    return false;
}

// Returns true when methodBase can be seen from a type derived from its declaring
// type in another assembly. That means it is not private, not internal (Assembly)
// and not "private protected" (FamilyAndAssembly). Public, protected and protected
// internal methods qualify. Returns false for null, so callers can pass
// possibly-missing accessors directly.
static bool IsVisibleMethod(MethodBase methodBase)
{
    if (methodBase == null)
    {
        return false;
    }

    // FamANDAssem members are reachable only from derived types inside the same
    // assembly, so generated code in another assembly cannot hide them either.
    return !methodBase.IsPrivate && !methodBase.IsAssembly && !methodBase.IsFamilyAndAssembly;
}

// Decides whether a reflected member counts as visible when checking for name
// clashes:
//  - events: visible if the add or the remove accessor passes IsVisibleMethod;
//  - fields: visible unless private, internal or private protected;
//  - methods/constructors: special-name methods (accessors, operators, ctors) are
//    skipped, others are judged by IsVisibleMethod;
//  - properties: visible if the getter or the setter passes IsVisibleMethod.
// Any other member kind (e.g. nested types) is reported as not visible.
static bool IsVisibleMember(MemberInfo memberInfo)
{
    EventInfo eventInfo = memberInfo as EventInfo;
    if (eventInfo != null)
    {
        // nonPublic: true so protected accessors are returned instead of null. The
        // parameterless overloads return only public accessors, which made protected
        // events look invisible. IsVisibleMethod filters out private/internal ones.
        return IsVisibleMethod(eventInfo.GetAddMethod(true))
            || IsVisibleMethod(eventInfo.GetRemoveMethod(true));
    }

    FieldInfo fieldInfo = memberInfo as FieldInfo;
    if (fieldInfo != null)
    {
        return !fieldInfo.IsPrivate && !fieldInfo.IsAssembly && !fieldInfo.IsFamilyAndAssembly;
    }

    MethodBase methodBase = memberInfo as MethodBase;
    if (methodBase != null)
    {
        if (methodBase.IsSpecialName)
        {
            return false;
        }
        return IsVisibleMethod(methodBase);
    }

    PropertyInfo propertyInfo = memberInfo as PropertyInfo;
    if (propertyInfo != null)
    {
        // Same reasoning as for events: include non-public accessors so protected
        // properties are detected, then let IsVisibleMethod decide.
        return IsVisibleMethod(propertyInfo.GetGetMethod(true))
            || IsVisibleMethod(propertyInfo.GetSetMethod(true));
    }

    return false;
}

// Backing store for EdmToObjectNamespaceMap.
// NOTE(review): this field is public, so it is exposed alongside the property. It
// is kept public for compatibility with any existing references.
public Dictionary<string, string> _edmToObjectNamespaceMap = new Dictionary<string, string>();

// Maps conceptual-model (CSDL) namespaces to the CLR namespaces used for generated
// code. GetObjectNamespace reads it and UpdateObjectNamespaceMap adds to it.
public Dictionary<string, string> EdmToObjectNamespaceMap
{
    get
    {
        return _edmToObjectNamespaceMap;
    }
    set
    {
        _edmToObjectNamespaceMap = value;
    }
}

// Records ModelNamespace -> objectNamespace in EdmToObjectNamespaceMap. Nothing is
// recorded when the two namespaces are identical, or when ModelNamespace already
// has an entry (the first registration wins).
void UpdateObjectNamespaceMap(string objectNamespace)
{
    bool sameNamespace = objectNamespace == ModelNamespace;
    if (sameNamespace || EdmToObjectNamespaceMap.ContainsKey(ModelNamespace))
    {
        return;
    }

    EdmToObjectNamespaceMap.Add(ModelNamespace, objectNamespace);
}

// Stamps this template's metadata with its template name
// ("CSharpSelfTracking.Context") and version ("5.0").
private void DefineMetadata()
{
    TemplateMetadata[MetadataConstants.TT_TEMPLATE_NAME] = "CSharpSelfTracking.Context";
    TemplateMetadata[MetadataConstants.TT_TEMPLATE_VERSION] = "5.0";
}
#>

By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.

If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
Software Developer (Senior)
United States
Weidong has been an information systems professional since 1990. He has a Master's degree in Computer Science and is currently an MCSD .NET.

Comments and Discussions