FullyConnectedNode: deserialization fix (#8345)

* FullyConnected deserialization fix

* Added a test case with MatMul to the caching tests

* Post-review fixes
Vladislav Golubev 2021-11-09 12:14:36 +03:00 committed by GitHub
parent 9057c15e92
commit 62b084a524
6 changed files with 57 additions and 13 deletions

@@ -9,7 +9,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
                                                      const ngraph::Rank& output_rank,
                                                      const ngraph::element::Type output_type)
     : Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) {
-    constructor_validate_and_infer_types();
+    validate_and_infer_types();
 }
 
 MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
@@ -18,7 +18,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
                                                      const ngraph::Rank& output_rank,
                                                      const ngraph::element::Type output_type)
     : Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) {
-    constructor_validate_and_infer_types();
+    validate_and_infer_types();
 }
 
 std::shared_ptr<ngraph::Node> MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
@@ -40,13 +40,6 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
                           input_size,
                           ", expected: 2 or 3.");
 
-    const auto output_size = get_output_size();
-    NODE_VALIDATION_CHECK(this,
-                          output_size == 1,
-                          "Number of outputs is incorrect. Current value is: ",
-                          output_size,
-                          ", expected: 1.");
-
     // Weights shape: [O, I1, ..., Im];
     // O - output channels dimensions, Ik - input channels dimensions
     const auto weights_pshape = get_input_partial_shape(1);
@@ -101,10 +94,7 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
 }
 
 bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) {
-    if (m_output_rank.is_static()) {
-        std::int64_t value = m_output_rank.get_length();
-        visitor.on_attribute("out-rank", value);
-    }
+    visitor.on_attribute("out-rank", m_output_rank);
     visitor.on_attribute("out-type", m_output_type);
     return true;
 }
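
Note on the visit_attributes() change above: with the old is_static() guard, a FullyConnectedNode whose output rank was dynamic was serialized without any "out-rank" attribute at all, so there was nothing for the deserializer to restore. The sketch below illustrates the round-trip convention this commit standardizes (a dynamic rank encoded as -1, matching the new AttributeAdapter<ov::Dimension> further down); the AttributeMap alias and the save/load helpers are illustrative stand-ins, not the actual IR serializer.

// Hedged sketch, not part of the commit: why "out-rank" must be visited
// unconditionally. AttributeMap, save_out_rank and load_out_rank are
// hypothetical stand-ins for the real serialization path.
#include <cstdint>
#include <map>
#include <string>

#include "openvino/core/dimension.hpp"

using AttributeMap = std::map<std::string, int64_t>;

void save_out_rank(const ov::Dimension& out_rank, AttributeMap& attrs) {
    // Same convention as AttributeAdapter<ov::Dimension>: dynamic -> -1.
    attrs["out-rank"] = out_rank.is_dynamic() ? -1 : out_rank.get_length();
}

ov::Dimension load_out_rank(const AttributeMap& attrs) {
    const int64_t v = attrs.at("out-rank");
    return v == -1 ? ov::Dimension::dynamic() : ov::Dimension(v);
}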

@@ -109,6 +109,9 @@ std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getStandardFunctio
     res.push_back(nGraphFunctionWithName {
         inputShapeWrapper(ngraph::builder::subgraph::makeReadConcatSplitAssign, {1, 1, 2, 4}),
         "ReadConcatSplitAssign"});
+    res.push_back(nGraphFunctionWithName{
+        inputShapeWrapper(ngraph::builder::subgraph::makeMatMulBias, {1, 3, 24, 24}),
+        "MatMulBias" });
     return res;
 }

@@ -550,6 +550,23 @@ inline std::shared_ptr<ngraph::Function> makeReadConcatSplitAssign(std::vector<s
     fn_ptr->set_friendly_name("ReadConcatSplitAssign");
     return fn_ptr;
 }
+
+inline std::shared_ptr<ngraph::Function> makeMatMulBias(std::vector<size_t> inputShape = { 1, 3, 24, 24 },
+                                                        ngraph::element::Type type = ngraph::element::Type_t::f32) {
+    auto parameter = ngraph::builder::makeParams(type, { inputShape });
+    parameter[0]->set_friendly_name("parameter");
+    auto weights = ngraph::opset1::Constant::create(type, ngraph::Shape{ 24, 24 }, { 1 });
+    auto biases = ngraph::opset1::Constant::create(type, ngraph::Shape{ 1, 24 }, { 1 });
+    auto matmul = std::make_shared<opset1::MatMul>(parameter[0], weights);
+    matmul->set_friendly_name("matmul");
+    auto add = std::make_shared<opset1::Add>(matmul, biases);
+    add->set_friendly_name("add");
+    auto result = std::make_shared<ngraph::opset1::Result>(add);
+    result->set_friendly_name("result");
+    std::shared_ptr<ngraph::Function> fn_ptr = std::make_shared<ngraph::Function>(ngraph::ResultVector{ result }, ngraph::ParameterVector{ parameter });
+    fn_ptr->set_friendly_name("MatMulBias");
+    return fn_ptr;
+}
 } // namespace subgraph
 } // namespace builder
 } // namespace ngraph
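
The caching tests serialize the compiled model and load it back, and on the MKLDNN/CPU plugin a MatMul with constant weights followed by an Add is the pattern that gets converted into the FullyConnectedNode above, which appears to be why this subgraph was added to cover the fix. Below is a hedged usage sketch of the new helper outside the test harness; the header path and linkage against the ngraph_functions test utilities are assumptions.

// Hedged sketch: builds the new helper directly. The include path is an
// assumption about where the ngraph_functions test utilities live.
#include <iostream>

#include "ngraph_functions/subgraph_builders.hpp"

int main() {
    auto fn = ngraph::builder::subgraph::makeMatMulBias();  // default shape {1, 3, 24, 24}
    std::cout << fn->get_friendly_name() << std::endl;      // "MatMulBias"
    std::cout << fn->get_ops().size() << std::endl;         // Parameter, Constants, MatMul, Add, Result
    return 0;
}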

@@ -4,6 +4,7 @@
 
 #pragma once
 
+#include "ngraph/attribute_adapter.hpp"
 #include "ngraph/interval.hpp"
 #include "openvino/core/dimension.hpp"

@@ -8,6 +8,7 @@
 #include <limits>
 #include <stdexcept>
 
+#include "openvino/core/attribute_adapter.hpp"
 #include "openvino/core/core_visibility.hpp"
 #include "openvino/core/interval.hpp"
@@ -169,4 +170,23 @@ private:
 /// Inserts the string `?` if `dimension` is dynamic; else inserts `dimension.get_length()`.
 OPENVINO_API
 std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
+
+template <>
+class OPENVINO_API AttributeAdapter<ov::Dimension> : public ValueAccessor<int64_t> {
+public:
+    AttributeAdapter(ov::Dimension& value) : m_ref(value) {}
+
+    const int64_t& get() override;
+    void set(const int64_t& value) override;
+    operator ov::Dimension&() {
+        return m_ref;
+    }
+    OPENVINO_RTTI("AttributeAdapter<ov::Dimension>");
+
+protected:
+    ov::Dimension& m_ref;
+    int64_t m_buffer{0};
+    bool m_buffer_valid{false};
+};
+
 } // namespace ov

@@ -105,3 +105,16 @@ Dimension::value_type Dimension::get_max_length() const {
 Dimension::value_type Dimension::get_min_length() const {
     return dimension_length(m_dimension.get_min_val());
 }
+
+const int64_t& ov::AttributeAdapter<ov::Dimension>::get() {
+    if (!m_buffer_valid) {
+        m_buffer = m_ref.is_dynamic() ? -1 : m_ref.get_length();
+        m_buffer_valid = true;
+    }
+    return m_buffer;
+}
+
+void ov::AttributeAdapter<ov::Dimension>::set(const int64_t& value) {
+    m_ref = value == -1 ? ov::Dimension::dynamic() : Dimension(value);
+    m_buffer_valid = false;
+}
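
A minimal usage sketch of the new adapter (assuming a build that already contains this patch); it exercises the get()/set() mapping directly rather than going through an AttributeVisitor.

// Hedged sketch: drives AttributeAdapter<ov::Dimension> directly to show the
// dynamic <-> -1 encoding introduced above.
#include <iostream>

#include "openvino/core/dimension.hpp"

int main() {
    ov::Dimension rank = ov::Dimension::dynamic();
    ov::AttributeAdapter<ov::Dimension> adapter(rank);

    std::cout << adapter.get() << std::endl;      // -1: serialized form of a dynamic rank
    adapter.set(3);                               // reading back "3" restores a static rank
    std::cout << rank.get_length() << std::endl;  // 3
    adapter.set(-1);                              // -1 round-trips back to dynamic
    std::cout << rank.is_dynamic() << std::endl;  // 1
    return 0;
}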