Advertisement
Guest User

Untitled

a guest
Apr 20th, 2018
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.09 KB | None | 0 0
  1. /** This software is distributed under BSD 3-clause license (see LICENSE file).
  2. *
  3. * Authors: Sergey Lisitsyn, Viktor Gal
  4. */
  5.  
  6. #include <memory>
  7.  
  8. #include <shogun/base/class_list.h>
  9. #include <shogun/base/macros.h>
  10. #include <shogun/io/serialization/JsonDeserializer.h>
  11. #include <shogun/lib/SGVector.h>
  12. #include <shogun/lib/SGMatrix.h>
  13.  
  14. #include <rapidjson/document.h>
  15.  
  16. using namespace shogun;
  17.  
  18. template<typename RapidJsonReader>
  19. class JSONReaderVisitor : public AnyVisitor
  20. {
  21. public:
  22. JSONReaderVisitor(RapidJsonReader& jr, CJsonDeserializer* dser):
  23. AnyVisitor(), m_json_reader(jr), m_deser(dser) {}
  24.  
  25. virtual void on(bool* v)
  26. {
  27. SG_SDEBUG("reading bool")
  28. *v = m_json_reader.GetBool();
  29. SG_SDEBUG("%d\n", *v)
  30. }
  31. virtual void on(int32_t* v)
  32. {
  33. SG_SDEBUG("reading int32_t")
  34. *v = m_json_reader.GetInt();
  35. SG_SDEBUG("%d\n", *v)
  36. }
  37. virtual void on(int64_t* v)
  38. {
  39. SG_SDEBUG("reading int64_t")
  40. *v = m_json_reader.GetInt64();
  41. SG_SDEBUG("%d\n", *v)
  42. }
  43. virtual void on(float* v)
  44. {
  45. SG_SDEBUG("reading float: ")
  46. *v = (float32_t)m_json_reader.GetDouble();
  47. SG_SDEBUG("%f\n", *v)
  48. }
  49. virtual void on(double* v)
  50. {
  51. SG_SDEBUG("reading double: ")
  52. *v = m_json_reader.GetDouble();
  53. SG_SDEBUG("%f\n", *v)
  54. }
  55. virtual void on(CSGObject** v)
  56. {
  57. SG_SDEBUG("reading SGObject: ")
  58. *v = m_deser->read().get();
  59. /*
  60. std::string object_name;
  61. EPrimitiveType primitive_type;
  62. m_archive(object_name, primitive_type);
  63. SG_SDEBUG("%s %d\n", object_name.c_str(), primitive_type)
  64. if (*v == nullptr)
  65. SG_UNREF(*v);
  66. *v = create(object_name.c_str(), primitive_type);
  67. m_archive(**v);
  68. */
  69. }
  70. virtual void on(SGVector<int>* v)
  71. {
  72. SG_SDEBUG("reading SGVector<int>\n")
  73. }
  74. virtual void on(SGVector<float>* v)
  75. {
  76. SG_SDEBUG("reading SGVector<float>\n")
  77. }
  78. virtual void on(SGVector<double>* v)
  79. {
  80. SG_SDEBUG("reading SGVector<double>\n")
  81. }
  82. virtual void on(SGMatrix<int>* v)
  83. {
  84. SG_SDEBUG("reading SGMatrix<int>>\n")
  85. }
  86. virtual void on(SGMatrix<float>* v)
  87. {
  88. SG_SDEBUG("reading SGMatrix<float>>\n")
  89. }
  90. virtual void on(SGMatrix<double>* v)
  91. {
  92. SG_SDEBUG("reading SGMatrix<double>>\n")
  93. }
  94.  
  95. private:
  96. RapidJsonReader& m_json_reader;
  97. CJsonDeserializer* m_deser; // we need this because of recursion :) lalal this super ugly
  98. };
  99.  
  100. class CIStreamAdapter
  101. {
  102. public:
  103. typedef char Ch;
  104.  
  105. CIStreamAdapter(CInputStream* is): m_stream(is) {}
  106.  
  107. Ch Peek() const
  108. {
  109. //int c = m_stream.peek();
  110. // return c == std::char_traits<char>::eof() ? '\0' : (Ch)c;
  111. }
  112.  
  113. Ch Take()
  114. {
  115. // int c = m_stream.get();
  116. // return c == std::char_traits<char>::eof() ? '\0' : (Ch)c;
  117. }
  118.  
  119. size_t Tell() const
  120. {
  121. // return (size_t)m_stream.tellg();
  122. }
  123.  
  124. Ch* PutBegin() { assert(false); return 0; }
  125. void Put(Ch) { assert(false); }
  126. void Flush() { assert(false); }
  127. size_t PutEnd(Ch*) { assert(false); return 0; }
  128.  
  129. private:
  130. CInputStream* m_stream;
  131. SG_DELETE_COPY_AND_ASSIGN(CIStreamAdapter);
  132. };
  133.  
  134.  
  135. CJsonDeserializer::CJsonDeserializer() : CDeserializer()
  136. {
  137. }
  138.  
  139. CJsonDeserializer::~CJsonDeserializer()
  140. {
  141. }
  142.  
  143. Some<CSGObject> CJsonDeserializer::read()
  144. {
  145. CIStreamAdapter is(stream().get());
  146. rapidjson::Document obj_json;
  147. obj_json.ParseStream(is);
  148. auto reader_visitor = std::make_unique<JSONReaderVisitor<rapidjson::Document>>(obj_json, this); // watch out we are passing this....
  149.  
  150. if (!obj_json.IsObject())
  151. throw ShogunException("JSON value is not an object!");
  152.  
  153. std::string obj_name(obj_json["name"].GetString());
  154. EPrimitiveType primitive_type((EPrimitiveType) obj_json["generic"].GetInt());
  155. auto obj = create(obj_name.c_str(), primitive_type);
  156. for (auto it = obj_json.MemberBegin(); it != obj_json.MemberEnd(); ++it)
  157. {
  158. auto param_name = it->name.GetString();
  159. if (!has(param_name))
  160. throw ShogunException(
  161. "cannot deserialize the object from file!");
  162.  
  163. BaseTag tag(param_name);
  164. auto parameter = obj->get_parameter(tag); // NOTE: currently private
  165. parameter.get_value().visit(reader_visitor.get());
  166. obj->update_parameter(tag, parameter.get_value()); // NOTE: currently private
  167. }
  168. return wrap<CSGObject>(obj);
  169. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement