Guest User

Untitled

a guest
Dec 18th, 2018
129
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.93 KB | None | 0 0
  1. const DEFAULT_OPTS: any = {
  2. activation: 'sigmoid',
  3. };
  4.  
  5. export function createModel(inputSize: number, outputSize: number, opts: any = {}): tf.Model {
  6. const model = tf.sequential();
  7. opts = Object.assign({}, DEFAULT_OPTS, opts);
  8. /* INPUT */
  9. model.add(
  10. tf.layers.dense({
  11. inputDim: inputSize,
  12. units: 64,
  13. activation: opts.activation,
  14. })
  15. );
  16.  
  17. /* HIDDEN */
  18. model.add(
  19. tf.layers.dense({
  20. inputDim: 64,
  21. units: 128,
  22. activation: opts.activation,
  23. })
  24. );
  25. model.add(
  26. tf.layers.dense({
  27. inputDim: 128,
  28. units: 64,
  29. activation: opts.activation,
  30. })
  31. );
  32.  
  33. /* OUTPUT */
  34. model.add(
  35. tf.layers.dense({
  36. inputDim: 64,
  37. units: outputSize,
  38. activation: 'relu',
  39. })
  40. );
  41.  
  42. return model;
  43. }
  44.  
  45. constructor(config: AgentConfig) {
  46. this.config = Object.assign({}, DEFAULT_CONFIG, config);
  47. this.memory = new Memory(config.memorySize);
  48. this.Q = createModel(config.inputSize, config.outputSize);
  49. this.QTarget = createModel(config.inputSize, config.outputSize);
  50. this.optimizer = tf.train.adam(this.config.learningRate);
  51. // Get weights refs
  52. this.weights = [];
  53. for (const w of this.Q.weights) {
  54. this.weights.push((w as any).val);
  55. }
  56. this.updateTarget();
  57. }
  58.  
  59. ...
  60.  
  61. public async learn() {
  62. if (this.stats.learnCount % this.config.refreshTargetEvery === 0) {
  63. this.updateTarget();
  64. }
  65. const batchSize = 32;
  66. if (this.memory.getLength() > batchSize) {
  67. const batch = this.memory.getBatch(batchSize);
  68. // Batch tensors
  69. const batchState = tf.tensor2d(batch.map((el: any) => el.state)).asType('float32');
  70. const batchAction = tf
  71. .oneHot(tf.tensor1d(batch.map((el: any) => actions.indexOf(el.action)), 'int32'), actions.length)
  72. .asType('float32');
  73. const batchReward = tf.tensor1d(batch.map((el: any) => el.reward)).asType('float32');
  74. const batchNextState = tf.tensor2d(batch.map((el: any) => el.nextState)).asType('float32');
  75. const batchDone = tf.tensor1d(batch.map((el: any) => el.done)).asType('float32');
  76.  
  77. // prodict nextState with targetNet
  78. const targets = this.calcTarget(batchReward, batchNextState, batchDone).asType('float32');
  79.  
  80. const loss = this.optimizer.minimize(
  81. () => {
  82. const x = tf.variable(batchState);
  83. const predictions = (this.Q.predict(x) as tf.Tensor).argMax(1).asType('float32');
  84. return tf.losses.meanSquaredError(targets, predictions) as any;
  85. },
  86. true,
  87. this.weights
  88. );
  89. console.log('loss');
  90. console.log(loss);
  91. }
  92. // qMaxNextState = reward + game + this.QTarget.predict()
  93. this.stats.learnCount++;
  94. return;
  95. }
  96.  
  97. private calcTarget(batchReward: any, batchNextState: any, batchDone: any) {
  98. return tf.tidy(() => {
  99. const maxQ = (this.QTarget.predict(batchNextState) as tf.Tensor).argMax(1).asType('float32');
  100. const targets = batchReward.add(maxQ.mul(tf.scalar(this.config.rewardDiscount)).mul(batchDone));
  101. return targets;
  102. });
  103. }
  104.  
  105. TypeError: Cannot read property 'values' of undefined
  106. at NodeJSKernelBackend.getInputTensorIds (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-node/dist/nodejs_kernel_backend.js:99:22)
  107. at NodeJSKernelBackend.executeSingleOutput (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-node/dist/nodejs_kernel_backend.js:123:73)
  108. at NodeJSKernelBackend.subtract (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-node/dist/nodejs_kernel_backend.js:248:21)
  109. at environment_1.ENV.engine.runKernel.$a (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/ops/binary_ops.ts:202:33)
  110. at /home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/engine.ts:206:22
  111. at Engine.scopedRun (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/engine.ts:167:19)
  112. at Engine.runKernel (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/engine.ts:202:10)
  113. at sub_ (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/ops/binary_ops.ts:201:21)
  114. at Object.sub (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/ops/operation.ts:46:24)
  115. at Tensor.sub (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/tensor.ts:842:22)
  116. at Object.$x (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/ops/unary_ops.ts:372:46)
  117. at _loop_1 (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/tape.ts:171:43)
  118. at Object.backpropagateGradients (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/dist/tape.js:112:9)
  119. at /home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/engine.ts:500:7
  120. at /home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/engine.ts:156:20
  121. at Engine.scopedRun (/home/clement/DEV/Crypto/influx-crypto-trader/node_modules/@tensorflow/tfjs-core/src/engine.ts:167:19)
  122.  
  123. NodeJSKernelBackend.prototype.getInputTensorIds = function (tensors) {
  124. var ids = [];
  125. for (var i = 0; i < tensors.length; i++) {
  126. var info = this.tensorMap.get(tensors[i].dataId);
  127. /*if (!info) {
  128. console.log('tensors[i]')
  129. console.log(this.tensorMap)
  130. console.log(tensors[i])
  131. console.log(info)
  132. }*/
  133. if (info.values != null) {
  134. info.id =
  135. this.binding.createTensor(info.shape, info.dtype, info.values);
  136. info.values = null;
  137. this.tensorMap.set(tensors[i].dataId, info);
  138. }
  139. ids.push(info.id);
  140. }
  141. return ids;
  142. };
  143.  
  144. Tensor {
  145. isDisposedInternal: true,
  146. shape: [ 32, 64 ],
  147. dtype: 'float32',
  148. size: 2048,
  149. strides: [ 64 ],
  150. dataId: {},
  151. id: 1494,
  152. rankType: '2' }
Add Comment
Please, Sign In to add comment