Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /**
- * Iterative reduce with
- * flat map using map partitions
- *
- * @author Adam Gibson
- * modified by Christophe Cerisara
- */
- public class IterativeReduceFlatMap implements FlatMapFunction<Iterator<DataSet>, INDArray> {
- private String json;
- private Broadcast<INDArray> params;
- private static Logger log = LoggerFactory.getLogger(IterativeReduceFlatMap.class);
/**
 * Builds the function from a serialized network configuration and the
 * broadcast parameters that each invocation starts training from.
 *
 * @param json   json configuration for the network
 * @param params the parameters to use for the network
 */
public IterativeReduceFlatMap(String json, Broadcast<INDArray> params) {
    this.params = params;
    this.json = json;
}
- public static INDArray flattenParms(MultiLayerNetwork network) {
- ArrayList<Double> ps = new ArrayList<Double>();
- final int nl=network.getLayers().length;
- ps.add((double)nl);
- for (int i=0;i<nl;i++) {
- INDArray pp = network.getLayer(i).params();
- final int nx=pp.length();
- ps.add((double)nx);
- for (int j=0;j<nx;j++) {
- ps.add(pp.getDouble(j));
- }
- }
- double[] vs = new double[ps.size()];
- for (int i=0;i<vs.length;i++) vs[i]=ps.get(i);
- INDArray v = Nd4j.create(vs);
- return v;
- }
- public static void deflattenParms(MultiLayerNetwork network, INDArray parms) {
- int pi=0;
- int nLayer = (int)parms.getDouble(pi++);
- assert nLayer == network.getLayers().length;
- for (int i=0;i<nLayer;i++) {
- int nx=(int)parms.getDouble(pi++);
- double[] w = new double[nx];
- for (int j=0;j<nx;j++) w[j] = parms.getDouble(pi++);
- INDArray pl = Nd4j.create(w);
- network.getLayer(i).setParams(pl);
- }
- }
- @Override
- public Iterable<INDArray> call(Iterator<DataSet> dataSetIterator) throws Exception {
- if (!dataSetIterator.hasNext()) {
- return Collections.singletonList(Nd4j.zeros(params.value().shape()));
- }
- List<DataSet> collect = new ArrayList<DataSet>();
- while (dataSetIterator.hasNext()) {
- collect.add(dataSetIterator.next());
- }
- DataSet data = DataSet.merge(collect, false);
- log.debug("Training on " + data.labelCounts());
- MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
- network.init();
- INDArray val = params.value();
- if (val.length() != network.numParams())
- throw new IllegalStateException("Network did not have same number of parameters as the broadcasted set parameters");
- network.setParameters(val);
- network.fit(data);
- INDArray trainedParms = flattenParms(network);
- return Collections.singletonList(trainedParms);
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement