Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package com.dilisim.nn.dl4j.iterators
- import java.util.Locale
- import com.dilisim.nn.datasets.{Document, DocumentFields}
- import com.dilisim.nn.dl4j.models.ProcessTypes
- import com.dilisim.nn.dl4j.processors.{DL4JBOWProcessor, DL4JBOWSequenceProcessor, DL4JNERProcessor, DL4JNGramTokenTagProcessor, DL4JTokenCharProcessor, INDProcessor}
- import com.dilisim.nn.dl4j.splitters.{LineSplitSequencer, Sequencer}
- import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration
- import org.nd4j.linalg.api.memory.enums.{AllocationPolicy, LearningPolicy, MirroringPolicy, SpillPolicy}
- import org.nd4j.linalg.api.ndarray.INDArray
- import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
- import org.nd4j.linalg.factory.Nd4j
- import org.nd4j.linalg.indexing.NDArrayIndex
/**
 * Base [[MultiDataSetIterator]] that keeps an ordered list of named input and
 * output processors and offers helpers for stacking per-sample arrays into
 * batch arrays.
 *
 * The iteration primitives used here (`hasNext`, `next()`, `reset()`) are not
 * defined in this class; concrete subclasses must supply them.
 *
 * @param limitSampleSize maximum number of elements consumed by [[build]]
 * @param ngramLength     n-gram length handed to `DL4JNGramTokenTagProcessor`
 *                        when a `tokenngrambow` processor is created with an
 *                        explicit sequencer
 */
abstract class GenericIterator(val limitSampleSize: Int, val ngramLength: Int = 3) extends MultiDataSetIterator {

  /** Registered input processors as (field name, processor) pairs, in registration order. */
  var inputProcessor: Array[(String, INDProcessor)] = Array.empty[(String, INDProcessor)]

  /** Registered output processors as (field name, processor) pairs, in registration order. */
  var outputProcessor: Array[(String, INDProcessor)] = Array.empty[(String, INDProcessor)]

  /* val basicConfig = WorkspaceConfiguration.builder()
       .policyAllocation(AllocationPolicy.STRICT)
       .policyLearning(LearningPolicy.FIRST_LOOP)
       .policyMirroring(MirroringPolicy.HOST_ONLY)
       .policySpill(SpillPolicy.EXTERNAL)
       .build();*/

  /**
   * Walks the iterator for at most `limitSampleSize` elements, prints how many
   * were consumed, then resets the iterator.
   *
   * NOTE(review): presumably this warm-up pass lets the processors populate
   * their dictionaries before training — confirm against the subclass `next()`.
   *
   * @return this iterator, for chaining
   */
  def build(): this.type = {
    var count = 0
    while (hasNext && count < limitSampleSize) {
      next()
      count += 1
    }
    println(s"Sample count: ${count}")
    reset()
    this
  }

  /** @return the `dictionarySize()` of every input processor, in registration order */
  def inputSize(): Array[Int] =
    inputProcessor.map { case (_, proc) => proc.dictionarySize() }

  /** @return the `dictionarySize()` of every output processor, in registration order */
  def outputSize(): Array[Int] =
    outputProcessor.map { case (_, proc) => proc.dictionarySize() }

  /**
   * Returns a copy of `processors` with a newly created processor for `field`
   * added at the END. Despite the name, the new entry is appended, not
   * prepended; the name is kept for caller compatibility.
   *
   * @param processors   existing (field, processor) pairs
   * @param field        field name the new processor is registered under
   * @param typ          processor type, one of the `ProcessTypes` constants
   * @param locale       locale forwarded to the processor factory
   * @param maxSize      size forwarded to the processor factory
   * @param windowLength window length forwarded to the processor factory
   * @return a new array holding the old entries followed by the new one
   */
  def prepend(processors: Array[(String, INDProcessor)], field: String, typ: String,
              locale: Locale, maxSize: Int, windowLength: Int): Array[(String, INDProcessor)] = {
    // Explicit tuple: `xs :+ (a, b)` only compiles through deprecated argument adaptation.
    processors :+ ((field, processor(typ, locale, maxSize, windowLength)))
  }

  /**
   * Same as the other `prepend` overload, but the processor is built with the
   * supplied `sequencer`. The new entry is appended at the END of the array.
   *
   * @param processors   existing (field, processor) pairs
   * @param sequencer    sequencer forwarded to the processor factory
   * @param field        field name the new processor is registered under
   * @param typ          processor type, one of the `ProcessTypes` constants
   * @param locale       locale forwarded to the processor factory
   * @param maxSize      size forwarded to the processor factory
   * @param windowLength window length forwarded to the processor factory
   * @return a new array holding the old entries followed by the new one
   */
  def prepend(processors: Array[(String, INDProcessor)], sequencer: Sequencer, field: String, typ: String,
              locale: Locale, maxSize: Int, windowLength: Int): Array[(String, INDProcessor)] = {
    processors :+ ((field, processor(sequencer, typ, locale, maxSize, windowLength)))
  }

  /**
   * Creates the processor matching `processorType` using default sequencing
   * (a `LineSplitSequencer(indice = 1)` for `multiLabelAttrTag`).
   *
   * @param processorType one of `ProcessTypes.bow`, `tokenngrambow`,
   *                      `bowsequence` or `multiLabelAttrTag`
   * @param locale        locale passed to the locale-aware processors
   * @param size          size passed to the processor constructor (used for
   *                      both size arguments where the constructor takes two)
   * @param windowLength  window length passed to the windowed processors
   * @throws Exception if `processorType` is not one of the supported types
   */
  def processor(processorType: String, locale: Locale, size: Int, windowLength: Int): INDProcessor = {
    if (processorType == ProcessTypes.bow) new DL4JBOWProcessor(locale, size)
    else if (processorType == ProcessTypes.tokenngrambow) new DL4JTokenCharProcessor(size, size, windowLength)
    else if (processorType == ProcessTypes.bowsequence) new DL4JBOWSequenceProcessor(locale, size, windowLength)
    else if (processorType == ProcessTypes.multiLabelAttrTag)
      new DL4JNERProcessor(new LineSplitSequencer(indice = 1), size, size, windowLength, locale)
    else throw new Exception("Processor is not defined...")
  }

  /**
   * Creates the processor matching `processorType`, using `sequencer` for the
   * processor types that accept one (`tokenngrambow`, `multiLabelAttrTag`).
   * For `tokenngrambow` this overload builds a `DL4JNGramTokenTagProcessor`
   * with this iterator's `ngramLength`, unlike the sequencer-less overload.
   *
   * @param sequencer     sequencer handed to sequencer-aware processors
   * @param processorType one of `ProcessTypes.bow`, `tokenngrambow`,
   *                      `bowsequence` or `multiLabelAttrTag`
   * @param locale        locale passed to the locale-aware processors
   * @param size          size passed to the processor constructor
   * @param windowLength  window length passed to the windowed processors
   * @throws Exception if `processorType` is not one of the supported types
   */
  def processor(sequencer: Sequencer, processorType: String, locale: Locale, size: Int, windowLength: Int): INDProcessor = {
    if (processorType == ProcessTypes.bow) new DL4JBOWProcessor(locale, size)
    else if (processorType == ProcessTypes.tokenngrambow)
      new DL4JNGramTokenTagProcessor(sequencer, ngramLength, size, windowLength)
    else if (processorType == ProcessTypes.bowsequence) new DL4JBOWSequenceProcessor(locale, size, windowLength)
    else if (processorType == ProcessTypes.multiLabelAttrTag)
      new DL4JNERProcessor(sequencer, size, size, windowLength, locale)
    else throw new Exception("Processor is not defined...")
  }

