import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Iterative reduce with flat map, intended for use with mapPartitions:
 * each partition trains a copy of the network on its local data and emits
 * the resulting parameters.
 *
 * @author Adam Gibson
 * modified by Christophe Cerisara
 */
public class IterativeReduceFlatMap implements FlatMapFunction<Iterator<DataSet>, INDArray> {

    private String json;
    private Broadcast<INDArray> params;
    private static Logger log = LoggerFactory.getLogger(IterativeReduceFlatMap.class);

    /**
     * Pass in the json configuration and baseline parameters.
     *
     * @param json   json configuration for the network
     * @param params the parameters to use for the network
     */
    public IterativeReduceFlatMap(String json, Broadcast<INDArray> params) {
        this.json = json;
        this.params = params;
    }

    /**
     * Flattens all layer parameters into a single vector with the layout:
     * [numLayers, length(layer 0), params(layer 0), length(layer 1), params(layer 1), ...]
     */
    public static INDArray flattenParms(MultiLayerNetwork network) {
        ArrayList<Double> ps = new ArrayList<Double>();
        final int nl = network.getLayers().length;
        ps.add((double) nl);
        for (int i = 0; i < nl; i++) {
            INDArray pp = network.getLayer(i).params();
            final int nx = pp.length();
            ps.add((double) nx);
            for (int j = 0; j < nx; j++) {
                ps.add(pp.getDouble(j));
            }
        }
        double[] vs = new double[ps.size()];
        for (int i = 0; i < vs.length; i++) {
            vs[i] = ps.get(i);
        }
        return Nd4j.create(vs);
    }

    /**
     * Inverse of {@link #flattenParms}: reads the layer count and per-layer
     * lengths back out of the vector and writes the weights into the network.
     */
    public static void deflattenParms(MultiLayerNetwork network, INDArray parms) {
        int pi = 0;
        int nLayer = (int) parms.getDouble(pi++);
        assert nLayer == network.getLayers().length;
        for (int i = 0; i < nLayer; i++) {
            int nx = (int) parms.getDouble(pi++);
            double[] w = new double[nx];
            for (int j = 0; j < nx; j++) {
                w[j] = parms.getDouble(pi++);
            }
            network.getLayer(i).setParams(Nd4j.create(w));
        }
    }

    @Override
    public Iterable<INDArray> call(Iterator<DataSet> dataSetIterator) throws Exception {
        // Empty partition: return a zero vector so the caller still gets one result per partition.
        if (!dataSetIterator.hasNext()) {
            return Collections.singletonList(Nd4j.zeros(params.value().shape()));
        }

        // Collect the partition's examples into a single DataSet.
        List<DataSet> collect = new ArrayList<DataSet>();
        while (dataSetIterator.hasNext()) {
            collect.add(dataSetIterator.next());
        }
        DataSet data = DataSet.merge(collect, false);
        log.debug("Training on " + data.labelCounts());

        // Rebuild the network from its json configuration and load the broadcast parameters.
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
        network.init();
        INDArray val = params.value();
        if (val.length() != network.numParams())
            throw new IllegalStateException("Network did not have the same number of parameters as the broadcast parameters");

        network.setParameters(val);
        network.fit(data);

        // Emit this partition's trained parameters, flattened for aggregation on the driver.
        return Collections.singletonList(flattenParms(network));
    }
}
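
The paste contains only the function itself; the sketch below shows one plausible way to wire it into mapPartitions on the driver. It is an illustration, not part of the original code: the names sc, trainingData, confJson and startParams are hypothetical, and the collect-and-average aggregation step is an assumption (it also assumes every partition is non-empty, so all returned vectors share the flattenParms layout).

import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class IterativeReduceDriverSketch {

    /** One outer iteration: broadcast the parameters, train per partition, average the results. */
    public static INDArray trainOneIteration(JavaSparkContext sc,
                                             JavaRDD<DataSet> trainingData,
                                             String confJson,
                                             INDArray startParams) {
        Broadcast<INDArray> broadcastParams = sc.broadcast(startParams);

        // Run the per-partition training and bring the flattened parameter vectors back to the driver.
        List<INDArray> perPartition = trainingData
                .mapPartitions(new IterativeReduceFlatMap(confJson, broadcastParams))
                .collect();

        // Average the flattened parameter vectors element-wise.
        INDArray averaged = Nd4j.zeros(perPartition.get(0).shape());
        for (INDArray p : perPartition) {
            averaged.addi(p);
        }
        return averaged.divi(perPartition.size());
    }
}

The averaged vector can then be written back into a driver-side MultiLayerNetwork with deflattenParms before being broadcast again for the next iteration.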