Support FrameworkNode with subgraphs (#14994)

* Support FrameworkNode with subgraphs

* Fix code style

* Fix test fail

* Apply suggestions from code review

Co-authored-by: Ilya Churaev <ilyachur@gmail.com>

* Apply review feedback

* Fix code style

* Apply suggestions from code review

Co-authored-by: Ilya Churaev <ilyachur@gmail.com>

Co-authored-by: Ilya Churaev <ilyachur@gmail.com>
This commit is contained in:
Maxim Vafin 2023-01-10 20:40:39 +01:00 committed by GitHub
parent 11f1c138de
commit fba025f1e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 267 additions and 62 deletions

View File

@ -79,14 +79,6 @@ public:
const std::shared_ptr<v0::Result>& else_result);
void validate_and_infer_types() override;
private:
using OutputMap = std::map<int64_t, std::shared_ptr<MultiSubGraphOp::OutputDescription>>;
void validate_and_infer_type_body(const std::shared_ptr<Model>& body,
const MultiSubgraphInputDescriptionVector& input_descriptors);
OutputMap get_mapping_outputs_on_body_description(const MultiSubgraphOutputDescriptionVector& output_descriptors);
};
} // namespace v8
} // namespace op

View File

@ -8,7 +8,7 @@
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/strides.hpp"
#include "openvino/op/op.hpp"
#include "openvino/op/util/multi_subgraph_base.hpp"
namespace ov {
namespace op {
@ -69,20 +69,17 @@ private:
std::unordered_map<std::string, std::string> m_attrs;
};
class OPENVINO_API FrameworkNode : public Op {
class OPENVINO_API FrameworkNode : public MultiSubGraphOp {
public:
OPENVINO_OP("FrameworkNode", "util");
FrameworkNode() = default;
explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1);
explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1, size_t num_subgraphs = 0);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override {
visitor.on_attribute("framework_node_attrs", m_attrs);
return true;
}
bool visit_attributes(AttributeVisitor& visitor) override;
const FrameworkNodeAttrs& get_attrs() const {
return m_attrs;
@ -96,11 +93,17 @@ public:
void cache_output_descriptor();
protected:
FrameworkNode(const FrameworkNode&);
private:
void clone_to(FrameworkNode& dst) const;
std::vector<std::tuple<ov::PartialShape, ov::element::Type>> m_inputs_desc;
std::vector<std::tuple<ov::PartialShape, ov::element::Type>> m_output_desc;
FrameworkNodeAttrs m_attrs;
size_t m_num_bodies;
};
} // namespace util
} // namespace op

View File

@ -308,6 +308,11 @@ protected:
MultiSubGraphOp(const OutputVector& args, size_t number_of_bodies);
explicit MultiSubGraphOp(const OutputVector& args);
using OutputMap = std::map<int64_t, std::shared_ptr<MultiSubGraphOp::OutputDescription>>;
void validate_and_infer_type_body(const std::shared_ptr<ov::Model>& body,
const MultiSubgraphInputDescriptionVector& input_descriptors);
OutputMap get_mapping_outputs_on_body_description(const MultiSubgraphOutputDescriptionVector& output_descriptors);
std::vector<std::shared_ptr<Model>> m_bodies;
std::vector<MultiSubgraphInputDescriptionVector> m_input_descriptions;
std::vector<MultiSubgraphOutputDescriptionVector> m_output_descriptions;

View File

@ -75,19 +75,6 @@ bool ov::op::v8::If::visit_attributes(AttributeVisitor& visitor) {
return true;
}
void ov::op::v8::If::validate_and_infer_type_body(
const std::shared_ptr<ov::Model>& body,
const ngraph::op::util::MultiSubgraphInputDescriptionVector& input_descriptors) {
for (const auto& input_description : input_descriptors) {
auto index = input_description->m_input_index;
auto body_parameter = body->get_parameters().at(input_description->m_body_parameter_index);
auto input_partial_shape = input_value(index).get_partial_shape();
body_parameter->set_partial_shape(input_partial_shape);
}
body->validate_nodes_and_infer_types();
}
void ov::op::v8::If::validate_and_infer_types() {
OV_OP_SCOPE(v8_If_validate_and_infer_types);
@ -202,30 +189,6 @@ std::shared_ptr<ov::Node> ov::op::v8::If::clone_with_new_inputs(const OutputVect
return op;
}
ov::op::v8::If::OutputMap ov::op::v8::If::get_mapping_outputs_on_body_description(
const ngraph::op::util::MultiSubgraphOutputDescriptionVector& output_descriptors) {
OutputMap outputs_map = OutputMap();
std::unordered_set<int64_t> checked_results_in_body;
for (const auto& output_description : output_descriptors) {
auto out_index = output_description->m_output_index;
auto internal_result_index = output_description->m_body_value_index;
NODE_VALIDATION_CHECK(this,
checked_results_in_body.count(internal_result_index) == 0,
"Incorrect associating in then_body! Result ",
internal_result_index,
" is already associated with another output!");
NODE_VALIDATION_CHECK(this,
outputs_map.count(out_index) == 0,
"Incorrect associating in then_body! Several results try to "
"associate with the same output!");
checked_results_in_body.insert(internal_result_index);
outputs_map.insert({out_index, output_description});
}
return outputs_map;
}
void ov::op::v8::If::set_input(const Output<Node>& value,
const std::shared_ptr<v0::Parameter>& then_parameter,
const std::shared_ptr<v0::Parameter>& else_parameter) {

View File

@ -5,22 +5,54 @@
#include "openvino/op/util/framework_node.hpp"
#include "itt.hpp"
#include "ngraph/graph_util.hpp"
ov::op::util::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size) : Op(inputs) {
ov::op::util::FrameworkNode::FrameworkNode(const OutputVector& inputs, size_t output_size, size_t num_subgraphs)
: MultiSubGraphOp(num_subgraphs),
m_num_bodies(num_subgraphs) {
set_arguments(inputs);
set_output_size(output_size);
constructor_validate_and_infer_types();
}
ov::op::util::FrameworkNode::FrameworkNode(const ov::op::util::FrameworkNode& other) : MultiSubGraphOp() {
set_arguments(other.input_values());
other.clone_to(*this);
}
void ov::op::util::FrameworkNode::clone_to(ov::op::util::FrameworkNode& dst) const {
dst.set_output_size(get_output_size());
for (size_t i = 0; i < get_output_size(); ++i) {
dst.set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
}
dst.m_inputs_desc = m_inputs_desc;
dst.m_output_desc = m_output_desc;
dst.m_attrs = m_attrs;
dst.m_num_bodies = m_num_bodies;
dst.m_bodies.resize(m_num_bodies);
dst.m_input_descriptions.resize(m_num_bodies);
dst.m_output_descriptions.resize(m_num_bodies);
for (size_t i = 0; i < m_num_bodies; i++) {
dst.m_bodies[i] = get_function(i)->clone();
for (const auto& input_description : m_input_descriptions[i]) {
dst.m_input_descriptions[i].push_back(input_description->copy());
}
for (auto& output_description : m_output_descriptions[i]) {
dst.m_output_descriptions[i].push_back(output_description->copy());
}
}
dst.validate_and_infer_types();
}
std::shared_ptr<ov::Node> ov::op::util::FrameworkNode::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(FrameworkNode_clone_with_new_inputs);
check_new_args_count(this, new_args);
auto node = std::make_shared<op::util::FrameworkNode>(new_args);
for (size_t i = 0; i < get_output_size(); ++i) {
node->set_output_type(i, get_output_element_type(i), get_output_partial_shape(i));
}
node->m_inputs_desc = m_inputs_desc;
node->m_output_desc = m_output_desc;
node->m_attrs = m_attrs;
auto node = std::make_shared<op::util::FrameworkNode>(new_args, get_output_size(), m_num_bodies);
clone_to(*node);
return node;
}
@ -32,6 +64,60 @@ void ov::op::util::FrameworkNode::cache_output_descriptor() {
void ov::op::util::FrameworkNode::validate_and_infer_types() {
OV_OP_SCOPE(FrameworkNode_validate_and_infer_types);
if (m_inputs_desc.size() < get_input_size()) {
// case when we added inputs using set_invariant_inputs
m_inputs_desc.clear();
}
if (m_output_desc.size() < get_output_size()) {
// case when we added outputs using set_body_outputs
m_output_desc.clear();
}
// propagate shapes and types from bodies
std::unordered_map<size_t, PartialShape> shape_map;
std::unordered_map<size_t, element::Type> type_map;
for (size_t i = 0; i < m_bodies.size(); ++i) {
auto body = get_function(i);
// If body doesn't exist skip the validation
if (!body)
continue;
validate_and_infer_type_body(get_function(i), m_input_descriptions[i]);
auto outputs_map = get_mapping_outputs_on_body_description(m_output_descriptions[i]);
for (const auto& item : outputs_map) {
auto output_index = item.first;
auto desc = item.second;
auto node_result = m_bodies[i]->get_results().at(desc->m_body_value_index)->input_value(0);
auto pshape = PartialShape::dynamic();
if (shape_map.count(output_index)) {
pshape = shape_map.at(output_index);
}
if (PartialShape::merge_into(pshape, node_result.get_partial_shape())) {
shape_map[output_index] = pshape;
} else {
shape_map[output_index] = PartialShape::dynamic();
}
auto type = element::dynamic;
if (type_map.count(output_index)) {
type = type_map.at(output_index);
}
if (element::Type::merge(type, type, node_result.get_element_type())) {
type_map[output_index] = type;
} else {
type_map[output_index] = element::dynamic;
}
}
}
for (const auto& item : shape_map) {
auto output_index = item.first;
NODE_VALIDATION_CHECK(this,
type_map.count(output_index) != 0,
"Type map must contain same outputs as shape map");
set_output_type(output_index, type_map.at(output_index), item.second);
}
// Save initial inputs descriptors
bool initialize_input_desc = m_inputs_desc.empty();
bool reset_output_shape_to_dynamic = false;
@ -93,5 +179,21 @@ void ov::op::util::FrameworkNode::validate_and_infer_types() {
}
}
bool ov::op::util::FrameworkNode::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("framework_node_attrs", m_attrs);
visitor.on_attribute("num_bodies", m_num_bodies);
m_bodies.resize(m_num_bodies);
m_input_descriptions.resize(m_num_bodies);
m_output_descriptions.resize(m_num_bodies);
for (size_t i = 0; i < m_num_bodies; ++i) {
visitor.on_attribute("body" + std::to_string(i), m_bodies[i]);
visitor.on_attribute("input_descriptions" + std::to_string(i), m_input_descriptions[i]);
visitor.on_attribute("output_descriptions" + std::to_string(i), m_output_descriptions[i]);
}
return true;
}
ov::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>::AttributeAdapter(ov::op::util::FrameworkNodeAttrs& value)
: DirectValueAccessor<ov::op::util::FrameworkNodeAttrs>(value) {}

View File

@ -146,3 +146,41 @@ ov::Output<ov::Node> ov::op::util::MultiSubGraphOp::set_body_outputs(const Resul
validate_and_infer_types();
return Output<Node>(shared_from_this(), output_index);
}
void ov::op::util::MultiSubGraphOp::validate_and_infer_type_body(
const std::shared_ptr<ov::Model>& body,
const ov::op::util::MultiSubGraphOp::MultiSubgraphInputDescriptionVector& input_descriptors) {
const auto& params = body->get_parameters();
for (const auto& input_description : input_descriptors) {
auto index = input_description->m_input_index;
auto body_parameter = params.at(input_description->m_body_parameter_index);
auto input_partial_shape = input_value(index).get_partial_shape();
body_parameter->set_partial_shape(input_partial_shape);
}
body->validate_nodes_and_infer_types();
}
ov::op::util::MultiSubGraphOp::OutputMap ov::op::util::MultiSubGraphOp::get_mapping_outputs_on_body_description(
const ov::op::util::MultiSubGraphOp::MultiSubgraphOutputDescriptionVector& output_descriptors) {
OutputMap outputs_map = OutputMap();
std::unordered_set<int64_t> checked_results_in_body;
for (const auto& output_description : output_descriptors) {
const auto& out_index = output_description->m_output_index;
const auto& internal_result_index = output_description->m_body_value_index;
NODE_VALIDATION_CHECK(this,
checked_results_in_body.count(internal_result_index) == 0,
"Incorrect associating in body! Result ",
internal_result_index,
" is already associated with another output!");
NODE_VALIDATION_CHECK(this,
outputs_map.count(out_index) == 0,
"Incorrect associating in body! Several results try to "
"associate with the same output!");
checked_results_in_body.insert(internal_result_index);
outputs_map.insert({out_index, output_description});
}
return outputs_map;
}

View File

@ -388,6 +388,33 @@ public:
}
}
}
if (!is_body_target) {
std::string id = "input_descriptions";
std::string od = "output_descriptions";
const auto& id_pos = name.find("input_descriptions");
const auto& od_pos = name.find("output_descriptions");
auto id_str = name;
size_t body_id;
if (id_pos != std::string::npos) {
id_str.erase(id_pos, id.length());
std::stoi(id_str, &body_id);
is_body_target = true;
} else if (od_pos != std::string::npos) {
id_str.erase(od_pos, od.length());
std::stoi(id_str, &body_id);
is_body_target = true;
}
if (is_body_target) {
auto body_name = "body" + id_str;
if (m_xml_node.parent().child(body_name.c_str())) {
bnames = BodyTargetNames{body_name,
"port_map" + id_str,
{"input_descriptions" + id_str, "output_descriptions" + id_str}};
} else {
is_body_target = false;
}
}
}
if (is_body_target) {
auto body_name = std::get<0>(bnames);
auto portmap_name = std::get<1>(bnames);
@ -503,7 +530,8 @@ public:
m_xml_node.append_attribute(name.c_str()).set_value(create_atribute_list(adapter).c_str());
}
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::shared_ptr<Function>>& adapter) override {
if (name == "body" || name == "then_body" || name == "else_body") {
if (name.find("body") != std::string::npos) {
// name that contains subgraphs: body{n}, then_body, else_body
// TI, Loop do not have attributtes as regular ops, it is necessary to append "body"
// to layer above (m_xml_node.parent()) as in ngfunction_2_ir() layer (m_xml_node) with empty attributes
// is removed.

View File

@ -60,7 +60,7 @@ TEST_F(CustomOpsSerializationTest, CustomOpNoExtensions) {
</output>
</layer>
<layer name="operation" id="1" type="Template" version="custom_opset">
<data add="11"/>
<data num_bodies="0" add="11"/>
<input>
<port id="1" precision="FP32">
<dim>2</dim>

View File

@ -0,0 +1,66 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/util/framework_node.hpp"
#include "common_test_utils/graph_comparator.hpp"
#include "gtest/gtest.h"
#include "openvino/op/util/attr_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "util/visitor.hpp"
using namespace std;
using namespace ov;
using namespace ov::opset10;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;
TEST(attributes, framework_node_op) {
NodeBuilder::get_ops().register_factory<op::util::FrameworkNode>();
auto X = make_shared<Parameter>(element::f32, Shape{1, 2, 2});
auto Y = make_shared<Parameter>(element::f32, Shape{1, 2, 2});
auto cond = make_shared<Constant>(element::boolean, Shape{1}, true);
auto cond2 = make_shared<Constant>(element::boolean, Shape{1}, false);
auto Xt = make_shared<Parameter>(element::f32, PartialShape::dynamic());
auto Yt = make_shared<Parameter>(element::f32, PartialShape::dynamic());
auto Xe = make_shared<Parameter>(element::f32, PartialShape::dynamic());
auto Ye = make_shared<Parameter>(element::f32, PartialShape::dynamic());
auto then_op = make_shared<Multiply>(Xt, Yt);
auto res0 = make_shared<Result>(then_op);
auto res1 = make_shared<Result>(Xe);
auto body1 = make_shared<Model>(OutputVector{res0}, ParameterVector{Xt, Yt});
auto body2 = make_shared<Model>(OutputVector{res1}, ParameterVector{Xe});
auto fn_op = make_shared<op::util::FrameworkNode>(OutputVector{cond}, 0, 2);
// Add attributes
auto attrs = op::util::FrameworkNodeAttrs();
attrs.set_type_name("some_type");
fn_op->set_attrs(attrs);
fn_op->set_function(0, body1);
fn_op->set_function(1, body2);
fn_op->set_invariant_inputs(X, {Xt, Xe});
fn_op->set_invariant_inputs(Y, {Yt, nullptr});
auto out = fn_op->set_body_outputs({res0, res1});
fn_op->validate_and_infer_types();
EXPECT_EQ(fn_op->inputs().size(), 3);
EXPECT_EQ(fn_op->outputs().size(), 1);
NodeBuilder builder(fn_op);
auto g_fn = ov::as_type_ptr<op::util::FrameworkNode>(builder.create());
EXPECT_EQ(g_fn->get_attrs(), fn_op->get_attrs());
EXPECT_EQ(g_fn->get_input_descriptions(0), fn_op->get_input_descriptions(0));
EXPECT_EQ(g_fn->get_input_descriptions(1), fn_op->get_input_descriptions(1));
EXPECT_EQ(g_fn->get_output_descriptions(0), fn_op->get_output_descriptions(0));
EXPECT_EQ(g_fn->get_output_descriptions(1), fn_op->get_output_descriptions(1));
auto comparator = FunctionsComparator::with_default();
ASSERT_TRUE(g_fn->get_function(0));
auto res = comparator.compare(g_fn->get_function(0), fn_op->get_function(0));
EXPECT_TRUE(res.valid) << res.message;
ASSERT_TRUE(g_fn->get_function(1));
res = comparator.compare(g_fn->get_function(1), fn_op->get_function(1));
EXPECT_TRUE(res.valid) << res.message;
}

View File

@ -11,6 +11,7 @@
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/factory.hpp"
#include "ngraph/op/util/framework_node.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/runtime/host_tensor.hpp"
@ -120,6 +121,9 @@ public:
virtual operator std::shared_ptr<Variable>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator ov::op::util::FrameworkNodeAttrs&() {
NGRAPH_CHECK(false, "Invalid type access");
}
uint64_t get_index() {
return m_index;
}
@ -224,6 +228,8 @@ public:
a->set(m_values.get<ov::Dimension>(name));
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<Variable>>>(&adapter)) {
a->set(m_values.get<std::shared_ptr<Variable>>(name));
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
a->set(m_values.get<ov::op::util::FrameworkNodeAttrs>(name));
} else {
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled");
}
@ -310,6 +316,8 @@ public:
m_values.insert(name, a->get());
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<Variable>>>(&adapter)) {
m_values.insert(name, a->get());
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ov::op::util::FrameworkNodeAttrs>>(&adapter)) {
m_values.insert(name, a->get());
} else {
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled");
}