Advertisement
Guest User

Untitled

a guest
Aug 21st, 2019
80
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.95 KB | None | 0 0
  1. def batching(pairs: Seq[(Array[INDArray], Array[INDArray])], batchSize: Int): (Array[INDArray], Array[INDArray]) = {
  2.  
  3. val inputArray = pairs.head._1.map(indArray => {
  4. val longs: Array[Long] = batchSize.toLong +: indArray.shape().tail
  5. Nd4j.create(longs, 'f')
  6. }).toSeq
  7.  
  8. val outputArray = pairs.head._2.map(indArray => {
  9. val longs: Array[Long] = batchSize.toLong +: indArray.shape().tail
  10. Nd4j.create(longs, 'f')
  11. }).toSeq
  12.  
  13. pairs.zipWithIndex.foreach { case ((inputs, outputs), batchIndice) => {
  14. //NDArrayIndex.point(kk)
  15. inputs.zipWithIndex.foreach { case (indArray, indice) => {
  16. inputArray(indice).put(Array(NDArrayIndex.point(batchIndice)), indArray.getRow(0))
  17. }
  18. }
  19. outputs.zipWithIndex.foreach { case (indArray, indice) => {
  20. outputArray(indice).put(Array(NDArrayIndex.point(batchIndice)), indArray.getRow(0))
  21. }
  22. }
  23. }
  24. }
  25.  
  26. (inputArray.toArray, outputArray.toArray)
  27. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement