Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- /**
- * learn training data, printing SSE at 10 of the epochs, evenly spaced
- * if a validation set available, learning stops when SSE on validation set rises
- * this check is done by summing SSE over 10 epochs
- * @param numEpochs number of epochs
- * @param lRate learning rate
- * @param momentum momentum
- * @return String with data about learning eg SSEs at relevant epoch
- */
- public String doLearn (int numEpochs, double lRate, double momentum) {
- String s = "";
- String tmp = "";
- if (validationData==null) s = super.doLearn(numEpochs, lRate, momentum);
- // if no validation set, just use normal doLearn
- else {
- if(hasLearned)return ""; //if has learnt , do not proceed
- validationData.clearSSELog(); //clear the validations data SSE log
- for (int epoch = 0; epoch < numEpochs; epoch+=10) { //for every epoch , step (10) (because we check the SSE every 10 epochs)
- presentDataSet(validationData);
- tmp=super.doLearn(10, lRate, momentum); //perform 10 learning epochs in the train data
- //present the validation data into the network
- if(validationData.getTotalSSE()>currSSE) { //if this SSE rises from the previous one
- hasLearned=true; //break and give info
- s+="Stopped after "+trainData.sizeSSELog();
- break;
- }
- if(trainData.sizeSSELog()%100==0) {
- s+=addEpochString(trainData.sizeSSELog())+":"+trainData.dataAnalysis()+"\n"; //add string with analytics
- }
- currSSE=validationData.getTotalSSE(); //remember this SSE
- validationData.clearSSELog(); //clear the validation log
- }
- }
- return s; // return string showing learning
- }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement