Advertisement
Shimmy

Untitled

Apr 19th, 2015
35
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C# 5.46 KB | None | 0 0
  /// <summary>
  /// Sets the change-tracking state of every element of the collection selected from <paramref name="parent"/>.
  /// Elements whose key values are all non-default are marked <see cref="EntityState.Modified"/> when
  /// <paramref name="current"/> contains an entity with the same key values, and <see cref="EntityState.Deleted"/>
  /// otherwise; elements with a default key value (e.g. Id == 0) are marked <see cref="EntityState.Added"/>.
  /// </summary>
  /// <typeparam name="TEntity">Type of the parent entity.</typeparam>
  /// <typeparam name="TChild">Element type of the child collection.</typeparam>
  /// <param name="context">Context whose entries are updated.</param>
  /// <param name="parent">Entity that owns the collection.</param>
  /// <param name="expression">Selects the child collection from <paramref name="parent"/>.</param>
  /// <param name="current">
  /// Entities whose key values identify the collection elements to keep; must not be null and is enumerated once.
  /// </param>
  public static void UpdateCollection<TEntity, TChild>(this DbContext context, TEntity parent, Expression<Func<TEntity, IEnumerable<TChild>>> expression, IEnumerable<TChild> current)
    where TEntity : class
    where TChild : class
  {
    // Snapshot the elements: marking an entity Deleted can make EF's relationship fix-up remove it from the
    // live navigation collection, which would invalidate an enumerator running over that collection.
    var collection = expression.Compile()(parent).ToArray();
    var baseType = GetBaseType<TChild>();

    // Resolve the key properties once instead of repeating the reflection lookup for every element.
    var properties = GetEntityKeyNamesInternal(context, baseType)
      .Select(pn => baseType.GetProperty(pn))
      .ToArray();

    // Key values of the entities to keep, computed once; `current` may be a deferred query.
    var currentKeys = current
      .Select(cur => GetEntityKeyValues(context, cur).ToArray())
      .ToArray();

    foreach (var item in collection)
    {
      var ids = GetEntityKeyValuesInternal(item, properties);
      var entry = context.Entry(item);
      if (IsNotDefault(ids))
      {
        var isInCurrent = currentKeys.Any(keys => keys.SequenceEqual(ids));
        entry.State = isInCurrent ? EntityState.Modified : EntityState.Deleted;
      }
      else
      {
        entry.State = EntityState.Added;
      }
    }
  }
  24.  
  25.  
  26.    
  /// <summary>
  /// Reconciles a collection navigation property of <paramref name="parent"/> with the rows already in the store.
  /// Incoming elements whose key values are default (e.g. Id == 0) are marked <see cref="EntityState.Added"/>.
  /// The stored collection is then optionally loaded. Each element of the resulting collection with non-default
  /// key values is marked <see cref="EntityState.Modified"/> when an incoming element has the same key values,
  /// and <see cref="EntityState.Deleted"/> otherwise.
  /// </summary>
  /// <typeparam name="TEntity">Type of the parent entity.</typeparam>
  /// <typeparam name="TElement">Element type of the collection navigation property.</typeparam>
  /// <param name="context">Context whose entries are updated.</param>
  /// <param name="parent">Entity that owns the collection; the collection is read before and after loading.</param>
  /// <param name="navigationProperty">Selects the collection navigation property from <paramref name="parent"/>.</param>
  /// <param name="loadReference">Indicates whether to load the related collection, or it has already been pre-loaded before.</param>
  /// <param name="cancellationToken">Token observed while loading the related collection.</param>
  /// <returns>A task that completes once every entry state has been set.</returns>
  /// <exception cref="InvalidOperationException">
  /// More than one incoming element has the same key values as a stored element.
  /// </exception>
  public static async Task UpdateCollectionAsync<TEntity, TElement>(this DbContext context, TEntity parent, Expression<Func<TEntity, ICollection<TElement>>> navigationProperty, bool loadReference = true, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken))
    where TEntity : class
    where TElement : class
  {
    // An element counts as already stored when none of its key values equals the default (i.e. Id != 0).
    Func<TElement, bool> hasStoreKey = e => IsNotDefault(context.GetEntityKeyValues(e));

    var collectionExtractor = navigationProperty.Compile();

    // Snapshot the incoming elements before loading merges stored rows into the same collection.
    var newCollection = collectionExtractor(parent).ToArray();
    foreach (var newItem in newCollection.Where(e => !hasStoreKey(e)))
    {
      context.Entry(newItem).State = EntityState.Added;
    }

    var entityEntry = context.Entry(parent);
    if (loadReference)
    {
      await entityEntry.Collection(navigationProperty).LoadAsync(cancellationToken);
    }
    var existingCollection = collectionExtractor(parent).ToArray();

    // Key values of the incoming elements, computed once rather than once per stored element.
    var newKeys = newCollection
      .Select(e => context.GetEntityKeyValues(e).ToArray())
      .ToArray();

    // Only elements that already have a store key can be updated or deleted.
    foreach (var item in existingCollection.Where(hasStoreKey))
    {
      var itemKeys = context.GetEntityKeyValues(item).ToArray();
      // SingleOrDefault (rather than Any) keeps the behavior of throwing when several incoming elements share a key.
      var isIntersected = newKeys.SingleOrDefault(keys => keys.SequenceEqual(itemKeys)) != null;
      context.Entry(item).State = isIntersected ? EntityState.Modified : EntityState.Deleted;
    }
  }
  75.  
  /// <summary>
  /// Synchronizes a single-valued (reference) navigation property of <paramref name="parent"/> with the context.
  /// If the navigation holds an entity whose key values are all non-default, that entity is attached, marked
  /// <see cref="EntityState.Modified"/>, and the method returns. Otherwise the values selected by
  /// <paramref name="storeEntityIdsExpressions"/> are read from <paramref name="parent"/> and those properties are
  /// reset to their type's default. When every read value is non-default, the parent is marked Modified, and a
  /// key-only <typeparamref name="TNavigation"/> instance carrying those values is attached and marked
  /// <see cref="EntityState.Deleted"/>.
  /// </summary>
  /// <typeparam name="TEntity">Type of the parent entity.</typeparam>
  /// <typeparam name="TNavigation">Type of the referenced (navigation) entity.</typeparam>
  /// <param name="context">Context whose entries are updated.</param>
  /// <param name="parent">Entity that owns the navigation property.</param>
  /// <param name="navigationProperty">Selects the navigation property from <paramref name="parent"/>.</param>
  /// <param name="storeEntityIdsExpressions">
  /// Property selectors on <paramref name="parent"/>, e.g. <c>p =&gt; p.AddressId</c>, whose values become the key
  /// values of the entity to delete. NOTE(review): presumably they must be listed in the navigation entity's key
  /// order expected by SetEntityKeyValues; confirm. When null or empty, nothing is deleted.
  /// </param>
  /// <exception cref="ArgumentException">An id expression does not select a property.</exception>
  public static void UpdateNavigation<TEntity, TNavigation>(this DbContext context, TEntity parent, Expression<Func<TEntity, TNavigation>> navigationProperty, params Expression<Func<TEntity, object>>[] storeEntityIdsExpressions)
    where TEntity : class
    where TNavigation : class
  {
    // Both the current navigation target and the stale stored entity are TNavigation instances, so they must go
    // through the TNavigation set: a DbSet rejects entities that are not of its own type (or derived from it).
    var navigationBaseType = GetBaseType<TNavigation>();
    var set = context.Set(navigationBaseType);

    var currentEntity = navigationProperty.Compile()(parent);
    if (currentEntity != null && IsNotDefault(context.GetEntityKeyValues(currentEntity)))
    {
      set.Attach(currentEntity);
      context.SetState(currentEntity, EntityState.Modified);
      return; // No need to delete the stored entity: it is the same one the navigation now references.
    }
    // NOTE(review): the original left a TODO here to make sure a new (default-keyed) navigation entity is in the
    // Added state; that is still not done.

    // The `params` argument is an empty array (not null) when the caller passes no selectors; without any ids
    // there is no stored entity to identify, so nothing must be deleted.
    if (storeEntityIdsExpressions == null || storeEntityIdsExpressions.Length == 0)
    {
      return;
    }

    var ids = new object[storeEntityIdsExpressions.Length];
    for (var i = 0; i < storeEntityIdsExpressions.Length; i++)
    {
      var expression = storeEntityIdsExpressions[i];
      ids[i] = expression.Compile()(parent);

      // Value-typed properties are boxed to object through a Convert (UnaryExpression); reference-typed ones are not.
      var body = expression.Body;
      var unary = body as UnaryExpression;
      if (unary != null)
      {
        body = unary.Operand;
      }

      var member = body as MemberExpression;
      var property = member != null ? member.Member as PropertyInfo : null;
      if (property == null)
      {
        throw new ArgumentException(
          "Each store entity id expression must select a property of the parent entity, e.g. p => p.ForeignKeyId.",
          "storeEntityIdsExpressions");
      }

      // Clear the reference on the parent now that its value has been captured.
      var propType = property.PropertyType;
      property.SetValue(parent, propType.IsValueType ? Activator.CreateInstance(propType) : null);
    }

    if (IsNotDefault(ids))
    {
      context.SetState(parent, EntityState.Modified);
      var attachment = Activator.CreateInstance<TNavigation>();
      context.SetEntityKeyValues(attachment, ids);
      set.Attach(attachment);
      context.SetState(attachment, EntityState.Deleted);
    }
  }
  127.  
  /// <summary>
  /// Returns <c>true</c> when <paramref name="entityIds"/> contains at least one value and every value identifies
  /// an already-stored entity. Such a value is a non-empty string, a non-empty <see cref="Guid"/>, or an
  /// <see cref="IConvertible"/> whose <see cref="IConvertible.ToInt64"/> result is greater than zero.
  /// Returns <c>false</c> for <c>null</c>, for an empty sequence, or when any value is <c>null</c> or default.
  /// </summary>
  /// <param name="entityIds">Key values of an entity; may be <c>null</c>. Enumerated at most once.</param>
  /// <remarks>
  /// Values that are neither strings nor <see cref="Guid"/>s go through <see cref="IConvertible.ToInt64"/>. That call
  /// throws for values that cannot be represented as an <see cref="long"/>, e.g. <see cref="DateTime"/>.
  /// </remarks>
  private static bool IsNotDefault(IEnumerable<object> entityIds)
  {
    if (entityIds == null)
    {
      return false;
    }

    var hasAny = false;
    foreach (var id in entityIds)
    {
      if (!IsNotDefaultKeyValue(id))
      {
        return false;
      }
      hasAny = true;
    }

    // An empty key set identifies nothing, so it is never "not default" (All() would vacuously return true).
    return hasAny;
  }

  /// <summary>Returns <c>true</c> when a single key value is non-null and differs from its "new entity" value.</summary>
  private static bool IsNotDefaultKeyValue(object id)
  {
    var text = id as string;
    if (text != null)
    {
      // Must be handled before IConvertible: string implements IConvertible and ToInt64("") throws FormatException.
      return text.Length > 0;
    }

    if (id is Guid)
    {
      return (Guid)id != Guid.Empty;
    }

    var convertible = id as IConvertible;
    return convertible != null && convertible.ToInt64(null) > 0;
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement