Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System;
- using System.Collections.Generic;
- using System.Diagnostics;
- using System.Diagnostics.CodeAnalysis;
- using System.Drawing;
- using System.Drawing.Imaging;
- using System.IO;
- using System.Linq;
- using System.Runtime.CompilerServices;
- using System.Text;
- using System.Threading.Tasks;
- using NumSharp;
- using NumSharp.Backends;
- using NumSharp.Backends.Unmanaged;
- using Tensorflow;
- using static Tensorflow.Binding;
- using Buffer = System.Buffer;
- namespace ConsoleApp2
- {
class Program
{
    /// <summary>
    /// Entry point: builds a <see cref="DnCNN"/>, loads every file in ./github_dataset as a
    /// <see cref="Bitmap"/>, and trains for 100 epochs using a clone of each image as its target.
    /// </summary>
    static void Main(string[] args)
    {
        // The model (graph + session) is created before any image is read.
        var model = new DnCNN();

        FileInfo[] datasetFiles = new DirectoryInfo("./github_dataset").GetFiles();
        var images = new List<Bitmap>(datasetFiles.Length);
        foreach (FileInfo file in datasetFiles)
        {
            images.Add(new Bitmap(file.FullName));
        }

        // Targets are separate Bitmap instances holding the same pixels.
        // NOTE(review): presumably cloned because DnCNN.GenerateDataset converts with copy: false,
        // which may keep the source bitmap locked — confirm before sharing instances.
        // NOTE(review): inputs and targets are identical, so the denoiser is being trained toward an
        // identity mapping; presumably a placeholder until real noisy/clean pairs are available.
        List<Bitmap> targets = images.ConvertAll(bmp => (Bitmap)bmp.Clone());

        model.Train(images, targets, epochs: 100);
    }
}
/// <summary>
/// DnCNN-style residual denoiser built with TensorFlow.NET. A stack of 3x3 convolutions
/// estimates a residual, which is subtracted from the input image to produce the output.
/// </summary>
public class DnCNN
{
    /// <summary>
    /// Loss normalisation constant. Note that <see cref="Train"/> currently feeds the whole dataset
    /// in a single step, so this only scales the loss; it does not split the data into batches.
    /// </summary>
    private const int batch_size = 128;

    // Graph endpoints: X = network input, Y_ = training target, Y = network output.
    private readonly Tensor X, Y_, Y;
    private readonly Tensor loss;
    private readonly Operation optimizer;
    private readonly Session sess;

    /// <summary>
    /// Builds the graph (placeholders, model, loss, Adam optimiser), opens a session and runs the
    /// global variable initialiser.
    /// </summary>
    public DnCNN()
    {
        // Batch, height and width are dynamic; images must have exactly 3 channels.
        X = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "input_image");
        Y_ = tf.placeholder(tf.float32, shape: (-1, -1, -1, 3), name: "clean_image");
        Y = BuildModel(X);

        // Scalar L2 loss: sum((Y - Y_)^2) / 2 (the tf.nn.l2_loss definition), scaled by 1 / batch_size.
        // The previous relu(Y_ - Y) was non-scalar and one-sided: every pixel where the prediction
        // overshot the target contributed zero, so the loss could be minimised by inflating Y.
        loss = tf.reduce_sum(tf.square(Y - Y_)) / (2.0f * batch_size);

        optimizer = tf.train.AdamOptimizer(0.001f, name: "AdamOptimizer").minimize(loss);
        sess = new Session();
        var init = tf.global_variables_initializer();
        sess.run(init);
    }

    /// <summary>
    /// Builds 20 "same"-padded 3x3 convolutions: conv1..conv19 with 64 filters followed by ReLU,
    /// then conv20 with 3 filters and no activation. Returns <paramref name="input"/> minus conv20's output.
    /// </summary>
    /// <param name="input">Image tensor fed to the first convolution.</param>
    /// <param name="is_training">Currently unused; no batch-normalisation layers are built.</param>
    private Tensor BuildModel(Tensor input, bool is_training = true)
    {
        // Without a non-linearity the whole stack collapses to a single linear filter.
        var output = tf.nn.relu(tf.layers.conv2d(input, 64, new int[] { 3, 3 }, name: "conv1", padding: "same"));
        for (int i = 2; i < 20; i++)
        {
            var conv = tf.layers.conv2d(output, 64, new int[] { 3, 3 }, name: "conv" + i, padding: "same", use_bias: false);
            // TODO(review): the reference DnCNN also applies batch normalisation here (see is_training).
            output = tf.nn.relu(conv);
        }
        // Last layer stays linear: the estimated residual may be negative.
        output = tf.layers.conv2d(output, 3, new int[] { 3, 3 }, name: "conv20", padding: "same", use_bias: false);
        return input - output;
    }

    /// <summary>
    /// Runs <paramref name="epochs"/> optimiser steps, each over the full dataset. After every step
    /// it prints the loss and elapsed time, then saves a checkpoint.
    /// </summary>
    /// <param name="inputImages">Network inputs; all images must share the same size.</param>
    /// <param name="outputImages">Targets, in the same order and size as <paramref name="inputImages"/>.</param>
    /// <param name="epochs">Number of optimiser steps.</param>
    /// <param name="checkpointPath">Where the session is saved after each step (default keeps the original path).</param>
    public void Train(List<Bitmap> inputImages, List<Bitmap> outputImages, int epochs = 1, string checkpointPath = @"E:\Downloads\ciao.ckpt")
    {
        var sw = Stopwatch.StartNew();
        NDArray x_train = ToModelInput(inputImages);
        NDArray y_train = ToModelInput(outputImages);
        var saver = new Saver();
        print($"Dataset created in {sw.ElapsedMilliseconds}ms");
        sw.Restart();
        for (int i = 0; i < epochs; i++)
        {
            sess.run(optimizer, (X, x_train), (Y_, y_train));
            // Evaluate the loss on the same data after the step so progress is visible.
            NDArray stepLoss = sess.run(loss, new FeedItem(X, x_train), new FeedItem(Y_, y_train));
            print($"iter {i:000}: loss {stepLoss} {sw.ElapsedMilliseconds}ms");
            sw.Restart();
            saver.save(sess, checkpointPath);
        }
    }

    /// <summary>
    /// Runs the model on <paramref name="inputImages"/> and returns the first result as a Bitmap.
    /// </summary>
    public Bitmap Evaluate(List<Bitmap> inputImages)
    {
        var sw = Stopwatch.StartNew();
        NDArray x_eval = ToModelInput(inputImages);
        var output = sess.run(Y, new FeedItem(X, x_eval));
        sw.Stop();
        print($"Inference done in {sw.ElapsedMilliseconds}ms");
        // Map the [0, 1] model range back to pixel values.
        // NOTE(review): values outside [0, 1] wrap around when cast to byte; consider clipping first.
        output = 255 * output;
        return image(output[0].astype(NPTypeCode.Byte));
    }

    /// <summary>
    /// Converts each image with NumSharp's ToNDArray (not flattened, not copied) and stacks the
    /// results along the first axis. Pixel values are returned unscaled.
    /// </summary>
    public NDArray GenerateDataset(List<Bitmap> imgs)
    {
        return np.vstack(imgs.Select(img => img.ToNDArray(false, false)).ToArray());
    }

    /// <summary>
    /// Builds the float32 network input for <paramref name="imgs"/>. Pixel values are divided by 255,
    /// matching the 255 * output mapping applied in <see cref="Evaluate"/>.
    /// </summary>
    private NDArray ToModelInput(List<Bitmap> imgs)
    {
        // The placeholders are float32; cast again after the division in case the scalar promotes the dtype.
        return (GenerateDataset(imgs).astype(NPTypeCode.Single) / 255f).astype(NPTypeCode.Single);
    }

    /// <summary>
    /// Converts pixel data to a Bitmap. Accepts either a (1, height, width, channels) array or a
    /// single (height, width, channels) image.
    /// </summary>
    public Bitmap image(NDArray nd)
    {
        // ToBitmap below receives shape[2] as width and shape[1] as height, i.e. a leading batch axis
        // is expected; lift a single 3-D image into that layout.
        if (nd.ndim == 3)
        {
            nd = nd.reshape(1, nd.shape[0], nd.shape[1], nd.shape[2]);
        }
        return nd.ToBitmap(nd.shape[2], nd.shape[1]);
    }
}
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement