Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- template <class T>
- class StoppedMachineSerialization
- : public TrainedModelSerializationFixture<T>
- {
- };
- TYPED_TEST_CASE(StoppedMachineSerialization, StoppableMachineTypes);
- TYPED_TEST(StoppedMachineSerialization, Test)
- {
- int i = 0;
- std::function<bool()> callback = [&i]() {
- if (i >= 1)
- {
- get_global_signal()->get_subscriber()->on_next(SG_BLOCK_COMP);
- return true;
- }
- SG_SPRINT("%d ",i);
- i++;
- return false;
- };
- this->machine->set_callback(callback);
- this->machine->set_labels(this->train_labels);
- this->machine->train(this->train_feats);
- /* to avoid serialization of the data */
- // machine->set_features(NULL);
- // machine->set_labels(NULL);
- auto predictions = wrap<CLabels>(this->machine->apply(this->test_feats));
- std::string filename;
- ASSERT_TRUE(this->serialize_machine(this->machine, filename));
- ASSERT_TRUE(
- this->deserialize_machine(this->deserialized_machine, filename));
- auto deserialized_predictions =
- wrap<CLabels>(this->deserialized_machine->apply(this->test_feats));
- set_global_fequals_epsilon(1e-7);
- ASSERT_TRUE(i);
- ASSERT(predictions->equals(deserialized_predictions))
- set_global_fequals_epsilon(0);
- }
Add Comment
Please, Sign In to add comment