SHARE
TWEET

Untitled

a guest Aug 25th, 2019 66 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. package com.dilisim.nn.dl4j.iterators
  2.  
  3. import java.util.Locale
  4.  
  5. import com.dilisim.nn.datasets.{Document, DocumentFields}
  6. import com.dilisim.nn.dl4j.models.ProcessTypes
  7. import com.dilisim.nn.dl4j.processors.{DL4JBOWProcessor, DL4JBOWSequenceProcessor, DL4JNERProcessor, DL4JNGramTokenTagProcessor, DL4JTokenCharProcessor, INDProcessor}
  8. import com.dilisim.nn.dl4j.splitters.{LineSplitSequencer, Sequencer}
  9. import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration
  10. import org.nd4j.linalg.api.memory.enums.{AllocationPolicy, LearningPolicy, MirroringPolicy, SpillPolicy}
  11. import org.nd4j.linalg.api.ndarray.INDArray
  12. import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
  13. import org.nd4j.linalg.factory.Nd4j
  14. import org.nd4j.linalg.indexing.NDArrayIndex
  15.  
  16. abstract class GenericIterator(val limitSampleSize: Int, val ngramLength: Int = 3) extends MultiDataSetIterator {
  17.   var inputProcessor = Array[(String, INDProcessor)]()
  18.   var outputProcessor = Array[(String, INDProcessor)]()
  19.  
  20.   /* val basicConfig = WorkspaceConfiguration.builder()
  21.      .policyAllocation(AllocationPolicy.STRICT)
  22.      .policyLearning(LearningPolicy.FIRST_LOOP)
  23.      .policyMirroring(MirroringPolicy.HOST_ONLY)
  24.      .policySpill(SpillPolicy.EXTERNAL)
  25.      .build();*/
  26.  
  27.   def build(): this.type = {
  28.     var count = 0
  29.     while (hasNext && count < limitSampleSize) {
  30.       next()
  31.       count += 1
  32.     }
  33.  
  34.     println(s"Sample count: ${count}")
  35.     reset()
  36.     this
  37.   }
  38.  
  39.  
  40.   def inputSize(): Array[Int] = {
  41.     inputProcessor.map(_._2.dictionarySize())
  42.   }
  43.  
  44.   def outputSize(): Array[Int] = {
  45.     outputProcessor.map(_._2.dictionarySize())
  46.   }
  47.  
  48.   def prepend(processors: Array[(String, INDProcessor)], field: String, typ: String,
  49.               locale: Locale, maxSize: Int, windowLength: Int): Array[(String, INDProcessor)] = {
  50.     processors :+ (field, processor(typ, locale, maxSize, windowLength))
  51.   }
  52.  
  53.   def prepend(processors: Array[(String, INDProcessor)], sequencer: Sequencer, field: String, typ: String,
  54.               locale: Locale, maxSize: Int, windowLength: Int): Array[(String, INDProcessor)] = {
  55.     processors :+ (field, processor(sequencer, typ, locale, maxSize, windowLength))
  56.   }
  57.  
  58.  
  59.   def processor(processorType: String, locale: Locale, size: Int, windowLength: Int): INDProcessor = {
  60.     if (processorType.equals(ProcessTypes.bow)) new DL4JBOWProcessor(locale, size)
  61.     else if (processorType.equals(ProcessTypes.tokenngrambow)) new DL4JTokenCharProcessor(size, size, windowLength)
  62.     else if (processorType.equals(ProcessTypes.bowsequence)) new DL4JBOWSequenceProcessor(locale, size, windowLength)
  63.     else if (processorType.equals(ProcessTypes.multiLabelAttrTag)) new DL4JNERProcessor(new LineSplitSequencer(indice = 1), size, size, windowLength, locale)
  64.     else throw new Exception("Processor is not defined...")
  65.   }
  66.  
  67.   def processor(sequencer: Sequencer, processorType: String, locale: Locale, size: Int, windowLength: Int): INDProcessor = {
  68.     if (processorType.equals(ProcessTypes.bow)) new DL4JBOWProcessor(locale, size)
  69.     else if (processorType.equals(ProcessTypes.tokenngrambow)) new DL4JNGramTokenTagProcessor(sequencer, ngramLength, size, windowLength)
  70.     else if (processorType.equals(ProcessTypes.bowsequence)) new DL4JBOWSequenceProcessor(locale, size, windowLength)
  71.     else if (processorType.equals(ProcessTypes.multiLabelAttrTag)) new DL4JNERProcessor(sequencer, size, size, windowLength, locale)
  72.     else throw new Exception("Processor is not defined...")
  73.   }
  74.  
  75.   def input(field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
  76.     inputProcessor = prepend(inputProcessor, field, typ, locale, maxSize, windowLength)
  77.     this
  78.   }
  79.  
  80.   def input(sequencer: Sequencer, field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
  81.     inputProcessor = prepend(inputProcessor, sequencer, field, typ, locale, maxSize, windowLength)
  82.     this
  83.   }
  84.  
  85.   def currentInputProcessor(): INDProcessor = {
  86.     inputProcessor.last._2
  87.   }
  88.  
  89.   def currentOutputProcessor(): INDProcessor = {
  90.     outputProcessor.last._2
  91.   }
  92.  
  93.   def output(field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
  94.     outputProcessor = prepend(outputProcessor, field, typ, locale, maxSize, windowLength)
  95.     this
  96.   }
  97.  
  98.   def output(sequencer: Sequencer, field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
  99.     outputProcessor = prepend(outputProcessor, sequencer, field, typ, locale, maxSize, windowLength)
  100.     this
  101.   }
  102.  
  103.   def fieldString(document: Document, field: String): Seq[String] = {
  104.  
  105.  
  106.     if (field.equals(DocumentFields.author)) Seq(document.author)
  107.     else if (field.equals(DocumentFields.text)) Seq(document.text)
  108.     else if (field.equals(DocumentFields.date)) Seq(document.date)
  109.     else if (field.equals(DocumentFields.doctype)) Seq(document.docType)
  110.     else if (field.equals(DocumentFields.keywords)) document.keywords
  111.     else if (field.equals(DocumentFields.genre)) Seq(document.genre)
  112.     else if (field.equals(DocumentFields.paragraphs)) document.paragraphs
  113.     else if (field.equals(DocumentFields.title)) Seq(document.title)
  114.     else if (field.equals(DocumentFields.identifier)) Seq(document.docID)
  115.     else null
  116.   }
  117.  
  118.   def fieldString(element: String): Seq[String] = {
  119.     Seq(element)
  120.   }
  121.  
  122.  
  123.   def batching(pairs: Seq[(Array[INDArray], Array[INDArray])], batchSize: Int): (Array[INDArray], Array[INDArray]) = {
  124.  
  125.  
  126.     val inputArray = pairs.head._1.map(indArray => {
  127.       val longs: Array[Long] = batchSize.toLong +: (if (indArray.shape().length == 1) indArray.shape() else indArray.shape().tail)
  128.       Nd4j.create(longs, 'f')
  129.     }).toSeq
  130.  
  131.     val outputArray = pairs.head._2.map(indArray => {
  132.       val longs: Array[Long] = batchSize.toLong +: (if (indArray.shape().length == 1) indArray.shape() else indArray.shape().tail)
  133.       Nd4j.create(longs, 'f')
  134.     }).toSeq
  135.  
  136.     pairs.zipWithIndex.foreach { case ((inputs, outputs), batchIndice) => {
  137.       //NDArrayIndex.point(kk)
  138.  
  139.       inputs.zipWithIndex.foreach { case (indArray, indice) => {
  140.         inputArray(indice).put(Array(NDArrayIndex.point(batchIndice)), indArray.getRow(0))
  141.       }
  142.       }
  143.       outputs.zipWithIndex.foreach { case (indArray, indice) => {
  144.         outputArray(indice).put(Array(NDArrayIndex.point(batchIndice)), indArray.getRow(0))
  145.       }
  146.       }
  147.     }
  148.     }
  149.  
  150.     (inputArray.toArray, outputArray.toArray)
  151.  
  152.   }
  153.  
  154.   def batchingMasking(pairs: Seq[(Array[(INDArray, INDArray)], Array[(INDArray, INDArray)])], batchSize: Int):
  155.   (Array[(INDArray, INDArray)], Array[(INDArray, INDArray)]) = {
  156.     var inputArray = pairs.head._1.map { case (indArray, maskArray) => {
  157.       val longs: Array[Long] = batchSize.toLong +: indArray.shape().tail
  158.       val longsMask: Array[Long] = batchSize.toLong +: maskArray.shape().tail
  159.       (Nd4j.create(longs, 'f'), Nd4j.create(longsMask, 'f'))
  160.     }
  161.     }
  162.  
  163.     var outputArray = pairs.head._2.map { case (indArray, maskArray) => {
  164.       val longs: Array[Long] = batchSize.toLong +: indArray.shape().tail
  165.       val longsMask: Array[Long] = batchSize.toLong +: maskArray.shape().tail
  166.       (Nd4j.create(longs, 'f'), Nd4j.create(longsMask, 'f'))
  167.     }
  168.     }
  169.  
  170.     pairs.zipWithIndex.foreach { case ((inputs, outputs), batchIndice) => {
  171.       //NDArrayIndex.point(kk)
  172.       inputs.zipWithIndex.foreach { case ((indArray, maskArray), indice) => {
  173.         inputArray(indice)._1.put(Array(NDArrayIndex.point(batchIndice)), indArray.getRow(0))
  174.         inputArray(indice)._2.put(Array(NDArrayIndex.point(batchIndice)), maskArray.getRow(0))
  175.       }
  176.       }
  177.       outputs.zipWithIndex.foreach { case ((indArray, maskArray), indice) => {
  178.         outputArray(indice)._1.put(Array(NDArrayIndex.point(batchIndice)), indArray.getRow(0))
  179.         outputArray(indice)._2.put(Array(NDArrayIndex.point(batchIndice)), maskArray.getRow(0))
  180.       }
  181.       }
  182.     }
  183.     }
  184.  
  185.     (inputArray, outputArray)
  186.   }
  187. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top