Allow compatible shapes in FW Node, not only equal (#15363)
* Allow compatible shapes in FW Node, not only equal * Add tests * Add test for type
This commit is contained in:
parent
a19d50d4d2
commit
ea0183359f
@ -149,17 +149,15 @@ void ov::op::util::FrameworkNode::validate_and_infer_types() {
|
||||
rank.get_length() == orig_input_pshape.rank().get_length()) {
|
||||
for (int64_t dim = 0; dim < rank.get_length(); ++dim) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_pshape[dim].is_dynamic() ||
|
||||
(orig_input_pshape[dim].is_static() &&
|
||||
orig_input_pshape[dim].get_length() == input_pshape[dim].get_length()),
|
||||
orig_input_pshape[dim].compatible(input_pshape[dim]),
|
||||
get_error_message());
|
||||
}
|
||||
reset_output_shape_to_dynamic = true;
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
std::get<0>(m_inputs_desc[i]).compatible(input_pshape) &&
|
||||
std::get<1>(m_inputs_desc[i]).compatible(input_type),
|
||||
get_error_message());
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
orig_input_pshape.compatible(input_pshape) && std::get<1>(m_inputs_desc[i]).compatible(input_type),
|
||||
get_error_message());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -50,3 +50,55 @@ TEST(type_prop, framework_node) {
|
||||
|
||||
ASSERT_THROW(f_node->validate_and_infer_types(), ov::Exception);
|
||||
}
|
||||
|
||||
TEST(type_prop, dynamic_framework_node_with_dynamic_input) {
|
||||
auto param = std::make_shared<ov::opset8::Parameter>(ov::element::dynamic, ov::PartialShape::dynamic());
|
||||
auto f_node = std::make_shared<ov::op::util::FrameworkNode>(ov::OutputVector{param});
|
||||
|
||||
// Set partially dynamic shape
|
||||
param->set_partial_shape(ov::PartialShape{ov::Dimension::dynamic(), 64});
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
|
||||
|
||||
// Set dynamic shape with static rank
|
||||
param->set_partial_shape(ov::PartialShape::dynamic(2));
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
|
||||
|
||||
// Set static shape
|
||||
param->set_partial_shape(ov::PartialShape({1, 64}));
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
|
||||
|
||||
// Set static type
|
||||
param->set_element_type(ov::element::f32);
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_element_type(0), ov::element::dynamic);
|
||||
}
|
||||
|
||||
TEST(type_prop, dynamic_framework_node_with_static_rank) {
|
||||
auto param = std::make_shared<ov::opset8::Parameter>(ov::element::dynamic, ov::PartialShape::dynamic(2));
|
||||
auto f_node = std::make_shared<ov::op::util::FrameworkNode>(ov::OutputVector{param});
|
||||
|
||||
// Set partially dynamic shape
|
||||
param->set_partial_shape(ov::PartialShape{ov::Dimension::dynamic(), 64});
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
|
||||
|
||||
// Set static shape
|
||||
param->set_partial_shape(ov::PartialShape({1, 64}));
|
||||
param->validate_and_infer_types();
|
||||
|
||||
ASSERT_NO_THROW(f_node->validate_and_infer_types());
|
||||
ASSERT_EQ(f_node->get_output_partial_shape(0), ov::PartialShape::dynamic());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user