// Jeffrey Becker Jan 2012
// Released to the public domain, use at your own risk!
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Data;
using System.Data.Common;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;

namespace CrapyDataAccess
{
    /// <summary>
    /// Thin data-access helper. Opens a single connection from a named entry in the
    /// application's <c>connectionStrings</c> configuration, runs SQL whose <c>@Name</c>
    /// parameters are bound from the public properties of an argument object, and maps
    /// result rows onto either scalar values or objects with writable properties.
    /// </summary>
    /// <remarks>
    /// An instance owns one open connection and at most one transaction. It uses
    /// unsynchronized caches and contains no locking.
    /// </remarks>
    public class Db : IDisposable
    {
        private readonly DbConnection _connection;
        private readonly HydraterFactory _hydraterFactory = new HydraterFactory();
        private DbTransaction _transaction;
        private readonly DbProviderFactory _factory;

        /// <summary>
        /// Creates and opens a connection described by a named connection string.
        /// </summary>
        /// <param name="connectionStringName">Name of an entry in the configuration's
        /// <c>connectionStrings</c> section; it must define both <c>connectionString</c>
        /// and <c>providerName</c>.</param>
        /// <exception cref="InvalidOperationException">The entry is missing or incomplete,
        /// or the provider yields no factory/connection.</exception>
        public Db(string connectionStringName)
        {
            var css = ConfigurationManager.ConnectionStrings[connectionStringName];
            if (css == null)
                throw new InvalidOperationException(String.Format("No ConnectionString {0} defined.", connectionStringName));
            if (String.IsNullOrEmpty(css.ConnectionString))
                throw new InvalidOperationException(
                    String.Format("The ConnectionString {0} doesn't have a connectionString value defined.", connectionStringName));
            if (String.IsNullOrEmpty(css.ProviderName))
                throw new InvalidOperationException(
                    String.Format("The ConnectionString {0} doesn't have a providerName defined.", connectionStringName));

            _factory = DbProviderFactories.GetFactory(css.ProviderName);
            if (_factory == null)
                throw new InvalidOperationException(
                    String.Format("No DbProviderFactory is available for providerName {0}.", css.ProviderName));

            _connection = _factory.CreateConnection();
            if (_connection == null)
                throw new InvalidOperationException(
                    String.Format("The provider {0} did not create a connection.", css.ProviderName));

            try
            {
                _connection.ConnectionString = css.ConnectionString;
                _connection.Open();
            }
            catch
            {
                // The caller never receives an instance to dispose, so release the connection here.
                _connection.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Commits any transaction still open from <see cref="Begin"/> and then disposes the connection.
        /// </summary>
        /// <remarks>
        /// NOTE(review): a pending transaction is COMMITTED here, not rolled back, so work done
        /// before an exception inside a <c>using</c> block is still committed. This preserves the
        /// existing contract; call <see cref="Rollback"/> explicitly on failure paths.
        /// </remarks>
        public void Dispose()
        {
            try
            {
                if (_transaction != null)
                {
                    try
                    {
                        _transaction.Commit();
                    }
                    finally
                    {
                        _transaction.Dispose();
                        _transaction = null;
                    }
                }
            }
            finally
            {
                // Release the connection even when the commit fails.
                _connection.Dispose();
            }
        }

        /// <summary>
        /// Runs <paramref name="sql"/> and maps each row of the result to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">A supported scalar type (mapped from the first column) or a class
        /// with a public parameterless constructor (columns mapped to properties by name,
        /// case-insensitively).</typeparam>
        /// <param name="sql">SQL text; <c>@Name</c> tokens are bound from <paramref name="args"/>.</param>
        /// <param name="args">Optional object whose readable properties supply parameter values.</param>
        /// <returns>A lazy sequence: the command runs only when enumeration starts, and the
        /// reader stays open until enumeration finishes or the enumerator is disposed.</returns>
        public IEnumerable<T> Execute<T>(string sql, object args = null)
        {
            using (var cmd = CreateCommand(sql, args))
            using (var reader = cmd.ExecuteReader())
            {
                var hydrater = _hydraterFactory.GetHydrater<T>();
                while (reader.Read())
                {
                    yield return hydrater.Hydrate(reader);
                }
            }
        }

        /// <summary>
        /// Runs <paramref name="sql"/> and returns the first mapped row.
        /// </summary>
        /// <exception cref="InvalidOperationException">The query returned no rows, or the first
        /// value could not be mapped.</exception>
        public T Scalar<T>(string sql, object args = null)
        {
            return Execute<T>(sql, args).First();
        }

        /// <summary>
        /// Executes <paramref name="sql"/> as a non-query (INSERT/UPDATE/DDL and the like).
        /// </summary>
        public void Run(string sql, object args = null)
        {
            using (var cmd = CreateCommand(sql, args))
            {
                cmd.ExecuteNonQuery();
            }
        }

        /// <summary>
        /// Starts a transaction on the connection; does nothing if one is already active.
        /// </summary>
        public void Begin()
        {
            if (_transaction == null)
            {
                _transaction = _connection.BeginTransaction();
            }
        }

        /// <summary>
        /// Commits and releases the active transaction, if any.
        /// </summary>
        public void Commit()
        {
            if (_transaction != null)
            {
                _transaction.Commit();
                _transaction.Dispose();
                _transaction = null;
            }
        }

        /// <summary>
        /// Rolls back and releases the active transaction, if any.
        /// </summary>
        public void Rollback()
        {
            if (_transaction != null)
            {
                _transaction.Rollback();
                _transaction.Dispose();
                _transaction = null;
            }
        }

        /// <summary>
        /// Builds a command for <paramref name="sql"/>, enlists it in the active transaction and
        /// binds one input parameter for each property of <paramref name="args"/> whose
        /// <c>@Name</c> token appears in the SQL.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="sql"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">A referenced property has a type with no DbType mapping.</exception>
        private DbCommand CreateCommand(string sql, object args)
        {
            if (String.IsNullOrEmpty(sql))
                throw new ArgumentException("Sql must not be null or empty");

            var cmd = _connection.CreateCommand();
            try
            {
                cmd.CommandText = sql;
                if (_transaction != null)
                {
                    cmd.Transaction = _transaction;
                }

                if (args != null)
                {
                    // Indexers and write-only properties cannot be read with GetValue(args, null).
                    var properties = args.GetType()
                        .GetProperties()
                        .Where(p => p.CanRead && p.GetIndexParameters().Length == 0);

                    foreach (var property in properties)
                    {
                        var parameterName = "@" + property.Name;

                        // Check the SQL first so that unreferenced properties are neither read
                        // nor required to have a DbType mapping.
                        if (!IsParameterReferenced(sql, parameterName))
                            continue;

                        var param = cmd.CreateParameter();
                        param.DbType = TypeConverter.ToDbType(property.PropertyType);
                        param.ParameterName = parameterName;
                        param.Value = property.GetValue(args, null) ?? DBNull.Value;
                        param.Direction = ParameterDirection.Input;
                        cmd.Parameters.Add(param);
                    }
                }

                return cmd;
            }
            catch
            {
                // Callers only dispose commands they receive; don't leak one on failure.
                cmd.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Returns true when <paramref name="parameterName"/> (for example <c>@Id</c>) occurs in
        /// <paramref name="sql"/> as a whole token, ignoring case.
        /// </summary>
        private static bool IsParameterReferenced(string sql, string parameterName)
        {
            // A leading \b cannot be used: '@' is a non-word character, so \b before it only
            // matches right after a letter/digit/underscore. Instead reject a preceding word
            // character or '@' (e.g. "@@IDENTITY"), and use a trailing \b so "@Id" does not
            // match "@IdList".
            var pattern = @"(?<![\w@])" + Regex.Escape(parameterName) + @"\b";
            return Regex.IsMatch(sql, pattern, RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);
        }

        #region Nested type: ColumnInfo

        /// <summary>
        /// Writes a single column value into one writable property of an entity.
        /// </summary>
        private class ColumnInfo
        {
            private readonly bool _isNullable;
            private readonly Type _parentType;
            private readonly PropertyInfo _propertyInfo;

            public ColumnInfo(Type parentType, PropertyInfo propertyInfo)
            {
                _parentType = parentType;
                _propertyInfo = propertyInfo;
                Name = _propertyInfo.Name;
                _isNullable = TypeConverter.IsNullable(_propertyInfo.PropertyType);
            }

            /// <summary>Property name, matched case-insensitively against result column names.</summary>
            public string Name { get; private set; }

            /// <summary>
            /// Assigns <paramref name="o"/> to the property on <paramref name="instance"/>,
            /// translating <see cref="DBNull"/> to null for nullable property types.
            /// </summary>
            /// <exception cref="InvalidOperationException">The value could not be assigned
            /// (for example a type mismatch, or NULL for a non-nullable property).</exception>
            public void AssignValue(object o, object instance)
            {
                try
                {
                    if (_isNullable && o == DBNull.Value)
                        o = null;
                    _propertyInfo.SetValue(instance, o, null);
                }
                catch (ArgumentException ex)
                {
                    throw new InvalidOperationException(
                        String.Format("Error hydrating {0}.{1}: \"{2}\"", _parentType.Name, Name, ex.Message), ex);
                }
            }
        }

        #endregion

        #region Nested type: EntityHydrater

        /// <summary>
        /// Maps a data record onto a new <typeparamref name="T"/> by matching column names to properties.
        /// </summary>
        private class EntityHydrater<T> : IHydrater<T>
        {
            private readonly EntityInfo _entityInfo;
            private IDictionary<int, ColumnInfo> _lookup;

            public EntityHydrater(EntityInfo entityInfo)
            {
                _entityInfo = entityInfo;
            }

            #region IHydrater Members

            public T Hydrate(IDataRecord dr)
            {
                // Built once from the first record; Execute reads a single result set,
                // so later rows are assumed to share that column layout.
                _lookup = _lookup ?? _entityInfo.BuildLookup(dr);

                object instance = _entityInfo.CreateInstance();
                for (int i = 0; i < dr.FieldCount; i++)
                {
                    ColumnInfo column;
                    if (_lookup.TryGetValue(i, out column))
                    {
                        column.AssignValue(dr.GetValue(i), instance);
                    }
                }
                return (T) instance;
            }

            #endregion
        }

        #endregion

        #region Nested type: EntityInfo

        /// <summary>
        /// Reflection metadata for an entity type: its writable, non-indexer properties.
        /// </summary>
        private class EntityInfo
        {
            private readonly ICollection<ColumnInfo> _columns;
            private readonly Type _type;

            /// <exception cref="InvalidOperationException"><paramref name="type"/> has no public
            /// parameterless constructor.</exception>
            public EntityInfo(Type type)
            {
                _type = type;
                bool hasParameterlessConstructor = _type.GetConstructors()
                    .Any(c => c.GetParameters().Length == 0);
                if (!hasParameterlessConstructor)
                    throw new InvalidOperationException(
                        String.Format("{0} does not have a public parameterless constructor.", _type.Name));

                // Indexers are writable properties too, but cannot be set with a null index.
                _columns = _type.GetProperties()
                    .Where(p => p.CanWrite && p.GetIndexParameters().Length == 0)
                    .Select(p => new ColumnInfo(_type, p))
                    .ToList();
            }

            public object CreateInstance()
            {
                return Activator.CreateInstance(_type);
            }

            /// <summary>
            /// Maps each column ordinal of <paramref name="dr"/> to the property with the same
            /// name (ordinal, case-insensitive); columns without a matching property are omitted.
            /// </summary>
            public IDictionary<int, ColumnInfo> BuildLookup(IDataRecord dr)
            {
                return Enumerable.Range(0, dr.FieldCount)
                    .Select(i => new
                    {
                        Index = i,
                        Column = _columns.FirstOrDefault(
                            c => String.Equals(c.Name, dr.GetName(i), StringComparison.OrdinalIgnoreCase))
                    })
                    .Where(x => x.Column != null)
                    .ToDictionary(x => x.Index, x => x.Column);
            }
        }

        #endregion

        #region Nested type: HydraterFactory

        /// <summary>
        /// Chooses a scalar or entity hydrater for a type, caching entity metadata per type.
        /// </summary>
        private class HydraterFactory
        {
            private readonly Dictionary<Type, EntityInfo> _entityInfos = new Dictionary<Type, EntityInfo>();

            public IHydrater<T> GetHydrater<T>()
            {
                Type type = typeof (T);
                if (TypeConverter.IsScalarType(type))
                    return new ScalarHydrater<T>();

                EntityInfo entityInfo;
                if (!_entityInfos.TryGetValue(type, out entityInfo))
                {
                    entityInfo = new EntityInfo(type);
                    _entityInfos.Add(type, entityInfo);
                }
                return new EntityHydrater<T>(entityInfo);
            }
        }

        #endregion

        #region Nested type: IHydrater

        /// <summary>Turns the current data record into a <typeparamref name="T"/>.</summary>
        private interface IHydrater<T>
        {
            T Hydrate(IDataRecord dr);
        }

        #endregion

        #region Nested type: ScalarHydrater

        /// <summary>
        /// Maps the first column of a record to a scalar <typeparamref name="T"/>.
        /// </summary>
        private class ScalarHydrater<T> : IHydrater<T>
        {
            private readonly bool _isNullable;
            private readonly Type _type;

            public ScalarHydrater()
            {
                Type type = typeof (T);
                _isNullable = TypeConverter.IsNullable(type);
                // Convert.ChangeType cannot target Nullable<X>, so convert to X instead;
                // non-Nullable<> types (string, byte[], int, ...) are used as-is.
                _type = TypeConverter.GetNonNullable(type);
            }

            #region IHydrater Members

            /// <exception cref="InvalidOperationException">The record has no fields, or the value
            /// is NULL and <typeparamref name="T"/> cannot hold null.</exception>
            public T Hydrate(IDataRecord dr)
            {
                if (dr.FieldCount == 0)
                    throw new InvalidOperationException("No fields were returned");

                object obj = dr.GetValue(0);
                if (obj == DBNull.Value)
                {
                    if (!_isNullable)
                        throw new InvalidOperationException(String.Format("The Column {0} is NULL", dr.GetName(0)));
                    return (T) (object) null;
                }

                // Database values are machine data; never parse them with the UI culture.
                return (T) Convert.ChangeType(obj, _type, CultureInfo.InvariantCulture);
            }

            #endregion
        }

        #endregion

        #region Nested type: TypeConverter

        /// <summary>
        /// CLR type to <see cref="DbType"/> mapping plus nullability helpers.
        /// </summary>
        private static class TypeConverter
        {
            private static readonly Dictionary<Type, DbType> TypeToDbType = new Dictionary<Type, DbType>
            {
                {typeof (string), DbType.String},
                {typeof (DateTime), DbType.DateTime},
                {typeof (DateTime?), DbType.DateTime},
                {typeof (byte), DbType.Byte},
                {typeof (byte?), DbType.Byte},
                {typeof (short), DbType.Int16},
                {typeof (short?), DbType.Int16},
                {typeof (int), DbType.Int32},
                {typeof (int?), DbType.Int32},
                {typeof (long), DbType.Int64},
                {typeof (long?), DbType.Int64},
                {typeof (bool), DbType.Boolean},
                {typeof (bool?), DbType.Boolean},
                {typeof (byte[]), DbType.Binary},
                {typeof (decimal), DbType.Decimal},
                {typeof (decimal?), DbType.Decimal},
                {typeof (double), DbType.Double},
                {typeof (double?), DbType.Double},
                {typeof (float), DbType.Single},
                {typeof (float?), DbType.Single},
                {typeof (Guid), DbType.Guid},
                {typeof (Guid?), DbType.Guid}
            };

            /// <summary>True for types mapped directly to a single column value.</summary>
            public static bool IsScalarType(Type type)
            {
                return TypeToDbType.ContainsKey(type);
            }

            /// <summary>True for reference types and <see cref="Nullable{T}"/>.</summary>
            public static bool IsNullable(Type type)
            {
                return !type.IsValueType ||
                       (type.IsGenericType && type.GetGenericTypeDefinition() == typeof (Nullable<>));
            }

            /// <summary>
            /// Returns X for <c>Nullable&lt;X&gt;</c>; any other type is returned unchanged.
            /// </summary>
            public static Type GetNonNullable(Type type)
            {
                return Nullable.GetUnderlyingType(type) ?? type;
            }

            /// <exception cref="InvalidOperationException"><paramref name="type"/> has no mapping.</exception>
            public static DbType ToDbType(Type type)
            {
                DbType dbType;
                if (!TypeToDbType.TryGetValue(type, out dbType))
                {
                    throw new InvalidOperationException(
                        string.Format("Type {0} doesn't have a matching DbType configured", type.FullName));
                }
                return dbType;
            }
        }

        #endregion
    }
}