Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/**
 * Stacks per-sample ND4J arrays into freshly allocated batch arrays.
 *
 * Each element of `pairs` is one sample: `(inputs, outputs)`. For sample `i` and tensor
 * position `k`, `getRow(0)` of `inputs(k)` / `outputs(k)` is written into row `i` (via
 * `NDArrayIndex.point(i)`) of the k-th returned input / output batch array.
 *
 * The batch arrays are allocated with `Nd4j.create` in 'f' order, with shape
 * `batchSize +: sampleShape.tail`. The sample shapes are taken from the FIRST pair.
 *
 * NOTE(review): assumes every per-sample tensor has a leading dimension of 1 and that all
 * samples share the first pair's shapes — confirm against callers.
 *
 * @param pairs     samples to stack. Must be non-empty and contain at most `batchSize`
 *                  entries, and every sample must have the same number of input and output
 *                  tensors as the first one.
 * @param batchSize size of the leading (batch) dimension of the allocated arrays. If fewer
 *                  than `batchSize` samples are supplied, the remaining rows keep whatever
 *                  `Nd4j.create` initialised them to (presumably zeros — verify for your
 *                  ND4J version).
 * @return `(inputBatches, outputBatches)`, in the same tensor order as each sample's arrays
 * @throws IllegalArgumentException if `pairs` is empty or longer than `batchSize`, or if a
 *                                  sample's input/output count differs from the first sample's
 */
def batching(pairs: Seq[(Array[INDArray], Array[INDArray])], batchSize: Int): (Array[INDArray], Array[INDArray]) = {
  require(pairs.nonEmpty, "batching requires at least one (inputs, outputs) pair")
  require(
    pairs.size <= batchSize,
    s"batching got ${pairs.size} pairs but batchSize is only $batchSize"
  )

  // The first sample defines how many tensors there are and what their shapes look like.
  val (firstInputs, firstOutputs) = pairs.head
  val inputBatch: Array[INDArray] = firstInputs.map(allocateBatchLike(_, batchSize))
  val outputBatch: Array[INDArray] = firstOutputs.map(allocateBatchLike(_, batchSize))

  pairs.zipWithIndex.foreach { case ((inputs, outputs), batchIndex) =>
    // A mismatched sample would otherwise either crash with an index error or silently
    // leave some batch arrays unfilled for this row.
    require(
      inputs.length == inputBatch.length && outputs.length == outputBatch.length,
      s"sample $batchIndex has ${inputs.length} inputs / ${outputs.length} outputs, " +
        s"expected ${inputBatch.length} / ${outputBatch.length}"
    )
    copySampleIntoBatchRow(inputs, inputBatch, batchIndex)
    copySampleIntoBatchRow(outputs, outputBatch, batchIndex)
  }

  (inputBatch, outputBatch)
}

/** Allocates an 'f'-ordered array shaped like `sample`, with its leading dimension replaced by `batchSize`. */
private def allocateBatchLike(sample: INDArray, batchSize: Int): INDArray = {
  val shape: Array[Long] = batchSize.toLong +: sample.shape().tail
  Nd4j.create(shape, 'f')
}

/** Writes `getRow(0)` of each sample tensor into row `row` of the corresponding batch array. */
private def copySampleIntoBatchRow(sample: Array[INDArray], batch: Array[INDArray], row: Int): Unit =
  sample.zip(batch).foreach { case (source, target) =>
    target.put(Array(NDArrayIndex.point(row)), source.getRow(0))
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement