Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- public ArrayList<Weights> train(){
- if(tensorFlowInferenceInterface!=null){
- byte[] def = tensorFlowInferenceInterface.graph().toGraphDef();
- try(Graph graph = new Graph();
- Session session = new Session(graph);){
- graph.importGraphDef(def);
- session.runner().addTarget("init").run();
- Random random = new Random();
- for(int i = 1; i<=7; i++){
- for(int n = 0; n < 500; n++){
- float in = random.nextFloat();
- try(Tensor<Float> input = Tensors.create(in);
- Tensor<Float> target = Tensors.create(m*in + c);
- ){
- session.runner().feed("input", input)
- .feed("target", target)
- .addTarget("train")
- .run();
- }
- }
- weights.add(getVariables(session, i));
- }
- }
- }
- Log.d("Weights:", weights+"");
- return weights;
- }
Add Comment
Please, Sign In to add comment