From 95bc22f0652043642b355dc0fbf41c08cda36bad Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Mon, 1 Nov 2021 09:35:10 +0300 Subject: [PATCH] FullyConnected deserialization fix --- .../op/fully_connected.cpp | 16 +++----------- ngraph/core/include/ngraph/dimension.hpp | 1 + .../core/include/openvino/core/dimension.hpp | 21 +++++++++++++++++++ ngraph/core/src/dimension.cpp | 15 +++++++++++++ 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp index 3affd3ab47e..7467f9103d6 100644 --- a/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp +++ b/inference-engine/src/mkldnn_plugin/ngraph_transformations/op/fully_connected.cpp @@ -9,7 +9,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output& const ngraph::Rank& output_rank, const ngraph::element::Type output_type) : Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) { - constructor_validate_and_infer_types(); + validate_and_infer_types(); } MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output& A, @@ -18,7 +18,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output& const ngraph::Rank& output_rank, const ngraph::element::Type output_type) : Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) { - constructor_validate_and_infer_types(); + validate_and_infer_types(); } std::shared_ptr MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const { @@ -40,13 +40,6 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() { input_size, ", expected: 2 or 3."); - const auto output_size = get_output_size(); - NODE_VALIDATION_CHECK(this, - output_size == 1, - "Number of outputs is incorrect. Current value is: ", - output_size, - ", expected: 1."); - // Weights shape: [O, I1, ..., Im]; // O - output channels dimensions, Ik - input channels dimensions const auto weights_pshape = get_input_partial_shape(1); @@ -101,10 +94,7 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() { } bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) { - if (m_output_rank.is_static()) { - std::int64_t value = m_output_rank.get_length(); - visitor.on_attribute("out-rank", value); - } + visitor.on_attribute("out-rank", m_output_rank); visitor.on_attribute("out-type", m_output_type); return true; } diff --git a/ngraph/core/include/ngraph/dimension.hpp b/ngraph/core/include/ngraph/dimension.hpp index 9d748b26b30..762d4e43043 100644 --- a/ngraph/core/include/ngraph/dimension.hpp +++ b/ngraph/core/include/ngraph/dimension.hpp @@ -4,6 +4,7 @@ #pragma once +#include "ngraph/attribute_adapter.hpp" #include "ngraph/interval.hpp" #include "openvino/core/dimension.hpp" diff --git a/ngraph/core/include/openvino/core/dimension.hpp b/ngraph/core/include/openvino/core/dimension.hpp index e5423267754..f1441dc577d 100644 --- a/ngraph/core/include/openvino/core/dimension.hpp +++ b/ngraph/core/include/openvino/core/dimension.hpp @@ -8,6 +8,7 @@ #include #include +#include "openvino/core/attribute_adapter.hpp" #include "openvino/core/core_visibility.hpp" #include "openvino/core/interval.hpp" @@ -169,4 +170,24 @@ private: /// Inserts the string `?` if `dimension` is dynamic; else inserts `dimension.get_length()`. OPENVINO_API std::ostream& operator<<(std::ostream& str, const Dimension& dimension); + +template <> +class OPENVINO_API AttributeAdapter : public ValueAccessor { +public: + AttributeAdapter(ov::Dimension& value) : m_ref(value) {} + + const int64_t& get() override; + void set(const int64_t& value) override; + operator ov::Dimension&() { + return m_ref; + } + + OPENVINO_RTTI("AttributeAdapter"); + BWDCMP_RTTI_DECLARATION; + +protected: + ov::Dimension& m_ref; + int64_t m_buffer; + bool m_buffer_valid{false}; +}; } // namespace ov diff --git a/ngraph/core/src/dimension.cpp b/ngraph/core/src/dimension.cpp index 84b156ae7b4..7c60f89a95d 100644 --- a/ngraph/core/src/dimension.cpp +++ b/ngraph/core/src/dimension.cpp @@ -105,3 +105,18 @@ Dimension::value_type Dimension::get_max_length() const { Dimension::value_type Dimension::get_min_length() const { return dimension_length(m_dimension.get_min_val()); } + +BWDCMP_RTTI_DEFINITION(ov::AttributeAdapter); + +const int64_t& ov::AttributeAdapter::get() { + if (!m_buffer_valid) { + m_buffer = m_ref.is_dynamic() ? -1 : m_ref.get_length(); + m_buffer_valid = true; + } + return m_buffer; +} + +void ov::AttributeAdapter::set(const int64_t& value) { + m_ref = value == -1 ? ov::Dimension::dynamic() : Dimension(value); + m_buffer_valid = false; +}