  /**
   * Registers a new input processor for `field` after the existing ones.
   *
   * @return this iterator, for chaining
   */
  def input(field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
    inputProcessor = prepend(inputProcessor, field, typ, locale, maxSize, windowLength)
    this
  }

  /**
   * Registers a new sequencer-based input processor for `field` after the existing ones.
   *
   * @return this iterator, for chaining
   */
  def input(sequencer: Sequencer, field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
    inputProcessor = prepend(inputProcessor, sequencer, field, typ, locale, maxSize, windowLength)
    this
  }

  /**
   * @return the most recently registered input processor
   * @throws NoSuchElementException if no input processor has been registered
   */
  def currentInputProcessor(): INDProcessor = {
    inputProcessor.last._2
  }

  /**
   * @return the most recently registered output processor
   * @throws NoSuchElementException if no output processor has been registered
   */
  def currentOutputProcessor(): INDProcessor = {
    outputProcessor.last._2
  }

  /**
   * Registers a new output processor for `field` after the existing ones.
   *
   * @return this iterator, for chaining
   */
  def output(field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
    outputProcessor = prepend(outputProcessor, field, typ, locale, maxSize, windowLength)
    this
  }

  /**
   * Registers a new sequencer-based output processor for `field` after the existing ones.
   *
   * @return this iterator, for chaining
   */
  def output(sequencer: Sequencer, field: String, typ: String, locale: Locale, maxSize: Int, windowLength: Int): this.type = {
    outputProcessor = prepend(outputProcessor, sequencer, field, typ, locale, maxSize, windowLength)
    this
  }

  /**
   * Extracts the textual content of `field` from `document`.
   *
   * Single-valued fields are wrapped in a one-element `Seq`; `keywords` and
   * `paragraphs` are returned as they are stored on the document.
   *
   * @param document document to read from
   * @param field    one of the `DocumentFields` constants
   * @return the field content, or `null` when `field` is not recognised
   *         (kept for backward compatibility — callers must null-check)
   */
  def fieldString(document: Document, field: String): Seq[String] = {
    if (field == DocumentFields.author) Seq(document.author)
    else if (field == DocumentFields.text) Seq(document.text)
    else if (field == DocumentFields.date) Seq(document.date)
    else if (field == DocumentFields.doctype) Seq(document.docType)
    else if (field == DocumentFields.keywords) document.keywords
    else if (field == DocumentFields.genre) Seq(document.genre)
    else if (field == DocumentFields.paragraphs) document.paragraphs
    else if (field == DocumentFields.title) Seq(document.title)
    else if (field == DocumentFields.identifier) Seq(document.docID)
    else null
  }

  /** Wraps a raw string in a one-element `Seq`. */
  def fieldString(element: String): Seq[String] = {
    Seq(element)
  }

  /**
   * Stacks per-sample arrays into one batch array per input/output slot.
   *
   * The first pair is used as a template: for each of its arrays a buffer is
   * created (ordering `'f'`) whose first dimension is `batchSize`. The
   * remaining dimensions are the sample's shape without its first dimension,
   * or the full shape when the sample is rank 1. Row 0 of each sample array is
   * then written at the sample's batch index. Buffer rows at indices
   * `>= pairs.size` are never written.
   *
   * @param pairs     per-sample (inputs, outputs) arrays; must be non-empty
   * @param batchSize number of rows in every batch buffer; must be `>= pairs.size`
   * @return (batched inputs, batched outputs)
   * @throws IllegalArgumentException if `pairs` is empty or larger than `batchSize`
   */
  def batching(pairs: Seq[(Array[INDArray], Array[INDArray])], batchSize: Int): (Array[INDArray], Array[INDArray]) = {
    require(pairs.nonEmpty, "batching requires at least one sample pair")
    require(pairs.size <= batchSize,
      s"batching received ${pairs.size} samples but batchSize is only $batchSize")

    val (templateInputs, templateOutputs) = pairs.head
    val inputBatch = templateInputs.map(sample => allocateBatch(batchSize, perSampleShape(sample)))
    val outputBatch = templateOutputs.map(sample => allocateBatch(batchSize, perSampleShape(sample)))

    pairs.zipWithIndex.foreach { case ((inputs, outputs), batchIndex) =>
      inputs.zipWithIndex.foreach { case (sample, slot) =>
        writeRow(inputBatch(slot), batchIndex, sample)
      }
      outputs.zipWithIndex.foreach { case (sample, slot) =>
        writeRow(outputBatch(slot), batchIndex, sample)
      }
    }
    (inputBatch, outputBatch)
  }

