Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /** This software is distributed under BSD 3-clause license (see LICENSE file).
- *
- * Authors: Sergey Lisitsyn, Viktor Gal
- */
- #include <memory>
- #include <shogun/base/class_list.h>
- #include <shogun/base/macros.h>
- #include <shogun/io/serialization/JsonDeserializer.h>
- #include <shogun/lib/SGVector.h>
- #include <shogun/lib/SGMatrix.h>
- #include <rapidjson/document.h>
- using namespace shogun;
- template<typename RapidJsonReader>
- class JSONReaderVisitor : public AnyVisitor
- {
- public:
- JSONReaderVisitor(RapidJsonReader& jr, CJsonDeserializer* dser):
- AnyVisitor(), m_json_reader(jr), m_deser(dser) {}
- virtual void on(bool* v)
- {
- SG_SDEBUG("reading bool")
- *v = m_json_reader.GetBool();
- SG_SDEBUG("%d\n", *v)
- }
- virtual void on(int32_t* v)
- {
- SG_SDEBUG("reading int32_t")
- *v = m_json_reader.GetInt();
- SG_SDEBUG("%d\n", *v)
- }
- virtual void on(int64_t* v)
- {
- SG_SDEBUG("reading int64_t")
- *v = m_json_reader.GetInt64();
- SG_SDEBUG("%d\n", *v)
- }
- virtual void on(float* v)
- {
- SG_SDEBUG("reading float: ")
- *v = (float32_t)m_json_reader.GetDouble();
- SG_SDEBUG("%f\n", *v)
- }
- virtual void on(double* v)
- {
- SG_SDEBUG("reading double: ")
- *v = m_json_reader.GetDouble();
- SG_SDEBUG("%f\n", *v)
- }
- virtual void on(CSGObject** v)
- {
- SG_SDEBUG("reading SGObject: ")
- *v = m_deser->read().get();
- /*
- std::string object_name;
- EPrimitiveType primitive_type;
- m_archive(object_name, primitive_type);
- SG_SDEBUG("%s %d\n", object_name.c_str(), primitive_type)
- if (*v == nullptr)
- SG_UNREF(*v);
- *v = create(object_name.c_str(), primitive_type);
- m_archive(**v);
- */
- }
- virtual void on(SGVector<int>* v)
- {
- SG_SDEBUG("reading SGVector<int>\n")
- }
- virtual void on(SGVector<float>* v)
- {
- SG_SDEBUG("reading SGVector<float>\n")
- }
- virtual void on(SGVector<double>* v)
- {
- SG_SDEBUG("reading SGVector<double>\n")
- }
- virtual void on(SGMatrix<int>* v)
- {
- SG_SDEBUG("reading SGMatrix<int>>\n")
- }
- virtual void on(SGMatrix<float>* v)
- {
- SG_SDEBUG("reading SGMatrix<float>>\n")
- }
- virtual void on(SGMatrix<double>* v)
- {
- SG_SDEBUG("reading SGMatrix<double>>\n")
- }
- private:
- RapidJsonReader& m_json_reader;
- CJsonDeserializer* m_deser; // we need this because of recursion :) lalal this super ugly
- };
- class CIStreamAdapter
- {
- public:
- typedef char Ch;
- CIStreamAdapter(CInputStream* is): m_stream(is) {}
- Ch Peek() const
- {
- //int c = m_stream.peek();
- // return c == std::char_traits<char>::eof() ? '\0' : (Ch)c;
- }
- Ch Take()
- {
- // int c = m_stream.get();
- // return c == std::char_traits<char>::eof() ? '\0' : (Ch)c;
- }
- size_t Tell() const
- {
- // return (size_t)m_stream.tellg();
- }
- Ch* PutBegin() { assert(false); return 0; }
- void Put(Ch) { assert(false); }
- void Flush() { assert(false); }
- size_t PutEnd(Ch*) { assert(false); return 0; }
- private:
- CInputStream* m_stream;
- SG_DELETE_COPY_AND_ASSIGN(CIStreamAdapter);
- };
- CJsonDeserializer::CJsonDeserializer() : CDeserializer()
- {
- }
- CJsonDeserializer::~CJsonDeserializer()
- {
- }
- Some<CSGObject> CJsonDeserializer::read()
- {
- CIStreamAdapter is(stream().get());
- rapidjson::Document obj_json;
- obj_json.ParseStream(is);
- auto reader_visitor = std::make_unique<JSONReaderVisitor<rapidjson::Document>>(obj_json, this); // watch out we are passing this....
- if (!obj_json.IsObject())
- throw ShogunException("JSON value is not an object!");
- std::string obj_name(obj_json["name"].GetString());
- EPrimitiveType primitive_type((EPrimitiveType) obj_json["generic"].GetInt());
- auto obj = create(obj_name.c_str(), primitive_type);
- for (auto it = obj_json.MemberBegin(); it != obj_json.MemberEnd(); ++it)
- {
- auto param_name = it->name.GetString();
- if (!has(param_name))
- throw ShogunException(
- "cannot deserialize the object from file!");
- BaseTag tag(param_name);
- auto parameter = obj->get_parameter(tag); // NOTE: currently private
- parameter.get_value().visit(reader_visitor.get());
- obj->update_parameter(tag, parameter.get_value()); // NOTE: currently private
- }
- return wrap<CSGObject>(obj);
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement