Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.modelimport.keras.preprocessing.text.KerasTokenizer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
- public class TrainModel {
- private static final int VOCAB_SIZE = 500;
- private static final int EMBEDDING_DIM = 16;
- private static final int MAX_WORD_LEN = 20;
- public static void main(String[] args) {
- TrainingData trainingData = loadTrainingData();
- OneHotEncoder lblEncoder = new OneHotEncoder(trainingData.uniqueLabels());
- var wordEmbeddingsRaw = createWordEmbeddings(trainingData.trainingSentences().toArray(new String[0]));
- var wordEmbeddingsNd = Nd4j.create(padArray(wordEmbeddingsRaw, MAX_WORD_LEN));
- var trainingLabelsNd = lblEncoder.transform(trainingData.trainingLabels());
- DataSet trainingSet = new DataSet(wordEmbeddingsNd, trainingLabelsNd);
- var model = generateModel(trainingData.numUniqueLabels());
- model.init();
- System.out.println(model.summary());
- model.fit(trainingSet);
- }
- private static Integer[][] createWordEmbeddings(String[] words) {
- var tokenizer = getTokenizer();
- tokenizer.fitOnTexts(words);
- return tokenizer.textsToSequences(words);
- }
- private static INDArray getTestExampleFeatures() {
- String[] example = new String[] { "You suck" };
- var wordEmbeddingsRaw = createWordEmbeddings(example);
- return Nd4j.create(padArray(wordEmbeddingsRaw, MAX_WORD_LEN));
- }
- private static KerasTokenizer getTokenizer() {
- return new KerasTokenizer(VOCAB_SIZE, "!\"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n",
- true, " ", false, "<OOV>");
- }
- private static MultiLayerNetwork generateModel(int numClasses) {
- var modelConf = new NeuralNetConfiguration.Builder()
- .weightInit(WeightInit.NORMAL)
- .activation(Activation.RELU)
- .updater(new Adam.Builder().build())
- .list()
- .setInputType(InputType.feedForward(VOCAB_SIZE))
- .layer(new EmbeddingSequenceLayer.Builder()
- .nOut(EMBEDDING_DIM)
- .inputLength(MAX_WORD_LEN)
- .build())
- .layer(new DenseLayer.Builder()
- .nOut(16)
- .activation(Activation.RELU)
- .build())
- .layer(new DenseLayer.Builder()
- .nOut(16)
- .activation(Activation.RELU)
- .build())
- .layer(new OutputLayer.Builder()
- .nOut(numClasses)
- .activation(Activation.SOFTMAX)
- .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
- .build())
- .backpropType(BackpropType.Standard)
- .build();
- return new MultiLayerNetwork(modelConf);
- }
- private static JsonObject readJsonFile() {
- var stream = TrainModel.class.getClassLoader().getResourceAsStream("intents.json");
- try (BufferedReader br = new BufferedReader(
- new InputStreamReader(stream)
- )) {
- JsonParser parser = new JsonParser();
- return parser.parse(br).getAsJsonObject();
- } catch (IOException e) {
- e.printStackTrace();
- }
- return null;
- }
- private static TrainingData loadTrainingData() {
- JsonObject jsonObj = readJsonFile();
- var intentsArray = jsonObj.get("intents").getAsJsonArray();
- List<String> trainingSentences = new ArrayList<>();
- List<String> trainingLabels = new ArrayList<>();
- Set<String> uniqueLabels = new HashSet<>();
- for (JsonElement intentElement : intentsArray) {
- var intentObj = intentElement.getAsJsonObject();
- final String intentTag = intentObj.get("tag").getAsString();
- for (JsonElement patternElem : intentObj.get("patterns").getAsJsonArray()) {
- var patternStr = patternElem.getAsString();
- trainingSentences.add(patternStr);
- trainingLabels.add(intentTag);
- }
- uniqueLabels.add(intentTag);
- }
- return new TrainingData(trainingSentences, trainingLabels, uniqueLabels);
- }
- // Post-pads an array
- private static int[][] padArray(Integer[][] data, int maxLen) {
- int[][] newData = new int[data.length][maxLen];
- for (int i = 0; i < data.length; i++) {
- // 2
- int len = data[i].length;
- int fillLen = maxLen - len;
- if (fillLen < 0)
- fillLen = 0;
- for (int j = 0; j < fillLen; ++j) {
- newData[i][j] = 0;
- }
- int copyIdx = 0;
- for (int newDataIdx = fillLen; newDataIdx < maxLen; ++newDataIdx) {
- newData[i][newDataIdx] = data[i][copyIdx];
- copyIdx++;
- }
- }
- return newData;
- }
- public record TrainingData(List<String> trainingSentences, List<String> trainingLabels,
- Set<String> uniqueLabels) {
- public int numUniqueLabels() {
- return uniqueLabels.size();
- }
- }
- public static class OneHotEncoder {
- // int in this case represents index
- private final Map<String, Integer> encoder = new HashMap<>();
- private final Map<Integer, String> inverseEncoder = new HashMap<>();
- private int vectorSize = -1;
- public OneHotEncoder() {
- }
- public OneHotEncoder(Set<String> labels) {
- this.fit(labels);
- }
- public void fit(Set<String> labels) {
- encoder.clear();
- vectorSize = labels.size();
- int idx = 0;
- for (String label : labels) {
- encoder.put(label, idx);
- inverseEncoder.put(idx, label);
- idx++;
- }
- }
- // Returns a matrix with num labels rows and vector size columns.
- public INDArray transform(List<String> labels) {
- var array = Nd4j.zeros(labels.size(), vectorSize);
- for (int i = 0; i < labels.size(); i++) {
- String label = labels.get(i);
- Integer encodedIdx = encoder.get(label);
- if (encodedIdx == null) {
- throw new RuntimeException("Error encoding label '" + label +
- "'! No encoding found!");
- }
- array.putScalar(i, encodedIdx, 1);
- }
- return array;
- }
- public List<String> inverseTransform(INDArray encodedLabels) {
- long[] shape = encodedLabels.shape();
- int rows = (int) shape[0];
- int numLabels = (int) shape[1];
- List<String> transformedList = new ArrayList<>(rows);
- for (int i = 0; i < rows; i++) {
- var rowVector = encodedLabels.getRow(i);
- int encodedLabelidx = -1;
- for (int lblIdx = 0; lblIdx < numLabels; ++lblIdx) {
- if (rowVector.getInt(lblIdx) == 1) {
- encodedLabelidx = lblIdx;
- break;
- }
- }
- String inverseEncoding = inverseEncoder.get(encodedLabelidx);
- if (inverseEncoding == null) {
- throw new RuntimeException("No inverse encoding found for value '" + encodedLabelidx + "'!");
- }
- transformedList.add(inverseEncoding);
- }
- return transformedList;
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement