Advertisement
Guest User

Untitled

a guest
Oct 17th, 2019
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.62 KB | None | 0 0
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Diagnostics.CodeAnalysis;
  5. using System.Drawing;
  6. using System.Drawing.Imaging;
  7. using System.IO;
  8. using System.Linq;
  9. using System.Runtime.CompilerServices;
  10. using System.Text;
  11. using System.Threading.Tasks;
  12. using NumSharp;
  13. using NumSharp.Backends;
  14. using NumSharp.Backends.Unmanaged;
  15. using Tensorflow;
  16. using static Tensorflow.Binding;
  17. using Buffer = System.Buffer;
  18.  
  19. namespace ConsoleApp2
  20. {
  21. class Program
  22. {
  23. static void Main(string[] args)
  24. {
  25. var dncnn = new DnCNN();
  26. var l = new DirectoryInfo("./github_dataset").GetFiles().Select(f => new Bitmap(f.FullName)).ToList();
  27. dncnn.Train(l, l.Select(b=>(Bitmap)b.Clone()).ToList(), 100);
  28. }
  29. }
  30.  
  31. public class DnCNN
  32. {
  33. private const int batch_size = 128;
  34.  
  35. Tensor X, Y_, Y;
  36. Tensor loss;
  37. Operation optimizer;
  38.  
  39. Session sess;
  40.  
  41. public DnCNN()
  42. {
  43. X = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "input_image");
  44. Y_ = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "clean_image");
  45.  
  46. Y = BuildModel(X);
  47.  
  48. loss = (1.0 / batch_size) * tf.nn.relu(Y_ - Y);
  49. optimizer = tf.train.AdamOptimizer(0.001f, name: "AdamOptimizer").minimize(loss);
  50.  
  51. sess = new Session();
  52. var init = tf.global_variables_initializer();
  53. sess.run(init);
  54. }
  55.  
  56. private Tensor BuildModel(Tensor input, bool is_training = true)
  57. {
  58. var output = tf.layers.conv2d(input, 64, new int[] {3, 3}, name: "conv1", padding: "same");
  59. for (int i = 2; i < 20; i++)
  60. {
  61. output = tf.layers.conv2d(output, 64, new int[] {3, 3}, name: "conv" + i, padding: "same", use_bias: false);
  62. }
  63.  
  64. output = tf.layers.conv2d(output, 3, new int[] {3, 3}, name: "conv20", padding: "same", use_bias: false);
  65.  
  66. return input - output;
  67. }
  68.  
  69. public void Train(List<Bitmap> inputImages, List<Bitmap> outputImages, int epochs = 1)
  70. {
  71. var sw = new Stopwatch();
  72. sw.Start();
  73. NDArray x_train = GenerateDataset(inputImages);
  74. NDArray y_train = GenerateDataset(outputImages);
  75.  
  76. var saver = new Saver();
  77. print($"Dataset created in {sw.ElapsedMilliseconds}ms");
  78. sw.Restart();
  79. for (int i = 0; i < epochs; i++)
  80. {
  81. sess.run(optimizer, (X, (x_train)), (Y_, (y_train)));
  82. // Calculate and display the batch loss and accuracy
  83. var result = sess.run(new[] {loss}, new FeedItem(X, x_train), new FeedItem(Y_, y_train));
  84. print($"iter {i.ToString("000")}: {sw.ElapsedMilliseconds}ms");
  85. sw.Restart();
  86.  
  87. saver.save(sess, @"E:\Downloads\ciao.ckpt");
  88. }
  89. }
  90.  
  91. public Bitmap Evaluate(List<Bitmap> inputImages)
  92. {
  93. var sw = new Stopwatch();
  94. sw.Start();
  95. NDArray x_eval = GenerateDataset(inputImages);
  96. var output = sess.run(Y, new FeedItem(X, x_eval));
  97. sw.Stop();
  98. print($"Inference done in {sw.ElapsedMilliseconds}ms");
  99.  
  100. output = 255 * output;
  101. return image(output[0].astype(NPTypeCode.Byte));
  102. }
  103.  
  104. public NDArray GenerateDataset(List<Bitmap> imgs)
  105. {
  106. return np.vstack(imgs.Select(img => img.ToNDArray(false, false)).ToArray());
  107. }
  108.  
  109. public Bitmap image(NDArray nd)
  110. {
  111. return nd.ToBitmap(nd.shape[2], nd.shape[1]);
  112. }
  113. }
  114.  
  115. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement