Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- using System;
- using System.Collections.Generic;
- using System.Linq;
- using System.Threading.Tasks;
- using TensorFlow;
- using UnityEngine;
namespace TFClassify
{
    /// <summary>
    /// Runs a TensorFlow segmentation graph (loaded from serialized bytes) over images supplied
    /// as Unity <see cref="Color32"/> pixel arrays.
    /// </summary>
    /// <remarks>
    /// This class owns a native <see cref="TFGraph"/> and <see cref="TFSession"/>; call
    /// <see cref="Dispose"/> when finished to release them. Disposing while a task returned by
    /// <see cref="SegmentAsync"/> is still running is not guarded against.
    /// </remarks>
    public class Segmentation : IDisposable
    {
        // Names of the graph operations fed and fetched by SegmentAsync (output index 0 of each).
        // The original source noted "ImageTensor" / "SemanticPredictions" as alternative names,
        // presumably for a different model export — NOTE(review): confirm which model is shipped.
        private const string InputName = "in_node";
        private const string OutputName = "module1/Sigmoid";

        // Values written per pixel by TransformInput: R, G, B (alpha is dropped).
        private const int Channels = 3;

        private readonly TFGraph graph;
        private readonly TFSession session;

        // Stored from the constructor argument; not read anywhere in this class.
        private readonly int inputSize;

        private bool disposed;

        /// <summary>
        /// Imports a serialized graph and opens a session on it.
        /// </summary>
        /// <param name="model">Serialized graph definition passed to <see cref="TFGraph.Import(TFBuffer, string, TFStatus)"/>.</param>
        /// <param name="inputSize">Stored for callers' reference; not used by this class.</param>
        /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
        public Segmentation(byte[] model, int inputSize)
        {
            if (model == null)
            {
                throw new ArgumentNullException(nameof(model));
            }

#if (UNITY_ANDROID && !UNITY_EDITOR)
            // On Android device builds, initialize TensorFlowSharp's native binding before
            // any TensorFlow object is created.
            TensorFlowSharp.Android.NativeBinding.Init();
#endif
            this.inputSize = inputSize;
            this.graph = new TFGraph();
            try
            {
                // The buffer is disposable and only needed for the duration of the import.
                using (var buffer = new TFBuffer(model))
                {
                    this.graph.Import(buffer);
                }

                this.session = new TFSession(this.graph);
            }
            catch
            {
                // Don't leak the native graph if import or session creation fails.
                this.graph.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Runs one inference pass on a thread-pool thread.
        /// </summary>
        /// <param name="data">Pixels in the order they should be laid out in the tensor; must contain exactly <paramref name="width"/> * <paramref name="height"/> entries.</param>
        /// <param name="width">Image width in pixels (must be positive).</param>
        /// <param name="height">Image height in pixels (must be positive).</param>
        /// <returns>
        /// A task producing output 0 of the output operation as a 4-D float array. If that value
        /// cannot be converted to <c>float[,,,]</c>, the exception is logged and a zero-filled
        /// <c>[1, height, width, 1]</c> array is returned instead. Argument errors, missing graph
        /// operations and session-run failures fault the returned task.
        /// </returns>
        /// <exception cref="ObjectDisposedException">This instance has been disposed.</exception>
        public Task<float[,,,]> SegmentAsync(Color32[] data, int width, int height)
        {
            ThrowIfDisposed();
            return Task.Run(() => RunSegmentation(data, width, height));
        }

        // Synchronous body of SegmentAsync: build the input, run the session, copy out the result.
        private float[,,,] RunSegmentation(Color32[] data, int width, int height)
        {
            using (TFTensor tensor = TransformInput(data, width, height))
            {
                TFOutput input = ResolveOutput(InputName);
                TFOutput output = ResolveOutput(OutputName);

                var stopwatch = System.Diagnostics.Stopwatch.StartNew();
                TFTensor[] results = this.session.GetRunner()
                    .AddInput(input, tensor)
                    .Fetch(output)
                    .Run();
                try
                {
                    try
                    {
                        // GetValue copies into a managed array, so the native tensors can be
                        // released in the finally below.
                        var pixels = (float[,,,])results[0].GetValue(jagged: false);
                        // Elapsed.Ticks uses the same 100 ns units as the DateTime ticks logged previously.
                        Debug.Log("Finish : " + stopwatch.Elapsed.Ticks);
                        return pixels;
                    }
                    catch (Exception e)
                    {
                        // Best-effort: log and fall back to an all-zero mask.
                        Debug.Log(e.ToString());
                        return new float[1, height, width, 1];
                    }
                }
                finally
                {
                    // Each fetched tensor owns native memory; release it on every call.
                    foreach (TFTensor result in results)
                    {
                        if (result != null)
                        {
                            result.Dispose();
                        }
                    }
                }
            }
        }

        // Returns output 0 of the named graph operation, failing clearly if the graph lacks it.
        private TFOutput ResolveOutput(string operationName)
        {
            TFOperation operation = this.graph[operationName];
            if (operation == null)
            {
                throw new InvalidOperationException(
                    "Graph does not contain an operation named '" + operationName + "'.");
            }

            return operation[0];
        }

        /// <summary>
        /// Packs pixels into a <c>[1, height, width, 3]</c> float tensor holding each pixel's
        /// R, G, B byte values converted to float (not normalized); alpha is discarded.
        /// </summary>
        /// <param name="data">Pixels, consumed in array order; must contain exactly <paramref name="width"/> * <paramref name="height"/> entries.</param>
        /// <param name="width">Image width in pixels (must be positive).</param>
        /// <param name="height">Image height in pixels (must be positive).</param>
        /// <returns>A new tensor the caller must dispose.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="width"/> or <paramref name="height"/> is not positive.</exception>
        /// <exception cref="ArgumentException"><paramref name="data"/> length differs from width * height.</exception>
        public static TFTensor TransformInput(Color32[] data, int width, int height)
        {
            if (data == null)
            {
                throw new ArgumentNullException(nameof(data));
            }
            if (width <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(width), width, "Width must be positive.");
            }
            if (height <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(height), height, "Height must be positive.");
            }

            long expectedPixels = (long)width * height;
            if (data.Length != expectedPixels)
            {
                throw new ArgumentException(
                    "Expected " + expectedPixels + " pixels (width * height) but got " + data.Length + ".",
                    nameof(data));
            }

            // checked: surface size overflow instead of allocating a wrong-sized buffer.
            var values = new float[checked(data.Length * Channels)];
            for (int i = 0, j = 0; i < data.Length; i++, j += Channels)
            {
                Color32 color = data[i];
                values[j] = color.r;
                values[j + 1] = color.g;
                values[j + 2] = color.b;
            }

            var shape = new TFShape(1, height, width, Channels);
            return TFTensor.FromBuffer(shape, values, 0, values.Length);
        }

        /// <summary>
        /// Releases the native session and graph. Safe to call more than once.
        /// </summary>
        public void Dispose()
        {
            if (this.disposed)
            {
                return;
            }

            this.disposed = true;
            if (this.session != null)
            {
                this.session.Dispose();
            }
            if (this.graph != null)
            {
                this.graph.Dispose();
            }
        }

        // Guards public entry points against use after Dispose.
        private void ThrowIfDisposed()
        {
            if (this.disposed)
            {
                throw new ObjectDisposedException(nameof(Segmentation));
            }
        }
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement