Support FrameworkNode with subgraphs (#14994)
* Support FrameworkNode with subgraphs * Fix code style * Fix test fail * Apply suggestions from code review Co-authored-by: Ilya Churaev <ilyachur@gmail.com> * Apply review feedback * Fix code style * Apply suggestions from code review Co-authored-by: Ilya Churaev <ilyachur@gmail.com> Co-authored-by: Ilya Churaev <ilyachur@gmail.com>
This commit is contained in:
parent
11f1c138de
commit
fba025f1e7
@ -79,14 +79,6 @@ public:
|
||||
const std::shared_ptr<v0::Result>& else_result);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
private:
|
||||
using OutputMap = std::map<int64_t, std::shared_ptr<MultiSubGraphOp::OutputDescription>>;
|
||||
|
||||
void validate_and_infer_type_body(const std::shared_ptr<Model>& body,
|
||||
const MultiSubgraphInputDescriptionVector& input_descriptors);
|
||||
|
||||
OutputMap get_mapping_outputs_on_body_description(const MultiSubgraphOutputDescriptionVector& output_descriptors);
|
||||
};
|
||||
} // namespace v8
|
||||
} // namespace op
|
||||
|
@ -8,7 +8,7 @@
|
||||
|
||||
#include "openvino/core/partial_shape.hpp"
|
||||
#include "openvino/core/strides.hpp"
|
||||
#include "openvino/op/op.hpp"
|
||||
#include "openvino/op/util/multi_subgraph_base.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
@ -69,20 +69,17 @@ private:
|
||||
std::unordered_map<std::string, std::string> m_attrs;
|
||||
};
|
||||
|
||||
class OPENVINO_API FrameworkNode : public Op {
|
||||
class OPENVINO_API FrameworkNode : public MultiSubGraphOp {
|
||||
public:
|
||||
OPENVINO_OP("FrameworkNode", "util");
|
||||
|
||||
FrameworkNode() = default;
|
||||
|
||||
explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1);
|
||||
explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1, size_t num_subgraphs = 0);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override {
|
||||
visitor.on_attribute("framework_node_attrs", m_attrs);
|
||||
return true;
|
||||
}
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
const FrameworkNodeAttrs& get_attrs() const {
|
||||
return m_attrs;
|
||||
@ -96,11 +93,17 @@ public:
|
||||
|
||||
void cache_output_descriptor();
|
||||
|
||||
protected:
|
||||
FrameworkNode(const FrameworkNode&);
|
||||
|
||||
private:
|
||||
void clone_to(FrameworkNode& dst) const;
|
||||
|
||||
std::vector<std::tuple<ov::PartialShape, ov::element::Type>> m_inputs_desc;
|
||||
std::vector<std::tuple<ov::PartialShape, ov::element::Type>> m_output_desc;
|
||||
|
||||
FrameworkNodeAttrs m_attrs;
|
||||
size_t m_num_bodies;
|
||||
};
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
|
@ -308,6 +308,11 @@ protected:
|
||||
MultiSubGraphOp(const OutputVector& args, size_t number_of_bodies);
|
||||
explicit MultiSubGraphOp(const OutputVector& args);
|
||||
|
||||
using OutputMap = std::map<int64_t, std::shared_ptr<MultiSubGraphOp::OutputDescription>>;
|
||||
void validate_and_infer_type_body(const std::shared_ptr<ov::Model>& body,
|
||||
const MultiSubgraphInputDescriptionVector& input_descriptors);
|
||||
OutputMap get_mapping_outputs_on_body_description(const MultiSubgraphOutputDescriptionVector& output_descriptors);
|
||||
|
||||
std::vector<std::shared_ptr<Model>> m_bodies;
|
||||
std::vector<MultiSubgraphInputDescriptionVector> m_input_descriptions;
|
||||
std::vector<MultiSubgraphOutputDescriptionVector> m_output_descriptions;
|
||||
|
@ -75,19 +75,6 @@ bool ov::op::v8::If::visit_attributes(AttributeVisitor& visitor) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void ov::op::v8::If::validate_and_infer_type_body(
|
||||
const std::shared_ptr<ov::Model>& body,
|
||||
const ngraph::op::util::MultiSubgraphInputDescriptionVector& input_descriptors) {
|
||||
for (const auto& input_description : input_descriptors) {
|
||||
auto index = input_description->m_input_index;
|
||||
|
||||
auto body_parameter = body->get_parameters().at(input_description->m_body_parameter_index);
|
||||
auto input_partial_shape = input_value(index).get_partial_shape();
|
||||
body_parameter->set_partial_shape(input_partial_shape);
|
||||
}
|
||||
body->validate_nodes_and_infer_types();
|
||||
}
|
||||
|
||||
void ov::op::v8::If::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v8_If_validate_and_infer_types);
|
||||
|
||||
@ -202,30 +189,6 @@ std::shared_ptr<ov::Node> ov::op::v8::If::clone_with_new_inputs(const OutputVect
|
||||
return op;
|
||||
}
|
||||
|
||||
ov::op::v8::If::OutputMap ov::op::v8::If::get_mapping_outputs_on_body_description(
|
||||
const ngraph::op::util::MultiSubgraphOutputDescriptionVector& output_descriptors) {
|
||||
OutputMap outputs_map = OutputMap();
|
||||
std::unordered_set<int64_t> checked_results_in_body;
|
||||
|
||||
for (const auto& output_description : output_descriptors) {
|
||||
auto out_index = output_description->m_output_index;
|
||||
auto internal_result_index = output_description->m_body_value_index;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
checked_results_in_body.count(internal_result_index) == 0,
|
||||
"Incorrect associating in then_body! Result ",
|
||||
internal_result_index,
|
||||
" is already associated with another output!");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
outputs_map.count(out_index) == 0,
|
||||
"Incorrect associating in then_body! Several results try to "
|
||||
"associate with the same output!");
|
||||
checked_results_in_body.insert(internal_result_index);
|
||||
outputs_map.insert({out_index, output_description});
|
||||
}
|
||||
|
||||
return outputs_map;
|
||||
}
|
||||
|
||||
void ov::op::v8::If::set_input(const Output<Node>& value,
|
||||
const std::shared_ptr<v0::Parameter>& then_parameter,
|
||||
const std::shared_ptr<v0::Parameter>& else_parameter) {
|
||||
|
@ -5,22 +5,54 @@
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
|
||||
ov::op::util::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size) : Op(inputs) {
|
||||
ov::op::util::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size, size_t num_subgraphs)
|
||||
: MultiSubGraphOp(num_subgraphs),
|
||||
m_num_bodies(num_subgraphs) {
|
||||
set_arguments(inputs);
|
||||
set_output_size(output_size);
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
ov::op::util::FrameworkNode::FrameworkNode(const ov::op::util::FrameworkNode& other) : MultiSubGraphOp() {
|
||||
set_arguments(other.input_values());
|
||||
other.clone_to(*this);
|
||||
}
|
||||
|
||||
void ov::op::util::FrameworkNode::clone_to(ov::op::util::FrameworkNode& dst) const {
|
||||
dst.set_output_size(get_output_size());
|
||||
|
||||
for (size_t i = 0; i < get_output_size(); ++i) {
|
||||
dst.set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
|
||||
}
|
||||
dst.m_inputs_desc = m_inputs_desc;
|
||||
dst.m_output_desc = m_output_desc;
|
||||
dst.m_attrs = m_attrs;
|
||||
dst.m_num_bodies = m_num_bodies;
|
||||
|
||||
dst.m_bodies.resize(m_num_bodies);
|
||||
dst.m_input_descriptions.resize(m_num_bodies);
|
||||
dst.m_output_descriptions.resize(m_num_bodies);
|
||||
|
||||
for (size_t i = 0; i < m_num_bodies; i++) {
|
||||
dst.m_bodies[i] = get_function(i)->clone();
|
||||
for (const auto& input_description : m_input_descriptions[i]) {
|
||||
dst.m_input_descriptions[i].push_back(input_description->copy());
|
||||
}
|
||||
for (auto& output_description : m_output_descriptions[i]) {
|
||||
dst.m_output_descriptions[i].push_back(output_description->copy());
|
||||
}
|
||||
}
|
||||
|
||||
dst.validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> ov::op::util::FrameworkNode::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
OV_OP_SCOPE(FrameworkNode_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
auto node = std::make_shared<op::util::FrameworkNode>(new_args);
|
||||
for (size_t i = 0; i < get_output_size(); ++i) {
|
||||
node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
|
||||
}
|
||||
node->m_inputs_desc = m_inputs_desc;
|
||||
node->m_output_desc = m_output_desc;
|
||||
node->m_attrs = m_attrs;
|
||||
auto node = std::make_shared<op::util::FrameworkNode>(new_args, get_output_size(), m_num_bodies);
|
||||
clone_to(*node);
|
||||
return node;
|
||||
}
|
||||
|
||||
@ -32,6 +64,60 @@ void ov::op::util::FrameworkNode::cache_output_descriptor() {
|
||||
|
||||
void ov::op::util::FrameworkNode::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(FrameworkNode_validate_and_infer_types);
|
||||
|
||||
if (m_inputs_desc.size() < get_input_size()) {
|
||||
// case when we added inputs using set_invariant_inputs
|
||||
m_inputs_desc.clear();
|
||||
}
|
||||
if (m_output_desc.size() < get_output_size()) {
|
||||
// case when we added outputs using set_body_outputs
|
||||
m_output_desc.clear();
|
||||
}
|
||||
|
||||
// propagate shapes and types from bodies
|
||||
std::unordered_map<size_t, PartialShape> shape_map;
|
||||
std::unordered_map<size_t, element::Type> type_map;
|
||||
for (size_t i = 0; i < m_bodies.size(); ++i) {
|
||||
auto body = get_function(i);
|
||||
// If body doesn't exist skip the validation
|
||||
if (!body)
|
||||
continue;
|
||||
validate_and_infer_type_body(get_function(i), m_input_descriptions[i]);
|
||||
|
||||
auto outputs_map = get_mapping_outputs_on_body_description(m_output_descriptions[i]);
|
||||
|
||||
for (const auto& item : outputs_map) {
|
||||
auto output_index = item.first;
|
||||
auto desc = item.second;
|
||||
auto node_result = m_bodies[i]->get_results().at(desc->m_body_value_index)->input_value(0);
|
||||
auto pshape = PartialShape::dynamic();
|
||||
if (shape_map.count(output_index)) {
|
||||
pshape = shape_map.at(output_index);
|
||||
}
|
||||
if (PartialShape::merge_into(pshape, node_result.get_partial_shape())) {
|
||||
shape_map[output_index] = pshape;
|
||||
} else {
|
||||
shape_map[output_index] = PartialShape::dynamic();
|
||||
}
|
||||
auto type = element::dynamic;
|
||||
if (type_map.count(output_index)) {
|
||||
type = type_map.at(output_index);
|
||||
}
|
||||
if (element::Type::merge(type, type, node_result.get_element_type())) {
|
||||
type_map[output_index] = type;
|
||||
} else {
|
||||
type_map[output_index] = element::dynamic;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto& item : shape_map) {
|
||||
auto output_index = item.first;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
type_map.count(output_index) != 0,
|
||||
"Type map must contain same outputs as shape map");
|
||||
set_output_type(output_index, type_map.at(output_index), item.second);
|
||||
}
|
||||
|
||||
// Save initial inputs descriptors
|
||||
bool initialize_input_desc = m_inputs_desc.empty();
|
||||
bool reset_output_shape_to_dynamic = false;
|
||||
@ -93,5 +179,21 @@ void ov::op::util::FrameworkNode::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
bool ov::op::util::FrameworkNode::visit_attributes(AttributeVisitor& visitor) {
|
||||
visitor.on_attribute("framework_node_attrs", m_attrs);
|
||||
visitor.on_attribute("num_bodies", m_num_bodies);
|
||||
|
||||
m_bodies.resize(m_num_bodies);
|
||||
m_input_descriptions.resize(m_num_bodies);
|
||||
m_output_descriptions.resize(m_num_bodies);
|
||||
|
||||
for (size_t i = 0; i < m_num_bodies; ++i) {
|
||||
visitor.on_attribute("body" + std::to_string(i), m_bodies[i]);
|
||||
visitor.on_attribute("input_descriptions" + std::to_string(i), m_input_descriptions[i]);
|
||||
visitor.on_attribute("output_descriptions" + std::to_string(i), m_output_descriptions[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ov::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>::AttributeAdapter(ov::op::util::FrameworkNodeAttrs& value)
|
||||
: DirectValueAccessor<ov::op::util::FrameworkNodeAttrs>(value) {}
|
||||
|
@ -146,3 +146,41 @@ ov::Output<ov::Node> ov::op::util::MultiSubGraphOp::set_body_outputs(const Resul
|
||||
validate_and_infer_types();
|
||||
return Output<Node>(shared_from_this(), output_index);
|
||||
}
|
||||
|
||||
void ov::op::util::MultiSubGraphOp::validate_and_infer_type_body(
|
||||
const std::shared_ptr<ov::Model>& body,
|
||||
const ov::op::util::MultiSubGraphOp::MultiSubgraphInputDescriptionVector& input_descriptors) {
|
||||
const auto& params = body->get_parameters();
|
||||
for (const auto& input_description : input_descriptors) {
|
||||
auto index = input_description->m_input_index;
|
||||
|
||||
auto body_parameter = params.at(input_description->m_body_parameter_index);
|
||||
auto input_partial_shape = input_value(index).get_partial_shape();
|
||||
body_parameter->set_partial_shape(input_partial_shape);
|
||||
}
|
||||
body->validate_nodes_and_infer_types();
|
||||
}
|
||||
|
||||
ov::op::util::MultiSubGraphOp::OutputMap ov::op::util::MultiSubGraphOp::get_mapping_outputs_on_body_description(
|
||||
const ov::op::util::MultiSubGraphOp::MultiSubgraphOutputDescriptionVector& output_descriptors) {
|
||||
OutputMap outputs_map = OutputMap();
|
||||
std::unordered_set<int64_t> checked_results_in_body;
|
||||
|
||||
for (const auto& output_description : output_descriptors) {
|
||||
const auto& out_index = output_description->m_output_index;
|
||||
const auto& internal_result_index = output_description->m_body_value_index;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
checked_results_in_body.count(internal_result_index) == 0,
|
||||
"Incorrect associating in body! Result ",
|
||||
internal_result_index,
|
||||
" is already associated with another output!");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
outputs_map.count(out_index) == 0,
|
||||
"Incorrect associating in body! Several results try to "
|
||||
"associate with the same output!");
|
||||
checked_results_in_body.insert(internal_result_index);
|
||||
outputs_map.insert({out_index, output_description});
|
||||
}
|
||||
|
||||
return outputs_map;
|
||||
}
|
||||
|
@ -388,6 +388,33 @@ public:
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!is_body_target) {
|
||||
std::string id = "input_descriptions";
|
||||
std::string od = "output_descriptions";
|
||||
const auto& id_pos = name.find("input_descriptions");
|
||||
const auto& od_pos = name.find("output_descriptions");
|
||||
auto id_str = name;
|
||||
size_t body_id;
|
||||
if (id_pos != std::string::npos) {
|
||||
id_str.erase(id_pos, id.length());
|
||||
std::stoi(id_str, &body_id);
|
||||
is_body_target = true;
|
||||
} else if (od_pos != std::string::npos) {
|
||||
id_str.erase(od_pos, od.length());
|
||||
std::stoi(id_str, &body_id);
|
||||
is_body_target = true;
|
||||
}
|
||||
if (is_body_target) {
|
||||
auto body_name = "body" + id_str;
|
||||
if (m_xml_node.parent().child(body_name.c_str())) {
|
||||
bnames = BodyTargetNames{body_name,
|
||||
"port_map" + id_str,
|
||||
{"input_descriptions" + id_str, "output_descriptions" + id_str}};
|
||||
} else {
|
||||
is_body_target = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_body_target) {
|
||||
auto body_name = std::get<0>(bnames);
|
||||
auto portmap_name = std::get<1>(bnames);
|
||||
@ -503,7 +530,8 @@ public:
|
||||
m_xml_node.append_attribute(name.c_str()).set_value(create_atribute_list(adapter).c_str());
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::shared_ptr<Function>>& adapter) override {
|
||||
if (name == "body" || name == "then_body" || name == "else_body") {
|
||||
if (name.find("body") != std::string::npos) {
|
||||
// name that contains subgraphs: body{n}, then_body, else_body
|
||||
// TI, Loop do not have attributtes as regular ops, it is necessary to append "body"
|
||||
// to layer above (m_xml_node.parent()) as in ngfunction_2_ir() layer (m_xml_node) with empty attributes
|
||||
// is removed.
|
||||
|
@ -60,7 +60,7 @@ TEST_F(CustomOpsSerializationTest, CustomOpNoExtensions) {
|
||||
</output>
|
||||
</layer>
|
||||
<layer name="operation" id="1" type="Template" version="custom_opset">
|
||||
<data add="11"/>
|
||||
<data num_bodies="0" add="11"/>
|
||||
<input>
|
||||
<port id="1" precision="FP32">
|
||||
<dim>2</dim>
|
||||
|
66
src/core/tests/visitors/op/framework_node.cpp
Normal file
66
src/core/tests/visitors/op/framework_node.cpp
Normal file
@ -0,0 +1,66 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
|
||||
#include "common_test_utils/graph_comparator.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/op/util/attr_types.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
TEST(attributes, framework_node_op) {
|
||||
NodeBuilder::get_ops().register_factory<op::util::FrameworkNode>();
|
||||
auto X = make_shared<Parameter>(element::f32, Shape{1, 2, 2});
|
||||
auto Y = make_shared<Parameter>(element::f32, Shape{1, 2, 2});
|
||||
auto cond = make_shared<Constant>(element::boolean, Shape{1}, true);
|
||||
auto cond2 = make_shared<Constant>(element::boolean, Shape{1}, false);
|
||||
auto Xt = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Yt = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Xe = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto Ye = make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto then_op = make_shared<Multiply>(Xt, Yt);
|
||||
auto res0 = make_shared<Result>(then_op);
|
||||
auto res1 = make_shared<Result>(Xe);
|
||||
auto body1 = make_shared<Model>(OutputVector{res0}, ParameterVector{Xt, Yt});
|
||||
auto body2 = make_shared<Model>(OutputVector{res1}, ParameterVector{Xe});
|
||||
auto fn_op = make_shared<op::util::FrameworkNode>(OutputVector{cond}, 0, 2);
|
||||
|
||||
// Add attributes
|
||||
auto attrs = op::util::FrameworkNodeAttrs();
|
||||
attrs.set_type_name("some_type");
|
||||
fn_op->set_attrs(attrs);
|
||||
|
||||
fn_op->set_function(0, body1);
|
||||
fn_op->set_function(1, body2);
|
||||
fn_op->set_invariant_inputs(X, {Xt, Xe});
|
||||
fn_op->set_invariant_inputs(Y, {Yt, nullptr});
|
||||
auto out = fn_op->set_body_outputs({res0, res1});
|
||||
fn_op->validate_and_infer_types();
|
||||
EXPECT_EQ(fn_op->inputs().size(), 3);
|
||||
EXPECT_EQ(fn_op->outputs().size(), 1);
|
||||
|
||||
NodeBuilder builder(fn_op);
|
||||
auto g_fn = ov::as_type_ptr<op::util::FrameworkNode>(builder.create());
|
||||
EXPECT_EQ(g_fn->get_attrs(), fn_op->get_attrs());
|
||||
EXPECT_EQ(g_fn->get_input_descriptions(0), fn_op->get_input_descriptions(0));
|
||||
EXPECT_EQ(g_fn->get_input_descriptions(1), fn_op->get_input_descriptions(1));
|
||||
EXPECT_EQ(g_fn->get_output_descriptions(0), fn_op->get_output_descriptions(0));
|
||||
EXPECT_EQ(g_fn->get_output_descriptions(1), fn_op->get_output_descriptions(1));
|
||||
|
||||
auto comparator = FunctionsComparator::with_default();
|
||||
ASSERT_TRUE(g_fn->get_function(0));
|
||||
auto res = comparator.compare(g_fn->get_function(0), fn_op->get_function(0));
|
||||
EXPECT_TRUE(res.valid) << res.message;
|
||||
|
||||
ASSERT_TRUE(g_fn->get_function(1));
|
||||
res = comparator.compare(g_fn->get_function(1), fn_op->get_function(1));
|
||||
EXPECT_TRUE(res.valid) << res.message;
|
||||
}
|
@ -11,6 +11,7 @@
|
||||
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/factory.hpp"
|
||||
#include "ngraph/op/util/framework_node.hpp"
|
||||
#include "ngraph/ops.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
|
||||
@ -120,6 +121,9 @@ public:
|
||||
virtual operator std::shared_ptr<Variable>&() {
|
||||
NGRAPH_CHECK(false, "Invalid type access");
|
||||
}
|
||||
virtual operator ov::op::util::FrameworkNodeAttrs&() {
|
||||
NGRAPH_CHECK(false, "Invalid type access");
|
||||
}
|
||||
uint64_t get_index() {
|
||||
return m_index;
|
||||
}
|
||||
@ -224,6 +228,8 @@ public:
|
||||
a->set(m_values.get<ov::Dimension>(name));
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<Variable>>>(&adapter)) {
|
||||
a->set(m_values.get<std::shared_ptr<Variable>>(name));
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||
a->set(m_values.get<ov::op::util::FrameworkNodeAttrs>(name));
|
||||
} else {
|
||||
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled");
|
||||
}
|
||||
@ -310,6 +316,8 @@ public:
|
||||
m_values.insert(name, a->get());
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<Variable>>>(&adapter)) {
|
||||
m_values.insert(name, a->get());
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
|
||||
m_values.insert(name, a->get());
|
||||
} else {
|
||||
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled");
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user