diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp index 2eb6747dba0..2da7fb978f3 100644 --- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp +++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp @@ -349,13 +349,6 @@ public: void on_adapter(const std::string& name, ::ngraph::ValueAccessor& adapter) override; - void on_adapter(const std::string& name, ::ngraph::ValueAccessor& adapter) override { - if (std::string(node->get_type_name()) != "Constant") { - const auto data_beg = static_cast(adapter.get_ptr()); - params[name] = std::string(data_beg, adapter.size()); - } - } - private: std::shared_ptr<::ngraph::Node> node; std::map params; @@ -394,6 +387,11 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na (void)a; } else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter>(& adapter)) { (void)a; + } else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter>>(& adapter)) { + if (std::string(node->get_type_name()) != "Constant") { + const auto data_beg = static_cast(a->get()->get_ptr()); + params[name] = std::string(data_beg, a->get()->size()); + } } else { THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. " "Attribute adapter can not be found for " << name << " parameter"; diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp index 5191f022b99..e09918abad9 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.cpp @@ -224,6 +224,11 @@ ngraph::op::v5::Loop::SpecialBodyPorts V10Parser::XmlDeserializer::parsePurposeA } void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) { + static const std::unordered_set skip_names = { + "input_descriptions", + "output_descriptions", + "special_body_ports" + }; std::string val; // for TensorIterator look for 'port_map' as 'data' does not exist @@ -239,7 +244,7 @@ void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::Val } } - if (!getStrAttribute(node.child("data"), name, val)) return; + if (skip_names.count(name) && !getStrAttribute(node.child("data"), name, val)) return; if (auto a = ngraph::as_type>(&adapter)) { static_cast(*a) = details::convertPrecision(val); } else if (auto a = ngraph::as_type>(&adapter)) { @@ -292,6 +297,44 @@ void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::Val ngraph::element::dynamic, variable_id}); } a->set(variables[variable_id]); + } else if (auto a = ngraph::as_type>>(&adapter)) { + std::string value; + pugi::xml_node dn = node.child("data"); + auto type = XMLParseUtils::GetStrAttr(node, "type"); + + if (dn.empty()) + THROW_IE_EXCEPTION << "No attrtibutes defined for " << type << " op!"; + + if (getStrAttribute(dn, name, value)) { + auto buffer = std::make_shared(value.size()); + auto data = static_cast(buffer->get_ptr()); + value.copy(data, value.size()); + a->set(buffer); + } else if (name == "value" && type == "Const") { + std::vector shape; + std::string el_type_str; + + size_t offset = XMLParseUtils::GetUInt64Attr(dn, "offset"); + size_t size = XMLParseUtils::GetUInt64Attr(dn, "size"); + if (!getStrAttribute(dn, "element_type", el_type_str)) return; + if (!getParameters(dn, "shape", shape)) return; + + ngraph::element::Type el_type = details::convertPrecision(el_type_str); + + size_t length = weights->byteSize(); + if (!length) + THROW_IE_EXCEPTION << "Empty weights data in bin file or bin file cannot be found!"; + if (length < offset + size) + THROW_IE_EXCEPTION << "Incorrect weights in bin file!"; + if (size < std::ceil(ngraph::shape_size(shape) * el_type.bitwidth() / 8.f)) + THROW_IE_EXCEPTION << "Attribute and shape size are inconsistent for " << type << " op!"; + + char* data = weights->cbuffer().as() + offset; + + using SharedBuffer = ngraph::runtime::SharedBuffer; + auto buffer = std::make_shared(data, size, weights); + a->set(buffer); + } } else { THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name << " parameter"; @@ -742,6 +785,10 @@ std::shared_ptr V10Parser::XmlDeserializer::createNode( if (!ngraphNode) { THROW_IE_EXCEPTION << "Opset " << params.version << " doesn't contain the operation with type: " << type; } + // Share Weights form constant blob + if (auto constant = std::dynamic_pointer_cast(ngraphNode)) { + constant->alloc_buffer_on_visit_attributes(false); + } ngraphNode->set_arguments(inputs); XmlDeserializer visitor(node, weights, opsets, variables); if (ngraphNode->visit_attributes(visitor)) { diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp index 89a40549975..c8db0add302 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp @@ -216,43 +216,6 @@ private: stringToType(val, value); adapter.set(value); } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - std::string value; - pugi::xml_node dn = node.child("data"); - auto type = XMLParseUtils::GetStrAttr(node, "type"); - - if (dn.empty()) - THROW_IE_EXCEPTION << "No attrtibutes defined for " << type << " op!"; - - if (getStrAttribute(dn, name, value)) { - auto data = static_cast(adapter.get_ptr()); - size_t length = std::min(value.size(), adapter.size()); - value.copy(data, length); - } else if (name == "value" && type == "Const") { - std::vector shape; - std::string el_type_str; - - size_t offset = XMLParseUtils::GetUInt64Attr(dn, "offset"); - size_t size = XMLParseUtils::GetUInt64Attr(dn, "size"); - if (!getStrAttribute(dn, "element_type", el_type_str)) return; - if (!getParameters(dn, "shape", shape)) return; - - ngraph::element::Type el_type = details::convertPrecision(el_type_str); - - size_t length = weights->byteSize(); - if (!length) - THROW_IE_EXCEPTION << "Empty weights data in bin file or bin file cannot be found!"; - if (length < offset + size) - THROW_IE_EXCEPTION << "Incorrect weights in bin file!"; - if (size < std::ceil(ngraph::shape_size(shape) * el_type.bitwidth() / 8.f)) - THROW_IE_EXCEPTION << "Attribute and shape size are inconsistent for " << type << " op!"; - - auto data = static_cast(adapter.get_ptr()); - char* weights_data = weights->cbuffer().as() + offset; - - std::memcpy(data, weights_data, size); - } - } void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { std::string val; if (!getStrAttribute(node.child("data"), name, val)) diff --git a/inference-engine/src/transformations/src/transformations/serialize.cpp b/inference-engine/src/transformations/src/transformations/serialize.cpp index 52489dbc227..c14feb751a7 100644 --- a/inference-engine/src/transformations/src/transformations/serialize.cpp +++ b/inference-engine/src/transformations/src/transformations/serialize.cpp @@ -247,22 +247,15 @@ public: } } else if (const auto& a = ngraph::as_type>>(&adapter)) { m_xml_node.append_attribute(name.c_str()).set_value(a->get()->get_info().variable_id.c_str()); - } - } - - void on_adapter(const std::string& name, - ngraph::ValueAccessor& adapter) override { - if (name == "value" && translate_type_name(m_node_type_name) == "Const") { - using AlignedBufferAdapter = - ngraph::AttributeAdapter>; - if (auto a = ngraph::as_type(&adapter)) { - const int64_t size = a->size(); + } else if (const auto& a = ngraph::as_type>>(&adapter)) { + if (name == "value" && translate_type_name(m_node_type_name) == "Const") { + const int64_t size = a->get()->size(); const int64_t offset = m_bin_data.tellp(); m_xml_node.append_attribute("offset").set_value(offset); m_xml_node.append_attribute("size").set_value(size); - auto data = static_cast(a->get_ptr()); + auto data = static_cast(a->get()->get_ptr()); m_bin_data.write(data, size); } } diff --git a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index a3780a67e87..4045c8fd00e 100644 --- a/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/inference-engine/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -240,16 +240,16 @@ public: class ReadAndStoreAttributes : public ngraph::AttributeVisitor, protected storage::Storage { public: void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - if (auto inputs = - ngraph::as_type>(&adapter)) { + if (auto inputs = ngraph::as_type>(&adapter)) { insert(name, inputs->get()); - } else if ( - auto outputs = - ngraph::as_type>(&adapter)) { + } else if (auto outputs = ngraph::as_type>(&adapter)) { insert(name, outputs->get()); - } else if ( - auto ports = ngraph::as_type>(&adapter)) { + } else if (auto ports = ngraph::as_type>(&adapter)) { insert(name, ports->get()); + } else if (auto a = ngraph::as_type>>(&adapter)) { + const auto beg = static_cast(a->get()->get_ptr()); + const auto end = beg + a->get()->size(); + insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)}); } else { m_read_result += "store attr [ ERR ]: " + name + " [drop `void` comparison which is '" + adapter.get_type_info().name + @@ -257,12 +257,6 @@ public: } } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - const auto beg = static_cast(adapter.get_ptr()); - const auto end = beg + adapter.size(); - insert(name, storage::MemoryChunk{storage::MemoryChunk::Data(beg, end)}); - } - #define ON_ADAPTER(TYPE) \ void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { \ insert(name, adapter.get()); \ @@ -506,16 +500,25 @@ public: return; } m_visited_attributes.insert(name); - if (auto inputs = - ngraph::as_type>(&adapter)) { + if (auto inputs = ngraph::as_type>(&adapter)) { verify(name, inputs->get()); - } else if ( - auto outputs = - ngraph::as_type>(&adapter)) { + } else if (auto outputs = ngraph::as_type>(&adapter)) { verify(name, outputs->get()); - } else if ( - auto ports = ngraph::as_type>(&adapter)) { + } else if (auto ports = ngraph::as_type>(&adapter)) { verify(name, ports->get()); + } else if (auto a = ngraph::as_type>>(&adapter)) { + m_visited_attributes.insert(name); + const auto ref_value = m_attr_ref.get(name); + if (!ref_value) { + m_cmp_result += "missing attribute name: '" + name + "'"; + return; + } + + if (a->get()->size() != ref_value->size() || + std::memcmp(ref_value->data(), a->get()->get_ptr(), ref_value->size()) != 0) { + m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer"; + return; + } } else { m_cmp_result += "compare attr [ ERR ]: " + name + " [drop `void` comparison which is '" + adapter.get_type_info().name + @@ -523,24 +526,6 @@ public: } } - void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { - if (should_return()) { - return; - } - m_visited_attributes.insert(name); - const auto ref_value = m_attr_ref.get(name); - if (!ref_value) { - m_cmp_result += "missing attribute name: '" + name + "'"; - return; - } - - if (adapter.size() != ref_value->size() || - std::memcmp(ref_value->data(), adapter.get_ptr(), ref_value->size()) != 0) { - m_cmp_result += "mismatch in value: '" + name + "' : look in to the mem buffer"; - return; - } - } - #define ON_ADAPTER(TYPE) \ void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { \ verify(name, adapter.get()); \ diff --git a/ngraph/core/include/ngraph/op/constant.hpp b/ngraph/core/include/ngraph/op/constant.hpp index 93724d8cf16..749f94f2de1 100644 --- a/ngraph/core/include/ngraph/op/constant.hpp +++ b/ngraph/core/include/ngraph/op/constant.hpp @@ -470,6 +470,14 @@ namespace ngraph } std::string convert_value_to_string(size_t index) const; + /** + * \brief Allows to avoid buffer allocation on the visit_attributes call + */ + void alloc_buffer_on_visit_attributes(bool val) + { + m_alloc_buffer_on_visit_attributes = val; + } + protected: template void cast_vector(std::vector& output_vector) const @@ -591,6 +599,7 @@ namespace ngraph std::shared_ptr m_data; bool m_all_elements_bitwise_identical; bool are_all_data_elements_bitwise_identical() const; + bool m_alloc_buffer_on_visit_attributes = true; }; } using v0::Constant; diff --git a/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp b/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp index 772e739ddec..35eb94f2feb 100644 --- a/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp +++ b/ngraph/core/include/ngraph/runtime/aligned_buffer.hpp @@ -75,17 +75,13 @@ namespace ngraph } template <> class NGRAPH_API AttributeAdapter> - : public ValueAccessor + : public DirectValueAccessor> { public: AttributeAdapter(std::shared_ptr& value); - void* get_ptr() override; - size_t size() override; static constexpr DiscreteTypeInfo type_info{ "AttributeAdapter>", 0}; const DiscreteTypeInfo& get_type_info() const override { return type_info; } - protected: - std::shared_ptr& m_ref; }; } diff --git a/ngraph/core/src/op/constant.cpp b/ngraph/core/src/op/constant.cpp index d7fcfb93137..58a02471469 100644 --- a/ngraph/core/src/op/constant.cpp +++ b/ngraph/core/src/op/constant.cpp @@ -628,9 +628,13 @@ bool op::Constant::are_all_data_elements_bitwise_identical() const bool op::v0::Constant::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v0_Constant_visit_attributes); + Shape prev_shape = m_shape; + element::Type prev_type = m_element_type; visitor.on_attribute("element_type", m_element_type); visitor.on_attribute("shape", m_shape); - if (m_data == nullptr) + + bool need_to_reallocate = (m_shape != prev_shape || prev_type != m_element_type); + if (m_alloc_buffer_on_visit_attributes && need_to_reallocate) { // Filling in a fresh constant allocate_buffer(); diff --git a/ngraph/core/src/runtime/aligned_buffer.cpp b/ngraph/core/src/runtime/aligned_buffer.cpp index 3f38b024c85..1daca4d01c3 100644 --- a/ngraph/core/src/runtime/aligned_buffer.cpp +++ b/ngraph/core/src/runtime/aligned_buffer.cpp @@ -87,13 +87,7 @@ namespace ngraph AttributeAdapter>::AttributeAdapter( shared_ptr& value) - : m_ref(value) + : DirectValueAccessor>(value) { } - - void* AttributeAdapter>::get_ptr() - { - return m_ref->get_ptr(); - } - size_t AttributeAdapter>::size() { return m_ref->size(); } } diff --git a/ngraph/test/util/visitor.hpp b/ngraph/test/util/visitor.hpp index 35b66122960..7a40caa18a9 100644 --- a/ngraph/test/util/visitor.hpp +++ b/ngraph/test/util/visitor.hpp @@ -162,7 +162,16 @@ namespace ngraph } void on_adapter(const std::string& name, ValueAccessor& adapter) override { - NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled"); + if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter< + std::shared_ptr>>(&adapter)) + { + auto& data = m_values.get(name); + data->read(a->get()->get_ptr(), a->get()->size()); + } + else + { + NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled"); + } } // The remaining adapter methods fall back on the void adapter if not implemented void on_adapter(const std::string& name, ValueAccessor& adapter) override @@ -257,7 +266,18 @@ namespace ngraph void on_adapter(const std::string& name, ValueAccessor& adapter) override { - NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled"); + if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter< + std::shared_ptr>>(&adapter)) + { + HostTensorPtr data = + std::make_shared(element::u8, Shape{a->get()->size()}); + data->write(a->get()->get_ptr(), a->get()->size()); + m_values.insert(name, data); + } + else + { + NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled"); + } } // The remaining adapter methods fall back on the void adapter if not implemented void on_adapter(const std::string& name, ValueAccessor& adapter) override