FullyConnected deserialization fix

This commit is contained in:
Vladislav Golubev 2021-11-01 09:35:10 +03:00
parent 0b0202b90c
commit 95bc22f065
4 changed files with 40 additions and 13 deletions

View File

@ -9,7 +9,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
const ngraph::Rank& output_rank,
const ngraph::element::Type output_type)
: Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) {
constructor_validate_and_infer_types();
validate_and_infer_types();
}
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
@ -18,7 +18,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
const ngraph::Rank& output_rank,
const ngraph::element::Type output_type)
: Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) {
constructor_validate_and_infer_types();
validate_and_infer_types();
}
std::shared_ptr<ngraph::Node> MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
@ -40,13 +40,6 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
input_size,
", expected: 2 or 3.");
const auto output_size = get_output_size();
NODE_VALIDATION_CHECK(this,
output_size == 1,
"Number of outputs is incorrect. Current value is: ",
output_size,
", expected: 1.");
// Weights shape: [O, I1, ..., Im];
// O - output channels dimensions, Ik - input channels dimensions
const auto weights_pshape = get_input_partial_shape(1);
@ -101,10 +94,7 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
}
bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) {
if (m_output_rank.is_static()) {
std::int64_t value = m_output_rank.get_length();
visitor.on_attribute("out-rank", value);
}
visitor.on_attribute("out-rank", m_output_rank);
visitor.on_attribute("out-type", m_output_type);
return true;
}

View File

@ -4,6 +4,7 @@
#pragma once
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/interval.hpp"
#include "openvino/core/dimension.hpp"

View File

@ -8,6 +8,7 @@
#include <limits>
#include <stdexcept>
#include "openvino/core/attribute_adapter.hpp"
#include "openvino/core/core_visibility.hpp"
#include "openvino/core/interval.hpp"
@ -169,4 +170,24 @@ private:
/// Inserts the string `?` if `dimension` is dynamic; else inserts `dimension.get_length()`.
OPENVINO_API
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
template <>
class OPENVINO_API AttributeAdapter<ov::Dimension> : public ValueAccessor<int64_t> {
public:
AttributeAdapter(ov::Dimension& value) : m_ref(value) {}
const int64_t& get() override;
void set(const int64_t& value) override;
operator ov::Dimension&() {
return m_ref;
}
OPENVINO_RTTI("AttributeAdapter<ov::Dimension>");
BWDCMP_RTTI_DECLARATION;
protected:
ov::Dimension& m_ref;
int64_t m_buffer;
bool m_buffer_valid{false};
};
} // namespace ov

View File

@ -105,3 +105,18 @@ Dimension::value_type Dimension::get_max_length() const {
Dimension::value_type Dimension::get_min_length() const {
return dimension_length(m_dimension.get_min_val());
}
BWDCMP_RTTI_DEFINITION(ov::AttributeAdapter<ov::Dimension>);
const int64_t& ov::AttributeAdapter<ov::Dimension>::get() {
if (!m_buffer_valid) {
m_buffer = m_ref.is_dynamic() ? -1 : m_ref.get_length();
m_buffer_valid = true;
}
return m_buffer;
}
void ov::AttributeAdapter<ov::Dimension>::set(const int64_t& value) {
m_ref = value == -1 ? ov::Dimension::dynamic() : Dimension(value);
m_buffer_valid = false;
}