Allow compatible shapes in FW Node, not only equal (#15363)

* Allow compatible shapes in FW Node, not only equal

* Add tests

* Add test for type
This commit is contained in:
Maxim Vafin 2023-01-30 15:38:35 +01:00 committed by GitHub
parent a19d50d4d2
commit ea0183359f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 7 deletions

View File

@ -149,17 +149,15 @@ void ov::op::util::FrameworkNode::validate_and_infer_types() {
rank.get_length() == orig_input_pshape.rank().get_length()) {
for (int64_t dim = 0; dim < rank.get_length(); ++dim) {
NODE_VALIDATION_CHECK(this,
input_pshape[dim].is_dynamic() ||
(orig_input_pshape[dim].is_static() &&
orig_input_pshape[dim].get_length() == input_pshape[dim].get_length()),
orig_input_pshape[dim].compatible(input_pshape[dim]),
get_error_message());
}
reset_output_shape_to_dynamic = true;
} else {
NODE_VALIDATION_CHECK(this,
std::get<0>(m_inputs_desc[i]).compatible(input_pshape) &&
std::get<1>(m_inputs_desc[i]).compatible(input_type),
get_error_message());
NODE_VALIDATION_CHECK(
this,
orig_input_pshape.compatible(input_pshape) && std::get<1>(m_inputs_desc[i]).compatible(input_type),
get_error_message());
}
}
}

View File

@ -50,3 +50,55 @@ TEST(type_prop, framework_node) {
ASSERT_THROW(f_node->validate_and_infer_types(), ov::Exception);
}
TEST(type_prop, dynamic_framework_node_with_dynamic_input) {
auto param = std::make_shared<ov::opset8::Parameter>(ov::element::dynamic, ov::PartialShape::dynamic());
auto f_node = std::make_shared<ov::op::util::FrameworkNode>(ov::OutputVector{param});
// Set partially dynamic shape
param->set_partial_shape(ov::PartialShape{ov::Dimension::dynamic(), 64});
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
// Set dynamic shape with static rank
param->set_partial_shape(ov::PartialShape::dynamic(2));
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
// Set static shape
param->set_partial_shape(ov::PartialShape({1, 64}));
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
// Set static type
param->set_element_type(ov::element::f32);
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_element_type(0), ov::element::dynamic);
}
TEST(type_prop, dynamic_framework_node_with_static_rank) {
auto param = std::make_shared<ov::opset8::Parameter>(ov::element::dynamic, ov::PartialShape::dynamic(2));
auto f_node = std::make_shared<ov::op::util::FrameworkNode>(ov::OutputVector{param});
// Set partially dynamic shape
param->set_partial_shape(ov::PartialShape{ov::Dimension::dynamic(), 64});
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
// Set static shape
param->set_partial_shape(ov::PartialShape({1, 64}));
param->validate_and_infer_types();
ASSERT_NO_THROW(f_node->validate_and_infer_types());
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
}