[nGraph Transformations] Add missing rtinfo copy in ConvertPrecision transformation (#20347)
This commit is contained in:
parent
3ebadc14d1
commit
4426486e6f
@ -1172,6 +1172,7 @@ bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
|
||||
|
||||
new_const->validate_and_infer_types();
|
||||
new_const->set_friendly_name(constant->get_friendly_name());
|
||||
ov::copy_runtime_info(constant, new_const);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
@ -99,10 +99,11 @@ TEST(TransformationTests, ConvertPrecision_NMS4) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
@ -131,8 +132,10 @@ TEST(TransformationTests, ConvertPrecision_NMS5) {
|
||||
|
||||
pass::Manager manager;
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f32, element::f16}};
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f32>(f));
|
||||
}
|
||||
@ -154,8 +157,10 @@ TEST(TransformationTests, ConvertPrecision_MatrixNms) {
|
||||
|
||||
pass::Manager manager;
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
@ -177,8 +182,10 @@ TEST(TransformationTests, ConvertPrecision_MulticlassNms) {
|
||||
|
||||
pass::Manager manager;
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
@ -195,10 +202,11 @@ TEST(TransformationTests, ConvertPrecision_ShapeOf) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
@ -217,10 +225,11 @@ TEST(TransformationTests, ConvertPrecision_Range) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
@ -238,10 +247,11 @@ TEST(TransformationTests, ConvertPrecision_ConstantRelu) {
|
||||
|
||||
static const precisions_map precisions = {{element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
@ -258,10 +268,11 @@ TEST(TransformationTests, ConvertPrecision_Convert) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
}
|
||||
@ -276,6 +287,7 @@ TEST(TransformationTests, ConvertPrecision_ConvertElimination) {
|
||||
f = std::make_shared<Model>(NodeVector{convert}, ParameterVector{input});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f16, element::f32}});
|
||||
manager.run_passes(f);
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
@ -287,7 +299,7 @@ TEST(TransformationTests, ConvertPrecision_ConvertElimination) {
|
||||
|
||||
f_ref = std::make_shared<Model>(NodeVector{relu}, ParameterVector{input});
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
@ -305,10 +317,11 @@ TEST(TransformationTests, ConvertPrecision_TopK) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
}
|
||||
@ -325,10 +338,11 @@ TEST(TransformationTests, ConvertPrecision_Unique10) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<ov::pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(model));
|
||||
ASSERT_EQ(model->outputs().size(), 4);
|
||||
EXPECT_EQ(model->outputs()[0].get_element_type(), element::f32);
|
||||
EXPECT_EQ(model->outputs()[1].get_element_type(), element::i32);
|
||||
@ -353,10 +367,11 @@ TEST(TransformationTests, ConvertPrecision_NonZero) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
}
|
||||
@ -374,10 +389,11 @@ TEST(TransformationTests, ConvertPrecision_Bucketize) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
}
|
||||
@ -404,6 +420,7 @@ TEST(TransformationTests, ConvertPrecision_Roundings) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
|
||||
@ -413,7 +430,7 @@ TEST(TransformationTests, ConvertPrecision_Roundings) {
|
||||
ASSERT_EQ(casted_end->cast_vector<int32_t>(),
|
||||
std::vector<int32_t>({max_int32, max_int32, max_int32, max_int32}));
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
}
|
||||
@ -461,9 +478,11 @@ TEST(TransformationTests, ConvertPrecision_TIBody) {
|
||||
|
||||
static const precisions_map precisions = {{element::i64, element::i32}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i64>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(tensor_iterator->get_body()));
|
||||
@ -484,10 +503,11 @@ TEST(TransformationTests, ConvertPrecision_Equal) {
|
||||
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
@ -506,10 +526,11 @@ TEST(TransformationTests, ConvertPrecision_NotEqual) {
|
||||
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
@ -528,10 +549,11 @@ TEST(TransformationTests, ConvertPrecision_Greater) {
|
||||
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
@ -550,10 +572,11 @@ TEST(TransformationTests, ConvertPrecision_GreaterEqual) {
|
||||
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
@ -572,10 +595,11 @@ TEST(TransformationTests, ConvertPrecision_Less) {
|
||||
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
@ -594,10 +618,11 @@ TEST(TransformationTests, ConvertPrecision_LessEqual) {
|
||||
|
||||
static const precisions_map precisions = {{element::boolean, element::u8}, {element::f16, element::f32}};
|
||||
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions);
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
@ -613,10 +638,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalAnd) {
|
||||
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
}
|
||||
@ -631,10 +657,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalOr) {
|
||||
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
}
|
||||
@ -649,10 +676,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalXor) {
|
||||
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1, input2});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
}
|
||||
@ -666,10 +694,11 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
|
||||
f = std::make_shared<Model>(OutputVector{node}, ParameterVector{input1});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
|
||||
@ -692,10 +721,11 @@ TEST(TransformationTests, ConvertPrecision_Select) {
|
||||
f = std::make_shared<Model>(OutputVector{select}, ParameterVector{input1});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::u8}});
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||
}
|
||||
@ -710,11 +740,12 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxedWithSelect) {
|
||||
f = std::make_shared<Model>(OutputVector{select}, ParameterVector{input1});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::i32}});
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::i32, element::i64}});
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i32>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::i64>(f));
|
||||
@ -732,10 +763,12 @@ TEST(TransformationTests, ConvertPrecision_TypeRelaxed) {
|
||||
f = std::make_shared<Model>(OutputVector{type_relaxed}, ParameterVector{input1});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::boolean, element::i32}});
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::i32, element::i64}});
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::i32>(f));
|
||||
ASSERT_TRUE(has_type<element::Type_t::i64>(f));
|
||||
@ -758,10 +791,11 @@ TEST(TransformationTests, ConvertPrecision_Variables) {
|
||||
f = std::make_shared<Model>(NodeVector{mul}, ParameterVector{inp});
|
||||
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f16, element::f32}});
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f16>(f));
|
||||
}
|
||||
|
||||
@ -789,12 +823,13 @@ TEST(TransformationTests, ConvertPrecision_skip_precision_sensitive) {
|
||||
pass::Manager manager;
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = true;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(model));
|
||||
ASSERT_TRUE(has_type<element::Type_t::f32>(model));
|
||||
ASSERT_TRUE(interpolate->input_value(2).get_element_type() == element::Type_t::f32);
|
||||
}
|
||||
@ -823,12 +858,13 @@ TEST(TransformationTests, ConvertPrecision_without_keep_precision_sensitive_in_f
|
||||
pass::Manager manager;
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = false;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(model));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f32>(model));
|
||||
ASSERT_TRUE(interpolate->input_value(2).get_element_type() == element::Type_t::f16);
|
||||
}
|
||||
@ -847,6 +883,7 @@ TEST(TransformationTests, ConvertPrecision_check_marking_does_not_leak_in_trivia
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = true;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32);
|
||||
@ -863,7 +900,8 @@ TEST(TransformationTests, ConvertPrecision_check_marking_does_not_leak_in_trivia
|
||||
|
||||
const auto fc = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
const auto res = fc.compare(model, model_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
@ -888,6 +926,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_1) {
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = true;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32);
|
||||
@ -909,7 +948,8 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_1) {
|
||||
|
||||
const auto fc = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
const auto res = fc.compare(model, model_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
@ -943,6 +983,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_2) {
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = true;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32);
|
||||
@ -973,7 +1014,8 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_2) {
|
||||
|
||||
const auto fc = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
const auto res = fc.compare(model, model_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
@ -1024,6 +1066,7 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_3) {
|
||||
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = true;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32);
|
||||
@ -1071,7 +1114,8 @@ TEST(TransformationTests, ConvertPrecision_whole_shape_subgraph_is_marked_3) {
|
||||
|
||||
const auto fc = FunctionsComparator::with_default()
|
||||
.enable(FunctionsComparator::PRECISIONS)
|
||||
.enable(FunctionsComparator::CONST_VALUES);
|
||||
.enable(FunctionsComparator::CONST_VALUES)
|
||||
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
|
||||
const auto res = fc.compare(model, model_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
@ -1102,12 +1146,13 @@ TEST(TransformationTests, ConvertCompressedToMixedPrecission_do_not_keep_in_fp32
|
||||
pass::Manager manager;
|
||||
type_to_fuse_map empty_type_to_fuse_map = {};
|
||||
bool keep_precision_sensitive_in_fp32 = false; // didn't keep in FP32 intentionally
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
|
||||
empty_type_to_fuse_map,
|
||||
keep_precision_sensitive_in_fp32);
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
ASSERT_NO_THROW(check_rt_info(model));
|
||||
ASSERT_FALSE(has_type<element::Type_t::f32>(model));
|
||||
ASSERT_TRUE(interpolate->input_value(2).get_element_type() == element::Type_t::f16);
|
||||
ASSERT_TRUE(interpolate->output(0).get_partial_shape() == PartialShape({1, 3, 287, 511}));
|
||||
|
Loading…
Reference in New Issue
Block a user