[nGraph Transformations] Add missing rtinfo copy in ConvertPrecision transformation (#20347)

This commit is contained in:
Nadezhda Ageeva 2023-10-10 17:47:52 +04:00 committed by GitHub
parent 3ebadc14d1
commit 4426486e6f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 31 deletions

View File

@ -1172,6 +1172,7 @@ bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
new_const->validate_and_infer_types();
new_const->set_friendly_name(constant->get_friendly_name());
ov::copy_runtime_info(constant, new_const);
return true;
}
return false;

View File

@ -99,10 +99,11 @@ TEST(TransformationTests, ConvertPrecision_NMS4) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
}
@ -131,8 +132,10 @@ TEST(TransformationTests, ConvertPrecision_NMS5) {
pass::Manager manager;
static const precisions_map precisions = {{element::i64, element::i32}, {element::f32, element::f16}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f32>(f));
}
@ -154,8 +157,10 @@ TEST(TransformationTests, ConvertPrecision_MatrixNms) {
pass::Manager manager;
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
}
@ -177,8 +182,10 @@ TEST(TransformationTests, ConvertPrecision_MulticlassNms) {
pass::Manager manager;
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
}
@ -195,10 +202,11 @@ TEST(TransformationTests, ConvertPrecision_ShapeOf) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
}
@ -217,10 +225,11 @@ TEST(TransformationTests, ConvertPrecision_Range) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
}
@ -238,10 +247,11 @@ TEST(TransformationTests, ConvertPrecision_ConstantRelu) {
static const precisions_map precisions = {{element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
}
@ -258,10 +268,11 @@ TEST(TransformationTests, ConvertPrecision_Convert) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
}
@ -276,6 +287,7 @@ TEST(TransformationTests, ConvertPrecision_ConvertElimination) {
f = std::make_shared<Model>(NodeVector{convert}, ParameterVector{input});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f16, element::f32}});
manager.run_passes(f);
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
@ -287,7 +299,7 @@ TEST(TransformationTests, ConvertPrecision_ConvertElimination) {
f_ref = std::make_shared<Model>(NodeVector{relu}, ParameterVector{input});
}
ASSERT_NO_THROW(check_rt_info(f));
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
@ -305,10 +317,11 @@ TEST(TransformationTests, ConvertPrecision_TopK) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
}
@ -325,10 +338,11 @@ TEST(TransformationTests, ConvertPrecision_Unique10) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<ov::pass::ConvertPrecision>(precisions);
manager.run_passes(model);
}
ASSERT_NO_THROW(check_rt_info(model));
ASSERT_EQ(model->outputs().size(), 4);
EXPECT_EQ(model->outputs()[0].get_element_type(), element::f32);
EXPECT_EQ(model->outputs()[1].get_element_type(), element::i32);
@ -353,10 +367,11 @@ TEST(TransformationTests, ConvertPrecision_NonZero) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
}
@ -374,10 +389,11 @@ TEST(TransformationTests, ConvertPrecision_Bucketize) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
}
@ -404,6 +420,7 @@ TEST(TransformationTests, ConvertPrecision_Roundings) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
@ -413,7 +430,7 @@ TEST(TransformationTests, ConvertPrecision_Roundings) {
ASSERT_EQ(casted_end->cast_vector<int32_t>(),
std::vector<int32_t>({max_int32, max_int32, max_int32, max_int32}));
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
}
@ -461,9 +478,11 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(tensor_iterator->get_body()));
@ -484,10 +503,11 @@ TEST(TransformationTests, ConvertPrecision_Equal) {
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
@ -506,10 +526,11 @@ TEST(TransformationTests, ConvertPrecision_NotEqual) {
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
@ -528,10 +549,11 @@ TEST(TransformationTests, ConvertPrecision_Greater) {
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
@ -550,10 +572,11 @@ TEST(TransformationTests, ConvertPrecision_GreaterEqual) {
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
@ -572,10 +595,11 @@ TEST(TransformationTests, ConvertPrecision_Less) {
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
@ -594,10 +618,11 @@ TEST(TransformationTests, ConvertPrecision_LessEqual) {
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
@ -613,10 +638,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalAnd) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
}
@ -631,10 +657,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalOr) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
}
@ -649,10 +676,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalXor) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
}
@ -666,10 +694,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
@ -692,10 +721,11 @@ TEST(TransformationTests, ConvertPrecision_Select) {
f = std::make_shared<Model>(OutputVector{select}, ParameterVector{input1});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
}
@ -710,11 +740,12 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxedWithSelect) {
f = std::make_shared<Model>(OutputVector{select}, ParameterVector{input1});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::i32}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::i32, element::i64}});
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_FALSE(has_type<element::Type_t::i32>(f));
ASSERT_TRUE(has_type<element::Type_t::i64>(f));
@ -732,10 +763,12 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxed) {
f = std::make_shared<Model>(OutputVector{type_relaxed}, ParameterVector{input1});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::i32}});
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::i32, element::i64}});
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
ASSERT_FALSE(has_type<element::Type_t::i32>(f));
ASSERT_TRUE(has_type<element::Type_t::i64>(f));
@ -758,10 +791,11 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
f = std::make_shared<Model>(NodeVector{mul}, ParameterVector{inp});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f16, element::f32}});
manager.run_passes(f);
}
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
}
@ -789,12 +823,13 @@ TEST(TransformationTests, ConvertPrecision_skip_precision_sensitive) {
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
}
ASSERT_NO_THROW(check_rt_info(model));
ASSERT_TRUE(has_type<element::Type_t::f32>(model));
ASSERT_TRUE(interpolate->input_value(2).get_element_type() == element::Type_t::f32);
}
@ -823,12 +858,13 @@ TEST(TransformationTests, ConvertPrecision_without_keep_precision_sensitive_in_f
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = false;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
}
ASSERT_NO_THROW(check_rt_info(model));
ASSERT_FALSE(has_type<element::Type_t::f32>(model));
ASSERT_TRUE(interpolate->input_value(2).get_element_type() == element::Type_t::f16);
}
@ -847,6 +883,7 @@ TEST(TransformationTests, ConvertPrecision_check_marking_does_not_leak_in_trivia
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
@ -863,7 +900,8 @@ TEST(TransformationTests, ConvertPrecision_check_marking_does_not_leak_in_trivia
const auto fc = FunctionsComparator::with_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
const auto res = fc.compare(model, model_ref);
ASSERT_TRUE(res.valid) << res.message;
}
@ -888,6 +926,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_1) {
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
@ -909,7 +948,8 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_1) {
const auto fc = FunctionsComparator::with_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
const auto res = fc.compare(model, model_ref);
ASSERT_TRUE(res.valid) << res.message;
}
@ -943,6 +983,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_2) {
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
@ -973,7 +1014,8 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_2) {
const auto fc = FunctionsComparator::with_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
const auto res = fc.compare(model, model_ref);
ASSERT_TRUE(res.valid) << res.message;
}
@ -1024,6 +1066,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_3) {
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = true;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
@ -1071,7 +1114,8 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_3) {
const auto fc = FunctionsComparator::with_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES);
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
const auto res = fc.compare(model, model_ref);
ASSERT_TRUE(res.valid) << res.message;
}
@ -1102,12 +1146,13 @@ TEST(TransformationTests, ConvertCompressedToMixedPrecission_do_not_keep_in_fp32
pass::Manager manager;
type_to_fuse_map empty_type_to_fuse_map = {};
bool keep_precision_sensitive_in_fp32 = false; // didn't keep in FP32 intentionally
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
empty_type_to_fuse_map,
keep_precision_sensitive_in_fp32);
manager.run_passes(model);
}
ASSERT_NO_THROW(check_rt_info(model));
ASSERT_FALSE(has_type<element::Type_t::f32>(model));
ASSERT_TRUE(interpolate->input_value(2).get_element_type() == element::Type_t::f16);
ASSERT_TRUE(interpolate->output(0).get_partial_shape() == PartialShape({1, 3, 287, 511}));