Guest User

Untitled

a guest
May 29th, 2018
37
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 1.19 KB | None | 0 0
  1. template <class T>
  2. class StoppedMachineSerialization
  3. : public TrainedModelSerializationFixture<T>
  4. {
  5. };
  6.  
  7. TYPED_TEST_CASE(StoppedMachineSerialization, StoppableMachineTypes);
  8.  
  9. TYPED_TEST(StoppedMachineSerialization, Test)
  10. {
  11. int i = 0;
  12. std::function<bool()> callback = [&i]() {
  13. if (i >= 1)
  14. {
  15. get_global_signal()->get_subscriber()->on_next(SG_BLOCK_COMP);
  16. return true;
  17. }
  18. SG_SPRINT("%d ",i);
  19. i++;
  20. return false;
  21. };
  22.  
  23. this->machine->set_callback(callback);
  24. this->machine->set_labels(this->train_labels);
  25. this->machine->train(this->train_feats);
  26.  
  27. /* to avoid serialization of the data */
  28. // machine->set_features(NULL);
  29. // machine->set_labels(NULL);
  30.  
  31. auto predictions = wrap<CLabels>(this->machine->apply(this->test_feats));
  32.  
  33. std::string filename;
  34. ASSERT_TRUE(this->serialize_machine(this->machine, filename));
  35.  
  36. ASSERT_TRUE(
  37. this->deserialize_machine(this->deserialized_machine, filename));
  38.  
  39. auto deserialized_predictions =
  40. wrap<CLabels>(this->deserialized_machine->apply(this->test_feats));
  41.  
  42. set_global_fequals_epsilon(1e-7);
  43. ASSERT_TRUE(i);
  44. ASSERT(predictions->equals(deserialized_predictions))
  45. set_global_fequals_epsilon(0);
  46. }
Add Comment
Please sign in to add a comment.