Advertisement
Guest User

Untitled

a guest
Mar 1st, 2022
31
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 8.83 KB | None | 0 0
  1. import com.google.gson.JsonElement;
  2. import com.google.gson.JsonObject;
  3. import com.google.gson.JsonParser;
  4. import org.deeplearning4j.nn.conf.BackpropType;
  5. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  6. import org.deeplearning4j.nn.conf.inputs.InputType;
  7. import org.deeplearning4j.nn.conf.layers.DenseLayer;
  8. import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
  9. import org.deeplearning4j.nn.conf.layers.OutputLayer;
  10. import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
  11. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  12. import org.deeplearning4j.nn.weights.WeightInit;
  13. import org.nd4j.linalg.activations.Activation;
  14. import org.nd4j.linalg.api.ndarray.INDArray;
  15. import org.nd4j.linalg.dataset.DataSet;
  16. import org.nd4j.linalg.factory.Nd4j;
  17. import org.nd4j.linalg.learning.config.Adam;
  18. import org.nd4j.linalg.lossfunctions.LossFunctions;
  19.  
  20. import java.io.BufferedReader;
  21. import java.io.File;
  22. import java.io.IOException;
  23. import java.io.InputStreamReader;
  24. import java.util.ArrayList;
  25. import java.util.HashMap;
  26. import java.util.HashSet;
  27. import java.util.List;
  28. import java.util.Map;
  29. import java.util.Set;
  30.  
  31. public class TrainModel {
  32.  
  33.     private static final int VOCAB_SIZE = 500;
  34.     private static final int EMBEDDING_DIM = 16;
  35.     private static final int MAX_WORD_LEN = 20;
  36.  
  37.  
  38.     public static void main(String[] args) {
  39.         TrainingData trainingData = loadTrainingData();
  40.         OneHotEncoder lblEncoder = new OneHotEncoder(trainingData.uniqueLabels());
  41.  
  42.         var wordEmbeddingsRaw = createWordEmbeddings(trainingData.trainingSentences().toArray(new String[0]));
  43.        
  44.         var wordEmbeddingsNd = Nd4j.create(padArray(wordEmbeddingsRaw, MAX_WORD_LEN));
  45.         var trainingLabelsNd = lblEncoder.transform(trainingData.trainingLabels());
  46.  
  47.         DataSet trainingSet = new DataSet(wordEmbeddingsNd, trainingLabelsNd);
  48.  
  49.         var model = generateModel(trainingData.numUniqueLabels());
  50.         model.init();
  51.         System.out.println(model.summary());
  52.  
  53.         model.fit(trainingSet);
  54.     }
  55.  
  56.     private static Integer[][] createWordEmbeddings(String[] words) {
  57.         var tokenizer = getTokenizer();
  58.         tokenizer.fitOnTexts(words);
  59.         return tokenizer.textsToSequences(words);
  60.     }
  61.  
  62.     private static INDArray getTestExampleFeatures() {
  63.         String[] example = new String[] { "You suck" };
  64.         var wordEmbeddingsRaw = createWordEmbeddings(example);
  65.         return Nd4j.create(padArray(wordEmbeddingsRaw, MAX_WORD_LEN));
  66.     }
  67.  
  68.     private static KerasTokenizer getTokenizer() {
  69.         return new KerasTokenizer(VOCAB_SIZE, "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
  70.                 true, " ", false, "<OOV>");
  71.     }
  72.  
  73.     private static MultiLayerNetwork generateModel(int numClasses) {
  74.         var modelConf = new NeuralNetConfiguration.Builder()
  75.                 .weightInit(WeightInit.NORMAL)
  76.                 .activation(Activation.RELU)
  77.                 .updater(new Adam.Builder().build())
  78.                 .list()
  79.                 .setInputType(InputType.feedForward(VOCAB_SIZE))
  80.                 .layer(new EmbeddingSequenceLayer.Builder()
  81.                         .nOut(EMBEDDING_DIM)
  82.                         .inputLength(MAX_WORD_LEN)
  83.                         .build())
  84.                 .layer(new DenseLayer.Builder()
  85.                         .nOut(16)
  86.                         .activation(Activation.RELU)
  87.                         .build())
  88.                 .layer(new DenseLayer.Builder()
  89.                         .nOut(16)
  90.                         .activation(Activation.RELU)
  91.                         .build())
  92.                 .layer(new OutputLayer.Builder()
  93.                         .nOut(numClasses)
  94.                         .activation(Activation.SOFTMAX)
  95.                         .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
  96.                         .build())
  97.                 .backpropType(BackpropType.Standard)
  98.                 .build();
  99.  
  100.         return new MultiLayerNetwork(modelConf);
  101.     }
  102.  
  103.  
  104.  
  105.     private static JsonObject readJsonFile() {
  106.         var stream = TrainModel.class.getClassLoader().getResourceAsStream("intents.json");
  107.         try (BufferedReader br = new BufferedReader(
  108.                 new InputStreamReader(stream)
  109.             )) {
  110.             JsonParser parser = new JsonParser();
  111.             return parser.parse(br).getAsJsonObject();
  112.         } catch (IOException e) {
  113.             e.printStackTrace();
  114.         }
  115.         return null;
  116.     }
  117.  
  118.     private static TrainingData loadTrainingData() {
  119.         JsonObject jsonObj = readJsonFile();
  120.         var intentsArray = jsonObj.get("intents").getAsJsonArray();
  121.  
  122.         List<String> trainingSentences = new ArrayList<>();
  123.         List<String> trainingLabels = new ArrayList<>();
  124.         Set<String> uniqueLabels = new HashSet<>();
  125.  
  126.         for (JsonElement intentElement : intentsArray) {
  127.             var intentObj = intentElement.getAsJsonObject();
  128.             final String intentTag = intentObj.get("tag").getAsString();
  129.  
  130.             for (JsonElement patternElem : intentObj.get("patterns").getAsJsonArray()) {
  131.                 var patternStr = patternElem.getAsString();
  132.                 trainingSentences.add(patternStr);
  133.                 trainingLabels.add(intentTag);
  134.             }
  135.  
  136.             uniqueLabels.add(intentTag);
  137.         }
  138.  
  139.         return new TrainingData(trainingSentences, trainingLabels, uniqueLabels);
  140.     }
  141.  
  142.     // Post-pads an array
  143.     private static int[][] padArray(Integer[][] data, int maxLen) {
  144.         int[][] newData = new int[data.length][maxLen];
  145.         for (int i = 0; i < data.length; i++) {
  146.             // 2
  147.             int len = data[i].length;
  148.             int fillLen = maxLen - len;
  149.             if (fillLen < 0)
  150.                 fillLen = 0;
  151.  
  152.             for (int j = 0; j < fillLen; ++j) {
  153.                 newData[i][j] = 0;
  154.             }
  155.  
  156.             int copyIdx = 0;
  157.             for (int newDataIdx = fillLen; newDataIdx < maxLen; ++newDataIdx) {
  158.                 newData[i][newDataIdx] = data[i][copyIdx];
  159.                 copyIdx++;
  160.             }
  161.         }
  162.  
  163.         return newData;
  164.     }
  165.  
  166.     public record TrainingData(List<String> trainingSentences, List<String> trainingLabels,
  167.                                Set<String> uniqueLabels) {
  168.  
  169.         public int numUniqueLabels() {
  170.             return uniqueLabels.size();
  171.         }
  172.     }
  173.  
  174.     public static class OneHotEncoder {
  175.  
  176.         // int in this case represents index
  177.         private final Map<String, Integer> encoder = new HashMap<>();
  178.         private final Map<Integer, String> inverseEncoder = new HashMap<>();
  179.         private int vectorSize = -1;
  180.  
  181.         public OneHotEncoder() {
  182.         }
  183.  
  184.         public OneHotEncoder(Set<String> labels) {
  185.             this.fit(labels);
  186.         }
  187.        
  188.         public void fit(Set<String> labels) {
  189.             encoder.clear();
  190.             vectorSize = labels.size();
  191.             int idx = 0;
  192.             for (String label : labels) {
  193.                 encoder.put(label, idx);
  194.                 inverseEncoder.put(idx, label);
  195.                 idx++;
  196.             }
  197.         }
  198.  
  199.         // Returns a matrix with num labels rows and vector size columns.
  200.         public INDArray transform(List<String> labels) {
  201.             var array = Nd4j.zeros(labels.size(), vectorSize);
  202.  
  203.             for (int i = 0; i < labels.size(); i++) {
  204.                 String label = labels.get(i);
  205.                 Integer encodedIdx = encoder.get(label);
  206.                 if (encodedIdx == null) {
  207.                     throw new RuntimeException("Error encoding label '" + label +
  208.                             "'! No encoding found!");
  209.                 }
  210.  
  211.                 array.putScalar(i, encodedIdx, 1);
  212.             }
  213.  
  214.             return array;
  215.         }
  216.        
  217.         public List<String> inverseTransform(INDArray encodedLabels) {
  218.             long[] shape = encodedLabels.shape();
  219.             int rows = (int) shape[0];
  220.             int numLabels = (int) shape[1];
  221.             List<String> transformedList = new ArrayList<>(rows);
  222.  
  223.             for (int i = 0; i < rows; i++) {
  224.                 var rowVector = encodedLabels.getRow(i);
  225.  
  226.                 int encodedLabelidx = -1;
  227.                 for (int lblIdx = 0; lblIdx < numLabels; ++lblIdx) {
  228.                     if (rowVector.getInt(lblIdx) == 1) {
  229.                         encodedLabelidx = lblIdx;
  230.                         break;
  231.                     }
  232.                 }
  233.  
  234.                 String inverseEncoding = inverseEncoder.get(encodedLabelidx);
  235.                 if (inverseEncoding == null) {
  236.                     throw new RuntimeException("No inverse encoding found for value '" + encodedLabelidx + "'!");
  237.                 }
  238.                 transformedList.add(inverseEncoding);
  239.             }
  240.  
  241.             return transformedList;
  242.         }
  243.     }
  244.  
  245. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement