Add add_output API for ov Function (#7980)

* Add add_output API for ov Function

* Fixed comments
This commit is contained in:
Ilya Churaev 2021-10-19 14:07:48 +03:00 committed by GitHub
parent 7ee322be90
commit 4964ec890c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 342 additions and 134 deletions

View File

@ -107,6 +107,10 @@ public:
ov::Output<const ov::Node> input(size_t i) const; ov::Output<const ov::Node> input(size_t i) const;
ov::Output<const ov::Node> input(const std::string& tensor_name) const; ov::Output<const ov::Node> input(const std::string& tensor_name) const;
void add_output(const std::string& tensor_name);
void add_output(const std::string& op_name, size_t output_idx);
void add_output(const ov::Output<ov::Node>& port);
void reshape(const std::map<std::string, ov::PartialShape>& partial_shapes); void reshape(const std::map<std::string, ov::PartialShape>& partial_shapes);
/// Return the element type of output i /// Return the element type of output i

View File

@ -818,3 +818,50 @@ void ov::Function::reshape(const std::map<std::string, ov::PartialShape>& partia
throw ex; throw ex;
} }
} }
void ov::Function::add_output(const std::string& tensor_name) {
for (const auto& op : get_ops()) {
if (ov::op::util::is_output(op))
continue;
for (const auto& output : op->outputs()) {
const auto& names = output.get_tensor().get_names();
if (names.find(tensor_name) != names.end()) {
add_output(output);
return;
}
}
}
throw ov::Exception("Tensor name " + tensor_name + " was not found.");
}
void ov::Function::add_output(const std::string& op_name, size_t output_idx) {
for (const auto& op : get_ops()) {
if (op->get_friendly_name() == op_name) {
OPENVINO_ASSERT(output_idx < op->get_output_size(),
"Cannot add output to port ",
std::to_string(output_idx),
" operation ",
op->get_friendly_name(),
" has only ",
std::to_string(op->get_output_size()),
" outputs.");
add_output(op->output(output_idx));
return;
}
}
throw ov::Exception("Port " + std::to_string(output_idx) + " for operation with name " + op_name +
" was not found.");
}
void ov::Function::add_output(const ov::Output<ov::Node>& port) {
if (ov::op::util::is_output(port.get_node()))
return;
for (const auto& input : port.get_target_inputs()) {
// Do not add result if port is already connected with result
if (ov::op::util::is_output(input.get_node())) {
return;
}
}
auto result = std::make_shared<ov::op::v0::Result>(port);
add_results({result});
}

View File

@ -21,9 +21,9 @@ TEST(function, get_input_by_tensor_name) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto input = f->input("input"); auto input = f->input("input");
ASSERT_EQ(input.get_node(), arg0.get()); EXPECT_EQ(input.get_node(), arg0.get());
ASSERT_EQ(input.get_element_type(), ov::element::f32); EXPECT_EQ(input.get_element_type(), ov::element::f32);
ASSERT_EQ(input.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(input.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_output_by_tensor_name) { TEST(function, get_output_by_tensor_name) {
@ -40,12 +40,12 @@ TEST(function, get_output_by_tensor_name) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto output = f->output("relu_t"); auto output = f->output("relu_t");
ASSERT_EQ(output.get_tensor().get_names().size(), 2); EXPECT_EQ(output.get_tensor().get_names().size(), 2);
ASSERT_EQ(output.get_tensor().get_names(), out_names); EXPECT_EQ(output.get_tensor().get_names(), out_names);
ASSERT_EQ(output.get_node(), result.get()); EXPECT_EQ(output.get_node(), result.get());
ASSERT_EQ(f->output("identity"), output); EXPECT_EQ(f->output("identity"), output);
ASSERT_EQ(output.get_element_type(), ov::element::f32); EXPECT_EQ(output.get_element_type(), ov::element::f32);
ASSERT_EQ(output.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(output.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_incorrect_output_by_tensor_name) { TEST(function, get_incorrect_output_by_tensor_name) {
@ -59,7 +59,7 @@ TEST(function, get_incorrect_output_by_tensor_name) {
auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->output("input"), ov::Exception); EXPECT_THROW(f->output("input"), ov::Exception);
} }
TEST(function, get_incorrect_input_by_tensor_name) { TEST(function, get_incorrect_input_by_tensor_name) {
@ -73,7 +73,7 @@ TEST(function, get_incorrect_input_by_tensor_name) {
auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->input("relu_t"), ov::Exception); EXPECT_THROW(f->input("relu_t"), ov::Exception);
} }
TEST(function, get_input_by_index) { TEST(function, get_input_by_index) {
@ -88,9 +88,9 @@ TEST(function, get_input_by_index) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto input = f->input(0); auto input = f->input(0);
ASSERT_EQ(input.get_node(), arg0.get()); EXPECT_EQ(input.get_node(), arg0.get());
ASSERT_EQ(input.get_element_type(), ov::element::f32); EXPECT_EQ(input.get_element_type(), ov::element::f32);
ASSERT_EQ(input.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(input.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_output_by_index) { TEST(function, get_output_by_index) {
@ -106,9 +106,9 @@ TEST(function, get_output_by_index) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto output = f->output(0); auto output = f->output(0);
ASSERT_EQ(output.get_node(), result.get()); EXPECT_EQ(output.get_node(), result.get());
ASSERT_EQ(output.get_element_type(), ov::element::f32); EXPECT_EQ(output.get_element_type(), ov::element::f32);
ASSERT_EQ(output.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(output.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_input_without_index) { TEST(function, get_input_without_index) {
@ -123,9 +123,9 @@ TEST(function, get_input_without_index) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto input = f->input(); auto input = f->input();
ASSERT_EQ(input.get_node(), arg0.get()); EXPECT_EQ(input.get_node(), arg0.get());
ASSERT_EQ(input.get_element_type(), ov::element::f32); EXPECT_EQ(input.get_element_type(), ov::element::f32);
ASSERT_EQ(input.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(input.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_output_without_index) { TEST(function, get_output_without_index) {
@ -141,9 +141,9 @@ TEST(function, get_output_without_index) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto output = f->output(); auto output = f->output();
ASSERT_EQ(output.get_node(), result.get()); EXPECT_EQ(output.get_node(), result.get());
ASSERT_EQ(output.get_element_type(), ov::element::f32); EXPECT_EQ(output.get_element_type(), ov::element::f32);
ASSERT_EQ(output.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(output.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_incorrect_output_by_index) { TEST(function, get_incorrect_output_by_index) {
@ -157,7 +157,7 @@ TEST(function, get_incorrect_output_by_index) {
auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->output(2), std::exception); EXPECT_THROW(f->output(2), std::exception);
} }
TEST(function, get_incorrect_input_by_index) { TEST(function, get_incorrect_input_by_index) {
@ -171,7 +171,7 @@ TEST(function, get_incorrect_input_by_index) {
auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->input(2), std::exception); EXPECT_THROW(f->input(2), std::exception);
} }
TEST(function, incorrect_multiple_inputs_outputs_function) { TEST(function, incorrect_multiple_inputs_outputs_function) {
@ -196,8 +196,8 @@ TEST(function, incorrect_multiple_inputs_outputs_function) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->input(), ov::Exception); EXPECT_THROW(f->input(), ov::Exception);
ASSERT_THROW(f->output(), ov::Exception); EXPECT_THROW(f->output(), ov::Exception);
} }
TEST(function, multiple_inputs_outputs_function) { TEST(function, multiple_inputs_outputs_function) {
@ -225,28 +225,28 @@ TEST(function, multiple_inputs_outputs_function) {
auto input1 = f->input(0); auto input1 = f->input(0);
auto input2 = f->input("data1"); auto input2 = f->input("data1");
ASSERT_NE(input1, input2); EXPECT_NE(input1, input2);
ASSERT_EQ(input1, f->input("input1")); EXPECT_EQ(input1, f->input("input1"));
ASSERT_EQ(input2, f->input("input2")); EXPECT_EQ(input2, f->input("input2"));
ASSERT_EQ(input2, f->input(1)); EXPECT_EQ(input2, f->input(1));
ASSERT_EQ(input1.get_node(), arg0.get()); EXPECT_EQ(input1.get_node(), arg0.get());
ASSERT_EQ(input2.get_node_shared_ptr(), arg1); EXPECT_EQ(input2.get_node_shared_ptr(), arg1);
auto output1 = f->output(0); auto output1 = f->output(0);
auto output2 = f->output("shape_of_t"); auto output2 = f->output("shape_of_t");
ASSERT_NE(output1, output2); EXPECT_NE(output1, output2);
ASSERT_EQ(output1, f->output("concat_t")); EXPECT_EQ(output1, f->output("concat_t"));
ASSERT_EQ(output2, f->output("identity")); EXPECT_EQ(output2, f->output("identity"));
ASSERT_EQ(output2, f->output(1)); EXPECT_EQ(output2, f->output(1));
ASSERT_EQ(arg0.get(), f->input(0).get_node()); EXPECT_EQ(arg0.get(), f->input(0).get_node());
ASSERT_EQ(arg1.get(), f->input(1).get_node()); EXPECT_EQ(arg1.get(), f->input(1).get_node());
ASSERT_EQ(result1.get(), f->output(0).get_node()); EXPECT_EQ(result1.get(), f->output(0).get_node());
ASSERT_EQ(result2.get(), f->output(1).get_node()); EXPECT_EQ(result2.get(), f->output(1).get_node());
ASSERT_EQ(output1, result1); EXPECT_EQ(output1, result1);
ASSERT_EQ(output2, result2); EXPECT_EQ(output2, result2);
ASSERT_EQ(f->inputs().size(), 2); EXPECT_EQ(f->inputs().size(), 2);
ASSERT_EQ(f->outputs().size(), 2); EXPECT_EQ(f->outputs().size(), 2);
} }
TEST(function, DISABLED_create_function_with_incorrect_tensor_names) { TEST(function, DISABLED_create_function_with_incorrect_tensor_names) {
@ -258,7 +258,7 @@ TEST(function, DISABLED_create_function_with_incorrect_tensor_names) {
relu->set_friendly_name("relu"); relu->set_friendly_name("relu");
relu->get_output_tensor(0).set_names({"input"}); relu->get_output_tensor(0).set_names({"input"});
auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<ov::Function>(relu, ov::ParameterVector{arg0});
ASSERT_THROW(f->validate_nodes_and_infer_types(), ov::Exception); EXPECT_THROW(f->validate_nodes_and_infer_types(), ov::Exception);
} }
TEST(function, get_input_by_tensor_name_from_const) { TEST(function, get_input_by_tensor_name_from_const) {
@ -273,9 +273,9 @@ TEST(function, get_input_by_tensor_name_from_const) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto input = f->input("input"); auto input = f->input("input");
ASSERT_EQ(input.get_node(), arg0.get()); EXPECT_EQ(input.get_node(), arg0.get());
ASSERT_EQ(input.get_element_type(), ov::element::f32); EXPECT_EQ(input.get_element_type(), ov::element::f32);
ASSERT_EQ(input.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(input.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_output_by_tensor_name_from_const_function) { TEST(function, get_output_by_tensor_name_from_const_function) {
@ -292,12 +292,12 @@ TEST(function, get_output_by_tensor_name_from_const_function) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto output = f->output("relu_t"); auto output = f->output("relu_t");
ASSERT_EQ(output.get_tensor().get_names().size(), 2); EXPECT_EQ(output.get_tensor().get_names().size(), 2);
ASSERT_EQ(output.get_tensor().get_names(), out_names); EXPECT_EQ(output.get_tensor().get_names(), out_names);
ASSERT_EQ(output.get_node(), result.get()); EXPECT_EQ(output.get_node(), result.get());
ASSERT_EQ(f->output("identity"), output); EXPECT_EQ(f->output("identity"), output);
ASSERT_EQ(output.get_element_type(), ov::element::f32); EXPECT_EQ(output.get_element_type(), ov::element::f32);
ASSERT_EQ(output.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(output.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_incorrect_output_by_tensor_name_from_const_function) { TEST(function, get_incorrect_output_by_tensor_name_from_const_function) {
@ -311,7 +311,7 @@ TEST(function, get_incorrect_output_by_tensor_name_from_const_function) {
auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->output("input"), ov::Exception); EXPECT_THROW(f->output("input"), ov::Exception);
} }
TEST(function, get_incorrect_input_by_tensor_name_from_const_function) { TEST(function, get_incorrect_input_by_tensor_name_from_const_function) {
@ -325,7 +325,7 @@ TEST(function, get_incorrect_input_by_tensor_name_from_const_function) {
auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->input("relu_t"), ov::Exception); EXPECT_THROW(f->input("relu_t"), ov::Exception);
} }
TEST(function, get_input_by_index_from_const_function) { TEST(function, get_input_by_index_from_const_function) {
@ -340,9 +340,9 @@ TEST(function, get_input_by_index_from_const_function) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto input = f->input(0); auto input = f->input(0);
ASSERT_EQ(input.get_node(), arg0.get()); EXPECT_EQ(input.get_node(), arg0.get());
ASSERT_EQ(input.get_element_type(), ov::element::f32); EXPECT_EQ(input.get_element_type(), ov::element::f32);
ASSERT_EQ(input.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(input.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_output_by_index_from_const_function) { TEST(function, get_output_by_index_from_const_function) {
@ -358,9 +358,9 @@ TEST(function, get_output_by_index_from_const_function) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto output = f->output(0); auto output = f->output(0);
ASSERT_EQ(output.get_node(), result.get()); EXPECT_EQ(output.get_node(), result.get());
ASSERT_EQ(output.get_element_type(), ov::element::f32); EXPECT_EQ(output.get_element_type(), ov::element::f32);
ASSERT_EQ(output.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(output.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_input_without_index_from_const_function) { TEST(function, get_input_without_index_from_const_function) {
@ -375,9 +375,9 @@ TEST(function, get_input_without_index_from_const_function) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto input = f->input(); auto input = f->input();
ASSERT_EQ(input.get_node(), arg0.get()); EXPECT_EQ(input.get_node(), arg0.get());
ASSERT_EQ(input.get_element_type(), ov::element::f32); EXPECT_EQ(input.get_element_type(), ov::element::f32);
ASSERT_EQ(input.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(input.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_output_without_index_from_const_function) { TEST(function, get_output_without_index_from_const_function) {
@ -393,9 +393,9 @@ TEST(function, get_output_without_index_from_const_function) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
auto output = f->output(); auto output = f->output();
ASSERT_EQ(output.get_node(), result.get()); EXPECT_EQ(output.get_node(), result.get());
ASSERT_EQ(output.get_element_type(), ov::element::f32); EXPECT_EQ(output.get_element_type(), ov::element::f32);
ASSERT_EQ(output.get_partial_shape(), ov::PartialShape{1}); EXPECT_EQ(output.get_partial_shape(), ov::PartialShape{1});
} }
TEST(function, get_incorrect_output_by_index_from_const_function) { TEST(function, get_incorrect_output_by_index_from_const_function) {
@ -409,7 +409,7 @@ TEST(function, get_incorrect_output_by_index_from_const_function) {
auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->output(2), std::exception); EXPECT_THROW(f->output(2), std::exception);
} }
TEST(function, get_incorrect_input_by_index_from_const_function) { TEST(function, get_incorrect_input_by_index_from_const_function) {
@ -423,7 +423,7 @@ TEST(function, get_incorrect_input_by_index_from_const_function) {
auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->input(2), std::exception); EXPECT_THROW(f->input(2), std::exception);
} }
TEST(function, incorrect_multiple_inputs_outputs_function_from_const_function) { TEST(function, incorrect_multiple_inputs_outputs_function_from_const_function) {
@ -448,8 +448,8 @@ TEST(function, incorrect_multiple_inputs_outputs_function_from_const_function) {
f->validate_nodes_and_infer_types(); f->validate_nodes_and_infer_types();
ASSERT_THROW(f->input(), ov::Exception); EXPECT_THROW(f->input(), ov::Exception);
ASSERT_THROW(f->output(), ov::Exception); EXPECT_THROW(f->output(), ov::Exception);
} }
TEST(function, multiple_inputs_outputs_function_from_const_function) { TEST(function, multiple_inputs_outputs_function_from_const_function) {
@ -477,28 +477,28 @@ TEST(function, multiple_inputs_outputs_function_from_const_function) {
auto input1 = f->input(0); auto input1 = f->input(0);
auto input2 = f->input("data1"); auto input2 = f->input("data1");
ASSERT_NE(input1, input2); EXPECT_NE(input1, input2);
ASSERT_EQ(input1, f->input("input1")); EXPECT_EQ(input1, f->input("input1"));
ASSERT_EQ(input2, f->input("input2")); EXPECT_EQ(input2, f->input("input2"));
ASSERT_EQ(input2, f->input(1)); EXPECT_EQ(input2, f->input(1));
ASSERT_EQ(input1.get_node(), arg0.get()); EXPECT_EQ(input1.get_node(), arg0.get());
ASSERT_EQ(input2.get_node_shared_ptr(), arg1); EXPECT_EQ(input2.get_node_shared_ptr(), arg1);
auto output1 = f->output(0); auto output1 = f->output(0);
auto output2 = f->output("shape_of_t"); auto output2 = f->output("shape_of_t");
ASSERT_NE(output1, output2); EXPECT_NE(output1, output2);
ASSERT_EQ(output1, f->output("concat_t")); EXPECT_EQ(output1, f->output("concat_t"));
ASSERT_EQ(output2, f->output("identity")); EXPECT_EQ(output2, f->output("identity"));
ASSERT_EQ(arg0.get(), f->input(0).get_node()); EXPECT_EQ(arg0.get(), f->input(0).get_node());
ASSERT_EQ(arg1.get(), f->input(1).get_node()); EXPECT_EQ(arg1.get(), f->input(1).get_node());
ASSERT_EQ(result1.get(), f->output(0).get_node()); EXPECT_EQ(result1.get(), f->output(0).get_node());
ASSERT_EQ(result2.get(), f->output(1).get_node()); EXPECT_EQ(result2.get(), f->output(1).get_node());
ASSERT_EQ(output2, f->output(1)); EXPECT_EQ(output2, f->output(1));
ASSERT_EQ(output1.get_node(), result1.get()); EXPECT_EQ(output1.get_node(), result1.get());
ASSERT_EQ(output2.get_node(), result2.get()); EXPECT_EQ(output2.get_node(), result2.get());
ASSERT_EQ(f->inputs().size(), 2); EXPECT_EQ(f->inputs().size(), 2);
ASSERT_EQ(f->outputs().size(), 2); EXPECT_EQ(f->outputs().size(), 2);
} }
TEST(function, DISABLED_create_function_with_incorrect_tensor_names_from_const_function) { TEST(function, DISABLED_create_function_with_incorrect_tensor_names_from_const_function) {
@ -510,7 +510,7 @@ TEST(function, DISABLED_create_function_with_incorrect_tensor_names_from_const_f
relu->set_friendly_name("relu"); relu->set_friendly_name("relu");
relu->get_output_tensor(0).set_names({"input"}); relu->get_output_tensor(0).set_names({"input"});
auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0});
ASSERT_THROW(f->validate_nodes_and_infer_types(), ov::Exception); EXPECT_THROW(f->validate_nodes_and_infer_types(), ov::Exception);
} }
TEST(function_reshape, ReshapedDynamicShapeLayout) { TEST(function_reshape, ReshapedDynamicShapeLayout) {
@ -530,7 +530,7 @@ TEST(function_reshape, ReshapedDynamicShapeLayout) {
std::map<std::string, ov::PartialShape> new_shape; std::map<std::string, ov::PartialShape> new_shape;
new_shape["tensor"] = ov::Shape{1, 3, 22, 22}; new_shape["tensor"] = ov::Shape{1, 3, 22, 22};
ASSERT_NO_THROW(ngraph->reshape(new_shape)); EXPECT_NO_THROW(ngraph->reshape(new_shape));
EXPECT_FALSE(ngraph->input().get_partial_shape().is_dynamic()); EXPECT_FALSE(ngraph->input().get_partial_shape().is_dynamic());
EXPECT_FALSE(ngraph->get_parameters().front()->get_partial_shape().is_dynamic()); EXPECT_FALSE(ngraph->get_parameters().front()->get_partial_shape().is_dynamic());
@ -552,17 +552,17 @@ TEST(function_reshape, ReshapeBatchReLU) {
ngraph = std::make_shared<ov::Function>(results, params); ngraph = std::make_shared<ov::Function>(results, params);
} }
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
{ {
std::map<std::string, ov::PartialShape> new_shape; std::map<std::string, ov::PartialShape> new_shape;
new_shape["tensor2"] = ov::PartialShape{2, 3, 22, 22}; new_shape["tensor2"] = ov::PartialShape{2, 3, 22, 22};
ASSERT_NO_THROW(ngraph->reshape(new_shape)); EXPECT_NO_THROW(ngraph->reshape(new_shape));
} }
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({2, 3, 22, 22})); EXPECT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({2, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({2, 3, 22, 22})); EXPECT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({2, 3, 22, 22}));
} }
TEST(function_reshape, ReshapeSpatialReLU) { TEST(function_reshape, ReshapeSpatialReLU) {
@ -581,17 +581,17 @@ TEST(function_reshape, ReshapeSpatialReLU) {
ngraph = std::make_shared<ov::Function>(results, params); ngraph = std::make_shared<ov::Function>(results, params);
} }
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
{ {
std::map<std::string, ov::PartialShape> new_shape; std::map<std::string, ov::PartialShape> new_shape;
new_shape["tensor"] = ov::PartialShape{1, 3, 25, 25}; new_shape["tensor"] = ov::PartialShape{1, 3, 25, 25};
ASSERT_NO_THROW(ngraph->reshape(new_shape)); EXPECT_NO_THROW(ngraph->reshape(new_shape));
} }
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 25, 25})); EXPECT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 25, 25}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 25, 25})); EXPECT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 25, 25}));
} }
TEST(function_reshape, ReshapeSpatialReLUWithoutReplaceParameter) { TEST(function_reshape, ReshapeSpatialReLUWithoutReplaceParameter) {
@ -609,16 +609,16 @@ TEST(function_reshape, ReshapeSpatialReLUWithoutReplaceParameter) {
ngraph = std::make_shared<ov::Function>(results, params); ngraph = std::make_shared<ov::Function>(results, params);
} }
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
{ {
ngraph->get_parameters()[0]->set_partial_shape({1, 3, 25, 25}); ngraph->get_parameters()[0]->set_partial_shape({1, 3, 25, 25});
ngraph->validate_nodes_and_infer_types(); ngraph->validate_nodes_and_infer_types();
} }
ASSERT_EQ(ngraph->input().get_partial_shape(), ov::Shape({1, 3, 25, 25})); EXPECT_EQ(ngraph->input().get_partial_shape(), ov::Shape({1, 3, 25, 25}));
ASSERT_EQ(ngraph->output().get_partial_shape(), ov::Shape({1, 3, 25, 25})); EXPECT_EQ(ngraph->output().get_partial_shape(), ov::Shape({1, 3, 25, 25}));
} }
TEST(function_reshape, ReshapeSpatialReLUStaticToDynamic) { TEST(function_reshape, ReshapeSpatialReLUStaticToDynamic) {
@ -638,19 +638,19 @@ TEST(function_reshape, ReshapeSpatialReLUStaticToDynamic) {
ngraph = std::make_shared<ov::Function>(results, params); ngraph = std::make_shared<ov::Function>(results, params);
} }
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
{ {
std::map<std::string, ov::PartialShape> new_shape; std::map<std::string, ov::PartialShape> new_shape;
new_shape["tensor"] = refShape; new_shape["tensor"] = refShape;
ASSERT_NO_THROW(ngraph->reshape(new_shape)); EXPECT_NO_THROW(ngraph->reshape(new_shape));
} }
ASSERT_TRUE(ngraph->input(0).get_partial_shape().is_dynamic()); EXPECT_TRUE(ngraph->input(0).get_partial_shape().is_dynamic());
ASSERT_TRUE(ngraph->output(0).get_partial_shape().is_dynamic()); EXPECT_TRUE(ngraph->output(0).get_partial_shape().is_dynamic());
ASSERT_EQ(ngraph->input(0).get_partial_shape(), refShape); EXPECT_EQ(ngraph->input(0).get_partial_shape(), refShape);
ASSERT_EQ(ngraph->output(0).get_partial_shape(), refShape); EXPECT_EQ(ngraph->output(0).get_partial_shape(), refShape);
} }
TEST(function_reshape, ReshapeSpatialReLUStaticToFullyDynamic) { TEST(function_reshape, ReshapeSpatialReLUStaticToFullyDynamic) {
@ -670,19 +670,19 @@ TEST(function_reshape, ReshapeSpatialReLUStaticToFullyDynamic) {
ngraph = std::make_shared<ov::Function>(results, params); ngraph = std::make_shared<ov::Function>(results, params);
} }
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22})); EXPECT_EQ(ngraph->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
{ {
std::map<std::string, ov::PartialShape> new_shape; std::map<std::string, ov::PartialShape> new_shape;
new_shape["tensor"] = refShape; new_shape["tensor"] = refShape;
ASSERT_NO_THROW(ngraph->reshape(new_shape)); EXPECT_NO_THROW(ngraph->reshape(new_shape));
} }
ASSERT_TRUE(ngraph->input().get_partial_shape().is_dynamic()); EXPECT_TRUE(ngraph->input().get_partial_shape().is_dynamic());
ASSERT_TRUE(ngraph->output().get_partial_shape().is_dynamic()); EXPECT_TRUE(ngraph->output().get_partial_shape().is_dynamic());
ASSERT_EQ(ngraph->input().get_partial_shape(), refShape); EXPECT_EQ(ngraph->input().get_partial_shape(), refShape);
ASSERT_EQ(ngraph->output().get_partial_shape(), refShape); EXPECT_EQ(ngraph->output().get_partial_shape(), refShape);
} }
TEST(function_reshape, ReshapeSpatialReLUDynamicToDynamic) { TEST(function_reshape, ReshapeSpatialReLUDynamicToDynamic) {
@ -702,19 +702,19 @@ TEST(function_reshape, ReshapeSpatialReLUDynamicToDynamic) {
ngraph = std::make_shared<ov::Function>(results, params); ngraph = std::make_shared<ov::Function>(results, params);
} }
ASSERT_EQ(ngraph->input().get_partial_shape(), ov::PartialShape({1, 3, 22, ov::Dimension::dynamic()})); EXPECT_EQ(ngraph->input().get_partial_shape(), ov::PartialShape({1, 3, 22, ov::Dimension::dynamic()}));
ASSERT_EQ(ngraph->output().get_partial_shape(), ov::PartialShape({1, 3, 22, ov::Dimension::dynamic()})); EXPECT_EQ(ngraph->output().get_partial_shape(), ov::PartialShape({1, 3, 22, ov::Dimension::dynamic()}));
{ {
std::map<std::string, ov::PartialShape> new_shape; std::map<std::string, ov::PartialShape> new_shape;
new_shape["tensor"] = refShape; new_shape["tensor"] = refShape;
ASSERT_NO_THROW(ngraph->reshape(new_shape)); EXPECT_NO_THROW(ngraph->reshape(new_shape));
} }
ASSERT_TRUE(ngraph->input().get_partial_shape().is_dynamic()); EXPECT_TRUE(ngraph->input().get_partial_shape().is_dynamic());
ASSERT_TRUE(ngraph->output().get_partial_shape().is_dynamic()); EXPECT_TRUE(ngraph->output().get_partial_shape().is_dynamic());
ASSERT_EQ(ngraph->input().get_partial_shape(), refShape); EXPECT_EQ(ngraph->input().get_partial_shape(), refShape);
ASSERT_EQ(ngraph->output().get_partial_shape(), refShape); EXPECT_EQ(ngraph->output().get_partial_shape(), refShape);
} }
TEST(function_reshape, TestInvalidReshape) { TEST(function_reshape, TestInvalidReshape) {
@ -727,12 +727,12 @@ TEST(function_reshape, TestInvalidReshape) {
f = std::make_shared<ov::Function>(ov::OutputVector{reshape}, ov::ParameterVector{input}); f = std::make_shared<ov::Function>(ov::OutputVector{reshape}, ov::ParameterVector{input});
} }
ASSERT_ANY_THROW(f->reshape({{"tensor", ov::Shape({4})}})); EXPECT_ANY_THROW(f->reshape({{"tensor", ov::Shape({4})}}));
auto param = f->get_parameters().front(); auto param = f->get_parameters().front();
ASSERT_EQ(param->get_output_shape(0), ov::Shape({1, 1000, 4})); EXPECT_EQ(param->get_output_shape(0), ov::Shape({1, 1000, 4}));
ASSERT_NO_THROW(f->reshape({{"tensor", ov::Shape({1, 1000, 4})}})); EXPECT_NO_THROW(f->reshape({{"tensor", ov::Shape({1, 1000, 4})}}));
} }
TEST(function_reshape, TestReshapeWithInvalidTensorName) { TEST(function_reshape, TestReshapeWithInvalidTensorName) {
@ -747,10 +747,10 @@ TEST(function_reshape, TestReshapeWithInvalidTensorName) {
} }
// both operation names and tensor names are specified // both operation names and tensor names are specified
ASSERT_ANY_THROW(f->reshape({{"param", ov::Shape({4, 4, 4})}, {"tensor", ov::Shape({4, 4, 4})}})); EXPECT_ANY_THROW(f->reshape({{"param", ov::Shape({4, 4, 4})}, {"tensor", ov::Shape({4, 4, 4})}}));
// operation name does not work // operation name does not work
ASSERT_ANY_THROW(f->reshape({{"param", ov::Shape({4, 4, 4})}})); EXPECT_ANY_THROW(f->reshape({{"param", ov::Shape({4, 4, 4})}}));
} }
TEST(function_reshape, TestReshapeWithInvalidShapesForTheSameTensor) { TEST(function_reshape, TestReshapeWithInvalidShapesForTheSameTensor) {
@ -765,5 +765,162 @@ TEST(function_reshape, TestReshapeWithInvalidShapesForTheSameTensor) {
} }
// both tensor names are specified, but have different shapes // both tensor names are specified, but have different shapes
ASSERT_ANY_THROW(f->reshape({{"tensor1", ov::Shape({2, 500, 4})}, {"tensor2", ov::Shape({4, 250, 4})}})); EXPECT_ANY_THROW(f->reshape({{"tensor1", ov::Shape({2, 500, 4})}, {"tensor2", ov::Shape({4, 250, 4})}}));
}
TEST(function, add_output_tensor_name) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto f = std::make_shared<ov::Function>(relu2, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_NO_THROW(f->add_output("relu_t1"));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_NO_THROW(f->add_output("relu_t1"));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
}
TEST(function, add_output_op_name) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto f = std::make_shared<ov::Function>(relu2, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_NO_THROW(f->add_output("relu1", 0));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_NO_THROW(f->add_output("relu_t1"));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_NO_THROW(f->add_output("relu2", 0));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
}
TEST(function, add_output_port) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto f = std::make_shared<ov::Function>(relu2, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_NO_THROW(f->add_output(relu1->output(0)));
EXPECT_EQ(f->get_results().size(), 2);
EXPECT_EQ(f->get_results()[1]->input_value(0).get_node(), relu1.get());
}
TEST(function, add_output_incorrect_tensor_name) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto f = std::make_shared<ov::Function>(relu2, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_THROW(f->add_output("relu"), ov::Exception);
EXPECT_EQ(f->get_results().size(), 1);
}
TEST(function, add_output_op_incorrect_name) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto f = std::make_shared<ov::Function>(relu2, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_THROW(f->add_output("relu_t1", 0), ov::Exception);
EXPECT_EQ(f->get_results().size(), 1);
}
TEST(function, add_output_op_name_incorrect_idx) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto f = std::make_shared<ov::Function>(relu2, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_THROW(f->add_output("relu1", 10), ov::Exception);
EXPECT_EQ(f->get_results().size(), 1);
}
TEST(function, add_output_port_to_result) {
auto arg0 = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{1});
arg0->set_friendly_name("data");
arg0->get_output_tensor(0).set_names({"input"});
auto relu1 = std::make_shared<ov::opset8::Relu>(arg0);
relu1->set_friendly_name("relu1");
relu1->get_output_tensor(0).set_names({"relu_t1"});
auto relu2 = std::make_shared<ov::opset8::Relu>(relu1);
relu2->set_friendly_name("relu2");
relu2->get_output_tensor(0).set_names({"relu_t2"});
auto result = std::make_shared<ov::opset8::Result>(relu2);
auto f = std::make_shared<ov::Function>(ov::ResultVector{result}, ov::ParameterVector{arg0});
f->validate_nodes_and_infer_types();
EXPECT_EQ(f->get_results().size(), 1);
EXPECT_NO_THROW(f->add_output(result->output(0)));
EXPECT_EQ(f->get_results().size(), 1);
} }