FullyConnectedNode: deserialization fix (#8345)
* FullyConnected deserialization fix * Added test-case with MatMul to caching tests * postreview fixes
This commit is contained in:
parent
9057c15e92
commit
62b084a524
@ -9,7 +9,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
|
||||
const ngraph::Rank& output_rank,
|
||||
const ngraph::element::Type output_type)
|
||||
: Op({A, B}), m_output_rank(output_rank), m_output_type(output_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>& A,
|
||||
@ -18,7 +18,7 @@ MKLDNNPlugin::FullyConnectedNode::FullyConnectedNode(const ngraph::Output<Node>&
|
||||
const ngraph::Rank& output_rank,
|
||||
const ngraph::element::Type output_type)
|
||||
: Op({A, B, C}), m_output_rank(output_rank), m_output_type(output_type) {
|
||||
constructor_validate_and_infer_types();
|
||||
validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> MKLDNNPlugin::FullyConnectedNode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
|
||||
@ -40,13 +40,6 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
||||
input_size,
|
||||
", expected: 2 or 3.");
|
||||
|
||||
const auto output_size = get_output_size();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
output_size == 1,
|
||||
"Number of outputs is incorrect. Current value is: ",
|
||||
output_size,
|
||||
", expected: 1.");
|
||||
|
||||
// Weights shape: [O, I1, ..., Im];
|
||||
// O - output channels dimensions, Ik - input channels dimensions
|
||||
const auto weights_pshape = get_input_partial_shape(1);
|
||||
@ -101,10 +94,7 @@ void MKLDNNPlugin::FullyConnectedNode::validate_and_infer_types() {
|
||||
}
|
||||
|
||||
bool MKLDNNPlugin::FullyConnectedNode::visit_attributes(ngraph::AttributeVisitor &visitor) {
|
||||
if (m_output_rank.is_static()) {
|
||||
std::int64_t value = m_output_rank.get_length();
|
||||
visitor.on_attribute("out-rank", value);
|
||||
}
|
||||
visitor.on_attribute("out-rank", m_output_rank);
|
||||
visitor.on_attribute("out-type", m_output_type);
|
||||
return true;
|
||||
}
|
||||
|
@ -109,6 +109,9 @@ std::vector<nGraphFunctionWithName> LoadNetworkCacheTestBase::getStandardFunctio
|
||||
res.push_back(nGraphFunctionWithName {
|
||||
inputShapeWrapper(ngraph::builder::subgraph::makeReadConcatSplitAssign, {1, 1, 2, 4}),
|
||||
"ReadConcatSplitAssign"});
|
||||
res.push_back(nGraphFunctionWithName{
|
||||
inputShapeWrapper(ngraph::builder::subgraph::makeMatMulBias, {1, 3, 24, 24}),
|
||||
"MatMulBias" });
|
||||
|
||||
return res;
|
||||
}
|
||||
|
@ -550,6 +550,23 @@ inline std::shared_ptr<ngraph::Function> makeReadConcatSplitAssign(std::vector<s
|
||||
fn_ptr->set_friendly_name("ReadConcatSplitAssign");
|
||||
return fn_ptr;
|
||||
}
|
||||
|
||||
inline std::shared_ptr<ngraph::Function> makeMatMulBias(std::vector<size_t> inputShape = { 1, 3, 24, 24 },
|
||||
ngraph::element::Type type = ngraph::element::Type_t::f32) {
|
||||
auto parameter = ngraph::builder::makeParams(type, { inputShape });
|
||||
parameter[0]->set_friendly_name("parameter");
|
||||
auto weights = ngraph::opset1::Constant::create(type, ngraph::Shape{ 24, 24 }, { 1 });
|
||||
auto biases = ngraph::opset1::Constant::create(type, ngraph::Shape{ 1, 24 }, { 1 });
|
||||
auto matmul = std::make_shared<opset1::MatMul>(parameter[0], weights);
|
||||
matmul->set_friendly_name("matmul");
|
||||
auto add = std::make_shared<opset1::Add>(matmul, biases);
|
||||
add->set_friendly_name("add");
|
||||
auto result = std::make_shared<ngraph::opset1::Result>(add);
|
||||
result->set_friendly_name("result");
|
||||
std::shared_ptr<ngraph::Function> fn_ptr = std::make_shared<ngraph::Function>(ngraph::ResultVector{ result }, ngraph::ParameterVector{ parameter });
|
||||
fn_ptr->set_friendly_name("MatMulBias");
|
||||
return fn_ptr;
|
||||
}
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/attribute_adapter.hpp"
|
||||
#include "ngraph/interval.hpp"
|
||||
#include "openvino/core/dimension.hpp"
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <limits>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "openvino/core/attribute_adapter.hpp"
|
||||
#include "openvino/core/core_visibility.hpp"
|
||||
#include "openvino/core/interval.hpp"
|
||||
|
||||
@ -169,4 +170,23 @@ private:
|
||||
/// Inserts the string `?` if `dimension` is dynamic; else inserts `dimension.get_length()`.
|
||||
OPENVINO_API
|
||||
std::ostream& operator<<(std::ostream& str, const Dimension& dimension);
|
||||
|
||||
template <>
|
||||
class OPENVINO_API AttributeAdapter<ov::Dimension> : public ValueAccessor<int64_t> {
|
||||
public:
|
||||
AttributeAdapter(ov::Dimension& value) : m_ref(value) {}
|
||||
|
||||
const int64_t& get() override;
|
||||
void set(const int64_t& value) override;
|
||||
operator ov::Dimension&() {
|
||||
return m_ref;
|
||||
}
|
||||
|
||||
OPENVINO_RTTI("AttributeAdapter<ov::Dimension>");
|
||||
|
||||
protected:
|
||||
ov::Dimension& m_ref;
|
||||
int64_t m_buffer{0};
|
||||
bool m_buffer_valid{false};
|
||||
};
|
||||
} // namespace ov
|
||||
|
@ -105,3 +105,16 @@ Dimension::value_type Dimension::get_max_length() const {
|
||||
Dimension::value_type Dimension::get_min_length() const {
|
||||
return dimension_length(m_dimension.get_min_val());
|
||||
}
|
||||
|
||||
const int64_t& ov::AttributeAdapter<ov::Dimension>::get() {
|
||||
if (!m_buffer_valid) {
|
||||
m_buffer = m_ref.is_dynamic() ? -1 : m_ref.get_length();
|
||||
m_buffer_valid = true;
|
||||
}
|
||||
return m_buffer;
|
||||
}
|
||||
|
||||
void ov::AttributeAdapter<ov::Dimension>::set(const int64_t& value) {
|
||||
m_ref = value == -1 ? ov::Dimension::dynamic() : Dimension(value);
|
||||
m_buffer_valid = false;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user