Guest User

Untitled

a guest
Jan 23rd, 2018
67
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.62 KB | None | 0 0
  1. NDShape shape({ 3, 2 });
  2. Variable W = InputVariable(shape, DataType::Float, true);
  3. auto data = make_shared<vector<float>>(3, float(2));
  4. for (int j = 0; j < shape[0]; j++)
  5. {
  6. data->push_back(2 + 1);
  7. }
  8. auto val = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(shape, data->data(), data->size(), DeviceDescriptor::CPUDevice()));
  9.  
  10. // | 2 2 2 | when transposed after
  11. // | 3 3 3 |
  12.  
  13. Variable x = InputVariable(NDShape({ 3, 1 }), DataType::Float, true);
  14. auto data = make_shared<vector<float>>(3, float(1));
  15. auto valx = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(NDShape({ 3, 1 }), data->data(), data->size(), DeviceDescriptor::CPUDevice()));
  16.  
  17. // [1 1 1]
  18.  
  19. auto left = Transpose(W);
  20. auto op = Times(left, x);
  21.  
  22. std::unordered_map<Variable, ValuePtr> outputValues = { { op->Output(), nullptr } };
  23. auto back = op->Forward({ {W, val}, {x, valx} }, outputValues, DeviceDescriptor::UseDefaultDevice(), { op->Output() });
  24. auto outVal = outputValues[op->Output()];
  25. wcout << outVal->AsString() << "n";
  26. printMatrix(outVal->Data()->DeepClone(DeviceDescriptor::CPUDevice()));
  27. fflush(stdout);
  28.  
  29. //prints out correctly [6 9]
  30.  
  31. ValuePtr rootGrad = MakeSharedObject<Value>(MakeSharedObject<NDArrayView>(DataType::Float, op->Output().Shape(), DeviceDescriptor::UseDefaultDevice()));
  32. rootGrad->Data()->SetValue(1.0f);
  33. unordered_map<Variable, ValuePtr> gradientOut = { { x, nullptr } };
  34. op->Backward(back, { { op->Output(), rootGrad } }, gradientOut);
  35. auto outGrad = gradientOut[x];
  36. wcout << outGrad->AsString() << "n";
  37. printMatrix(outGrad->Data()->DeepClone(DeviceDescriptor::CPUDevice()));
  38. fflush(stdout);
  39.  
  40. //[5 5 5] !!!
  41. // and not
  42. // |2 2 2|
  43. // |3 3 3| !!!
Add Comment
Please, Sign In to add comment