Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/// <summary>
/// Stage of the fluent repository API that is reached after a <c>Where(...)</c> filter has been applied.
/// Offers terminal operations (materialize, first, sum, group, count) and the transition to ordering.
/// </summary>
/// <typeparam name="TObject">Entity type handled by the repository.</typeparam>
public interface IRepositoryFiltered<TObject> where TObject : class
{
    /// <summary>Gets a row count as exposed by the implementing repository.</summary>
    int Count { get; }

    /// <summary>
    /// Gets all filtered objects from database
    /// </summary>
    /// <returns>The filtered rows materialized as an array.</returns>
    TObject[] Get();

    /// <summary>Returns the first filtered row, or <c>null</c> when there is none.</summary>
    TObject FirstOrDefault();

    /// <summary>Sums <paramref name="selector"/> over the filtered rows.</summary>
    /// <param name="selector">Nullable integer projection to add up.</param>
    int Sum(Expression<Func<TObject, int?>> selector);

    /// <summary>Groups the filtered rows by a key.</summary>
    /// <param name="orderBy">Key selector used for grouping (the parameter name is historical; it is not an ordering).</param>
    IQueryable<IGrouping<TKey, TObject>> GroupBy<TKey>(Expression<Func<TObject, TKey>> orderBy);

    /// <summary>Orders the filtered rows ascending by <paramref name="orderBy"/>.</summary>
    IRepositoryOrdered<TObject> OrderBy<TKey>(Expression<Func<TObject, TKey>> orderBy) where TKey : IComparable;

    /// <summary>Orders the filtered rows descending by <paramref name="orderBy"/>.</summary>
    IRepositoryOrdered<TObject> OrderByDescending<TKey>(Expression<Func<TObject, TKey>> orderBy) where TKey : IComparable;
}
/// <summary>
/// Configures a Moq mock of <c>IRepository&lt;T&gt;</c> so that any call to
/// <c>GetMany(predicate)</c> evaluates the predicate in memory against <paramref name="collection"/>.
/// </summary>
/// <param name="mock">Mock to configure.</param>
/// <param name="collection">In-memory rows standing in for the database table.</param>
/// <exception cref="ArgumentNullException"><paramref name="collection"/> is null.</exception>
/// <remarks>
/// The result is materialized with <c>ToList()</c> at call time. This matches
/// <c>Repository&lt;TObject&gt;.GetMany</c>, which also returns a list, so a test sees a snapshot
/// rather than a sequence that is re-filtered each time it is enumerated.
/// </remarks>
private static void SetupMany<T>(this Mock<IRepository<T>> mock, IEnumerable<T> collection) where T : class
{
    if (collection == null)
        throw new ArgumentNullException("collection");

    mock
        .Setup(x => x.GetMany(It.IsAny<Expression<Func<T, bool>>>()))
        .Returns<Expression<Func<T, bool>>>(expr => collection.Where(expr.Compile()).ToList());
}
- using System;
- using System.Collections.Generic;
- using System.Data;
- using System.Data.Common;
- using System.Data.Entity;
- using System.Data.Entity.Infrastructure;
- using System.Data.SqlClient;
- using System.Linq;
- using System.Linq.Expressions;
- using System.Reflection;
- using System.Text;
- using System.Text.RegularExpressions;
- using ExtensionMethods;
namespace MyRepository
{
    /// <summary>
    /// Entity Framework backed repository for <typeparamref name="TObject"/>.
    /// Exposes a fluent Include / Where / OrderBy / ThenBy chain and SQL Server bulk insert and
    /// MERGE helpers. The bulk helpers stage rows in a <c>#tmpImport</c> temp table via
    /// <see cref="SqlBulkCopy"/>.
    /// </summary>
    /// <remarks>
    /// The fluent chain is stateful. Include/Where/OrderBy store the partially built query in the
    /// <c>_queryWithInclude</c>, <c>_queryFiltered</c> and <c>_queryOrdered</c> fields, which are
    /// presumably declared on <c>AbstractRepository&lt;TObject&gt;</c>; confirm there. Terminal
    /// operations (Get, FirstOrDefault(), Sum, GroupBy, Page) clear that state. Nothing in this
    /// class synchronizes access to it.
    /// </remarks>
    /// <typeparam name="TObject">Entity type mapped by the EF context.</typeparam>
    public class Repository<TObject> : AbstractRepository<TObject>,
        IRepository<TObject>,
        IRepositoryInclude<TObject>,
        IRepositoryFiltered<TObject>,
        IRepositoryOrdered<TObject>
        where TObject : class
    {
        // Extracts the table reference from the SQL that EF generates for the entity set.
        // The reference is presumably "[dbo].[Table]", the form OrderPropertiesAsInDbTable and
        // GetInsertFields compare against.
        private static readonly Regex TableNameFromSqlRegex = new Regex(@"FROM\s+(?<table>.+)\s+AS");

        // CLR types (after unwrapping Nullable<T>) that bulk operations copy into the temp table.
        // Properties of any other type (navigation properties, collections, enums, byte[], Guid, ...)
        // are skipped.
        private static readonly HashSet<Type> BulkCopyColumnTypes = new HashSet<Type>
        {
            typeof(sbyte), typeof(byte), typeof(short), typeof(ushort),
            typeof(int), typeof(uint), typeof(long), typeof(ulong),
            typeof(char), typeof(float), typeof(double), typeof(DateTime),
            typeof(string), typeof(bool), typeof(decimal)
        };

        private readonly IConnectionFactory _connectionFactory;
        private readonly IUnitOfWork _unitOfWork;

        /// <param name="connection">Supplies the EF context used for every operation.</param>
        /// <param name="unitOfWork">Provides the transaction used by raw SQL and bulk copy.</param>
        public Repository(IConnectionFactory connection, IUnitOfWork unitOfWork)
        {
            _connectionFactory = connection;
            _unitOfWork = unitOfWork;
        }

        /// <summary>EF context obtained from the connection factory on each access.</summary>
        protected MWMS_NewWarehouseContext Context
        {
            get { return _connectionFactory.Get(); }
        }

        /// <summary>Entity set as an <see cref="IQueryable{T}"/> (base-class hook).</summary>
        protected override IQueryable<TObject> QueryableDbSet
        {
            get { return DbSet.AsQueryable(); }
        }

        /// <summary>
        /// Entity set for <typeparamref name="TObject"/>. Every access first runs
        /// <see cref="ConfigureConnection"/>, which opens the connection if needed and sets
        /// READ UNCOMMITTED isolation.
        /// </summary>
        protected DbSet<TObject> DbSet
        {
            get
            {
                ConfigureConnection();
                return Context.Set<TObject>();
            }
        }

        /// <summary>
        /// Table reference parsed from EF's generated SQL. Empty string when the pattern does not match.
        /// </summary>
        private string TableName
        {
            get
            {
                string sql = Context.Set<TObject>().ToString();
                Match match = TableNameFromSqlRegex.Match(sql);
                return match.Groups["table"].Value;
            }
        }

        /// <summary>
        /// Base-class column properties, reordered to the database column order.
        /// Each access queries INFORMATION_SCHEMA.
        /// </summary>
        protected override IEnumerable<PropertyInfo> ColumnProperties
        {
            get
            {
                IEnumerable<PropertyInfo> columnProperties = base.ColumnProperties;
                return OrderPropertiesAsInDbTable(columnProperties);
            }
        }

        /// <summary>Raw queryable over the entity set, with no fluent-chain state applied.</summary>
        public IQueryable<TObject> Query
        {
            get { return DbSet.AsQueryable(); }
        }

        /// <summary>
        /// Materializes the base-class <c>Queryable</c> into an array.
        /// Unlike <see cref="Get"/>, this does not clear pending Include/Where/OrderBy state.
        /// </summary>
        public virtual TObject[] All()
        {
            return Queryable.ToArray();
        }

        /// <summary>Returns the rows matching <paramref name="where"/> as a materialized list.</summary>
        public virtual IEnumerable<TObject> GetMany(Expression<Func<TObject, bool>> where)
        {
            return Queryable.Where(where).ToList();
        }

        /// <summary>
        /// Groups the rows selected by a previous <c>Where(...)</c> and clears the fluent state.
        /// The returned query is not executed here.
        /// </summary>
        public IQueryable<IGrouping<TKey, TObject>> GroupBy<TKey>(Expression<Func<TObject, TKey>> groupBy)
        {
            try
            {
                return _queryFiltered.GroupBy(groupBy);
            }
            finally
            {
                ClearFluentQuery();
            }
        }

        /// <summary>Starts (or continues) the fluent chain with a filter.</summary>
        public virtual IRepositoryFiltered<TObject> Where(Expression<Func<TObject, bool>> predicate)
        {
            _queryFiltered = Queryable.Where(predicate);
            return this;
        }

        /// <summary>First row matching <paramref name="predicate"/>, or <c>null</c>.</summary>
        public TObject FirstOrDefault(Expression<Func<TObject, bool>> predicate)
        {
            return Queryable.FirstOrDefault(predicate);
        }

        /// <summary>First row matching <paramref name="predicate"/>; LINQ throws when there is none.</summary>
        public TObject First(Expression<Func<TObject, bool>> predicate)
        {
            return Queryable.First(predicate);
        }

        /// <summary>Single row matching <paramref name="predicate"/>, or <c>null</c>; LINQ throws on multiple matches.</summary>
        public TObject SingleOrDefault(Expression<Func<TObject, bool>> predicate)
        {
            return Queryable.SingleOrDefault(predicate);
        }

        /// <summary>
        /// Single row matching <paramref name="predicate"/>.
        /// </summary>
        /// <exception cref="MwDbException">No row matches.</exception>
        /// <remarks>Multiple matches surface as the exception thrown by LINQ's <c>SingleOrDefault</c>.</remarks>
        public TObject Single(Expression<Func<TObject, bool>> predicate)
        {
            TObject result = Queryable.SingleOrDefault(predicate);
            if (result == null)
                throw new MwDbException("No records found for query: " + predicate.Body);
            return result;
        }

        /// <summary>
        /// Adds an eager-load path from a member-access lambda such as <c>x =&gt; x.Order.Lines</c>.
        /// The body's text is split on '.', and the leading parameter name is dropped.
        /// </summary>
        /// <remarks>
        /// NOTE(review): bodies that are not plain member chains (e.g. ones wrapped in Convert(...))
        /// produce an unusable path. Confirm callers only pass member chains.
        /// </remarks>
        public IRepositoryInclude<TObject> Include(Expression<Func<TObject, object>> entityToInclude)
        {
            string[] expressionString = entityToInclude.Body.ToString().Split(new[] { '.' });
            string path = string.Join(".", expressionString.Skip(1));
            return Include(path);
        }

        /// <summary><c>true</c> when any row matches <paramref name="predicate"/>.</summary>
        public bool Contains(Expression<Func<TObject, bool>> predicate)
        {
            return Queryable.Any(predicate);
        }

        /// <summary>Adds the entity to the DbSet. Nothing is saved here.</summary>
        public virtual TObject Create(TObject TObject)
        {
            TObject newEntry = DbSet.Add(TObject);
            return newEntry;
        }

        /// <summary>Attaches the entity and marks it Modified. Nothing is saved here; always returns 0.</summary>
        public virtual int Update(TObject TObject)
        {
            DbEntityEntry<TObject> entry = Context.Entry(TObject);
            DbSet.Attach(TObject);
            entry.State = EntityState.Modified;
            return 0;
        }

        /// <summary>
        /// Loads the rows matching <paramref name="predicate"/> and marks each for removal.
        /// Nothing is saved here; always returns 0.
        /// </summary>
        public virtual int Delete(Expression<Func<TObject, bool>> predicate)
        {
            TObject[] objects = Where(predicate).Get();
            foreach (TObject obj in objects)
                DbSet.Remove(obj);
            return 0;
        }

        /// <summary>Marks the entity for removal. Nothing is saved here.</summary>
        public void Delete(TObject TObject)
        {
            DbSet.Remove(TObject);
        }

        /// <summary>Bulk MERGE of <paramref name="list"/> on <paramref name="mergeKey"/>, updating every non-key column.</summary>
        public void BulkInsertOrUpdate(IEnumerable<TObject> list, Expression<Func<TObject, object>> mergeKey)
        {
            BulkInsertOrUpdate(list, mergeKey, new Expression<Func<TObject, object>>[0]);
        }

        /// <summary>
        /// Bulk-copies <paramref name="list"/> into a temp table, then MERGEs it into the entity's table.
        /// Rows matching on <paramref name="mergeKey"/> are updated; others are inserted.
        /// </summary>
        /// <param name="list">Rows to upsert.</param>
        /// <param name="mergeKey">Property used in the MERGE ON clause.</param>
        /// <param name="skipUpdateFields">Properties left untouched when a row already exists.</param>
        public void BulkInsertOrUpdate(IEnumerable<TObject> list, Expression<Func<TObject, object>> mergeKey, params Expression<Func<TObject, object>>[] skipUpdateFields)
        {
            var columns = GetColumnPropertyInfos();
            var importTable = BuildTemporaryImportTable(list, columns);
            var query = BuildMergeQuery(GetPropertyNameFromExpression(mergeKey), skipUpdateFields, importTable.TableName, columns);
            ExecuteSqlCommand(query.ToString());
        }

        /// <summary>Bulk-copies <paramref name="list"/> into a temp table, then INSERTs it into the entity's table.</summary>
        public void BulkInsert(IEnumerable<TObject> list)
        {
            var columns = GetColumnPropertyInfos();
            var importTable = BuildTemporaryImportTable(list, columns);
            var query = BuildInsertQuery(importTable.TableName, columns);
            ExecuteSqlCommand(query.ToString());
        }

        /// <summary>Column properties whose (nullable-unwrapped) type is in <see cref="BulkCopyColumnTypes"/>, in table order.</summary>
        private List<PropertyInfo> GetColumnPropertyInfos()
        {
            var columns = new List<PropertyInfo>();
            foreach (PropertyInfo property in ColumnProperties)
            {
                Type valueType = Nullable.GetUnderlyingType(property.PropertyType) ?? property.PropertyType;
                if (BulkCopyColumnTypes.Contains(valueType))
                    columns.Add(property);
            }
            return columns;
        }

        /// <summary>
        /// Builds an in-memory <c>#tmpImport</c> DataTable from <paramref name="list"/>.
        /// Also creates the matching SQL temp table and bulk-copies the rows into it.
        /// </summary>
        private DataTable BuildTemporaryImportTable(IEnumerable<TObject> list, List<PropertyInfo> columns)
        {
            var dt = new DataTable
            {
                TableName = "#tmpImport"
            };

            foreach (PropertyInfo p in columns)
            {
                // DataColumn cannot be typed as Nullable<T>; nulls are represented by DBNull instead.
                Type columnType = Nullable.GetUnderlyingType(p.PropertyType) ?? p.PropertyType;
                dt.Columns.Add(new DataColumn(p.Name, columnType));
            }

            foreach (TObject item in list)
            {
                DataRow row = dt.NewRow();
                foreach (PropertyInfo columnProperty in columns)
                    row[columnProperty.Name] = columnProperty.GetValue(item, null) ?? DBNull.Value;
                dt.Rows.Add(row);
            }

            PopulateTemporaryImportTable(columns, dt);
            return dt;
        }

        /// <summary>
        /// (Re)creates the SQL temp table with the target table's column shape (SELECT TOP 0 ... INTO),
        /// then bulk-copies <paramref name="importTable"/> into it inside the unit-of-work transaction.
        /// </summary>
        private void PopulateTemporaryImportTable(IEnumerable<PropertyInfo> columns, DataTable importTable)
        {
            var fields = string.Join(",", columns.Select(i => "[" + i.Name + "]"));
            var sql = string.Format(
                "IF OBJECT_ID('tempdb..{0}', 'U') IS NOT NULL{2} DROP TABLE {0}; {2}{2} select top 0 {3} into {0} from {1}",
                importTable.TableName, TableName, Environment.NewLine, fields);
            ExecuteSqlCommand(sql);

            using (var bulkCopy = new SqlBulkCopy(GetSqlConnection(), SqlBulkCopyOptions.Default, (SqlTransaction)_unitOfWork.Transaction.UnderlyingTransaction))
            {
                // 0 disables the timeout so large uploads are not cut off.
                bulkCopy.BulkCopyTimeout = 0;
                bulkCopy.DestinationTableName = importTable.TableName;
                bulkCopy.WriteToServer(importTable);
            }
        }

        /// <summary>
        /// Builds the MERGE statement from the temp table into the entity's table.
        /// The identity column is excluded from the INSERT; the merge key and skipped fields are excluded from the UPDATE.
        /// </summary>
        private StringBuilder BuildMergeQuery(string mergeKeyName, IEnumerable<Expression<Func<TObject, object>>> skipUpdateFields, string temporaryTableName, List<PropertyInfo> columns)
        {
            var newLine = Environment.NewLine;
            var query = new StringBuilder();
            query.AppendFormat("MERGE {0} AS Target {1}", TableName, newLine);
            query.AppendFormat("USING {0} as Source {1}", temporaryTableName, newLine);
            query.AppendFormat("ON Target.[{0}] = Source.[{0}] {1}", mergeKeyName, newLine);

            var updateFields = GetUpdateFields(skipUpdateFields, mergeKeyName, columns);
            var insertFields = GetInsertFields(columns);

            if (updateFields.Any())
            {
                query.AppendFormat("WHEN MATCHED THEN UPDATE SET {0}", newLine);
                query.Append(string.Join("," + newLine, updateFields));
            }

            query.AppendLine(" WHEN NOT MATCHED BY TARGET THEN");
            query.AppendFormat("INSERT ({0}) {1}", string.Join(",", insertFields.Select(i => "[" + i + "]")), newLine);
            query.AppendFormat("VALUES ({0});", string.Join(",", insertFields.Select(i => "Source." + "[" + i + "]")));
            return query;
        }

        /// <summary>Builds INSERT INTO target (...) SELECT ... FROM temp, excluding the identity column.</summary>
        private StringBuilder BuildInsertQuery(string temporaryTableName, List<PropertyInfo> columns)
        {
            var query = new StringBuilder();
            var insertFields = GetInsertFields(columns);
            string fieldList = string.Join(",", insertFields.Select(i => "[" + i + "]"));
            query.AppendFormat("INSERT INTO {0} ({1}) {2}", TableName, fieldList, Environment.NewLine);
            // Newline keeps the select list and FROM as separate tokens.
            query.AppendFormat("SELECT {0} {1}", fieldList, Environment.NewLine);
            query.AppendFormat("FROM {0};", temporaryTableName);
            return query;
        }

        /// <summary>Number of rows in the entity set.</summary>
        /// <remarks>
        /// NOTE(review): this counts the whole DbSet and ignores a pending <c>Where(...)</c> filter,
        /// even when reached through <c>IRepositoryFiltered&lt;TObject&gt;.Count</c>. Confirm this is intended.
        /// </remarks>
        public virtual int Count
        {
            get { return DbSet.Count(); }
        }

        /// <summary>Materializes the current fluent query into an array and clears the fluent state.</summary>
        public virtual TObject[] Get()
        {
            try
            {
                return Queryable.ToArray();
            }
            finally
            {
                ClearFluentQuery();
            }
        }

        /// <summary>
        /// Sums <paramref name="selector"/> over the ordered query if present, otherwise over the filtered one.
        /// A null sum is mapped to 0. Clears the fluent state.
        /// </summary>
        public virtual int Sum(Expression<Func<TObject, int?>> selector)
        {
            try
            {
                if (_queryOrdered != null)
                {
                    return _queryOrdered.Sum(selector) ?? 0;
                }
                return _queryFiltered.Sum(selector) ?? 0;
            }
            finally
            {
                ClearFluentQuery();
            }
        }

        /// <summary>First row of the current fluent query, or <c>null</c>. Clears the fluent state.</summary>
        public TObject FirstOrDefault()
        {
            try
            {
                return Queryable.FirstOrDefault();
            }
            finally
            {
                ClearFluentQuery();
            }
        }

        /// <summary>Starts a descending ordering of the current fluent query.</summary>
        public virtual IRepositoryOrdered<TObject> OrderByDescending<TKey>(Expression<Func<TObject, TKey>> orderBy) where TKey : IComparable
        {
            _queryOrdered = Queryable.OrderByDescending(orderBy);
            return this;
        }

        /// <summary>Starts an ascending ordering of the current fluent query.</summary>
        public virtual IRepositoryOrdered<TObject> OrderBy<TKey>(Expression<Func<TObject, TKey>> orderBy) where TKey : IComparable
        {
            _queryOrdered = Queryable.OrderBy(orderBy);
            return this;
        }

        /// <summary>First row matching <paramref name="predicate"/>, or <c>null</c>.</summary>
        public virtual TObject Find(Expression<Func<TObject, bool>> predicate)
        {
            return Queryable.FirstOrDefault(predicate);
        }

        /// <summary>
        /// Adds a secondary descending sort key to the existing ordering. With no prior ordering,
        /// starts a descending ordering instead.
        /// </summary>
        public IRepositoryOrdered<TObject> ThenByDescending<TKey>(Expression<Func<TObject, TKey>> orderBy) where TKey : IComparable
        {
            var ordered = _queryOrdered as IOrderedQueryable<TObject>;
            _queryOrdered = ordered != null
                ? ordered.ThenByDescending(orderBy)
                : Queryable.OrderByDescending(orderBy);
            return this;
        }

        /// <summary>
        /// Adds a secondary ascending sort key to the existing ordering. With no prior ordering,
        /// starts an ascending ordering instead.
        /// </summary>
        public IRepositoryOrdered<TObject> ThenBy<TKey>(Expression<Func<TObject, TKey>> orderBy) where TKey : IComparable
        {
            var ordered = _queryOrdered as IOrderedQueryable<TObject>;
            _queryOrdered = ordered != null
                ? ordered.ThenBy(orderBy)
                : Queryable.OrderBy(orderBy);
            return this;
        }

        /// <summary>
        /// Returns one page of the current fluent query and clears the fluent state.
        /// </summary>
        /// <param name="index">Specified the page index (zero-based).</param>
        /// <param name="size">Specified the page size</param>
        /// <returns>The rows of the requested page.</returns>
        public TObject[] Page(int index = 0, int size = 50)
        {
            int totalCount;
            return Page(out totalCount, false, index, size);
        }

        /// <summary>
        /// Returns one page of the current fluent query together with the unpaged row count,
        /// and clears the fluent state.
        /// </summary>
        /// <param name="totalCount">Total count of rows before paging is applied</param>
        /// <param name="index">Specified the page index (zero-based).</param>
        /// <param name="size">Specified the page size</param>
        /// <returns>The rows of the requested page.</returns>
        public TObject[] Page(out int totalCount, int index = 0, int size = 50)
        {
            return Page(out totalCount, true, index, size);
        }

        /// <summary>
        /// Opens the context's connection if it is closed, then sets the session isolation level to
        /// READ UNCOMMITTED.
        /// </summary>
        protected void ConfigureConnection()
        {
            DbConnection connection = Context.Database.Connection;
            if (connection.State == ConnectionState.Closed)
                connection.Open();

            using (DbCommand command = connection.CreateCommand())
            {
                command.CommandText = "SET TRANSACTION ISOLATION LEVEL READ UNCOMMITTED;";
                try
                {
                    command.ExecuteNonQuery();
                }
                catch (SqlException)
                {
                    // Single retry after re-opening. Presumably this targets a connection that was
                    // dropped by the server; confirm, because Open() throws if the connection is still open.
                    connection.Open();
                    command.ExecuteNonQuery();
                }
            }
        }

        /// <summary>
        /// Eager-loads each navigation path in <paramref name="children"/> into the context for the rows
        /// matching <paramref name="filter"/>, then returns a query over those rows.
        /// </summary>
        /// <param name="filter">Row filter; <c>null</c> means all rows.</param>
        /// <param name="children">Navigation properties to load.</param>
        public virtual IQueryable<TObject> GetAllLazyLoad(Expression<Func<TObject, bool>> filter, params Expression<Func<TObject, object>>[] children)
        {
            if (children != null)
            {
                foreach (Expression<Func<TObject, object>> child in children)
                {
                    IQueryable<TObject> loadQuery = DbSet.Include(child);
                    if (filter != null)
                        loadQuery = loadQuery.Where(filter);
                    loadQuery.Load();
                }
            }

            IQueryable<TObject> result = DbSet.AsQueryable();
            return filter != null ? result.Where(filter) : result;
        }

        /// <summary>
        /// Pages the current fluent query. When <paramref name="countTotal"/> is false,
        /// <paramref name="totalCount"/> is -1. Always clears the fluent state.
        /// </summary>
        private TObject[] Page(out int totalCount, bool countTotal, int index = 0, int size = 50)
        {
            int skipCount = index * size;
            try
            {
                TObject[] result = skipCount == 0
                    ? Queryable.Take(size).ToArray()
                    : Queryable.Skip(skipCount).Take(size).ToArray();

                totalCount = countTotal ? Queryable.Count() : -1;
                return result;
            }
            finally
            {
                ClearFluentQuery();
            }
        }

        /// <summary>Adds a string include path to the fluent chain, starting from the DbSet if needed.</summary>
        protected IRepositoryInclude<TObject> Include(string path)
        {
            _queryWithInclude = _queryWithInclude != null ? _queryWithInclude.Include(path) : DbSet.Include(path);
            return this;
        }

        /// <summary>
        /// Builds the "Target.[col] = Source.[col]" assignments for the MERGE UPDATE clause.
        /// Excludes the match key and <paramref name="skipUpdateFields"/>, compared case-insensitively.
        /// </summary>
        private List<string> GetUpdateFields(IEnumerable<Expression<Func<TObject, object>>> skipUpdateFields, string matchKey, IEnumerable<PropertyInfo> columns)
        {
            var skipFields = new HashSet<string>(StringComparer.OrdinalIgnoreCase) { matchKey };
            foreach (var skipUpdateField in skipUpdateFields)
            {
                skipFields.Add(GetPropertyNameFromExpression(skipUpdateField));
            }

            var updateFields = new List<string>();
            foreach (PropertyInfo columnProperty in columns)
            {
                if (skipFields.Contains(columnProperty.Name))
                    continue;
                updateFields.Add(string.Format("Target.[{0}] = Source.[{0}]", columnProperty.Name));
            }
            return updateFields;
        }

        /// <summary>Runs <paramref name="query"/> in the unit-of-work transaction with a 5-minute timeout.</summary>
        protected void ExecuteSqlCommand(string query)
        {
            using (SqlCommand command = CreateSqlCommand(query))
            {
                command.CommandTimeout = 60 * 5;
                command.ExecuteNonQuery();
            }
        }

        /// <summary>Creates a command bound to the context connection and the unit-of-work transaction.</summary>
        protected SqlCommand CreateSqlCommand(string query)
        {
            SqlConnection dbConnection = GetSqlConnection();
            return new SqlCommand(query, dbConnection, (SqlTransaction)_unitOfWork.Transaction.UnderlyingTransaction);
        }

        /// <summary>Context connection as a <see cref="SqlConnection"/>, opened if closed.</summary>
        /// <exception cref="InvalidOperationException">The context is not using a SqlConnection.</exception>
        private SqlConnection GetSqlConnection()
        {
            var dbConnection = Context.Database.Connection as SqlConnection;
            if (dbConnection == null)
                throw new InvalidOperationException("The context connection is not a SqlConnection; raw SQL and bulk operations require SQL Server.");
            if (dbConnection.State == ConnectionState.Closed)
                dbConnection.Open();
            return dbConnection;
        }

        /// <summary>
        /// Reorders <paramref name="propertyInfos"/> to match the table's ORDINAL_POSITION.
        /// Column and property names are compared case-insensitively.
        /// </summary>
        /// <exception cref="InvalidOperationException">A table column has no matching property.</exception>
        private IEnumerable<PropertyInfo> OrderPropertiesAsInDbTable(IEnumerable<PropertyInfo> propertyInfos)
        {
            const string query = "select COLUMN_NAME from INFORMATION_SCHEMA.COLUMNS " +
                                 "where TABLE_SCHEMA = 'dbo' " +
                                 "and '[dbo].['+rtrim(TABLE_NAME)+']' = {0} " +
                                 "order by ORDINAL_POSITION";
            DbRawSqlQuery<string> columns = Context.Database.SqlQuery<string>(query, TableName);

            // Materialize once; the source may be a lazy reflection query.
            List<PropertyInfo> candidates = propertyInfos.ToList();
            var result = new List<PropertyInfo>();
            foreach (string column in columns)
            {
                string wanted = column.Trim();
                PropertyInfo property = candidates.FirstOrDefault(
                    pi => string.Equals(pi.Name.Trim(), wanted, StringComparison.OrdinalIgnoreCase));
                if (property == null)
                    throw new InvalidOperationException("Cannot find field:" + column);
                result.Add(property);
            }
            return result;
        }

        /// <summary>
        /// Names of <paramref name="columns"/> to use in INSERT lists. The table's identity column,
        /// if any, is excluded, matched case-insensitively as in <c>OrderPropertiesAsInDbTable</c>.
        /// </summary>
        protected List<string> GetInsertFields(List<PropertyInfo> columns)
        {
            const string query =
                "select COLUMN_NAME from INFORMATION_SCHEMA.COLUMNS where TABLE_SCHEMA = 'dbo' and COLUMNPROPERTY(object_id(TABLE_NAME), COLUMN_NAME, 'IsIdentity') = 1 and '[dbo].['+rtrim(TABLE_NAME)+']' = {0} order by ORDINAL_POSITION";
            string identity = Context.Database.SqlQuery<string>(query, TableName).FirstOrDefault();
            return columns
                .Where(i => !string.Equals(i.Name, identity, StringComparison.OrdinalIgnoreCase))
                .Select(i => i.Name)
                .ToList();
        }

        // Clears the fluent-chain state (Include / Where / OrderBy) after a terminal operation.
        private void ClearFluentQuery()
        {
            _queryWithInclude = null;
            _queryOrdered = null;
            _queryFiltered = null;
        }
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement