FullyConnected deserialization fix
This commit is contained in:
parent
0b0202b90c
commit
95bc22f065
@ -9,7 +9,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
|
||||
const ngraph::Rank& output_rank,
|
||||
const ngraph::element::Type output_type)
|
||||
: Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
|
||||
@ -18,7 +18,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
|
||||
const ngraph::Rank& output_rank,
|
||||
const ngraph::element::Type output_type)
|
||||
: Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
|
||||
@ -40,13 +40,6 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
||||
input_size,
|
||||
", expected: 2 or 3.");
|
||||
|
||||
const auto output_size = get_output_size();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
output_size == 1,
|
||||
"Number of outputs is incorrect. Current value is: ",
|
||||
output_size,
|
||||
", expected: 1.");
|
||||
|
||||
// Weights shape: [O, I1, ..., Im];
|
||||
// O - output channels dimensions, Ik - input channels dimensions
|
||||
const auto weights_pshape = get_input_partial_shape(1);
|
||||
@ -101,10 +94,7 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
||||
}
|
||||
|
||||
bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) {
|
||||
if (m_output_rank.is_static()) {
|
||||
std::int64_t value = m_output_rank.get_length();
|
||||
visitor.on_attribute("out-rank", value);
|
||||
}
|
||||
visitor.on_attribute("out-rank", m_output_rank);
|
||||
visitor.on_attribute("out-type", m_output_type);
|
||||
return true;
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/attribute_adapter.hpp"
|
||||
#include "ngraph/interval.hpp"
|
||||
#include "openvino/core/dimension.hpp"
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "openvino/core/attribute_adapter.hpp"
|
||||
#include "openvino/core/core_visibility.hpp"
|
||||
#include "openvino/core/interval.hpp"
|
||||
|
||||
@ -169,4 +170,24 @@ private:
|
||||
/// Inserts the string `?` if `dimension` is dynamic; else inserts `dimension.get_length()`.
|
||||
OPENVINO_API
|
||||
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
|
||||
|
||||
template <>
|
||||
class OPENVINO_API AttributeAdapter<ov::Dimension> : public ValueAccessor<int64_t> {
|
||||
public:
|
||||
AttributeAdapter(ov::Dimension& value) : m_ref(value) {}
|
||||
|
||||
const int64_t& get() override;
|
||||
void set(const int64_t& value) override;
|
||||
operator ov::Dimension&() {
|
||||
return m_ref;
|
||||
}
|
||||
|
||||
OPENVINO_RTTI("AttributeAdapter<ov::Dimension>");
|
||||
BWDCMP_RTTI_DECLARATION;
|
||||
|
||||
protected:
|
||||
ov::Dimension& m_ref;
|
||||
int64_t m_buffer;
|
||||
bool m_buffer_valid{false};
|
||||
};
|
||||
} // namespace ov
|
||||
|
@ -105,3 +105,18 @@ Dimension::value_type Dimension::get_max_length() const {
|
||||
Dimension::value_type Dimension::get_min_length() const {
|
||||
return dimension_length(m_dimension.get_min_val());
|
||||
}
|
||||
|
||||
BWDCMP_RTTI_DEFINITION(ov::AttributeAdapter<ov::Dimension>);
|
||||
|
||||
const int64_t& ov::AttributeAdapter<ov::Dimension>::get() {
|
||||
if (!m_buffer_valid) {
|
||||
m_buffer = m_ref.is_dynamic() ? -1 : m_ref.get_length();
|
||||
m_buffer_valid = true;
|
||||
}
|
||||
return m_buffer;
|
||||
}
|
||||
|
||||
void ov::AttributeAdapter<ov::Dimension>::set(const int64_t& value) {
|
||||
m_ref = value == -1 ? ov::Dimension::dynamic() : Dimension(value);
|
||||
m_buffer_valid = false;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user