  /**
   * Stacks per-sample (array, mask) tuples into one (batch array, batch mask)
   * tuple per input/output slot.
   *
   * The first pair is used as a template: every buffer has `batchSize` as its
   * first dimension followed by the template's shape without its first
   * dimension. Row 0 of every sample array and mask is written at the sample's
   * batch index.
   *
   * NOTE(review): unlike `batching`, rank-1 samples are not special-cased
   * here — confirm that arrays and masks are always at least rank 2.
   *
   * @param pairs     per-sample (inputs, outputs), each a list of (array, mask); must be non-empty
   * @param batchSize number of rows in every batch buffer; must be `>= pairs.size`
   * @return (batched inputs with masks, batched outputs with masks)
   * @throws IllegalArgumentException if `pairs` is empty or larger than `batchSize`
   */
  def batchingMasking(pairs: Seq[(Array[(INDArray, INDArray)], Array[(INDArray, INDArray)])], batchSize: Int):
  (Array[(INDArray, INDArray)], Array[(INDArray, INDArray)]) = {
    require(pairs.nonEmpty, "batchingMasking requires at least one sample pair")
    require(pairs.size <= batchSize,
      s"batchingMasking received ${pairs.size} samples but batchSize is only $batchSize")

    val (templateInputs, templateOutputs) = pairs.head
    val inputBatch = templateInputs.map { case (sample, mask) =>
      (allocateBatch(batchSize, sample.shape().tail), allocateBatch(batchSize, mask.shape().tail))
    }
    val outputBatch = templateOutputs.map { case (sample, mask) =>
      (allocateBatch(batchSize, sample.shape().tail), allocateBatch(batchSize, mask.shape().tail))
    }

    pairs.zipWithIndex.foreach { case ((inputs, outputs), batchIndex) =>
      inputs.zipWithIndex.foreach { case ((sample, mask), slot) =>
        writeRow(inputBatch(slot)._1, batchIndex, sample)
        writeRow(inputBatch(slot)._2, batchIndex, mask)
      }
      outputs.zipWithIndex.foreach { case ((sample, mask), slot) =>
        writeRow(outputBatch(slot)._1, batchIndex, sample)
        writeRow(outputBatch(slot)._2, batchIndex, mask)
      }
    }
    (inputBatch, outputBatch)
  }

  /** Shape of one sample inside a batch: the full shape for rank-1 arrays, otherwise the shape minus its first dimension. */
  private def perSampleShape(sample: INDArray): Array[Long] = {
    val shape = sample.shape()
    if (shape.length == 1) shape else shape.tail
  }

  /** Creates an `'f'`-ordered array of shape `batchSize +: sampleShape`. */
  private def allocateBatch(batchSize: Int, sampleShape: Array[Long]): INDArray = {
    val shape: Array[Long] = batchSize.toLong +: sampleShape
    Nd4j.create(shape, 'f')
  }

  /** Writes row 0 of `sample` into `target` at position `batchIndex` of the first dimension. */
  private def writeRow(target: INDArray, batchIndex: Int, sample: INDArray): Unit = {
    target.put(Array(NDArrayIndex.point(batchIndex)), sample.getRow(0))
    ()
  }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement