internal static class Extensions
{
    /// <summary>
    /// Returns the names of the properties of <paramref name="elementType"/> that carry a
    /// <see cref="KeyAttribute"/>, as reported by <see cref="TypeDescriptor"/>.
    /// </summary>
    /// <param name="elementType">The type whose property descriptors are inspected.</param>
    /// <returns>
    /// A lazily-filtered sequence of property names. The descriptor collection itself is
    /// fetched immediately; the attribute filter runs when the result is enumerated.
    /// </returns>
    public static IEnumerable<string> GetPrimaryKeyForType(this Type elementType)
    {
        return TypeDescriptor.GetProperties(elementType)
            .Cast<PropertyDescriptor>()
            .Where(descriptor => descriptor.Attributes[typeof(KeyAttribute)] != null)
            .Select(descriptor => descriptor.Name);
    }

    /// <summary>
    /// Composes a single ordering step onto <paramref name="source"/>: a call to
    /// <c>Queryable.{methodName}</c> (or <c>{methodName}Descending</c>) whose key selector is
    /// <c>p =&gt; p.{propertyName}</c>.
    /// </summary>
    /// <param name="source">The query to extend; its expression becomes the call's first argument.</param>
    /// <param name="propertyName">Name of the element property used as the sort key.</param>
    /// <param name="sortDirection">When <see cref="ListSortDirection.Descending"/>, "Descending" is appended to the method name.</param>
    /// <param name="methodName">Base <see cref="Queryable"/> method name, e.g. "OrderBy" or "ThenBy".</param>
    /// <returns>The provider-created query, cast to <see cref="IOrderedQueryable"/>.</returns>
    private static IOrderedQueryable SortCore(IQueryable source, string propertyName, ListSortDirection sortDirection, string methodName)
    {
        var queryMethod = sortDirection == ListSortDirection.Descending
            ? methodName + "Descending"
            : methodName;

        var element = Expression.Parameter(source.ElementType, "p");
        var key = Expression.Property(element, propertyName);
        var keySelector = Expression.Lambda(key, element);

        var typeArguments = new[] { element.Type, key.Type };
        var orderingCall = Expression.Call(
            typeof(Queryable),
            queryMethod,
            typeArguments,
            source.Expression,
            Expression.Quote(keySelector));

        return (IOrderedQueryable)source.Provider.CreateQuery(orderingCall);
    }

    /// <summary>
    /// Orders <paramref name="source"/> by a sequence of (property name, direction) criteria.
    /// The first non-null criterion becomes an <c>OrderBy</c>/<c>OrderByDescending</c>; every
    /// later one becomes a <c>ThenBy</c>/<c>ThenByDescending</c>. Null criteria are skipped.
    /// </summary>
    /// <param name="source">The query to sort.</param>
    /// <param name="sortBy">Sort criteria in priority order; <c>Item1</c> is the property name, <c>Item2</c> the direction.</param>
    /// <returns>The ordered query.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> or <paramref name="sortBy"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="sortBy"/> contains no non-null criteria.</exception>
    public static IOrderedQueryable Sort(this IQueryable source, IEnumerable<Tuple<string, ListSortDirection>> sortBy)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }
        if (sortBy == null)
        {
            throw new ArgumentNullException("sortBy");
        }

        IOrderedQueryable ordered = null;
        var primaryApplied = false;

        foreach (var criterion in sortBy)
        {
            if (criterion == null)
            {
                // Null entries are ignored rather than rejected.
                continue;
            }

            if (primaryApplied)
            {
                ordered = SortCore(ordered, criterion.Item1, criterion.Item2, "ThenBy");
            }
            else
            {
                ordered = SortCore(source, criterion.Item1, criterion.Item2, "OrderBy");
                primaryApplied = true;
            }
        }

        if (!primaryApplied)
        {
            throw new ArgumentException(Resources.Error_EmptyEnumerable, "sortBy");
        }

        return ordered;
    }

    /// <summary>
    /// Executes <c>Queryable.LongCount</c> against the query's provider and returns the result.
    /// </summary>
    /// <param name="source">The query to count.</param>
    /// <returns>The element count as a <see cref="long"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public static long LongCount(this IQueryable source)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }

        var countCall = Expression.Call(
            typeof(Queryable),
            "LongCount",
            new[] { source.ElementType },
            source.Expression);

        return source.Provider.Execute<long>(countCall);
    }

    /// <summary>
    /// Appends a <c>Queryable.Take</c> call to the query expression using the provider's
    /// <c>CreateQuery</c>.
    /// </summary>
    /// <param name="source">The query to limit.</param>
    /// <param name="count">The number of elements to take.</param>
    /// <returns>The composed query.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public static IQueryable Take(this IQueryable source, int count)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }

        var limit = Expression.Constant(count);
        var takeCall = Expression.Call(
            typeof(Queryable),
            "Take",
            new[] { source.ElementType },
            source.Expression,
            limit);

        return source.Provider.CreateQuery(takeCall);
    }

    /// <summary>
    /// Skips <paramref name="count"/> elements, where the count may exceed <see cref="int.MaxValue"/>.
    /// </summary>
    /// <remarks>
    /// <c>Queryable.Skip</c> only accepts an <see cref="int"/>, so the count is consumed in chunks
    /// of at most <see cref="int.MaxValue"/>, each chained as its own <c>Skip</c> call.
    /// A count of zero or less leaves the query unchanged.
    /// </remarks>
    /// <param name="source">The query to offset.</param>
    /// <param name="count">The total number of elements to skip.</param>
    /// <returns>The composed query, or <paramref name="source"/> itself when nothing is skipped.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public static IQueryable SkipLong(this IQueryable source, long count)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }

        var remaining = count;
        var query = source;
        while (remaining > 0L)
        {
            var chunk = remaining > int.MaxValue ? int.MaxValue : (int)remaining;
            remaining -= chunk;

            var skipCall = Expression.Call(
                typeof(Queryable),
                "Skip",
                new[] { query.ElementType },
                query.Expression,
                Expression.Constant(chunk));
            query = query.Provider.CreateQuery(skipCall);
        }

        return query;
    }

    /// <summary>
    /// Enumerates <paramref name="source"/> completely and buffers its elements in a
    /// <see cref="List{T}"/> of <see cref="object"/>.
    /// </summary>
    /// <param name="source">The query to materialize.</param>
    /// <returns>An in-memory list holding every element, in enumeration order.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public static IEnumerable Materialize(this IQueryable source)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }

        // Called statically so it binds to Enumerable.Cast; an extension-method call on an
        // IQueryable would bind to Queryable.Cast instead.
        return new List<object>(Enumerable.Cast<object>(source));
    }
}