Advertisement
Guest User

Untitled

a guest
Dec 2nd, 2015
69
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 12.52 KB | None | 0 0
  1. package net.yohanes.deepsums;
  2.  
  3. import java.io.File;
  4. import java.io.FileInputStream;
  5. import java.io.IOException;
  6. import java.util.*;
  7.  
  8. import com.fasterxml.jackson.databind.ObjectMapper;
  9. import org.apache.commons.io.IOUtils;
  10. import org.apache.commons.lang3.StringUtils;
  11. import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
  12. import org.deeplearning4j.datasets.iterator.BaseDatasetIterator;
  13. import org.deeplearning4j.datasets.iterator.DataSetIterator;
  14. import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
  15. import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
  16. import org.deeplearning4j.nn.conf.layers.RBM;
  17. import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
  18. import org.deeplearning4j.nn.weights.WeightInit;
  19. import org.deeplearning4j.optimize.api.IterationListener;
  20. import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
  21. import org.nd4j.linalg.api.ndarray.INDArray;
  22. import org.nd4j.linalg.dataset.DataSet;
  23. import org.nd4j.linalg.factory.Nd4j;
  24. import org.nd4j.linalg.lossfunctions.LossFunctions;
  25. import org.slf4j.Logger;
  26. import org.slf4j.LoggerFactory;
  27.  
  28.  
  29. public class DeepLearning {
  30.  
  31. private static Logger log = LoggerFactory.getLogger(DeepLearning.class);
  32.  
  33. private static MultiLayerNetwork model;
  34.  
  35. private final int numRows = 4;
  36. private int iterations = 10;
  37. private int seed = 123;
  38.  
  39. private MultiLayerNetwork getModel() {
  40. if (model != null) return model;
  41. log.info("Build model....");
  42. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
  43. .seed(seed) // Locks in weight initialization for tuning
  44. .iterations(iterations) // # training iterations predict/classify & backprop
  45. .learningRate(0.1) // Optimization step size
  46. .list(2) // # NN layers (doesn't count input layer)
  47. .layer(0, new RBM.Builder()
  48. .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
  49. .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
  50. .nIn(numRows) // # input nodes
  51. .nOut(numRows) // # output nodes
  52. .weightInit(WeightInit.UNIFORM) // Weight initialization
  53. .k(1) // # contrastive divergence iterations
  54. .activation("sigmoid") // Activation function type
  55. .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) // Loss function type
  56. .build())
  57. .layer(1, new RBM.Builder()
  58. .visibleUnit(RBM.VisibleUnit.GAUSSIAN)
  59. .hiddenUnit(RBM.HiddenUnit.RECTIFIED)
  60. .nIn(numRows) // # input nodes
  61. .nOut(numRows) // # output nodes
  62. .weightInit(WeightInit.UNIFORM) // Weight initialization
  63. .k(1) // # contrastive divergence iterations
  64. .activation("sigmoid") // Activation function type
  65. .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) // Loss function type
  66. .build())
  67. .build();
  68. model = new MultiLayerNetwork(conf);
  69. model.init();
  70. return model;
  71. }
  72.  
  73. public void train(String filepathTrain) {
  74. int numSamples = 1000;
  75. int batchSize = 10;
  76. int listenerFreq = 1;
  77. log.info("Load data....");
  78. DataSetIterator iterTrain = new DUCDataSetIterator(batchSize, numSamples, filepathTrain);
  79. DataSet train = iterTrain.next();
  80. MultiLayerNetwork trainModel = this.getModel();
  81. trainModel.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq)));
  82. trainModel.fit(train.getFeatureMatrix());
  83. }
  84.  
  85. public Summary summarize(String filepathTest, String filepathSentencesTest, float[] fThreshold) throws IOException {
  86. int numSamples = 1000;
  87. int batchSize = 100;
  88.  
  89. // Customizing params
  90. Nd4j.MAX_SLICES_TO_PRINT = -1;
  91. Nd4j.MAX_ELEMENTS_PER_SLICE = -1;
  92.  
  93. MultiLayerNetwork testModel = this.getModel();
  94. DataSetIterator iterTest = new DUCDataSetIterator(batchSize, numSamples, filepathTest);
  95. // do summarization based on threshold
  96. List<Integer> sentencesIds = new ArrayList<Integer>();
  97. int totalSentences = 0;
  98. int totalCorrect = 0;
  99. float[] maxThreshold = {0.0f, 0.0f, 0.0f, 0.0f};
  100. float[] averageThreshold = {0.0f, 0.0f, 0.0f, 0.0f};
  101. while (iterTest.hasNext()) {
  102. DataSet batch = iterTest.next();
  103. INDArray labels = batch.getLabels();
  104. // choose random threshold {f1, f2, f3, f4}
  105. INDArray o = testModel.output(batch.getFeatureMatrix());
  106. for (int i=0; i<o.rows();i++) {
  107. INDArray row = o.getRow(i);
  108. boolean take = true;
  109. for (int j=0; j<fThreshold.length; j++) {
  110. // get max threshold for analysis
  111. maxThreshold[j] = (maxThreshold[j] < row.getFloat(j)) ? row.getFloat(j) : maxThreshold[j];
  112. averageThreshold[j] += row.getFloat(j);
  113. if (row.getFloat(j) <= fThreshold[j]) {
  114. // don't use sentence as summary
  115. take = false;
  116. break;
  117. }
  118. }
  119. if (take) {
  120. sentencesIds.add(i);
  121. if (labels.getRow(i).getFloat(0) == 1.0f) {
  122. totalCorrect++;
  123. }
  124. }
  125. }
  126. totalSentences += o.rows();
  127. }
  128.  
  129. // generate summary
  130. List<String> sentences = DUCUtil.getSentencesList(filepathSentencesTest);
  131. float[][] rawData = DUCUtil.getRawData(filepathTest);
  132. List<String> summary = new ArrayList<String>();
  133. for (int id : sentencesIds) {
  134. summary.add(sentences.get(id));
  135. }
  136.  
  137. for (int i=0; i<averageThreshold.length; i++) {
  138. averageThreshold[i] = averageThreshold[i] / totalSentences;
  139. }
  140.  
  141. log.info("Max threshold: " + StringUtils.join(maxThreshold, ','));
  142. log.info("Avg threshold: " + StringUtils.join(averageThreshold, ','));
  143.  
  144. return new Summary(summary, rawData, sentencesIds.size(), totalCorrect);
  145. }
  146.  
  147. public static void main(String[] args) throws IOException {
  148. DeepLearning deepLearning = new DeepLearning();
  149. ObjectMapper mapper = new ObjectMapper();
  150. Map<String,Object> conf = mapper.readValue(new File("conf.json"), Map.class);
  151.  
  152. // Customizing params
  153. Nd4j.MAX_SLICES_TO_PRINT = -1;
  154. Nd4j.MAX_ELEMENTS_PER_SLICE = -1;
  155.  
  156. // training
  157. for (String filepath : (ArrayList<String>)conf.get("training")) {
  158. log.info("training: " + filepath);
  159. deepLearning.train(filepath);
  160. }
  161. //testing
  162. ArrayList<Map<String, Object>> results = new ArrayList<Map<String, Object>>();
  163. for (Map<String, Object> testing : (ArrayList<Map<String, Object>>) conf.get("testing")) {
  164. // log.info("testing: " + testing.get("data"));
  165. ArrayList<Double> recallList = new ArrayList<Double>();
  166. ArrayList<Double> precisionList = new ArrayList<Double>();
  167. ArrayList<Double> f1List = new ArrayList<Double>();
  168. ArrayList<Double> thresholdRaw = (ArrayList<Double>) testing.get("threshold");
  169. for (Double raw : thresholdRaw) {
  170. float[] threshold = { 0.0f, 0.0f, raw.floatValue(), 0.0f };
  171. Summary summary = deepLearning.summarize(
  172. (String) testing.get("data"),
  173. (String) testing.get("sentences"),
  174. threshold);
  175. recallList.add(new Double(summary.getRecall()));
  176. precisionList.add(new Double(summary.getPrecision()));
  177. f1List.add(new Double(summary.getFMeasure()));
  178. log.info("Percentage: " + summary.getCorrectPercentage());
  179. log.info("TotalCorrect: " + summary.getTotalCorrect() + " / " + summary.getTotalRetrieved());
  180. log.info("Recall: " + summary.getRecall());
  181. log.info("Precision: " + summary.getPrecision());
  182. log.info("Fmeasure: " + summary.getFMeasure());
  183. }
  184. Map<String, Object> result = new LinkedHashMap<String, Object>();
  185. result.put("name", testing.get("data"));
  186. result.put("labels", thresholdRaw);
  187. result.put("recall", recallList);
  188. result.put("precision", precisionList);
  189. result.put("f1", f1List);
  190. results.add(result);
  191. }
  192. mapper.writeValue(new File("report.json"), results);
  193. }
  194. }
  195.  
  196. class DUCDataSetIterator extends BaseDatasetIterator {
  197. private static final long serialVersionUID = -2022454995728680368L;
  198. public DUCDataSetIterator(int batch, int numExamples, String path) {
  199. super(batch,numExamples,new DUCDataFetcher(path));
  200. }
  201.  
  202. @Override
  203. public boolean hasNext() {
  204. return fetcher.hasMore();
  205. }
  206.  
  207. }
  208.  
  209.  
  210. class DUCDataFetcher extends BaseDataFetcher {
  211. private static final long serialVersionUID = 4566329799221375262L;
  212. public final static int NUM_EXAMPLES = 150;
  213. private String filepath;
  214. public DUCDataFetcher(String path) {
  215. numOutcomes = 4;
  216. inputColumns = 4;
  217. totalExamples = NUM_EXAMPLES;
  218. filepath = path;
  219. totalExamples = this.totalExamples();
  220. }
  221.  
  222. @Override
  223. public int totalExamples() {
  224. int total = 0;
  225. try {
  226. total = DUCUtil.getTotalSentences(filepath);
  227. } catch (Exception e) {
  228. log.error(e.getMessage(), e);
  229. }
  230. return total;
  231. }
  232.  
  233. @Override
  234. public boolean hasMore() {
  235. return cursor < totalExamples;
  236. }
  237.  
  238. public void fetch(int numExamples) {
  239. int from = cursor;
  240. int to = cursor + numExamples;
  241. if(to > totalExamples)
  242. to = totalExamples;
  243. try {
  244. initializeCurrFromList(DUCUtil.loadDUC(to, from, filepath));
  245. cursor += numExamples;
  246. } catch (IOException e) {
  247. throw new IllegalStateException("Unable to load duc");
  248. }
  249.  
  250. }
  251. }
  252.  
  253. class DUCUtil {
  254.  
  255. public static List<DataSet> loadDUC(int to, int from, String filepath) throws IOException {
  256. FileInputStream fis = new FileInputStream(filepath);
  257. @SuppressWarnings("unchecked")
  258. List<String> lines = IOUtils.readLines(fis);
  259. List<DataSet> list = new ArrayList<DataSet>();
  260. INDArray ret = Nd4j.ones(Math.abs(to - from), 4);
  261. double[][] outcomes = new double[lines.size()][4];
  262. int putCount = 0;
  263.  
  264. for(int i = from; i < to; i++) {
  265. String line = lines.get(i);
  266. String[] split = line.split(",");
  267.  
  268. addRow(ret,putCount++,split);
  269.  
  270. String outcome = split[4];
  271. double[] rowOutcome = new double[4];
  272. rowOutcome[new Float(outcome).intValue()] = 1;
  273. outcomes[i] = rowOutcome;
  274. }
  275.  
  276. for(int i = 0; i < ret.rows(); i++) {
  277. int idx = (outcomes.length > (from + i)) ? from + i : outcomes.length-1;
  278. DataSet add = new DataSet(ret.getRow(i), Nd4j.create(outcomes[idx]));
  279. list.add(add);
  280. if (idx == (outcomes.length-1)) break;
  281. }
  282. return list;
  283. }
  284.  
  285. public static int getTotalSentences(String filepath) throws IOException {
  286. FileInputStream fis = new FileInputStream(filepath);
  287. List<String> lines = IOUtils.readLines(fis);
  288. return lines.size();
  289. }
  290.  
  291. public static void addRow(INDArray ret,int row,String[] line) {
  292. double[] vector = new double[4];
  293. for(int i = 0; i < 4; i++)
  294. vector[i] = Double.parseDouble(line[i]);
  295.  
  296. ret.putRow(row,Nd4j.create(vector));
  297. }
  298.  
  299. public static List<String> getSentencesList(String filepath) throws IOException {
  300. FileInputStream fis = new FileInputStream(filepath);
  301. List<String> res = IOUtils.readLines(fis);
  302. fis.close();
  303. return res;
  304. }
  305.  
  306. public static float[][] getRawData(String filepath) throws IOException {
  307. FileInputStream fis = new FileInputStream(filepath);
  308. List<String> res = IOUtils.readLines(fis);
  309. float[][] matrix = new float[res.size()][7];
  310. for (int i = 0; i < matrix.length; i++) {
  311. String[] strArray = StringUtils.split(res.get(i), ',');
  312. for (int j = 0; j < strArray.length; j++) {
  313. matrix[i][j] = Float.parseFloat(strArray[j]);
  314. }
  315. }
  316. fis.close();
  317. return matrix;
  318. }
  319.  
  320. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement