Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- namespace LinqToDB.Extensions.SqlServer
- {
- using LinqToDB;
- using LinqToDB.Data;
- using LinqToDB.Mapping;
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Linq.Expressions;
- using System.Reflection;
- using System.Runtime.CompilerServices;
- using System.Text;
/// <summary>
/// SQL Server helpers for linq2db that expose in-memory sequences as queryable sources, either as an inline
/// table value constructor (<c>SELECT ... FROM (VALUES (...), (...)) AS b(...)</c>) or as a bulk-loaded
/// <c>#temp</c> table, plus a few related utilities (temp-table creation, bulk loading, a one-row "dual" source).
/// </summary>
public static class Linq2DBExtensions
{
    /// <summary>
    /// Upper bound on (rows × parameterised columns) that <c>ToQueryable</c> will inline into a table value
    /// constructor before falling back to a temp table.
    /// </summary>
    public static int TableValueConstructorMaxParameterCount = 100;

    /// <summary>
    /// Upper bound on the row count that <c>ToQueryable</c> will inline into a table value constructor before
    /// falling back to a temp table.
    /// </summary>
    public static int TableValueConstructorMaxNonParameterCount = 2000;

    // Property types whose values are sent as SQL parameters; all other property values are rendered as
    // SQL literals through the mapping schema's ValueToSqlConverter.
    private static readonly Type[] _parameterTypes =
    {
        typeof(string),
        typeof(char),
        typeof(char?),
        typeof(Guid),
        typeof(Guid?),
        typeof(DateTime),
        typeof(DateTime?),
        typeof(DateTimeOffset),
        typeof(DateTimeOffset?),
        typeof(TimeSpan),
        typeof(TimeSpan?)
    };

    /// <summary>
    /// Returns <paramref name="expressions"/> unchanged. <paramref name="source"/> is never read; it only lets
    /// the compiler infer <typeparamref name="T"/> so primary-key selectors can be written as lambdas over the
    /// element type (including anonymous types) and passed to <c>ToQueryable</c> / <c>ToTempTable</c>.
    /// </summary>
    public static Expression<Func<T, object>>[] ToPK<T>(this IEnumerable<T> source,
        params Expression<Func<T, object>>[] expressions) where T : class
        => expressions;

    /// <summary>
    /// Same as the full <c>ToQueryable</c> overload with a generated temp-table name, no database override and
    /// no pre-drop, using <paramref name="pk"/> as the temp table's primary key if a temp table is needed.
    /// </summary>
    public static IQueryable<T> ToQueryable<T>(this IEnumerable<T> numbers, IDataContext db,
        params Expression<Func<T, object>>[] pk)
        where T : class
        => ToQueryable(numbers, db, null, null, false, pk);

    /// <summary>
    /// Exposes <paramref name="numbers"/> as an <see cref="IQueryable{T}"/> usable in queries against
    /// <paramref name="db"/>. Small inputs are inlined as a table value constructor; inputs that exceed
    /// <see cref="TableValueConstructorMaxNonParameterCount"/> rows, or whose parameter count would exceed
    /// <see cref="TableValueConstructorMaxParameterCount"/>, are bulk-copied into a temp table instead.
    /// </summary>
    /// <param name="numbers">Rows to expose. Materialised once unless it already implements <see cref="IList{T}"/>.</param>
    /// <param name="db">Target data context.</param>
    /// <param name="tableName">Temp-table name (only used on the temp-table path); '#' is prefixed if missing.</param>
    /// <param name="databaseName">Database for the temp table (only used on the temp-table path).</param>
    /// <param name="dropFirst">Drop an existing temp table of the same name first (temp-table path only).</param>
    /// <param name="pk">Primary-key members for the temp table (temp-table path only).</param>
    /// <exception cref="ArgumentNullException"><paramref name="numbers"/> or <paramref name="db"/> is null.</exception>
    public static IQueryable<T> ToQueryable<T>(this IEnumerable<T> numbers, IDataContext db,
        string tableName = null,
        string databaseName = null,
        bool dropFirst = false,
        Expression<Func<T, object>>[] pk = null)
        where T : class
    {
        if (numbers == null) throw new ArgumentNullException(nameof(numbers));
        if (db == null) throw new ArgumentNullException(nameof(db));

        var listData = numbers as IList<T> ?? numbers.ToList();
        var rowCount = listData.Count;
        var properties = typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance);
        var parameterizedColumnCount = properties.Count(x => _parameterTypes.Contains(x.PropertyType));

        // rows × parameterised columns is an upper bound: null values are inlined as NULL without a parameter.
        // The row-count check short-circuits first, so the product stays small enough not to overflow.
        if (rowCount > TableValueConstructorMaxNonParameterCount ||
            rowCount * parameterizedColumnCount > TableValueConstructorMaxParameterCount)
        {
            return CreateTempTable(db, listData, pk, tableName, databaseName, dropFirst);
        }

        return BuildTableValuedConstructor(db, listData, rowCount, properties);
    }

    /// <summary>
    /// Always exposes <paramref name="numbers"/> as an inline table value constructor, regardless of size.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="numbers"/> or <paramref name="db"/> is null.</exception>
    public static IQueryable<T> ToTableValuedConstructor<T>(this IEnumerable<T> numbers, IDataContext db)
        where T : class
    {
        if (numbers == null) throw new ArgumentNullException(nameof(numbers));
        if (db == null) throw new ArgumentNullException(nameof(db));

        var listData = numbers as IList<T> ?? numbers.ToList();
        var properties = typeof(T).GetProperties(BindingFlags.Public | BindingFlags.Instance);
        return BuildTableValuedConstructor(db, listData, listData.Count, properties);
    }

    // Builds "(SELECT cols FROM (VALUES ...) AS b(cols))" over the given rows, one column per property,
    // with column names quoted by the provider's SQL builder.
    private static IQueryable<T> BuildTableValuedConstructor<T>(IDataContext db, IList<T> listData, int rowCount,
        PropertyInfo[] properties)
        where T : class
    {
        var builder = db.CreateSqlProvider();
        var fieldNames = string.Join(", ", properties.Select(x =>
            builder.Convert(x.Name, SqlProvider.ConvertType.NameToQueryField)));

        if (rowCount == 0)
        {
            // A VALUES list needs at least one row, so emit a single dummy row of (0) literals and filter it
            // away with an always-false predicate.
            return db.FromSql<T>($"(SELECT {fieldNames} " +
                $"FROM (VALUES ({string.Join(", ", Enumerable.Repeat("(0)", properties.Length))}))" +
                $" AS b({fieldNames}) WHERE 0 <> 0)", new object[0]);
        }

        var paramList = new List<DataParameter>();
        var data = BuildTableValuedData(db.MappingSchema, listData, properties, paramList);
        return db.FromSql<T>($"(SELECT {fieldNames} FROM (VALUES {data}) AS b({fieldNames}))", paramList.ToArray());
    }

    /// <summary>
    /// Same as <c>CreateTempTable</c> with a generated name, no database override and no pre-drop.
    /// </summary>
    public static IQueryable<T> ToTempTable<T>(this IEnumerable<T> numbers, IDataContext db,
        params Expression<Func<T, object>>[] pk)
        where T : class
        => CreateTempTable(db, numbers, pk, null, null, false);

    /// <summary>
    /// Bulk-copies <paramref name="numbers"/> into a new temp table and returns it as a queryable.
    /// See <c>CreateTempTable</c> for parameter details.
    /// </summary>
    public static IQueryable<T> ToTempTable<T>(this IEnumerable<T> numbers, IDataContext db,
        Expression<Func<T, object>>[] pk = null,
        string tableName = null,
        string databaseName = null,
        bool dropFirst = false)
        where T : class
        => CreateTempTable(db, numbers, pk, tableName, databaseName, dropFirst);

    /// <summary>
    /// Creates a temp table shaped like <typeparamref name="T"/>, bulk-copies <paramref name="numbers"/> into it
    /// (<see cref="BulkCopyType.MultipleRows"/>) and returns the table.
    /// </summary>
    /// <param name="db">Target data context.</param>
    /// <param name="numbers">Rows to load; enumerated once by the bulk copy.</param>
    /// <param name="pk">
    /// Optional primary-key members. They are applied through the fluent mapping builder only when
    /// <typeparamref name="T"/> is anonymous or not already a known entity of the mapping schema.
    /// </param>
    /// <param name="tableName">Table name; a '#' + GUID name is generated when empty, and '#' is prefixed if missing.</param>
    /// <param name="databaseName">Optional database name passed to CreateTable.</param>
    /// <param name="dropFirst">When true, drops an existing temp table with the same name first.</param>
    /// <exception cref="ArgumentNullException"><paramref name="db"/> or <paramref name="numbers"/> is null.</exception>
    public static IQueryable<T> CreateTempTable<T>(this IDataContext db, IEnumerable<T> numbers,
        Expression<Func<T, object>>[] pk = null,
        string tableName = null,
        string databaseName = null,
        bool dropFirst = false)
        where T : class
    {
        if (db == null) throw new ArgumentNullException(nameof(db));
        // Validate before any DDL runs so a null source cannot leave an empty temp table behind.
        if (numbers == null) throw new ArgumentNullException(nameof(numbers));

        var tType = typeof(T);
        tableName = string.IsNullOrEmpty(tableName) ? "#" + Guid.NewGuid().ToString() : tableName;
        if (tableName[0] != '#')
        {
            tableName = "#" + tableName;
        }

        if (dropFirst)
        {
            DropTableIfExists<T>(db, tableName);
        }

        // An empty key list would configure nothing, so skip touching the (possibly shared) mapping schema.
        // NOTE(review): db.MappingSchema may be shared across contexts, and for anonymous types this branch
        // runs on every call - confirm the fluent builder does not accumulate state per call.
        if (pk != null && pk.Length > 0 &&
            (tType.IsAnonymousType() || !db.MappingSchema.GetEntites().Contains(tType)))
        {
            var entityMapper = db.MappingSchema.GetFluentMappingBuilder().Entity<T>();
            for (int i = 0; i < pk.Length; i++)
            {
                var pkField = pk[i];
                var fieldSetter = entityMapper.Property(pkField).IsPrimaryKey(i);
                if (DoesExpressionReturnString(pkField))
                {
                    // String keys get a bounded, non-nullable length; presumably 450 keeps an nvarchar key
                    // within SQL Server's 900-byte index-key limit.
                    fieldSetter.HasLength(450).IsNullable(false);
                }
            }
        }

        var temp = db.CreateTable<T>(tableName, databaseName);
        temp.BulkCopy(new BulkCopyOptions { BulkCopyType = BulkCopyType.MultipleRows }, numbers);
        return temp;
    }

    private static readonly Type StringType = typeof(string);

    // Returns true when the PK selector targets a string member. Accepts both properties and fields.
    // Throws ArgumentException when the selector is not a (possibly boxed) member access.
    private static bool DoesExpressionReturnString<T>(Expression<Func<T, object>> expr)
    {
        // Value-typed members are wrapped in a Convert(...) node because the selector returns object.
        var unary = expr.Body as UnaryExpression;
        var memberExpr = (unary != null ? unary.Operand : expr.Body) as MemberExpression;
        if (memberExpr == null)
        {
            throw new ArgumentException("PK specified is invalid");
        }

        var property = memberExpr.Member as PropertyInfo;
        if (property != null)
        {
            return property.PropertyType == StringType;
        }

        var field = memberExpr.Member as FieldInfo;
        if (field != null)
        {
            return field.FieldType == StringType;
        }

        throw new ArgumentException("PK specified is invalid");
    }

    // Row shape for the OBJECT_ID lookup in DropTableIfExists.
    private class SqlServerObject
    {
        public int? ObjectId { get; set; }
    }

    // Drops the '#' temp table when OBJECT_ID('tempdb..<name>') finds it.
    // Returns false without querying for names that are empty or do not start with '#'.
    private static bool DropTableIfExists<T>(IDataContext db, string tableName)
    {
        if (string.IsNullOrEmpty(tableName) || !tableName.StartsWith("#", StringComparison.Ordinal))
        {
            return false;
        }

        var lookup = db.FromSql<SqlServerObject>("SELECT OBJECT_ID({0}) " +
                $"AS {GetColumnName(nameof(SqlServerObject.ObjectId), db)}",
                new DataParameter("name", "tempdb.." + GetTableName(db, tableName), DataType.NVarChar))
            .FirstOrDefault();

        if (lookup == null || !lookup.ObjectId.HasValue)
        {
            return false;
        }

        db.DropTable<T>(tableName);
        return true;
    }

    // Table name rendered (and quoted) by the provider's SQL builder.
    private static string GetTableName(IDataContext db, string tableName, string schema = null, string database = null)
        => db.CreateSqlProvider()
            .ConvertTableName(new StringBuilder(), database, schema, tableName).ToString();

    // Column name rendered (and quoted) by the provider's SQL builder.
    private static string GetColumnName(string property, IDataContext db)
        => db.CreateSqlProvider()
            .Convert(property, SqlProvider.ConvertType.NameToQueryField).ToString();

    /// <summary>
    /// Bulk-copies <paramref name="source"/> into the mapped table for <typeparamref name="T"/>.
    /// </summary>
    /// <returns>The number of rows copied, as reported by the bulk copy.</returns>
    public static long BulkLoad<T>(this IDataContext db, IEnumerable<T> source,
        BulkCopyType bulkCopyType = BulkCopyType.MultipleRows)
        where T : class
    {
        var iTable = db.GetTable<T>();
        var result = iTable.BulkCopy(new BulkCopyOptions { BulkCopyType = bulkCopyType }, source);
        return result.RowsCopied;
    }

    // Renders the rows as "(v11, v12), (v21, v22), ..." for a VALUES clause. Parameterised values become
    // "{n}" placeholders (n = index into `parameters`, which this method appends to); other values become literals.
    private static string BuildTableValuedData<T>(MappingSchema mappingSchema, IEnumerable<T> data,
        PropertyInfo[] properties, List<DataParameter> parameters) where T : class
    {
        var columns = properties
            .Select(x => new PropertyData
            {
                IsParameterNeeded = _parameterTypes.Contains(x.PropertyType),
                Name = x.Name
            })
            .ToArray();

        // string.Join enumerates rows in order, so placeholder indices line up with `parameters`.
        var rows = data.Select((row, rowIndex) =>
            DataToString(mappingSchema, rowIndex, parameters, columns,
                properties.Select(p => p.GetValue(row, null)).ToArray()));

        return "(" + string.Join("), (", rows) + ")";
    }

    // Per-column rendering metadata for BuildTableValuedData.
    private class PropertyData
    {
        public string Name { get; set; }
        public bool IsParameterNeeded { get; set; }
    }

    // Renders one row's values as a comma-separated list.
    // Parameterised columns become "{n}" placeholders, or NULL when the value is null; other columns become
    // SQL literals from ValueToSqlConverter.
    private static string DataToString(MappingSchema mappingSchema, int rowIndex, List<DataParameter> dataParams,
        PropertyData[] columns, object[] values)
    {
        var sb = new StringBuilder();
        var rendered = new string[values.Length];
        for (int i = 0; i < values.Length; i++)
        {
            var value = values[i];
            if (columns[i].IsParameterNeeded)
            {
                if (value == null)
                {
                    // NULL is inlined, so no parameter is registered; an unused one would still be sent to the
                    // server and count against its parameter limit.
                    rendered[i] = "NULL";
                }
                else
                {
                    rendered[i] = "{" + dataParams.Count + "}";
                    dataParams.Add(new DataParameter(columns[i].Name + "_" + rowIndex, value));
                }
            }
            else
            {
                // NOTE(review): literals end up in a string whose {n} placeholders FromSql resolves; a literal
                // containing '{' or '}' would likely break that - confirm for the linq2db version in use.
                sb.Clear();
                rendered[i] = mappingSchema.ValueToSqlConverter.Convert(sb, value).ToString();
            }
        }

        return string.Join(", ", rendered);
    }

    /// <summary>
    /// Returns a one-row queryable over <c>(SELECT 1 AS Discard)</c>, with the column name quoted by the
    /// provider's SQL builder.
    /// </summary>
    public static IQueryable<Dual> Dual(this IDataContext db)
    {
        var builder = db.CreateSqlProvider();
        var fieldName = builder.Convert("Discard", SqlProvider.ConvertType.NameToQueryField);
        var sql = $"(SELECT 1 AS {fieldName})";
        return db.FromSql<Dual>(sql);
    }

    // Heuristic check for C# ("<>...AnonymousType...") and VB ("VB$AnonymousType...") anonymous types:
    // compiler-generated, generic, non-public, with a compiler-style name.
    private static bool IsAnonymousType(this Type type)
    {
        if (type == null)
        {
            throw new ArgumentNullException(nameof(type));
        }

        // HACK: The only way to detect anonymous types right now. Name checks are ordinal: these are
        // compiler-emitted identifiers, not culture-sensitive text.
        return Attribute.IsDefined(type, typeof(CompilerGeneratedAttribute), false)
            && type.IsGenericType
            && type.Name.Contains("AnonymousType")
            && (type.Name.StartsWith("<>", StringComparison.Ordinal)
                || type.Name.StartsWith("VB$", StringComparison.Ordinal))
            && (type.Attributes & TypeAttributes.NotPublic) == TypeAttributes.NotPublic;
    }
}
/// <summary>
/// Row type for the one-row <c>(SELECT 1 AS Discard)</c> source produced by the <c>Dual</c> extension method.
/// </summary>
public class Dual
{
    /// <summary>Maps the single <c>Discard</c> column of that row.</summary>
    public bool Discard { set; get; }
}
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement