<#@ template language="C#" debug="false" hostspecific="true"#>
<#@ include file="EF.Utility.CS.ttinclude"#><#@
output extension=".cs"#><#
// CodeGenerationTools, MetadataTools, MetadataLoader, CodeRegion and EntityFrameworkTemplateFileManager
// come from the included EF.Utility.CS.ttinclude.
DefineMetadata();
CodeGenerationTools code = new CodeGenerationTools(this);
MetadataTools ef = new MetadataTools(this);
MetadataLoader loader = new MetadataLoader(this);
CodeRegion region = new CodeRegion(this);
EntityFrameworkTemplateFileManager fileManager = EntityFrameworkTemplateFileManager.Create(this);

// Load the conceptual model from SchoolModel.edmx.
string inputFile = @"SchoolModel.edmx";
EdmItemCollection ItemCollection = loader.CreateEdmItemCollection(inputFile);
ModelNamespace = loader.GetModelNamespace(inputFile);
DefaultSummaryComment = CodeGenerationTools.GetResourceString("Template_CommentNoDocumentation");
string namespaceName = code.EscapeNamespace(@"SchoolSample.EntityModel");
UpdateObjectNamespaceMap(namespaceName);

// When "true", the client-query helper classes (ObjectQueryExtension, SourceExpressionVisitor,
// Deserializer, TypeResolver) are emitted after the context class.
string codeForClientQuery = @"true";

// Only the first EntityContainer of the model is generated; without one there is nothing to emit.
EntityContainer container = ItemCollection.GetItems<EntityContainer>().FirstOrDefault();
if (container == null)
{
    return "// No EntityContainer exists in the model, so no code was generated";
}

WriteHeader(fileManager);
BeginNamespace(namespaceName, code);
#>
<#=Accessibility.ForType(container)#> partial class <#=code.Escape(container)#> : ObjectContext
{
    public const string ConnectionString = "name=<#=container.Name#>";
    public const string ContainerName = "<#=container.Name#>";

<#
region.Begin("Constructors", 2);
#>
    public <#=code.Escape(container)#>()
        : base(ConnectionString, ContainerName)
    {
        Initialize();
    }

    public <#=code.Escape(container)#>(string connectionString)
        : base(connectionString, ContainerName)
    {
        Initialize();
    }

    public <#=code.Escape(container)#>(EntityConnection connection)
        : base(connection, ContainerName)
    {
        Initialize();
    }

    private void Initialize()
    {
        // Creating proxies requires the use of the ProxyDataContractResolver and
        // may allow lazy loading which can expand the loaded graph during serialization.
        ContextOptions.ProxyCreationEnabled = false;
        ObjectMaterialized += new ObjectMaterializedEventHandler(HandleObjectMaterialized);
    }

    private void HandleObjectMaterialized(object sender, ObjectMaterializedEventArgs e)
    {
        // Materialized change-tracking entities are marked Unchanged while keeping whatever
        // ChangeTrackingEnabled value they had, then their reference key values are recorded.
        var entity = e.Entity as IObjectWithChangeTracker;
        if (entity != null)
        {
            bool changeTrackingEnabled = entity.ChangeTracker.ChangeTrackingEnabled;
            try
            {
                entity.MarkAsUnchanged();
            }
            finally
            {
                entity.ChangeTracker.ChangeTrackingEnabled = changeTrackingEnabled;
            }
            this.StoreReferenceKeyValues(entity);
        }
    }
<#
region.End();

region.Begin("ObjectSet Properties");
foreach (EntitySet entitySet in container.BaseEntitySets.OfType<EntitySet>())
{
    // One lazily created ObjectSet<T> property, backed by a private field, per entity set.
#>

    <#=Accessibility.ForReadOnlyProperty(entitySet)#> ObjectSet<<#=code.GetTypeName(entitySet.ElementType)#>> <#=code.Escape(entitySet)#>
    {
        get { return <#=code.FieldName(entitySet)#> ?? (<#=code.FieldName(entitySet)#> = CreateObjectSet<<#=code.GetTypeName(entitySet.ElementType)#>>("<#=entitySet.Name#>")); }
    }
    private ObjectSet<<#=code.GetTypeName(entitySet.ElementType)#>> <#=code.FieldName(entitySet)#>;
<#
}
region.End();

region.Begin("Function Imports");
foreach (EdmFunction edmFunction in container.FunctionImports)
{
    var parameters = FunctionImportParameter.Create(edmFunction.Parameters, code, ef);
    string paramList = String.Join(", ", parameters.Select(p => p.FunctionParameterType + " " + p.FunctionParameterName).ToArray());
    // Element type of the first return parameter, or null when the function returns nothing.
    TypeUsage returnType = edmFunction.ReturnParameters.Count == 0 ? null : ef.GetElementType(edmFunction.ReturnParameters[0].TypeUsage);
    if (edmFunction.IsComposableAttribute)
    {
        // Composable functions become IQueryable<T> methods built with CreateQuery over an Entity SQL call.
#>

    /// <summary>
    /// <#=SummaryComment(edmFunction)#>
    /// </summary><#=LongDescriptionCommentElement(edmFunction, region.CurrentIndentLevel)#><#=ParameterComments(parameters.Select(p => new Tuple<string, string>(p.RawFunctionParameterName, SummaryComment(p.Source))), region.CurrentIndentLevel)#>
    [EdmFunction("<#=edmFunction.NamespaceName#>", "<#=edmFunction.Name#>")]
    <#=code.SpaceAfter(NewModifier(edmFunction))#><#=AccessibilityAndVirtual(Accessibility.ForMethod(edmFunction))#> <#="IQueryable<" + code.GetTypeName(returnType, ModelNamespace) + ">"#> <#=code.Escape(edmFunction)#>(<#=paramList#>)
    {
<#
        WriteFunctionParameters(parameters);
#>
        return base.CreateQuery<<#=code.GetTypeName(returnType, ModelNamespace)#>>("[<#=edmFunction.NamespaceName#>].[<#=edmFunction.Name#>](<#=string.Join(", ", parameters.Select(p => "@" + p.EsqlParameterName).ToArray())#>)"<#=code.StringBefore(", ", string.Join(", ", parameters.Select(p => p.ExecuteParameterName).ToArray()))#>);
    }
<#
    }
    else
    {
        // Non-composable functions call ExecuteFunction: int without a result type, ObjectResult<T> otherwise.
#>

    /// <summary>
    /// <#=SummaryComment(edmFunction)#>
    /// </summary><#=LongDescriptionCommentElement(edmFunction, region.CurrentIndentLevel)#><#=ParameterComments(parameters.Select(p => new Tuple<string, string>(p.RawFunctionParameterName, SummaryComment(p.Source))), region.CurrentIndentLevel)#>
    <#=code.SpaceAfter(NewModifier(edmFunction))#><#=AccessibilityAndVirtual(Accessibility.ForMethod(edmFunction))#> <#=returnType == null ? "int" : "ObjectResult<" + code.GetTypeName(returnType, ModelNamespace) + ">"#> <#=code.Escape(edmFunction)#>(<#=paramList#>)
    {
<#
        WriteFunctionParameters(parameters);
#>
        return base.ExecuteFunction<#=returnType == null ? "" : "<" + code.GetTypeName(returnType, ModelNamespace) + ">"#>("<#=edmFunction.Name#>"<#=code.StringBefore(", ", String.Join(", ", parameters.Select(p => p.ExecuteParameterName).ToArray()))#>);
    }
<#
        if (returnType != null && returnType.EdmType.BuiltInTypeKind == BuiltInTypeKind.EntityType)
        {
            // Entity results get an extra overload that takes a MergeOption. returnType is known to be
            // non-null here, so the result type is always ObjectResult<T>.
#>

    /// <summary>
    /// <#=SummaryComment(edmFunction)#>
    /// </summary><#=LongDescriptionCommentElement(edmFunction, region.CurrentIndentLevel)#>
    /// <param name="mergeOption"></param><#=ParameterComments(parameters.Select(p => new Tuple<string, string>(p.RawFunctionParameterName, SummaryComment(p.Source))), region.CurrentIndentLevel)#>
    <#=code.SpaceAfter(NewModifier(edmFunction))#><#=Accessibility.ForMethod(edmFunction)#> ObjectResult<<#=code.GetTypeName(returnType, ModelNamespace)#>> <#=code.Escape(edmFunction)#>(<#=code.StringAfter(paramList, ", ")#>MergeOption mergeOption)
    {
<#
            WriteFunctionParameters(parameters);
#>
        return base.ExecuteFunction<<#=code.GetTypeName(returnType, ModelNamespace)#>>("<#=edmFunction.Name#>", mergeOption<#=code.StringBefore(", ", string.Join(", ", parameters.Select(p => p.ExecuteParameterName).ToArray()))#>);
    }
<#
        }
    }
}
region.End();
#>
}
<#
// The client-query helper classes are only emitted when codeForClientQuery is "true".
if (codeForClientQuery == "true")
{
#>

public static class ObjectQueryExtension
{
    /// <summary>
    /// ApplyIncludePath takes the list of include paths from a ClientQuery object
    /// and calls Include() on the ObjectQuery.
    /// </summary>
    /// <typeparam name="TEntity">Expected type of the ObjectQuery</typeparam>
    /// <param name="source">The ObjectQuery to which the list of include paths will be applied</param>
    /// <param name="clientQuery">The ClientQuery object that contains the list of include paths</param>
    /// <returns>The ObjectQuery with Include() applied once for every path in clientQuery.IncludeList</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="clientQuery"/> is null.</exception>
    public static ObjectQuery<TEntity> ApplyIncludePath<TEntity>(this ObjectQuery<TEntity> source,
        ClientQuery clientQuery) where TEntity : class
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }
        if (clientQuery == null)
        {
            throw new ArgumentNullException("clientQuery");
        }
        return clientQuery.IncludeList.Aggregate(source, (current, includeItem) => current.Include(includeItem));
    }

    /// <summary>
    /// ApplyClientQuery takes both the list of include paths and serialized Expression tree
    /// from a ClientQuery object and applies them to the ObjectQuery.
    /// </summary>
    /// <typeparam name="TEntity">Expected type of the ObjectQuery</typeparam>
    /// <param name="source">The ObjectQuery to which both the list of include paths and serialized Expression tree will be applied</param>
    /// <param name="clientQuery">The ClientQuery object that contains both the list of include paths and serialized Expression tree</param>
    /// <param name="assemblies">Additional assemblies where types will be searched</param>
    /// <returns>The query obtained by compiling the deserialized expression and invoking it on the context of <paramref name="source"/></returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="clientQuery"/> is null.</exception>
    public static ObjectQuery<TEntity> ApplyClientQuery<TEntity>(this ObjectQuery<TEntity> source,
        ClientQuery clientQuery,
        IEnumerable<Assembly> assemblies = null)
        where TEntity : class
    {
        // NOTE(review): clientQuery.XmlExpression presumably arrives from a remote client; the Deserializer
        // resolves type and member names taken from it, so confirm the resolvable assemblies are acceptable
        // at this trust boundary.
        source = source.ApplyIncludePath(clientQuery);

        // Replace the root query constant with an access on a context parameter (context.<EntitySet>),
        // so the deserialized expression can be compiled as a function of the context.
        var sourceExpressionVisitor = new SourceExpressionVisitor();
        var sourceExpression = sourceExpressionVisitor.Visit(((IOrderedQueryable) source).Expression);

        var deserializer = new Deserializer(sourceExpression, assemblies);
        var expression = deserializer.Deserialize(clientQuery.XmlExpression);
        var lambda = Expression.Lambda<Func<<#=code.Escape(container)#>, IQueryable<TEntity>>>(
            expression, sourceExpressionVisitor.ContextParameterExpression);

        var compiledQuery = CompiledQuery.Compile(lambda);
        return (ObjectQuery<TEntity>) compiledQuery.Invoke((<#=code.Escape(container)#>) source.Context);
    }
}

internal class SourceExpressionVisitor : ExpressionVisitor
{
    #region Private Data Member

    private ParameterExpression _contextParameterExpression;

    #endregion Private Data Member

    #region Public Property

    /// <summary>
    /// The context parameter of the lambda whose body replaced the most recently visited queryable constant.
    /// </summary>
    /// <exception cref="InvalidOperationException">No queryable constant has been replaced yet.</exception>
    public ParameterExpression ContextParameterExpression
    {
        get
        {
            if (_contextParameterExpression == null)
            {
                throw new InvalidOperationException("ContextParameterExpression cannot be null.");
            }
            return _contextParameterExpression;
        }
    }

    #endregion Public Property

    #region Protected & Private Method

    protected override Expression VisitConstant(ConstantExpression node)
    {
        // A null constant of a queryable type has no ElementType to read, so leave it untouched.
        if (node.Value != null && typeof (IOrderedQueryable).IsAssignableFrom(node.Type))
        {
            var elementType = ((IOrderedQueryable) node.Value).ElementType;
            return GetExpression(elementType);
        }
        return base.VisitConstant(node);
    }

    /// <summary>
    /// Returns "context.EntitySet" (or "context.EntitySet.OfType&lt;T&gt;()" for derived types) for the given
    /// element type and records the context parameter it uses.
    /// </summary>
    /// <exception cref="NotSupportedException">No entity type of the model matches <paramref name="elementType"/>.</exception>
    private Expression GetExpression(Type elementType)
    {
<#
    foreach (EntityType entity in ItemCollection.GetItems<EntityType>().OrderBy(e => e.Name))
    {
        // Derived types are queried through the root type's entity set and narrowed with OfType<T>().
        string entityTypeName = code.Escape(entity);
        string ofTypeFilter = entity.BaseType == null ? String.Empty : ".OfType<" + entityTypeName + ">()";
#>
        if (elementType == typeof (<#=entityTypeName#>))
        {
            Expression<Func<<#=code.Escape(container)#>, ObjectQuery<<#=entityTypeName#>>>> query =
                context => context.<#=code.Escape(GetEntitySetFromEntityType(container, entity))#><#=ofTypeFilter#>;
            _contextParameterExpression = query.Parameters[0];
            return query.Body;
        }
<#
    }
#>
        throw new NotSupportedException("No entity set matches element type '" + elementType + "'.");
    }

    #endregion Protected & Private Method
}
<#
    WriteDeserializer();
    WriteTypeResolver();
}

EndNamespace(namespaceName);

// Second output file, <template name>.Extensions.cs, holds the code produced by WriteApplyChanges.
fileManager.StartNewFile(Path.GetFileNameWithoutExtension(Host.TemplateFile) + ".Extensions.cs");
BeginNamespace(namespaceName, code);
WriteApplyChanges(code);
EndNamespace(namespaceName);
fileManager.Process();
#>
<#+
/// <summary>
/// Writes the auto-generated banner and the fixed set of using directives into the file header,
/// followed by one "using X;" line for each entry of <paramref name="extraUsings"/>.
/// </summary>
private void WriteHeader(EntityFrameworkTemplateFileManager fileManager, params string[] extraUsings)
{
    string[] extraUsingLines = extraUsings
        .Select(ns => "using " + ns + ";" + Environment.NewLine)
        .ToArray();
    string extraUsingBlock = String.Join(String.Empty, extraUsingLines);

    fileManager.StartHeader();
#>
//------------------------------------------------------------------------------
// <auto-generated>
// This code was generated from a template.
//
// Changes to this file may cause incorrect behavior and will be lost if
// the code is regenerated.
// </auto-generated>
//------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Data.Common;
using System.Data.EntityClient;
using System.Data.Metadata.Edm;
using System.Data.Objects;
using System.Data.Objects.DataClasses;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.Serialization;
using System.Text;
using System.Threading;
using System.Xml.Linq;
<#=extraUsingBlock#>
<#+
    fileManager.EndBlock();
}
/// <summary>
/// Writes "namespace X {" for a non-empty namespace name and pushes one indentation level,
/// which the matching EndNamespace call pops again. Does nothing for a null or empty name.
/// </summary>
void BeginNamespace(string namespaceName, CodeGenerationTools code)
{
    if (!String.IsNullOrEmpty(namespaceName))
    {
#>
namespace <#=code.EscapeNamespace(namespaceName)#>
{
<#+
        PushIndent(CodeRegion.GetIndent(1));
    }
}
/// <summary>
/// Pops the indentation pushed by BeginNamespace and writes the closing namespace brace.
/// Does nothing for a null or empty namespace name.
/// </summary>
void EndNamespace(string namespaceName)
{
    if (String.IsNullOrEmpty(namespaceName))
    {
        return;
    }

    PopIndent();
#>
}
<#+
}
/// <summary>
/// Appends " virtual" to an accessibility keyword unless it is "private".
/// </summary>
string AccessibilityAndVirtual(string accessibility)
{
    bool isPrivate = accessibility == "private";
    return isPrivate ? accessibility : accessibility + " virtual";
}
/// <summary>
/// Walks to the root of the inheritance chain of <paramref name="entity"/> and returns the single
/// entity set in <paramref name="container"/> whose element type has the root type's name.
/// </summary>
/// <exception cref="InvalidOperationException">No entity set, or more than one, matches.</exception>
EntitySet GetEntitySetFromEntityType(EntityContainer container, EntityType entity)
{
    EdmType rootType = entity;
    while (rootType.BaseType != null)
    {
        rootType = rootType.BaseType;
    }

    string rootTypeName = rootType.Name;
    return container.BaseEntitySets
        .OfType<EntitySet>()
        .Single(candidate => candidate.ElementType.Name.Equals(rootTypeName));
}
/// <summary>
/// Emits the internal Deserializer class, which rebuilds a LINQ expression tree from its XML representation.
/// </summary>
void WriteDeserializer()
{
#>
internal class Deserializer
{
    // NOTE(review): Deserialize resolves CLR type names, member names and DataContract-serialized constant
    // values read directly from the XML. If that XML comes from an untrusted client, confirm that every
    // resolvable type/assembly is safe to expose before relying on this class.

    #region Private Data Member

    // One ParameterExpression per (name, type) pair, so that every reference to a parameter inside a
    // single Deserialize call resolves to the same instance.
    private readonly Dictionary<Tuple<string, Type>, ParameterExpression> _parameters =
        new Dictionary<Tuple<string, Type>, ParameterExpression>();

    private readonly Expression _substituteExpression;

    private readonly TypeResolver _resolver;

    #endregion Private Data Member

    #region Constructor

    /// <summary>
    /// Creates a deserializer.
    /// </summary>
    /// <param name="substituteExpression">Expression returned in place of a ClientQuery element that carries an elementType attribute.</param>
    /// <param name="assemblies">Additional assemblies where types will be searched.</param>
    public Deserializer(Expression substituteExpression, IEnumerable<Assembly> assemblies = null)
    {
        _substituteExpression = substituteExpression;
        _resolver = new TypeResolver(assemblies);
    }

    #endregion Constructor

    #region Public Deserializer Method

    /// <summary>
    /// Rebuilds an expression tree from its XML representation.
    /// </summary>
    /// <param name="xml">Root element of the serialized expression.</param>
    /// <returns>The deserialized expression.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="xml"/> is null.</exception>
    public Expression Deserialize(XElement xml)
    {
        if (xml == null)
        {
            throw new ArgumentNullException("xml");
        }

        // Parameter identities never carry over from a previous call.
        _parameters.Clear();
        return ParseExpressionFromXmlNonNull(xml);
    }

    #endregion Public Deserializer Method

    #region Private Deserializer Method

    /// <summary>
    /// Dispatches on the element name to the matching Parse*FromXml method.
    /// </summary>
    private Expression ParseExpressionFromXmlNonNull(XElement xml)
    {
        // A ClientQuery element with an elementType attribute marks the query source; it is replaced by
        // the expression supplied to the constructor instead of being deserialized.
        if (xml.Name.LocalName == "ClientQuery" && xml.Attribute("elementType") != null)
        {
            return _substituteExpression;
        }
        switch (xml.Name.LocalName)
        {
            case "BinaryExpression":
                return ParseBinaryExpressionFromXml(xml);
            case "ConditionalExpression":
                return ParseConditionalExpressionFromXml(xml);
            case "ConstantExpression":
                return ParseConstantExpressionFromXml(xml);
            case "InvocationExpression":
                return ParseInvocationExpressionFromXml(xml);
            case "LambdaExpression":
                return ParseLambdaExpressionFromXml(xml);
            case "ListInitExpression":
                return ParseListInitExpressionFromXml(xml);
            case "MemberExpression":
                return ParseMemberExpressionFromXml(xml);
            case "MemberInitExpression":
                return ParseMemberInitExpressionFromXml(xml);
            case "MethodCallExpression":
                return ParseMethodCallExpressionFromXml(xml);
            case "NewArrayExpression":
                return ParseNewArrayExpressionFromXml(xml);
            case "NewExpression":
                return ParseNewExpressionFromXml(xml);
            case "ParameterExpression":
                return ParseParameterExpressionFromXml(xml);
            case "TypeBinaryExpression":
                return ParseTypeBinaryExpressionFromXml(xml);
            case "UnaryExpression":
                return ParseUnaryExpressionFromXml(xml);
            default:
                throw new NotSupportedException(xml.Name.LocalName);
        }
    }

    /// <summary>
    /// Parse BinaryExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Right, Left, Method, Conversion, IsLiftedToNull and NodeType children.</param>
    /// <returns>The rebuilt BinaryExpression.</returns>
    private BinaryExpression ParseBinaryExpressionFromXml(XElement xml)
    {
        var right = ParseExpressionFromXml(xml.Element("Right"));
        var left = ParseExpressionFromXml(xml.Element("Left"));
        var method = ParseMethodInfoFromXml(xml.Element("Method"));
        var conversion = ParseExpressionFromXml(xml.Element("Conversion")) as LambdaExpression;
        var isLiftedToNull = ParseWithDataContractSerializer<bool>(xml.Element("IsLiftedToNull"));
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        return Expression.MakeBinary(expressionType, left, right, isLiftedToNull, method, conversion);
    }

    /// <summary>
    /// Parse ConditionalExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Test, IfTrue, IfFalse and Type children.</param>
    /// <returns>The rebuilt ConditionalExpression.</returns>
    private ConditionalExpression ParseConditionalExpressionFromXml(XElement xml)
    {
        var test = ParseExpressionFromXml(xml.Element("Test"));
        var ifTrue = ParseExpressionFromXml(xml.Element("IfTrue"));
        var ifFalse = ParseExpressionFromXml(xml.Element("IfFalse"));
        var type = ParseTypeFromXml(xml.Element("Type"));
        return Expression.Condition(test, ifTrue, ifFalse, type);
    }

    /// <summary>
    /// Parse ConstantExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Type and DataContract-serialized Value children.</param>
    /// <returns>The rebuilt ConstantExpression.</returns>
    private ConstantExpression ParseConstantExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        var result = ParseWithDataContractSerializer(xml.Element("Value"), type);
        return Expression.Constant(result, type);
    }

    /// <summary>
    /// Parse InvocationExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Expression and Arguments children.</param>
    /// <returns>The rebuilt InvocationExpression.</returns>
    private InvocationExpression ParseInvocationExpressionFromXml(XElement xml)
    {
        var expression = ParseExpressionFromXml(xml.Element("Expression"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments");
        return Expression.Invoke(expression, arguments);
    }

    /// <summary>
    /// Parse LambdaExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Type, Body, Name, TailCall and Parameters children.</param>
    /// <returns>The rebuilt LambdaExpression.</returns>
    private LambdaExpression ParseLambdaExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        var body = ParseExpressionFromXml(xml.Element("Body"));
        var name = ParseWithDataContractSerializer<string>(xml.Element("Name"));
        var tailCall = ParseWithDataContractSerializer<bool>(xml.Element("TailCall"));
        var parameters = ParseExpressionListFromXml<ParameterExpression>(xml, "Parameters");
        return Expression.Lambda(type, body, name, tailCall, parameters);
    }

    /// <summary>
    /// Parse ListInitExpression From XElement
    /// </summary>
    /// <param name="xml">Element with NewExpression and Initializers children.</param>
    /// <returns>The rebuilt ListInitExpression.</returns>
    private ListInitExpression ParseListInitExpressionFromXml(XElement xml)
    {
        var newExpression = ParseExpressionFromXml(xml.Element("NewExpression")) as NewExpression;
        if (newExpression == null) throw new SerializationException("Expected a NewExpression");
        var initializers = ParseElementInitListFromXml(xml, "Initializers");
        return Expression.ListInit(newExpression, initializers);
    }

    /// <summary>
    /// Parse MemberExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Expression and Member children.</param>
    /// <returns>The rebuilt MemberExpression.</returns>
    private MemberExpression ParseMemberExpressionFromXml(XElement xml)
    {
        var expression = ParseExpressionFromXml(xml.Element("Expression"));
        var member = ParseMemberInfoFromXml(xml.Element("Member"));
        return Expression.MakeMemberAccess(expression, member);
    }

    /// <summary>
    /// Parse MemberInitExpression From XElement
    /// </summary>
    /// <param name="xml">Element with NewExpression and Bindings children.</param>
    /// <returns>The rebuilt MemberInitExpression.</returns>
    private MemberInitExpression ParseMemberInitExpressionFromXml(XElement xml)
    {
        var newExpression = ParseExpressionFromXml(xml.Element("NewExpression")) as NewExpression;
        if (newExpression == null) throw new SerializationException("Expected a NewExpression");
        var bindings = ParseMemberBindingListFromXml(xml, "Bindings");
        return Expression.MemberInit(newExpression, bindings);
    }

    /// <summary>
    /// Parse MethodCallExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Object (empty for static calls), Method and Arguments children.</param>
    /// <returns>The rebuilt MethodCallExpression.</returns>
    private MethodCallExpression ParseMethodCallExpressionFromXml(XElement xml)
    {
        var instance = ParseExpressionFromXml(xml.Element("Object"));
        var method = ParseMethodInfoFromXml(xml.Element("Method"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments");
        return Expression.Call(instance, method, arguments);
    }

    /// <summary>
    /// Parse NewArrayExpression From XElement
    /// </summary>
    /// <param name="xml">Element with an array Type, Expressions and NodeType children.</param>
    /// <returns>The rebuilt NewArrayExpression (NewArrayInit or NewArrayBounds).</returns>
    private NewArrayExpression ParseNewArrayExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        if (type == null) throw new SerializationException("Type cannot be null");
        if (!type.IsArray) throw new SerializationException("Expected array type");
        var elemType = type.GetElementType();
        var expressions = ParseExpressionListFromXml<Expression>(xml, "Expressions");
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        switch (expressionType)
        {
            case ExpressionType.NewArrayInit:
                return Expression.NewArrayInit(elemType, expressions);
            case ExpressionType.NewArrayBounds:
                return Expression.NewArrayBounds(elemType, expressions);
            default:
                throw new SerializationException("Expected NewArrayInit or NewArrayBounds");
        }
    }

    /// <summary>
    /// Parse NewExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Constructor, Arguments and Members children.</param>
    /// <returns>The rebuilt NewExpression; members are attached only when any were serialized.</returns>
    private NewExpression ParseNewExpressionFromXml(XElement xml)
    {
        var constructor = ParseConstructorInfoFromXml(xml.Element("Constructor"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments").ToArray();
        var members = ParseMemberInfoListFromXml<MemberInfo>(xml, "Members").ToArray();
        return members.Length == 0
            ? Expression.New(constructor, arguments)
            : Expression.New(constructor, arguments, members);
    }

    /// <summary>
    /// Parse ParameterExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Type and Name children.</param>
    /// <returns>The shared ParameterExpression for this (name, type) pair.</returns>
    private ParameterExpression ParseParameterExpressionFromXml(XElement xml)
    {
        var type = ParseTypeFromXml(xml.Element("Type"));
        var name = ParseWithDataContractSerializer<string>(xml.Element("Name"));
        // Keyed on the pair itself rather than a concatenated string, which could collide.
        var key = Tuple.Create(name, type);
        ParameterExpression parameter;
        if (!_parameters.TryGetValue(key, out parameter))
        {
            parameter = Expression.Parameter(type, name);
            _parameters.Add(key, parameter);
        }
        return parameter;
    }

    /// <summary>
    /// Parse TypeBinaryExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Expression, TypeOperand and NodeType children.</param>
    /// <returns>The rebuilt TypeBinaryExpression (TypeEqual or TypeIs).</returns>
    private TypeBinaryExpression ParseTypeBinaryExpressionFromXml(XElement xml)
    {
        var expression = ParseExpressionFromXml(xml.Element("Expression"));
        var typeOperand = ParseTypeFromXml(xml.Element("TypeOperand"));
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        switch (expressionType)
        {
            case ExpressionType.TypeEqual:
                return Expression.TypeEqual(expression, typeOperand);
            case ExpressionType.TypeIs:
                return Expression.TypeIs(expression, typeOperand);
            default:
                throw new SerializationException("Expected TypeEqual or TypeIs");
        }
    }

    /// <summary>
    /// Parse UnaryExpression From XElement
    /// </summary>
    /// <param name="xml">Element with Operand, Type, Method and NodeType children.</param>
    /// <returns>The rebuilt UnaryExpression.</returns>
    private UnaryExpression ParseUnaryExpressionFromXml(XElement xml)
    {
        var operand = ParseExpressionFromXml(xml.Element("Operand"));
        var type = ParseTypeFromXml(xml.Element("Type"));
        var method = ParseMethodInfoFromXml(xml.Element("Method"));
        var expressionType = ParseWithDataContractSerializer<ExpressionType>(xml.Element("NodeType"));
        return Expression.MakeUnary(expressionType, operand, type, method);
    }

    #endregion Private Deserializer Method

    #region Private Deserializer Helper Method

    // Returns the named child element, or throws when it is missing.
    private static XElement GetRequiredElement(XElement xml, string elementName)
    {
        var element = xml.Element(elementName);
        if (element == null) throw new SerializationException(elementName + " not found.");
        return element;
    }

    // Returns the value of the named attribute, or throws when it is missing.
    private static string GetRequiredAttributeValue(XElement xml, string attributeName)
    {
        var attribute = xml.Attribute(attributeName);
        if (attribute == null) throw new SerializationException(attributeName + " not found.");
        return attribute.Value;
    }

    private ConstructorInfo ParseConstructorInfoFromXml(XElement xml)
    {
        // A missing or empty element means "no constructor".
        if (xml == null || xml.IsEmpty) return null;
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        var parameterTypes = GetRequiredElement(xml, "Parameters").Elements()
            .Select(paramXml => ParseTypeFromXml(paramXml))
            .ToArray();
        return declaringType.GetConstructor(parameterTypes);
    }

    private ElementInit ParseElementInitFromXml(XElement xml)
    {
        var addMethod = ParseMethodInfoFromXml(xml.Element("AddMethod"));
        var arguments = ParseExpressionListFromXml<Expression>(xml, "Arguments");
        return Expression.ElementInit(addMethod, arguments);
    }

    private Expression ParseExpressionFromXml(XElement xml)
    {
        // A missing or empty wrapper element stands for a null sub-expression.
        return xml == null || xml.IsEmpty ? null : ParseExpressionFromXmlNonNull(xml.Elements().First());
    }

    private FieldInfo ParseFieldInfoFromXml(XElement xml)
    {
        var fieldName = GetRequiredAttributeValue(xml, "FieldName");
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        return declaringType.GetField(fieldName);
    }

    private MemberBinding ParseMemberBindingFromXml(XElement xml)
    {
        var member = ParseMemberInfoFromXml(xml.Element("Member"));
        switch (xml.Name.LocalName)
        {
            case "MemberAssignment":
                var expression = ParseExpressionFromXml(xml.Element("Expression"));
                return Expression.Bind(member, expression);
            case "MemberMemberBinding":
                var bindings = ParseMemberBindingListFromXml(xml, "Bindings");
                return Expression.MemberBind(member, bindings);
            case "MemberListBinding":
                var initializers = ParseElementInitListFromXml(xml, "Initializers");
                return Expression.ListBind(member, initializers);
            default:
                throw new NotSupportedException(xml.Name.LocalName);
        }
    }

    private MemberInfo ParseMemberInfoFromXml(XElement xml)
    {
        var memberType = (MemberTypes) Enum.Parse(typeof (MemberTypes), GetRequiredAttributeValue(xml, "MemberType"), false);
        switch (memberType)
        {
            case MemberTypes.Field:
                return ParseFieldInfoFromXml(xml);
            case MemberTypes.Property:
                return ParsePropertyInfoFromXml(xml);
            case MemberTypes.Method:
                return ParseMethodInfoFromXml(xml);
            case MemberTypes.Constructor:
                return ParseConstructorInfoFromXml(xml);
            default:
                throw new NotSupportedException(string.Format(CultureInfo.InvariantCulture, "MemberType {0} not supported", memberType));
        }
    }

    private MethodInfo ParseMethodInfoFromXml(XElement xml)
    {
        // A missing or empty element means "no method" (e.g. an operator without a user-defined method).
        if (xml == null || xml.IsEmpty) return null;
        var name = GetRequiredAttributeValue(xml, "MethodName");
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        var parameterTypes = GetRequiredElement(xml, "Parameters").Elements()
            .Select(paramXml => ParseTypeFromXml(paramXml))
            .ToArray();
        var genericArgumentTypes = GetRequiredElement(xml, "GenericArgTypes").Elements()
            .Select(argXml => ParseTypeFromXml(argXml))
            .ToArray();
        return _resolver.GetMethod(declaringType, name, parameterTypes, genericArgumentTypes);
    }

    private PropertyInfo ParsePropertyInfoFromXml(XElement xml)
    {
        var propertyName = GetRequiredAttributeValue(xml, "PropertyName");
        var declaringType = ParseTypeFromXml(xml.Element("DeclaringType"));
        return declaringType.GetProperty(propertyName);
    }

    private Type ParseTypeFromXml(XElement xml)
    {
        if (xml == null) throw new SerializationException("Type element not found.");
        return ParseTypeFromXmlCore(xml.Elements().First());
    }

    private IEnumerable<ElementInit> ParseElementInitListFromXml(XElement xml, string elemName)
    {
        return GetRequiredElement(xml, elemName).Elements()
            .Select(tXml => ParseElementInitFromXml(tXml))
            .ToArray();
    }

    private IEnumerable<T> ParseExpressionListFromXml<T>(XElement xml, string elemName) where T : Expression
    {
        return GetRequiredElement(xml, elemName).Elements()
            .Select(tXml => (T) ParseExpressionFromXmlNonNull(tXml))
            .ToArray();
    }

    private IEnumerable<MemberBinding> ParseMemberBindingListFromXml(XElement xml, string elemName)
    {
        return GetRequiredElement(xml, elemName).Elements()
            .Select(tXml => ParseMemberBindingFromXml(tXml))
            .ToArray();
    }

    private IEnumerable<T> ParseMemberInfoListFromXml<T>(XElement xml, string elemName) where T : MemberInfo
    {
        return GetRequiredElement(xml, elemName).Elements()
            .Select(tXml => (T) ParseMemberInfoFromXml(tXml))
            .ToArray();
    }

    private Type ParseTypeFromXmlCore(XElement xml)
    {
        switch (xml.Name.LocalName)
        {
            case "Type":
                return ParseNormalTypeFromXmlCore(xml);
            case "AnonymousType":
                return ParseAnonymousTypeFromXmlCore(xml);
            default:
                throw new SerializationException("Expected 'Type' or 'AnonymousType'");
        }
    }

    private Type ParseNormalTypeFromXmlCore(XElement xml)
    {
        var typeName = GetRequiredAttributeValue(xml, "Name");
        if (!xml.HasElements)
        {
            return _resolver.GetType(typeName);
        }
        // Child elements are the generic type arguments of a constructed generic type.
        var genericArgumentTypes = xml.Elements()
            .Select(genArgXml => ParseTypeFromXmlCore(genArgXml))
            .ToArray();
        return _resolver.GetType(typeName, genericArgumentTypes);
    }

    private Type ParseAnonymousTypeFromXmlCore(XElement xElement)
    {
        var name = GetRequiredAttributeValue(xElement, "Name");
        var properties = from propXml in xElement.Elements("Property")
                         let attribute = propXml.Attribute("Name")
                         where attribute != null
                         select new TypeResolver.NameTypePair
                         {
                             Name = attribute.Value,
                             Type = ParseTypeFromXml(propXml)
                         };
        var ctrParams = from propXml in xElement.Elements("Constructor").Elements("Parameter")
                        let attribute = propXml.Attribute("Name")
                        where attribute != null
                        select new TypeResolver.NameTypePair
                        {
                            Name = attribute.Value,
                            Type = ParseTypeFromXml(propXml)
                        };
        return _resolver.GetOrCreateAnonymousType(name, properties.ToArray(), ctrParams.ToArray());
    }

    private T ParseWithDataContractSerializer<T>(XElement xml)
    {
        return (T) ParseWithDataContractSerializer(xml, typeof (T));
    }

    private object ParseWithDataContractSerializer(XElement xml, Type expectedType)
    {
        if (xml == null) throw new SerializationException("XElement cannot be null.");
        // The element's text content is itself a DataContractSerializer document.
        using (var stream = new MemoryStream(Encoding.UTF8.GetBytes(xml.Value)))
        {
            var deserializer = new DataContractSerializer(expectedType);
            return deserializer.ReadObject(stream);
        }
    }

    #endregion Private Deserializer Helper Method
}
<#+
}
void WriteTypeResolver()
{
#>
internal sealed class TypeResolver
{
#region Private Data Member
private readonly Dictionary<AnonTypeId, Type> _anonymousTypes = new Dictionary<AnonTypeId, Type>();
private readonly ModuleBuilder _moduleBuilder;
private int _anonymousTypeIndex;
private readonly HashSet<Assembly> _assemblies = new HashSet<Assembly>
{
typeof (String).Assembly,
// mscorlib.dll
typeof (ExpressionType).Assembly,
// System.Core.dll
typeof (XElement).Assembly,
// System.Xml.Linq.dll
Assembly.GetExecutingAssembly(),
};
#endregion Private Data Member
#region Constructor
public TypeResolver(IEnumerable<Assembly> assemblies = null)
{
var asmname = new AssemblyName {Name = "AnonymousTypes"};
var assemblyBuilder = Thread.GetDomain().DefineDynamicAssembly(asmname,
AssemblyBuilderAccess.Run);
_moduleBuilder = assemblyBuilder.DefineDynamicModule("AnonymousTypes");
if (assemblies != null)
{
foreach (var assembly in assemblies)
_assemblies.Add(assembly);
}
}
#endregion Constructor
#region Public Method
/// <summary>
/// Get type of a generic type
/// </summary>
/// <param name="typeName"></param>
/// <param name="genericArgumentTypes"></param>
/// <returns></returns>
public Type GetType(string typeName, IEnumerable<Type> genericArgumentTypes)
{
return GetType(typeName).MakeGenericType(genericArgumentTypes.ToArray());
}
/// <summary>
/// Get type based on type name
/// </summary>
/// <param name="typeName"></param>
/// <returns></returns>
public Type GetType(string typeName)
{
Type type;
if (string.IsNullOrEmpty(typeName))
throw new ArgumentNullException("typeName");
// array type
if (typeName.EndsWith("[]"))
return GetType(typeName.Substring(0, typeName.Length - 2)).MakeArrayType();
// load type from assemblies
foreach (var assembly in _assemblies)
{
type = assembly.GetType(typeName);
if (type != null)
return type;
}
// call Type.GetType()
type = Type.GetType(typeName, false, true);
if (type != null)
return type;
throw new ArgumentException("Could not find a matching type", typeName);
}
/// <summary>
/// Finds a public method of <paramref name="declaringType"/> with the given name whose parameter
/// types match <paramref name="parameterTypes"/> position by position. Generic candidates are first
/// closed over <paramref name="genArgTypes"/>.
/// </summary>
/// <param name="declaringType">Type whose public methods are searched.</param>
/// <param name="name">Method name (exact match).</param>
/// <param name="parameterTypes">Expected parameter types, in order.</param>
/// <param name="genArgTypes">Type arguments applied to generic candidates.</param>
/// <returns>The first matching (possibly constructed) method, or null when none matches.</returns>
public MethodInfo GetMethod(Type declaringType, string name, Type[] parameterTypes, Type[] genArgTypes)
{
foreach (MethodInfo candidate in declaringType.GetMethods())
{
if (candidate.Name != name)
continue;
try
{
MethodInfo closedCandidate = candidate.IsGenericMethod
? candidate.MakeGenericMethod(genArgTypes)
: candidate;
IEnumerable<Type> candidateParameterTypes = closedCandidate.GetParameters().Select(p => p.ParameterType);
if (MatchPiecewise(parameterTypes, candidateParameterTypes))
return closedCandidate;
}
catch (ArgumentException)
{
// Deliberately ignored: an ArgumentException (including ArgumentNullException) here,
// e.g. from MakeGenericMethod when genArgTypes does not fit the candidate, just means
// this overload is not a match, so the search continues.
}
}
return null;
}
/// <summary>
/// Returns the cached dynamic type for the given shape, emitting it on first use. The emitted type
/// exposes one read-only public property per entry in <paramref name="properties"/>, each backed by a
/// private field, and one public constructor whose parameters are <paramref name="ctrParams"/>.
/// </summary>
/// <param name="name">Logical name of the anonymous type; part of the cache key.</param>
/// <param name="properties">Property names and types, in declaration order.</param>
/// <param name="ctrParams">Constructor parameters. Parameter i initializes the field backing
/// property i, so it must not have more entries than <paramref name="properties"/>.</param>
/// <returns>The cached or newly created type.</returns>
/// <exception cref="ArgumentException"><paramref name="ctrParams"/> has more entries than <paramref name="properties"/>.</exception>
public Type GetOrCreateAnonymousType(string name, NameTypePair[] properties, NameTypePair[] ctrParams)
{
// Snapshot the key's members so later changes to the caller's arrays cannot
// change the hash of an entry already stored in the cache.
var id = new AnonTypeId(name, properties.Concat(ctrParams).ToArray());
Type cachedType;
if (_anonymousTypes.TryGetValue(id, out cachedType))
return cachedType;
if (ctrParams.Length > properties.Length)
throw new ArgumentException("There are more constructor parameters than properties.", "ctrParams");
const string anonPrefix = "<>f__AnonymousType";
var anonTypeBuilder = _moduleBuilder.DefineType(anonPrefix + _anonymousTypeIndex++,
TypeAttributes.Public | TypeAttributes.Class);
var fieldBuilders = new FieldBuilder[properties.Length];
for (var i = 0; i < properties.Length; i++)
{
fieldBuilders[i] = anonTypeBuilder.DefineField("_generatedfield_" + properties[i].Name,
properties[i].Type, FieldAttributes.Private);
var propertyBuilder = anonTypeBuilder.DefineProperty(properties[i].Name,
System.Reflection.PropertyAttributes.None,
properties[i].Type, Type.EmptyTypes);
// Property accessors are conventionally marked SpecialName | HideBySig.
var propertyGetterBuilder = anonTypeBuilder.DefineMethod("get_" + properties[i].Name,
MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.HideBySig,
properties[i].Type, Type.EmptyTypes);
var getIlGenerator = propertyGetterBuilder.GetILGenerator();
getIlGenerator.Emit(OpCodes.Ldarg_0);
getIlGenerator.Emit(OpCodes.Ldfld, fieldBuilders[i]);
getIlGenerator.Emit(OpCodes.Ret);
propertyBuilder.SetGetMethod(propertyGetterBuilder);
}
var constructorBuilder = anonTypeBuilder.DefineConstructor(
MethodAttributes.HideBySig | MethodAttributes.Public,
CallingConventions.Standard, ctrParams.Select(prop => prop.Type).ToArray());
var constructorIlGenerator = constructorBuilder.GetILGenerator();
// Chain to System.Object's constructor; a constructor without this call is unverifiable IL.
constructorIlGenerator.Emit(OpCodes.Ldarg_0);
constructorIlGenerator.Emit(OpCodes.Call, typeof(object).GetConstructor(Type.EmptyTypes));
for (var i = 0; i < ctrParams.Length; i++)
{
constructorIlGenerator.Emit(OpCodes.Ldarg_0);
// ldarg takes an unsigned 16-bit operand; the Emit(OpCode, int) overload writes 4 bytes.
constructorIlGenerator.Emit(OpCodes.Ldarg, (short)(i + 1));
constructorIlGenerator.Emit(OpCodes.Stfld, fieldBuilders[i]);
constructorBuilder.DefineParameter(i + 1, ParameterAttributes.None, ctrParams[i].Name);
}
constructorIlGenerator.Emit(OpCodes.Ret);
var anonType = anonTypeBuilder.CreateType();
_anonymousTypes.Add(id, anonType);
return anonType;
}
#endregion Public Method
#region Private Method
/// <summary>
/// Compares two sequences position by position.
/// </summary>
/// <returns>True when both sequences have the same length and the elements at each position are
/// equal according to EqualityComparer&lt;T&gt;.Default; null elements are compared safely.</returns>
private static bool MatchPiecewise<T>(IEnumerable<T> first, IEnumerable<T> second)
{
// SequenceEqual checks the length and uses EqualityComparer<T>.Default, so a null
// element (e.g. an unresolved parameter type) cannot cause a NullReferenceException.
return first.SequenceEqual(second);
}
#endregion Private Method
#region nested classes
/// <summary>
/// A name paired with a CLR type, used to describe anonymous-type properties and constructor
/// parameters. Two pairs are equal when both the name (ordinal) and the type are equal; either
/// member may be null.
/// </summary>
public class NameTypePair
{
public string Name { get; set; }
public Type Type { get; set; }
public override int GetHashCode()
{
// Both members are plain settable properties and may still be null; hash null as 0.
unchecked
{
int nameHash = Name == null ? 0 : Name.GetHashCode();
int typeHash = Type == null ? 0 : Type.GetHashCode();
return nameHash + typeHash;
}
}
public override bool Equals(object obj)
{
var other = obj as NameTypePair;
if (other == null)
return false;
return string.Equals(Name, other.Name) && object.Equals(Type, other.Type);
}
}
/// <summary>
/// Cache key for generated anonymous types: the logical type name plus the ordered list of
/// name/type pairs describing the type's members.
/// </summary>
private class AnonTypeId
{
public string Name { get; private set; }
public IEnumerable<NameTypePair> Properties { get; private set; }
public AnonTypeId(string name, IEnumerable<NameTypePair> properties)
{
Name = name;
// Snapshot the sequence: a lazily evaluated or later-mutated source would make the hash
// code of a key that is already stored in a dictionary unstable.
Properties = properties.ToArray();
}
public override int GetHashCode()
{
// Accumulate in an unchecked context. Enumerable.Sum over ints is checked and throws
// OverflowException as soon as the member hashes add up past int.MaxValue.
unchecked
{
int hash = Name == null ? 0 : Name.GetHashCode();
foreach (var pair in Properties)
hash += pair == null ? 0 : pair.GetHashCode();
return hash;
}
}
public override bool Equals(object obj)
{
var other = obj as AnonTypeId;
if (other == null)
return false;
return string.Equals(Name, other.Name)
&& Properties.SequenceEqual(other.Properties);
}
}
#endregion
}
<#+
}
void WriteApplyChanges(CodeGenerationTools code)
{
#>
public static class SelfTrackingEntitiesContextExtensions
{
/// <summary>
/// Applies the changes recorded in a connected graph of self-tracking entities to the
/// ObjectContext that owns <paramref name="objectSet"/>.
/// </summary>
/// <typeparam name="TEntity">Expected type of the ObjectSet</typeparam>
/// <param name="objectSet">The ObjectSet whose context and entity set receive the changes.</param>
/// <param name="entity">The entity serving as the entry point of the object graph that contains changes.</param>
/// <exception cref="ArgumentNullException"><paramref name="objectSet"/> is null.</exception>
public static void ApplyChanges<TEntity>(this ObjectSet<TEntity> objectSet, TEntity entity) where TEntity : class, IObjectWithChangeTracker
{
if (objectSet == null)
{
throw new ArgumentNullException("objectSet");
}
// The context overload takes the container-qualified "Container.EntitySet" name.
EntitySet targetSet = objectSet.EntitySet;
string qualifiedSetName = targetSet.EntityContainer.Name + "." + targetSet.Name;
objectSet.Context.ApplyChanges<TEntity>(qualifiedSetName, entity);
}
/// <summary>
/// ApplyChanges takes the changes in a connected set of entities and applies them to an ObjectContext.
/// </summary>
/// <remarks>
/// Lazy loading is disabled for the duration of the call and the previous setting is restored in a
/// finally block. Entities tracked as Deleted are handled before all other entities. Navigation
/// properties backed by independent (non foreign-key) associations are reconciled next. Finally,
/// every relationship still left in the collected relationship set is changed to the state recorded for it.
/// </remarks>
/// <typeparam name="TEntity">Expected type of the EntitySet</typeparam>
/// <param name="context">The ObjectContext to which changes will be applied.</param>
/// <param name="entitySetName">The EntitySet name of the entity.</param>
/// <param name="entity">The entity serving as the entry point of the object graph that contains changes.</param>
/// <exception cref="ArgumentNullException"><paramref name="context"/> or <paramref name="entity"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="entitySetName"/> is null or empty.</exception>
public static void ApplyChanges<TEntity>(this ObjectContext context, string entitySetName, TEntity entity) where TEntity : IObjectWithChangeTracker
{
if (context == null)
{
throw new ArgumentNullException("context");
}
if (String.IsNullOrEmpty(entitySetName))
{
throw new ArgumentException("String parameter cannot be null or empty.", "entitySetName");
}
if (entity == null)
{
throw new ArgumentNullException("entity");
}
// Remember the caller's setting so the finally block can restore it.
bool lazyLoadingSetting = context.ContextOptions.LazyLoadingEnabled;
try
{
// NOTE(review): presumably disabled so that enumerating related ends below does not
// trigger database loads - confirm.
context.ContextOptions.LazyLoadingEnabled = false;
// AddHelper.AddAllEntities (defined elsewhere) presumably attaches the graph reachable from
// entity and indexes it; allRelationships is built from the indexed entities.
EntityIndex entityIndex = AddHelper.AddAllEntities(context, entitySetName, entity);
RelationshipSet allRelationships = new RelationshipSet(context, entityIndex.AllEntities);
#region Handle Initial Entity State
// First pass: entities tracked as Deleted, before any other entity's state is changed.
foreach (IObjectWithChangeTracker changedEntity in entityIndex.AllEntities.Where(x => x.ChangeTracker.State == ObjectState.Deleted))
{
HandleDeletedEntity(context, entityIndex, allRelationships, changedEntity);
}
// Second pass: every entity that is not Deleted.
foreach (IObjectWithChangeTracker changedEntity in entityIndex.AllEntities.Where(x => x.ChangeTracker.State != ObjectState.Deleted))
{
HandleEntity(context, entityIndex, allRelationships, changedEntity);
}
#endregion
#region Loop through each object state entries
// Replay the recorded relationship changes for every navigation property whose
// association is independent (i.e. not a foreign-key association).
foreach (IObjectWithChangeTracker changedEntity in entityIndex.AllEntities)
{
ObjectStateEntry entry = context.ObjectStateManager.GetObjectStateEntry(changedEntity);
EntityType entityType = context.MetadataWorkspace.GetCSpaceEntityType(changedEntity);
foreach (NavigationProperty navProp in entityType.NavigationProperties)
{
RelatedEnd relatedEnd = entry.GetRelatedEnd(navProp.Name);
if(!((AssociationType)relatedEnd.RelationshipSet.ElementType).IsForeignKey)
{
ApplyChangesToIndependentAssociation(context, (IObjectWithChangeTracker)changedEntity, entry, navProp, relatedEnd, allRelationships);
}
}
}
#endregion
// Change all the remaining relationships to the appropriate state
// (relationships already handled above were removed from allRelationships by the helpers).
foreach (var relationship in allRelationships)
{
context.ObjectStateManager.ChangeRelationshipState(
relationship.End0,
relationship.End1,
relationship.AssociationSet.ElementType.FullName,
relationship.AssociationEndMembers[1].Name,
relationship.State);
}
}
finally
{
context.ContextOptions.LazyLoadingEnabled = lazyLoadingSetting;
}
}
// Reconciles one navigation property of an independent (non foreign-key) association with the
// change tracker of changedEntity. Every relationship entry created or changed here is removed
// from allRelationships, so ApplyChanges' final pass does not change its state again.
//  - Added entity: each currently related entity gets an Added relationship.
//  - Collection navigation property (to-end multiplicity Many): recorded removals become Deleted
//    relationships; recorded additions become Added ones unless they already exist as Unchanged.
//  - Reference navigation property: when an original value was recorded, the relationship to it
//    is marked Deleted and the relationship to the current value (if any) is added.
private static void ApplyChangesToIndependentAssociation(ObjectContext context, IObjectWithChangeTracker changedEntity, ObjectStateEntry entry, NavigationProperty navProp,
IRelatedEnd relatedEnd, RelationshipSet allRelationships)
{
ObjectChangeTracker changeTracker = changedEntity.ChangeTracker;
if (changeTracker.State == ObjectState.Added)
{
// Relationships should remain added so remove them from the list of allRelationships
foreach (object relatedEntity in relatedEnd)
{
ObjectStateEntry addedRelationshipEntry =
context.ObjectStateManager.ChangeRelationshipState(
changedEntity,
relatedEntity,
navProp.Name,
EntityState.Added);
allRelationships.Remove(addedRelationshipEntry);
}
}
else
{
if (navProp.ToEndMember.RelationshipMultiplicity == RelationshipMultiplicity.Many)
{
// Handle removals recorded in ObjectsRemovedFromCollectionProperties
ObjectList collectionPropertyChanges = null;
if (changeTracker.ObjectsRemovedFromCollectionProperties.TryGetValue(navProp.Name, out collectionPropertyChanges))
{
foreach (var removedEntityFromAssociation in collectionPropertyChanges)
{
ObjectStateEntry deletedRelationshipEntry =
context.ObjectStateManager.ChangeRelationshipState(
changedEntity,
removedEntityFromAssociation,
navProp.Name,
EntityState.Deleted);
allRelationships.Remove(deletedRelationshipEntry);
}
}
// Handle additions recorded in ObjectsAddedToCollectionProperties
if (changeTracker.ObjectsAddedToCollectionProperties.TryGetValue(navProp.Name, out collectionPropertyChanges))
{
foreach (var addedEntityFromAssociation in collectionPropertyChanges)
{
allRelationships.Remove(AddRelationshipUnlessExists(context, relatedEnd, entry, addedEntityFromAssociation, navProp.Name));
}
}
}
else
{
// Handle original relationship values
// (nothing happens for a reference property without a recorded original value)
object originalReferenceValue;
if (changeTracker.OriginalValues.TryGetValue(navProp.Name, out originalReferenceValue))
{
if (originalReferenceValue != null)
{
// Capture the deletion of the association to the original value
ObjectStateEntry deletedRelationshipEntry =
context.ObjectStateManager.ChangeRelationshipState(
entry.Entity,
originalReferenceValue,
navProp.Name,
EntityState.Deleted);
allRelationships.Remove(deletedRelationshipEntry);
}
// Capture the addition of the association to the current value
// (the first, and for a reference the only, related entity)
object currentReferenceValue = null;
foreach (object o in relatedEnd)
{
currentReferenceValue = o;
break;
}
if (currentReferenceValue != null)
{
allRelationships.Remove(AddRelationshipUnlessExists(context, relatedEnd, entry, currentReferenceValue, navProp.Name));
}
// if the current value of the reference is null, then the user must set the entity reference to null
// which is already being handled by the deletion of the relationship
}
}
}
}
// Returns the relationship entry between fromEntry and toEntity for the association behind
// relatedEnd. A relationship that already exists in the Unchanged state is returned untouched;
// otherwise the relationship is set to Added via ChangeRelationshipState and that entry is returned.
private static ObjectStateEntry AddRelationshipUnlessExists(ObjectContext context, IRelatedEnd relatedEnd, ObjectStateEntry fromEntry, object toEntity, string navPropName)
{
var stateManager = context.ObjectStateManager;
ObjectStateEntry targetEntry = stateManager.GetObjectStateEntry(toEntity);
AssociationSet association = (AssociationSet)relatedEnd.RelationshipSet;
AssociationEndMember sourceEnd = association.AssociationSetEnds[relatedEnd.SourceRoleName].CorrespondingAssociationEndMember;
AssociationEndMember targetEnd = association.AssociationSetEnds[relatedEnd.TargetRoleName].CorrespondingAssociationEndMember;
ObjectStateEntry existing;
bool alreadyUnchanged =
context.TryGetObjectStateEntry(fromEntry.EntityKey, targetEntry.EntityKey, association, sourceEnd, targetEnd, out existing)
&& existing.State == EntityState.Unchanged;
if (alreadyUnchanged)
{
return existing;
}
return stateManager.ChangeRelationshipState(
fromEntry.Entity,
toEntity,
navPropName,
EntityState.Added);
}
// Extracts the relationship key information from the ExtendedProperties and OriginalValues records of each ObjectChangeTracker
// This is done by:
// 1. Creating any existing relationship specified in the ExtendedProperties
// 2. Determining if there was a previous relationship, and if there was, creating a deleted relationship between the entity and the previous entity or key value
// Only entities tracked as Unchanged, Modified or Deleted are processed, and only the references
// yielded by EnumerateSaveReferences (non foreign-key associations whose source end is not One).
// The order of the steps below matters: relationships are temporarily detached and recreated
// so that a Deleted relationship to the original key can be produced.
private static void HandleRelationshipKeys(ObjectContext context, EntityIndex entityIndex, RelationshipSet allRelationships, IObjectWithChangeTracker entity)
{
ObjectChangeTracker changeTracker = entity.ChangeTracker;
if (changeTracker.State == ObjectState.Unchanged ||
changeTracker.State == ObjectState.Modified ||
changeTracker.State == ObjectState.Deleted)
{
ObjectStateEntry entry = context.ObjectStateManager.GetObjectStateEntry(entity);
EntityType entityType = context.MetadataWorkspace.GetCSpaceEntityType(entity);
RelationshipManager relationshipManager = context.ObjectStateManager.GetRelationshipManager(entity);
foreach (var entityReference in EnumerateSaveReferences(relationshipManager))
{
AssociationSet associationSet = ((AssociationSet)entityReference.RelationshipSet);
AssociationEndMember fromEnd = associationSet.AssociationSetEnds[entityReference.SourceRoleName].CorrespondingAssociationEndMember;
AssociationEndMember toEnd = associationSet.AssociationSetEnds[entityReference.TargetRoleName].CorrespondingAssociationEndMember;
// Find if there is a NavigationProperty for this candidate
NavigationProperty navigationProperty = entityType.NavigationProperties.
SingleOrDefault(x => x.RelationshipType == associationSet.ElementType &&
x.FromEndMember == fromEnd &&
x.ToEndMember == toEnd);
// Only handle relationship keys in one of these cases
// 1. There is no navigation property
// 2. The navigation property has a null current reference value and there are no removes or adds
// 3. The navigation property has a current reference value, but there is no remove
// The current key comes from ExtendedProperties (null when nothing was saved for this reference).
EntityKey currentKey = GetSavedReferenceKey(entityIndex, entityReference, entity, navigationProperty, changeTracker.ExtendedProperties);
// Get any original value from the change tracking information
object originalValue = null;
EntityKey originalKey = null;
bool hasOriginalValue = false;
if (changeTracker.OriginalValues != null)
{
// Try to get the original value from the NavigationProperty first
if (navigationProperty != null)
{
hasOriginalValue = changeTracker.OriginalValues.TryGetValue(navigationProperty.Name, out originalValue);
}
// Try to get the original value from the reference key second
if (!hasOriginalValue || originalValue == null)
{
originalKey = GetSavedReferenceKey(entityIndex, entityReference, entity, navigationProperty, changeTracker.OriginalValues);
}
}
// Create the current relationship
if (currentKey != null)
{
// If the key is for a deleted entity, move that key to an originalValue and fixup the entities key values
// Otherwise create a new relationship
ObjectStateEntry currentEntry;
if (context.ObjectStateManager.TryGetObjectStateEntry(currentKey, out currentEntry) &&
currentEntry.Entity != null &&
currentEntry.State == EntityState.Deleted)
{
entityReference.EntityKey = null;
MoveSavedReferenceKey(entityReference, entity, navigationProperty, changeTracker.ExtendedProperties, changeTracker.OriginalValues);
originalKey = currentKey;
}
else
{
// Without an original key the saved key describes an existing relationship (Unchanged);
// with one, the current relationship is a new one (Added).
CreateRelationship(context, entityReference, entry.EntityKey, currentKey, originalKey == null ? EntityState.Unchanged : EntityState.Added);
}
}
else
{
// Find the current key
// Cannot get the EntityKey directly because this is null when it points to an Added entity
currentKey = entityReference.GetCurrentEntityKey(context);
}
// Create the original relationship
if (originalKey != null)
{
// If the key is for a deleted entity, remember to create a deleted relationship,
// otherwise use the entityReference to setup the deleted relationship
ObjectStateEntry originalEntry = null;
ObjectStateEntry deletedRelationshipEntry = null;
if (context.ObjectStateManager.TryGetObjectStateEntry(originalKey, out originalEntry) &&
originalEntry.Entity != null &&
originalEntry.State == EntityState.Deleted)
{
allRelationships.Add(entityReference, entry.Entity, originalEntry.Entity, EntityState.Deleted);
}
else
{
// To create a deleted relationship to a key, first detach the existing relationship between entry and currentKey
EntityState currentRelationshipState = DetachRelationship(context, entityReference, entry, currentKey);
// If the relationship is 1 to 0..1, detach the relationship from currentKey to its target (targetKey)
EntityState targetRelationshipState = EntityState.Detached;
EntityReference targetReference = null;
EntityKey targetKey = null;
if (originalEntry != null &&
originalEntry.Entity != null &&
originalEntry.RelationshipManager != null &&
associationSet.AssociationSetEnds[fromEnd.Name].CorrespondingAssociationEndMember.RelationshipMultiplicity != RelationshipMultiplicity.Many)
{
// NOTE(review): assumes the related end is an EntityReference; if the 'as' cast
// yielded null the next line would throw NullReferenceException.
targetReference = originalEntry.RelationshipManager.GetRelatedEnd(entityReference.RelationshipName, entityReference.SourceRoleName) as EntityReference;
targetKey = targetReference.GetCurrentEntityKey(context);
if (targetKey != null)
{
targetRelationshipState = DetachRelationship(context, targetReference, originalEntry, targetKey);
}
}
// Create the deleted relationship between entry and originalKey
deletedRelationshipEntry = CreateRelationship(context, entityReference, entry.EntityKey, originalKey, EntityState.Deleted);
// Set the previous relationship between entry and currentKey back
// (CreateRelationship does nothing when the saved state is Detached)
CreateRelationship(context, entityReference, entry.EntityKey, currentKey, currentRelationshipState);
// Set the previous relationship between originalEntry and targetKey back
if (targetKey != null)
{
CreateRelationship(context, targetReference, originalEntry.EntityKey, targetKey, targetRelationshipState);
}
}
if (deletedRelationshipEntry != null)
{
// Remove the deleted relationship from those that need to be processed later in ApplyChanges
allRelationships.Remove(deletedRelationshipEntry);
}
}
else if (currentKey == null && originalValue != null && entityReference.IsDependentEndOfReferentialConstraint())
{
// the graph won't have this hooked up because there is no current value, but there is an original value,
// so the relationship processing code will want to delete a relationship.
// we can add this one so it has a relationship to change to deleted.
context.ObjectStateManager.ChangeRelationshipState(
entry.Entity,
originalValue,
entityReference.RelationshipName,
entityReference.TargetRoleName,
EntityState.Added);
}
}
}
}
// Creates the relationship fromKey -> toKey through entityReference and leaves its state entry in
// the requested state:
//  - Added: left as created;
//  - Unchanged: AcceptChanges is called on the entry;
//  - Deleted: AcceptChanges is called, then the reference's EntityKey is cleared (presumably what
//    moves the accepted entry to Deleted).
// Returns the relationship's state entry, or null (doing nothing) when state is Detached.
private static ObjectStateEntry CreateRelationship(ObjectContext context, EntityReference entityReference, EntityKey fromKey, EntityKey toKey, EntityState state)
{
if (state != EntityState.Detached)
{
AssociationSet associationSet = ((AssociationSet)entityReference.RelationshipSet);
AssociationEndMember fromEnd = associationSet.AssociationSetEnds[entityReference.SourceRoleName].CorrespondingAssociationEndMember;
AssociationEndMember toEnd = associationSet.AssociationSetEnds[entityReference.TargetRoleName].CorrespondingAssociationEndMember;
// set the relationship to the original relationship in the unchanged state
Debug.Assert(toKey != null, "why/how would we do a delete with a null originalKey?");
if (toKey.IsTemporary)
{
// Clear any existing relationship
entityReference.EntityKey = null;
// If the target entity is Added, use Add on RelatedEnd
ObjectStateEntry targetEntry;
context.ObjectStateManager.TryGetObjectStateEntry(toKey, out targetEntry);
// NOTE(review): only asserted; in a release build a missing entry would throw
// NullReferenceException on the next line.
Debug.Assert(targetEntry != null, "Should have found the state entry");
((IRelatedEnd)entityReference).Add(targetEntry.Entity);
}
else
{
entityReference.EntityKey = toKey;
}
ObjectStateEntry relationshipEntry;
bool found = context.TryGetObjectStateEntry(fromKey, toKey, associationSet, fromEnd, toEnd, out relationshipEntry);
Debug.Assert(found, "Did not find the created relationship.");
switch (state)
{
case EntityState.Added:
break;
case EntityState.Unchanged:
relationshipEntry.AcceptChanges();
break;
case EntityState.Deleted:
relationshipEntry.AcceptChanges();
entityReference.EntityKey = null;
break;
}
return relationshipEntry;
}
return null;
}
// Detaches the tracked relationship between fromEntry and toKey, if there is one, and returns the
// state it had beforehand so the caller can recreate it later with CreateRelationship.
// Returns EntityState.Detached when toKey is null or no such relationship entry exists.
private static EntityState DetachRelationship(ObjectContext context, EntityReference entityReference, ObjectStateEntry fromEntry, EntityKey toKey)
{
if (toKey == null)
{
return EntityState.Detached;
}
AssociationSet association = (AssociationSet)entityReference.RelationshipSet;
AssociationEndMember sourceEnd = association.AssociationSetEnds[entityReference.SourceRoleName].CorrespondingAssociationEndMember;
AssociationEndMember targetEnd = association.AssociationSetEnds[entityReference.TargetRoleName].CorrespondingAssociationEndMember;
ObjectStateEntry relationshipEntry;
if (!context.TryGetObjectStateEntry(fromEntry.EntityKey, toKey, association, sourceEnd, targetEnd, out relationshipEntry))
{
return EntityState.Detached;
}
EntityState previousState = relationshipEntry.State;
entityReference.EntityKey = null;
// A relationship that is (now) Deleted is detached by accepting the deletion.
if (relationshipEntry.State == EntityState.Deleted)
{
relationshipEntry.AcceptChanges();
}
Debug.Assert(relationshipEntry.State == EntityState.Detached, "relationship was not detached");
return previousState;
}
// Builds the dictionary key under which one key member of a reference's target is saved in
// ExtendedProperties / OriginalValues: "<NavigationProperty>.<KeyMember>" when a navigation
// property exists, otherwise "Navigate(<AssociationFullName>.<TargetRole>).<KeyMember>".
private static string CreateReferenceKeyLookup(string keyMemberName, EntityReference reference, NavigationProperty navigationProperty)
{
// the navigation property name is the more readable qualifier, so it wins when present
return navigationProperty != null
? String.Format(CultureInfo.InvariantCulture, "{0}.{1}", navigationProperty.Name, keyMemberName)
: String.Format(CultureInfo.InvariantCulture, "Navigate({0}.{1}).{2}", reference.RelationshipSet.ElementType.FullName, reference.TargetRoleName, keyMemberName);
}
// Reads the key of the entity targeted by reference from a values dictionary (ExtendedProperties
// or OriginalValues), where each key member is saved under CreateReferenceKeyLookup(...).
// These keys can be set during the ObjectMaterialized event or through relationship fixup.
// Returns null when no key member is present; throws InvalidOperationException when only some
// are; otherwise returns the key passed through entityIndex.ConvertEntityKey.
private static EntityKey GetSavedReferenceKey(EntityIndex entityIndex, EntityReference reference, object entity, NavigationProperty navigationProperty, IDictionary<string, object> values)
{
Debug.Assert(navigationProperty == null || reference.RelationshipSet.ElementType == navigationProperty.RelationshipType, "the reference and navigationProperty should correspond");
EntitySet targetSet = ((AssociationSet)reference.RelationshipSet).AssociationSetEnds[reference.TargetRoleName].EntitySet;
List<EntityKeyMember> savedMembers = new List<EntityKeyMember>(1);
int missingCount = 0;
foreach (var keyMember in targetSet.ElementType.KeyMembers)
{
object memberValue;
string lookupKey = CreateReferenceKeyLookup(keyMember.Name, reference, navigationProperty);
if (!values.TryGetValue(lookupKey, out memberValue))
{
missingCount++;
continue;
}
savedMembers.Add(new EntityKeyMember(keyMember.Name, memberValue));
}
if (savedMembers.Count == 0)
{
// nothing was saved for this reference
return null;
}
if (missingCount > 0)
{
throw new InvalidOperationException(
String.Format(
CultureInfo.CurrentCulture,
"The OriginalValues or ExtendedProperties collections on the type '{0}' contained only a partial key to satisfy the relationship '{1}' targeting the role '{2}'",
ObjectContext.GetObjectType(entity.GetType()).FullName,
reference.RelationshipName,
reference.TargetRoleName));
}
return entityIndex.ConvertEntityKey(new EntityKey(reference.GetEntitySetName(), savedMembers));
}
// Moves every key member saved for reference from sourceValues to targetValues, overwriting any
// existing entry in targetValues. All key members are looked up before anything is moved, so an
// incomplete key throws InvalidOperationException without leaving either dictionary half-updated.
private static void MoveSavedReferenceKey(EntityReference reference, object entity, NavigationProperty navigationProperty, IDictionary<string, object> sourceValues, IDictionary<string, object> targetValues)
{
Debug.Assert(navigationProperty == null || reference.RelationshipSet.ElementType == navigationProperty.RelationshipType, "the reference and navigationProperty should correspond");
EntitySet entitySet = ((AssociationSet)reference.RelationshipSet).AssociationSetEnds[reference.TargetRoleName].EntitySet;
// Phase 1: collect every member; fail before mutating anything if one is missing.
var movedValues = new List<KeyValuePair<string, object>>();
foreach (var keyMember in entitySet.ElementType.KeyMembers)
{
string lookupKey = CreateReferenceKeyLookup(keyMember.Name, reference, navigationProperty);
object value;
if (!sourceValues.TryGetValue(lookupKey, out value))
{
throw new InvalidOperationException(
String.Format(
CultureInfo.CurrentCulture,
"The OriginalValues or ExtendedProperties collections on the type '{0}' contained only a partial key to satisfy the relationship '{1}' targeting the role '{2}'",
ObjectContext.GetObjectType(entity.GetType()).FullName,
reference.RelationshipName,
reference.TargetRoleName));
}
movedValues.Add(new KeyValuePair<string, object>(lookupKey, value));
}
// Phase 2: move them; the indexer adds a new entry or replaces an existing one.
foreach (var moved in movedValues)
{
targetValues[moved.Key] = moved.Value;
sourceValues.Remove(moved.Key);
}
}
// Yields the EntityReferences of a relationship manager whose keys may need saving: references
// whose source end does not have multiplicity One and whose association is not a foreign-key
// association. The query is deferred, as before.
private static IEnumerable<EntityReference> EnumerateSaveReferences(RelationshipManager manager)
{
return from reference in manager.GetAllRelatedEnds().OfType<EntityReference>()
let sourceEnd = reference.RelationshipSet.ElementType.RelationshipEndMembers[reference.SourceRoleName]
where sourceEnd.RelationshipMultiplicity != RelationshipMultiplicity.One
&& !((AssociationSet)reference.RelationshipSet).ElementType.IsForeignKey
select reference;
}
// Saves, into entity.ChangeTracker.ExtendedProperties, the EntityKey values of each reference
// yielded by EnumerateSaveReferences that either has no navigation property or whose navigation
// property currently holds no entity. Values are stored under CreateReferenceKeyLookup keys, which
// is where GetSavedReferenceKey reads them back during ApplyChanges.
// NOTE(review): presumably called from the context's ObjectMaterialized handler - confirm.
internal static void StoreReferenceKeyValues(this ObjectContext context, IObjectWithChangeTracker entity)
{
if (context == null)
{
throw new ArgumentNullException("context");
}
if(entity == null)
{
throw new ArgumentNullException("entity");
}
ObjectStateEntry entry;
if (!context.ObjectStateManager.TryGetObjectStateEntry(entity, out entry))
{
// must be a no tracking query, the reference key info won't be available
return;
}
var relationshipManager = entry.RelationshipManager;
EntityType entityType = context.MetadataWorkspace.GetCSpaceEntityType(entity);
foreach (EntityReference entityReference in EnumerateSaveReferences(relationshipManager))
{
NavigationProperty navigationProperty = entityType.NavigationProperties.FirstOrDefault(n => n.RelationshipType == entityReference.RelationshipSet.ElementType &&
n.FromEndMember.Name == entityReference.SourceRoleName &&
n.ToEndMember.Name == entityReference.TargetRoleName);
object value = entityReference.GetValue();
// Only the key can describe the relationship when the target is not reachable through
// a navigation property holding an entity.
if ((navigationProperty == null || value == null) && entityReference.EntityKey != null)
{
foreach (var item in entityReference.EntityKey.EntityKeyValues)
{
string key = CreateReferenceKeyLookup(item.Key, entityReference, navigationProperty);
// Assign through the indexer rather than Add: if this runs again for an entity whose
// ExtendedProperties already hold the key, the value is refreshed instead of Add
// throwing ArgumentException for a duplicate key.
entity.ChangeTracker.ExtendedProperties[key] = item.Value;
}
}
}
}
// Applies the tracked state of an entity that is not Deleted: its ObjectStateEntry state is
// changed first, then its saved relationship keys are processed, then its original values are
// copied onto the state entry.
private static void HandleEntity(ObjectContext context, EntityIndex entityIndex, RelationshipSet allRelationships, IObjectWithChangeTracker entity)
{
ChangeEntityStateBasedOnObjectState(context, entity);
HandleRelationshipKeys(context, entityIndex, allRelationships, entity);
UpdateOriginalValues(context, entity);
}
// Applies the tracked state of a Deleted entity. Unlike HandleEntity, the relationship keys are
// processed before the entity's state entry is changed to Deleted.
// NOTE(review): presumably because relationships cannot be recreated once the entity is marked
// Deleted - confirm before reordering these calls.
private static void HandleDeletedEntity(ObjectContext context, EntityIndex entityIndex, RelationshipSet allRelationships, IObjectWithChangeTracker entity)
{
HandleRelationshipKeys(context, entityIndex, allRelationships, entity);
ChangeEntityStateBasedOnObjectState(context, entity);
UpdateOriginalValues(context, entity);
}
// Copies the original values captured by the entity's change tracker onto the updatable
// original-value record of its ObjectStateEntry. Simple (scalar) properties are set directly;
// complex properties are walked recursively. Captured originals are expected to be sparse
// (only those required for concurrency, stored procedures, conditions, ...), so properties with
// no captured value are left alone. Unchanged or Added entities, and trackers with no
// OriginalValues, are skipped entirely.
private static void UpdateOriginalValues(ObjectContext context, IObjectWithChangeTracker entity)
{
ObjectChangeTracker tracker = entity.ChangeTracker;
if (tracker.State == ObjectState.Unchanged || tracker.State == ObjectState.Added || tracker.OriginalValues == null)
{
return;
}
ObjectStateEntry stateEntry = context.ObjectStateManager.GetObjectStateEntry(entity);
OriginalValueRecord originals = stateEntry.GetUpdatableOriginalValues();
EntityType cspaceType = context.MetadataWorkspace.GetCSpaceEntityType(entity);
foreach (EdmProperty member in cspaceType.Properties)
{
EdmType memberType = member.TypeUsage.EdmType;
if (memberType is ComplexType)
{
// recurse into the complex property's own record
OriginalValueRecord nestedOriginals = originals.GetOriginalValueRecord(member.Name);
string clrTypeName = ObjectContext.GetObjectType(entity.GetType()).FullName;
UpdateOriginalValues((ComplexType)memberType, clrTypeName, member.Name, tracker.OriginalValues, nestedOriginals);
continue;
}
object captured;
if (memberType is SimpleType && tracker.OriginalValues.TryGetValue(member.Name, out captured))
{
originals.SetValue(member, captured);
}
}
}
// Recursive step of UpdateOriginalValues for a complex property. Original values are looked up in
// originalValueSource under dotted paths ("Address.City"). complexOriginalValueRecord may be null;
// that only happens when a null reference was assigned to the complex property before ApplyChanges.
// In that case a non-null captured value cannot be applied and InvalidOperationException is thrown.
private static void UpdateOriginalValues(ComplexType complexType, string entityTypeName, string propertyPathToType, IDictionary<string, object> originalValueSource, OriginalValueRecord complexOriginalValueRecord)
{
foreach (EdmProperty member in complexType.Properties)
{
string memberPath = String.Format(CultureInfo.InvariantCulture, "{0}.{1}", propertyPathToType, member.Name);
EdmType memberType = member.TypeUsage.EdmType;
if (memberType is ComplexType)
{
// recurse down the chain of complex types, propagating a null record
OriginalValueRecord nestedRecord = complexOriginalValueRecord == null
? null
: complexOriginalValueRecord.GetOriginalValueRecord(member.Name);
UpdateOriginalValues((ComplexType)memberType, entityTypeName, memberPath, originalValueSource, nestedRecord);
continue;
}
object originalValue;
if (!(memberType is SimpleType) || !originalValueSource.TryGetValue(memberPath, out originalValue))
{
continue;
}
if (complexOriginalValueRecord != null)
{
complexOriginalValueRecord.SetValue(member, originalValue);
}
else if (originalValue != null)
{
throw new InvalidOperationException(
String.Format(
CultureInfo.CurrentCulture,
"Can not set the original value on the object stored in the property '{0}' on the type '{1}' because the property is null.",
propertyPathToType,
entityTypeName));
}
}
}
// Returns the nested original-value record of the named (complex) member, or null when the
// member's original value is DBNull.
private static OriginalValueRecord GetOriginalValueRecord(this OriginalValueRecord record, string name)
{
int memberOrdinal = record.GetOrdinal(name);
return record.IsDBNull(memberOrdinal)
? null
: record.GetDataRecord(memberOrdinal) as OriginalValueRecord;
}
// Sets the original value of edmProperty on record. A null value is silently skipped for a
// non-nullable property whose primitive CLR equivalent is a value type, because the
// ObjectStateEntry won't accept it.
private static void SetValue(this OriginalValueRecord record, EdmProperty edmProperty, object value)
{
if (value == null)
{
PrimitiveType primitive = edmProperty.TypeUsage.EdmType as PrimitiveType;
bool cannotHoldNull = primitive != null
&& primitive.ClrEquivalentType.IsValueType
&& !edmProperty.Nullable;
if (cannotHoldNull)
{
return;
}
}
record.SetValue(record.GetOrdinal(edmProperty.Name), value);
}
// Mirrors the entity's self-tracking state onto its ObjectStateEntry. Added entities are already
// tracked as Added (only asserted in debug builds); Unchanged, Modified and Deleted are applied
// through ObjectStateManager.ChangeObjectState; any other value is ignored.
private static void ChangeEntityStateBasedOnObjectState(ObjectContext context, IObjectWithChangeTracker entity)
{
ObjectState trackedState = entity.ChangeTracker.State;
if (trackedState == ObjectState.Added)
{
// No-op: the state entry is already marked as added
Debug.Assert(context.ObjectStateManager.GetObjectStateEntry(entity).State == EntityState.Added, "State should have been Added");
return;
}
EntityState newState;
if (trackedState == ObjectState.Unchanged)
{
newState = EntityState.Unchanged;
}
else if (trackedState == ObjectState.Modified)
{
newState = EntityState.Modified;
}
else if (trackedState == ObjectState.Deleted)
{
newState = EntityState.Deleted;
}
else
{
return;
}
context.ObjectStateManager.ChangeObjectState(entity, newState);
}
// Maps the entity's CLR type (proxy types are unwrapped by ObjectContext.GetObjectType) to its
// conceptual-model (C-space) EntityType via the O-space item. Throws ArgumentException when the
// type has no C-space EntityType.
private static EntityType GetCSpaceEntityType(this MetadataWorkspace workspace, object entity)
{
Type clrType = ObjectContext.GetObjectType(entity.GetType());
EntityType objectSpaceType;
StructuralType conceptualType;
EntityType result = null;
if (workspace.TryGetItem<EntityType>(clrType.FullName, DataSpace.OSpace, out objectSpaceType)
&& workspace.TryGetEdmSpaceType(objectSpaceType, out conceptualType))
{
result = conceptualType as EntityType;
}
if (result == null)
{
throw new ArgumentException(String.Format(CultureInfo.CurrentCulture, "Unable to find a CSpace type for type {0}", clrType.FullName));
}
return result;
}
// Returns the first entity enumerated from entityReference (for a reference, the only one),
// or null when nothing is loaded.
private static object GetValue(this System.Data.Objects.DataClasses.EntityReference entityReference)
{
object first = null;
foreach (object candidate in entityReference)
{
first = candidate;
break;
}
return first;
}
// Returns the key of the entity the reference currently points to. If a related object
// is present, its state-entry key is used; otherwise the reference's own EntityKey is returned.
private static EntityKey GetCurrentEntityKey(this System.Data.Objects.DataClasses.EntityReference entityReference, ObjectContext context)
{
    object relatedObject = entityReference.GetValue();
    if (relatedObject == null)
    {
        return entityReference.EntityKey;
    }

    return context.ObjectStateManager.GetObjectStateEntry(relatedObject).EntityKey;
}
// Finds the RelatedEnd on 'entry' that backs the named navigation property, using the
// property's relationship type and target role. Returns null if the related end
// returned by the RelationshipManager is not a RelatedEnd.
private static RelatedEnd GetRelatedEnd(this ObjectStateEntry entry, string navigationPropertyIdentity)
{
    EntityType conceptualType = entry.ObjectStateManager.MetadataWorkspace.GetCSpaceEntityType(entry.Entity);
    NavigationProperty navigationProperty = GetNavigationProperty(conceptualType, navigationPropertyIdentity);

    string relationshipName = navigationProperty.RelationshipType.FullName;
    string targetRoleName = navigationProperty.ToEndMember.Name;
    return entry.RelationshipManager.GetRelatedEnd(relationshipName, targetRoleName) as RelatedEnd;
}
// Looks up a navigation property by name (case-sensitive) on 'entityType'.
// Throws InvalidOperationException when the entity type has no such navigation property.
private static NavigationProperty GetNavigationProperty(this EntityType entityType, string navigationPropertyIdentity)
{
    NavigationProperty match;
    // second argument: ignoreCase = false
    bool found = entityType.NavigationProperties.TryGetValue(navigationPropertyIdentity, false, out match);
    if (found)
    {
        return match;
    }

    string message = String.Format(
        CultureInfo.CurrentCulture,
        "Could not find navigation property '{0}' in EntityType '{1}'.",
        navigationPropertyIdentity,
        entityType.FullName);
    throw new InvalidOperationException(message);
}
// Returns "Container.EntitySet" for the entity set at the target end of the related end's
// association set.
private static string GetEntitySetName(this RelatedEnd relatedEnd)
{
    AssociationSet associationSet = (AssociationSet)relatedEnd.RelationshipSet;
    EntitySet targetSet = associationSet.AssociationSetEnds[relatedEnd.TargetRoleName].EntitySet;
    return String.Concat(targetSet.EntityContainer.Name, ".", targetSet.Name);
}
// Returns true when the related end's source role is the dependent ("To") role of one of
// the association's referential constraints; false when there is no relationship set
// or no such constraint.
private static bool IsDependentEndOfReferentialConstraint(this RelatedEnd relatedEnd)
{
    if (null == relatedEnd.RelationshipSet)
    {
        return false;
    }

    // NOTE: the referential constraints collection usually contains 0 or 1 element,
    // so the linear scan below is not a performance concern.
    AssociationType associationType = (AssociationType)relatedEnd.RelationshipSet.ElementType;
    foreach (ReferentialConstraint constraint in associationType.ReferentialConstraints)
    {
        // Example:
        //   Client<C_ID> --- Order<O_ID, Client_ID>
        //   RI Constraint: Principal/From <Client.C_ID>, Dependent/To <Order.Client_ID>
        // When the current RelatedEnd is a CollectionOrReference in Order's relationships,
        // constraint.ToRole == this._fromEndProperty == Order
        if (constraint.ToRole.Name == relatedEnd.SourceRoleName)
        {
            return true;
        }
    }
    return false;
}
// Searches the Added and Unchanged relationship entries of 'associationSet' for one whose
// 'fromEnd' role holds the key 'from' and whose 'toEnd' role holds the key 'to'.
// On success returns true and sets 'entry' to the first match. Otherwise it returns false
// and 'entry' is null.
private static bool TryGetObjectStateEntry(this ObjectContext context, EntityKey from, EntityKey to, AssociationSet associationSet, AssociationEndMember fromEnd, AssociationEndMember toEnd, out ObjectStateEntry entry)
{
    entry = null;
    IEnumerable<ObjectStateEntry> liveEntries =
        context.ObjectStateManager.GetObjectStateEntries(EntityState.Added | EntityState.Unchanged);

    foreach (ObjectStateEntry candidate in liveEntries)
    {
        if (!candidate.IsRelationship || candidate.EntitySet != associationSet)
        {
            continue;
        }

        CurrentValueRecord values = candidate.CurrentValues;
        int fromOrdinal = values.GetOrdinal(fromEnd.Name);
        int toOrdinal = values.GetOrdinal(toEnd.Name);
        bool fromMatches = ((EntityKey)values.GetValue(fromOrdinal)) == from;
        if (fromMatches && ((EntityKey)values.GetValue(toOrdinal)) == to)
        {
            entry = candidate;
            return true;
        }
    }
    return false;
}
// Adds a graph of self-tracking entities to an ObjectContext. While adding, it listens to
// ObjectStateManagerChanged and queues further entities found in each added entity's
// change tracker (non-null original navigation values and items removed from collection
// properties) so that those entities get added as well.
private sealed class AddHelper
{
    private readonly ObjectContext _context;
    private readonly EntityIndex _entityIndex;

    // Used during add processing
    private readonly Queue<Tuple<string, IObjectWithChangeTracker>> _entitiesToAdd;
    private readonly Queue<Tuple<ObjectStateEntry, string, IEnumerable<object>>> _entitiesDuringAdd;

    // Adds 'entity' to the set 'entitySetName', then keeps adding queued entities until
    // none remain. Entities that already have a state entry in the context are not added
    // again. The state-manager subscription is always removed, even if an add throws.
    // Returns the index of entities that entered the ObjectStateManager during the call.
    public static EntityIndex AddAllEntities(ObjectContext context, string entitySetName, IObjectWithChangeTracker entity)
    {
        AddHelper addHelper = new AddHelper(context);
        try
        {
            // Include the root element to start the Apply
            addHelper.QueueAdd(entitySetName, entity);

            // Add everything
            while (addHelper.HasMore)
            {
                Tuple<string, IObjectWithChangeTracker> entityInSet = addHelper.NextAdd();

                // Only add the object if it's not already in the context
                ObjectStateEntry entry = null;
                if (!context.ObjectStateManager.TryGetObjectStateEntry(entityInSet.Item2, out entry))
                {
                    context.AddObject(entityInSet.Item1, entityInSet.Item2);
                }
            }
        }
        finally
        {
            addHelper.Detach();
        }
        return addHelper.EntityIndex;
    }

    // Subscribes to ObjectStateManagerChanged so every entity entering the state manager is seen.
    private AddHelper(ObjectContext context)
    {
        _context = context;
        _context.ObjectStateManager.ObjectStateManagerChanged += this.HandleStateManagerChange;
        _entityIndex = new EntityIndex(context);
        _entitiesToAdd = new Queue<Tuple<string, IObjectWithChangeTracker>>();
        _entitiesDuringAdd = new Queue<Tuple<ObjectStateEntry, string, IEnumerable<object>>>();
    }

    // Removes the subscription made by the constructor.
    private void Detach()
    {
        _context.ObjectStateManager.ObjectStateManagerChanged -= this.HandleStateManagerChange;
    }

    // For each entity added to the state manager, this:
    //  - disables its change tracking,
    //  - records it in the entity index,
    //  - queues (entry, navigation property name, related objects) tuples for its
    //    non-null original reference values and its removed collection items.
    // Those tuples are later expanded into adds by ProcessNewAdds.
    private void HandleStateManagerChange(object sender, CollectionChangeEventArgs args)
    {
        if (args.Action == CollectionChangeAction.Add)
        {
            IObjectWithChangeTracker entity = args.Element as IObjectWithChangeTracker;
            ObjectStateEntry entry = _context.ObjectStateManager.GetObjectStateEntry(entity);
            ObjectChangeTracker changeTracker = entity.ChangeTracker;
            changeTracker.ChangeTrackingEnabled = false;
            _entityIndex.Add(entry, changeTracker);

            // Queue removed reference values.
            // The names are materialized once so each membership test is a hash lookup,
            // rather than a fresh enumeration of the metadata projection per original value.
            HashSet<string> navPropNames = new HashSet<string>(
                _context.MetadataWorkspace.GetCSpaceEntityType(entity).NavigationProperties.Select(n => n.Name));
            foreach (KeyValuePair<string, object> originalValueWithName in changeTracker.OriginalValues)
            {
                if (originalValueWithName.Value != null && navPropNames.Contains(originalValueWithName.Key))
                {
                    _entitiesDuringAdd.Enqueue(new Tuple<ObjectStateEntry, string, IEnumerable<object>>(
                        entry,
                        originalValueWithName.Key,
                        new object[] { originalValueWithName.Value }));
                }
            }

            // Queue removed collection values
            foreach (KeyValuePair<string, ObjectList> collectionPropertyChangesWithName in changeTracker.ObjectsRemovedFromCollectionProperties)
            {
                _entitiesDuringAdd.Enqueue(new Tuple<ObjectStateEntry, string, IEnumerable<object>>(
                    entry,
                    collectionPropertyChangesWithName.Key,
                    collectionPropertyChangesWithName.Value));
            }
        }
    }

    private EntityIndex EntityIndex
    {
        get { return _entityIndex; }
    }

    // Expands any pending related-entity tuples into queued adds, then reports whether
    // anything is left to add.
    private bool HasMore
    {
        get { ProcessNewAdds(); return _entitiesToAdd.Count > 0; }
    }

    // Queues an entity for adding unless the index already contains it.
    private void QueueAdd(string entitySetName, IObjectWithChangeTracker entity)
    {
        if (!_entityIndex.Contains(entity))
        {
            // Queue the entity so that we can add the 'removed collection' items
            _entitiesToAdd.Enqueue(new Tuple<string, IObjectWithChangeTracker>(entitySetName, entity));
        }
    }

    private Tuple<string, IObjectWithChangeTracker> NextAdd()
    {
        ProcessNewAdds();
        return _entitiesToAdd.Dequeue();
    }

    // Converts each pending (entry, navigation property, related objects) tuple into
    // queued adds targeting the entity set at the far end of that navigation property.
    private void ProcessNewAdds()
    {
        while (_entitiesDuringAdd.Count > 0)
        {
            Tuple<ObjectStateEntry, string, IEnumerable<object>> relatedEntities = _entitiesDuringAdd.Dequeue();
            RelatedEnd relatedEnd = relatedEntities.Item1.GetRelatedEnd(relatedEntities.Item2);
            string entitySetName = relatedEnd.GetEntitySetName();
            foreach (var targetEntity in relatedEntities.Item3)
            {
                QueueAdd(entitySetName, targetEntity as IObjectWithChangeTracker);
            }
        }
    }
}
// Tracks the entities handled by one ApplyChanges pass. It also maps the key each entity
// will have in the context to the key it was given when first added.
private sealed class EntityIndex
{
    private readonly ObjectContext _context;

    // Set of all entities
    private readonly HashSet<IObjectWithChangeTracker> _allEntities;

    // Index of the final key that will be used in the context (could be real for non-added, could be temporary for added)
    // to the initial temporary key
    private readonly Dictionary<EntityKey, EntityKey> _temporaryKeyMap;

    public EntityIndex(ObjectContext context)
    {
        _context = context;
        _allEntities = new HashSet<IObjectWithChangeTracker>();
        _temporaryKeyMap = new Dictionary<EntityKey, EntityKey>();
    }

    // Records the entity of 'entry' and maps its final key to the entry's current key.
    // For Added entities the final key is the current key. For all others it is
    // recomputed from the entity's values via CreateEntityKey. The first mapping for a
    // given final key wins.
    public void Add(ObjectStateEntry entry, ObjectChangeTracker changeTracker)
    {
        EntityKey temporaryKey = entry.EntityKey;

        // Track that this Apply will be handling this entity (HashSet.Add ignores duplicates).
        _allEntities.Add(entry.Entity as IObjectWithChangeTracker);

        EntityKey finalKey = changeTracker.State == ObjectState.Added
            ? temporaryKey
            : _context.CreateEntityKey(temporaryKey.EntityContainerName + "." + temporaryKey.EntitySetName, entry.Entity);

        if (!_temporaryKeyMap.ContainsKey(finalKey))
        {
            _temporaryKeyMap.Add(finalKey, temporaryKey);
        }
    }

    // Returns true when 'entity' has been recorded by Add.
    public bool Contains(object entity)
    {
        // Passing an 'object' straight to _allEntities.Contains binds to
        // Enumerable.Contains<object>, which scans the whole set linearly. Converting to
        // the element type keeps this a hash lookup. A non-null object that is not an
        // IObjectWithChangeTracker can never be an element of the set.
        IObjectWithChangeTracker trackedEntity = entity as IObjectWithChangeTracker;
        if (trackedEntity == null && entity != null)
        {
            return false;
        }
        return _allEntities.Contains(trackedEntity);
    }

    public IEnumerable<IObjectWithChangeTracker> AllEntities
    {
        get { return _allEntities; }
    }

    // Converts the passed in EntityKey to the EntityKey that is usable by the current state of ApplyChanges
    public EntityKey ConvertEntityKey(EntityKey targetKey)
    {
        ObjectStateEntry targetEntry;
        if (!_context.ObjectStateManager.TryGetObjectStateEntry(targetKey, out targetEntry))
        {
            // If no entry exists, then either:
            // 1. This is an EntityKey that is not represented in the set of entities being dealt with during the Apply
            // 2. This is an EntityKey that will represent one of the yet-to-be-processed Added entries, so look it up
            EntityKey temporaryKey;
            if (_temporaryKeyMap.TryGetValue(targetKey, out temporaryKey))
            {
                targetKey = temporaryKey;
            }
        }
        return targetKey;
    }
}
// The RelationshipSet builds a list of all relationships from an
// initial set of entities
private sealed class RelationshipSet : IEnumerable<RelationshipWrapper>
{
    private readonly HashSet<RelationshipWrapper> _relationships;
    private readonly ObjectContext _context;

    // Seeds the set with one Unchanged wrapper per (entity, related object) pair reachable
    // through each entity's related ends. Foreign-key associations are skipped.
    public RelationshipSet(ObjectContext context, IEnumerable<object> allEntities)
    {
        _context = context;
        _relationships = new HashSet<RelationshipWrapper>();

        foreach (object sourceEntity in allEntities)
        {
            ObjectStateEntry sourceEntry = context.ObjectStateManager.GetObjectStateEntry(sourceEntity);
            foreach (IRelatedEnd end in sourceEntry.RelationshipManager.GetAllRelatedEnds())
            {
                AssociationType associationType = (AssociationType)end.RelationshipSet.ElementType;
                if (associationType.IsForeignKey)
                {
                    continue;
                }

                foreach (object targetEntity in end)
                {
                    Add(end, sourceEntity, targetEntity, EntityState.Unchanged);
                }
            }
        }
    }

    // Adds an entry to the index based on a IRelatedEnd (duplicates are ignored by the HashSet)
    public void Add(IRelatedEnd relatedEnd, object sourceEntity, object targetEntity, EntityState state)
    {
        _relationships.Add(new RelationshipWrapper(
            (AssociationSet)relatedEnd.RelationshipSet,
            relatedEnd.SourceRoleName,
            sourceEntity,
            relatedEnd.TargetRoleName,
            targetEntity,
            state));
    }

    // Removes an entry from the index based on a relationship ObjectStateEntry.
    // Deleted entries are read from their original values, all others from current values.
    public void Remove(ObjectStateEntry relationshipEntry)
    {
        Debug.Assert(relationshipEntry.IsRelationship);
        AssociationSet associationSet = (AssociationSet)relationshipEntry.EntitySet;
        ReadOnlyMetadataCollection<AssociationEndMember> ends = associationSet.ElementType.AssociationEndMembers;
        DbDataRecord values = relationshipEntry.State == EntityState.Deleted ? relationshipEntry.OriginalValues : relationshipEntry.CurrentValues;

        object end0Entity = ResolveEndEntity(values, ends[0].Name);
        object end1Entity = ResolveEndEntity(values, ends[1].Name);
        if (end0Entity == null || end1Entity == null)
        {
            return;
        }

        _relationships.Remove(new RelationshipWrapper(
            associationSet,
            ends[0].Name,
            end0Entity,
            ends[1].Name,
            end1Entity,
            EntityState.Unchanged));
    }

    // Reads the EntityKey stored in the given role column and returns the entity of its state entry.
    private object ResolveEndEntity(DbDataRecord values, string roleName)
    {
        EntityKey key = (EntityKey)values.GetValue(values.GetOrdinal(roleName));
        return _context.ObjectStateManager.GetObjectStateEntry(key).Entity;
    }

    #region IEnumerable<RelationshipWrapper>

    public IEnumerator<RelationshipWrapper> GetEnumerator()
    {
        return _relationships.GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _relationships.GetEnumerator();
    }

    #endregion
}
// A RelationshipWrapper is used to identify a relationship between two entities
// The relationship is identified by the AssociationSet, and the order of the entities based
// on the roles they play (via AssociationEndMember)
private sealed class RelationshipWrapper : IEquatable<RelationshipWrapper>
{
    internal readonly AssociationSet AssociationSet;
    // Entity playing AssociationEndMembers[0]
    internal readonly object End0;
    // Entity playing AssociationEndMembers[1]
    internal readonly object End1;
    // Carried along with the wrapper; not part of equality or hashing.
    internal readonly EntityState State;

    // Stores the two entities in association-end order: End0 always holds the entity for
    // AssociationEndMembers[0], whichever order the caller passed the roles in.
    internal RelationshipWrapper(AssociationSet extent,
                                 string role0, object end0,
                                 string role1, object end1,
                                 EntityState state)
    {
        Debug.Assert(null != extent, "null AssociationSet");
        Debug.Assert(null != (object)end0, "null end0");
        Debug.Assert(null != (object)end1, "null end1");

        AssociationSet = extent;
        Debug.Assert(extent.ElementType.AssociationEndMembers.Count == 2, "only 2 ends are supported");

        State = state;

        if (extent.ElementType.AssociationEndMembers[0].Name == role0)
        {
            Debug.Assert(extent.ElementType.AssociationEndMembers[1].Name == role1, "a)roleAndKey1 Name differs");
            End0 = end0;
            End1 = end1;
        }
        else
        {
            Debug.Assert(extent.ElementType.AssociationEndMembers[0].Name == role1, "b)roleAndKey1 Name differs");
            Debug.Assert(extent.ElementType.AssociationEndMembers[1].Name == role0, "b)roleAndKey0 Name differs");
            End0 = end1;
            End1 = end0;
        }
    }

    internal ReadOnlyMetadataCollection<AssociationEndMember> AssociationEndMembers
    {
        get { return this.AssociationSet.ElementType.AssociationEndMembers; }
    }

    // Equals compares the ends by reference, so they are hashed by object identity as
    // well. An entity that overrides GetHashCode over mutable state therefore cannot move
    // a wrapper that is already stored in a HashSet. 'unchecked' lets the sum wrap rather
    // than throw OverflowException when the project is compiled with /checked.
    public override int GetHashCode()
    {
        unchecked
        {
            int endsHash = System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(this.End0) +
                           System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(this.End1);
            return this.AssociationSet.Name.GetHashCode() ^ endsHash;
        }
    }

    public override bool Equals(object obj)
    {
        return Equals(obj as RelationshipWrapper);
    }

    // Two wrappers are equal when they reference the same AssociationSet and the same two
    // end entities (by reference) in the same end order.
    public bool Equals(RelationshipWrapper wrapper)
    {
        return (Object.ReferenceEquals(this, wrapper) ||
                ((null != wrapper) &&
                 Object.ReferenceEquals(this.AssociationSet, wrapper.AssociationSet) &&
                 Object.ReferenceEquals(this.End0, wrapper.End0) &&
                 Object.ReferenceEquals(this.End1, wrapper.End1)));
    }
}
}
<#+
}
#>
<#+
// Text used by SummaryComment when an item carries no summary documentation.
string DefaultSummaryComment { get; set; }

// Returns the item's Documentation.Summary, XML-escaped, with every line after the first
// prefixed by the XML doc-comment marker. Falls back to DefaultSummaryComment, then to "".
string SummaryComment(MetadataItem item)
{
    bool hasSummary = item.Documentation != null && item.Documentation.Summary != null;
    if (hasSummary)
    {
        string escapedSummary = XmlEntityize(item.Documentation.Summary);
        return PrefixLinesOfMultilineComment(XMLCOMMENT_START + " ", escapedSummary);
    }

    return DefaultSummaryComment ?? string.Empty;
}
// Builds a LongDescription element for an XML doc comment from the item's
// Documentation.LongDescription. Every line is indented for 'indentLevel' and carries the
// doc-comment marker, and the result begins with a line break.
// Returns "" when the item has no long description.
string LongDescriptionCommentElement(MetadataItem item, int indentLevel)
{
    bool hasLongDescription = item.Documentation != null && !String.IsNullOrEmpty(item.Documentation.LongDescription);
    if (!hasLongDescription)
    {
        return string.Empty;
    }

    string lineStart = CodeRegion.GetIndent(indentLevel) + XMLCOMMENT_START + " ";
    string body = PrefixLinesOfMultilineComment(lineStart, XmlEntityize(item.Documentation.LongDescription));

    return Environment.NewLine
        + lineStart + "<LongDescription>" + Environment.NewLine
        + lineStart + body + Environment.NewLine
        + lineStart + "</LongDescription>";
}
// Returns "new" when the generated navigation property would hide a visible EntityObject member.
string NewModifier(NavigationProperty navigationProperty)
{
    return NewModifier(typeof(EntityObject), navigationProperty.Name);
}

// Returns "new" when the generated function-import method would hide a visible ObjectContext member.
string NewModifier(EdmFunction edmFunction)
{
    return NewModifier(typeof(ObjectContext), edmFunction.Name);
}

// Returns "new" when the generated entity-set property would hide a visible ObjectContext member.
string NewModifier(EntitySet set)
{
    return NewModifier(typeof(ObjectContext), set.Name);
}

// Entity properties are checked against EntityObject; all other properties against ComplexObject.
string NewModifier(EdmProperty property)
{
    Type baseType = property.DeclaringType.BuiltInTypeKind == BuiltInTypeKind.EntityType
        ? typeof(EntityObject)
        : typeof(ComplexObject);
    return NewModifier(baseType, property.Name);
}

// Returns "new" if 'type' exposes a visible member named 'memberName', otherwise "".
string NewModifier(Type type, string memberName)
{
    return HasBaseMemberWithMatchingName(type, memberName) ? "new" : string.Empty;
}
// Inserts 'prefix' after every Environment.NewLine in 'comment', so every line except the
// first starts with it.
string PrefixLinesOfMultilineComment(string prefix, string comment)
{
    string prefixedLineBreak = Environment.NewLine + prefix;
    return comment.Replace(Environment.NewLine, prefixedLineBreak);
}
// Emits one XML doc-comment param line per (name, description) tuple. Each line starts
// with a line break followed by the indent for 'indentLevel'. Names and descriptions are
// inserted as given (not XML-escaped here).
string ParameterComments(IEnumerable<Tuple<string, string>> parameters, int indentLevel)
{
    System.Text.StringBuilder builder = new System.Text.StringBuilder();
    foreach (Tuple<string, string> parameter in parameters)
    {
        builder.AppendLine()
               .Append(CodeRegion.GetIndent(indentLevel))
               .Append(XMLCOMMENT_START)
               .AppendFormat(CultureInfo.InvariantCulture, " <param name=\"{0}\">{1}</param>", parameter.Item1, parameter.Item2);
    }
    return builder.ToString();
}
// Writes, for every function-import parameter whose NeedsLocalVariable is true, an
// ObjectParameter local into the template output. The local is built from the argument
// when it has a value: a HasValue test when IsNullableOfT, otherwise a null check.
// Otherwise it is built from the parameter's raw CLR type. Parameters that do not need a
// local are skipped.
private void WriteFunctionParameters(IEnumerable<FunctionImportParameter> parameters)
{
foreach (FunctionImportParameter parameter in parameters)
{
if (!parameter.NeedsLocalVariable)
{
continue;
}
// The text block below is written to the output once per remaining parameter.
#>
ObjectParameter <#=parameter.LocalVariableName#>;
if (<#=parameter.IsNullableOfT ? parameter.FunctionParameterName + ".HasValue" : parameter.FunctionParameterName + " != null"#>)
{
<#=parameter.LocalVariableName#> = new ObjectParameter("<#=parameter.EsqlParameterName#>", <#=parameter.FunctionParameterName#>);
}
else
{
<#=parameter.LocalVariableName#> = new ObjectParameter("<#=parameter.EsqlParameterName#>", typeof(<#=parameter.RawClrTypeName#>));
}
<#+
}
}
// Escapes text for embedding in an XML doc comment. Returns "" for null or empty input.
string XmlEntityize(string text)
{
    if (string.IsNullOrEmpty(text))
    {
        return string.Empty;
    }

    // SecurityElement.Escape rewrites the five XML special characters (ampersand,
    // less-than, greater-than, apostrophe, double quote) as entity references in a single
    // pass, so the ampersands it introduces are never escaped a second time.
    text = System.Security.SecurityElement.Escape(text);

    // Complete Environment.NewLine sequences stay real line breaks. Any other carriage
    // return or line feed becomes a hexadecimal numeric character reference (0xD / 0xA).
    // The references are formatted rather than spelled out as literals.
    string carriageReturnReference = String.Format(CultureInfo.InvariantCulture, "&#x{0:X};", (int)'\r');
    string lineFeedReference = String.Format(CultureInfo.InvariantCulture, "&#x{0:X};", (int)'\n');

    // Shield real line breaks behind a unique placeholder while stray CR/LF are replaced.
    string placeholder = Guid.NewGuid().ToString();
    text = text.Replace(Environment.NewLine, placeholder);
    text = text.Replace("\r", carriageReturnReference).Replace("\n", lineFeedReference);
    return text.Replace(placeholder, Environment.NewLine);
}
// Marker that starts every XML documentation comment line emitted by this template.
private const string XMLCOMMENT_START = "///";

// Model namespace; assigned at template start-up from loader.GetModelNamespace(inputFile).
public string ModelNamespace { get; set; }
// Maps a CSDL namespace to the namespace used for generated code via
// EdmToObjectNamespaceMap. Unmapped namespaces are returned unchanged.
string GetObjectNamespace(string csdlNamespaceName)
{
    string mappedNamespace;
    return EdmToObjectNamespaceMap.TryGetValue(csdlNamespaceName, out mappedNamespace)
        ? mappedNamespace
        : csdlNamespaceName;
}
// True when 'type' (including inherited members) has a visible member, instance or
// static, public or non-public, named exactly 'memberName'.
static bool HasBaseMemberWithMatchingName(Type type, string memberName)
{
    const BindingFlags everyMember = BindingFlags.FlattenHierarchy | BindingFlags.NonPublic | BindingFlags.Public
        | BindingFlags.Instance | BindingFlags.Static;

    foreach (MemberInfo candidate in type.GetMembers(everyMember))
    {
        if (IsVisibleMember(candidate) && candidate.Name == memberName)
        {
            return true;
        }
    }
    return false;
}

// A method counts as visible when it exists and is neither private nor assembly-only (internal).
static bool IsVisibleMethod(MethodBase methodBase)
{
    return methodBase != null && !methodBase.IsPrivate && !methodBase.IsAssembly;
}

// Decides visibility per member kind. Events and properties are judged by their accessors.
// NOTE(review): the parameterless GetAddMethod/GetRemoveMethod/GetGetMethod/GetSetMethod
// return only public accessors, so protected events/properties count as not visible,
// unlike protected methods and fields.
// Special-name methods (e.g. accessors) are never visible on their own.
static bool IsVisibleMember(MemberInfo memberInfo)
{
    EventInfo eventInfo = memberInfo as EventInfo;
    if (eventInfo != null)
    {
        MethodInfo adder = eventInfo.GetAddMethod();
        MethodInfo remover = eventInfo.GetRemoveMethod();
        return IsVisibleMethod(adder) || IsVisibleMethod(remover);
    }

    FieldInfo fieldInfo = memberInfo as FieldInfo;
    if (fieldInfo != null)
    {
        return !fieldInfo.IsPrivate && !fieldInfo.IsAssembly;
    }

    MethodBase methodBase = memberInfo as MethodBase;
    if (methodBase != null)
    {
        return !methodBase.IsSpecialName && IsVisibleMethod(methodBase);
    }

    PropertyInfo propertyInfo = memberInfo as PropertyInfo;
    if (propertyInfo != null)
    {
        MethodInfo getter = propertyInfo.GetGetMethod();
        MethodInfo setter = propertyInfo.GetSetMethod();
        return IsVisibleMethod(getter) || IsVisibleMethod(setter);
    }

    return false;
}
// Maps model (CSDL) namespaces to the namespace used for generated code.
public Dictionary<string, string> EdmToObjectNamespaceMap
{
    get
    {
        return _edmToObjectNamespaceMap;
    }
    set
    {
        _edmToObjectNamespaceMap = value;
    }
}

// Backing store for EdmToObjectNamespaceMap.
// NOTE(review): this field is public even though the property exposes it; consider
// making it private if nothing else references it directly.
public Dictionary<string, string> _edmToObjectNamespaceMap = new Dictionary<string, string>();
// Records ModelNamespace -> objectNamespace in EdmToObjectNamespaceMap when the two differ
// and ModelNamespace has no mapping yet.
void UpdateObjectNamespaceMap(string objectNamespace)
{
    bool namespacesDiffer = objectNamespace != ModelNamespace;
    if (!namespacesDiffer || EdmToObjectNamespaceMap.ContainsKey(ModelNamespace))
    {
        return;
    }

    EdmToObjectNamespaceMap.Add(ModelNamespace, objectNamespace);
}
// Stamps this template's name and version into TemplateMetadata.
// The entries are written in the same order as before.
private void DefineMetadata()
{
    TemplateMetadata[MetadataConstants.TT_TEMPLATE_NAME] =
        "CSharpSelfTracking.Context";
    TemplateMetadata[MetadataConstants.TT_TEMPLATE_VERSION] =
        "5.0";
}
#>