Fix f16/f32 element type mismatch for shape subgraphs: Parameter type should not be fused for precision-sensitive nodes (#19959)
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
parent 10c3b60aac
commit efe54362fd
@@ -201,6 +201,8 @@ bool convert_function_precision(const std::shared_ptr<Model>& f,
     }
 
     for (const auto& param : f->get_parameters()) {
+        if (skip_precision_sensitive && fp16_compression_is_disabled(param) && has_fp16_compression)
+            continue;
         is_changed |= fuse_type_to_parameter(param, precisions, convert_input_output_precision);
     }
 
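For context, the early continue added above keys off the same rt_info attribute that the shape-subgraph marking step (run when keep_precision_sensitive_in_fp32 is true) puts on precision-sensitive nodes. Below is a minimal sketch of that interaction, assuming the disable_fp16_compression / fp16_compression_is_disabled helpers declared in transformations/rt_info/disable_fp16_compression.hpp; only fp16_compression_is_disabled appears in the hunk above, the rest is illustrative.

// Sketch only: how a Parameter marked as precision-sensitive is expected to be
// skipped by fuse_type_to_parameter(). Names not in the hunk above are assumptions.
#include "openvino/opsets/opset10.hpp"
#include "transformations/rt_info/disable_fp16_compression.hpp"

void mark_shape_parameter(const std::shared_ptr<ov::opset10::Parameter>& shape_param) {
    // The marking pass sets this attribute automatically for inputs that only feed
    // shape computations; setting it by hand has the same effect.
    ov::disable_fp16_compression(shape_param);

    // With this fix, convert_function_precision() checks the flag and leaves the
    // Parameter's element type untouched instead of fusing it to f16.
    bool skipped = ov::fp16_compression_is_disabled(shape_param);
    (void)skipped;  // true after marking
}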
@@ -2136,3 +2136,93 @@ TEST(TransformationTests, ConvertPrecisionExplicitConvertsMultiSubgraphs) {
     const auto& results = model->get_results();
     ASSERT_EQ("if_result", results[0]->get_input_node_ptr(0)->get_friendly_name());
 }
+
+TEST(TransformationTests, align_mixed_fp16_fp32_with_parameter_for_shape_1) {
+    shared_ptr<Model> model, model_ref;
+    pass::Manager manager;
+    {
+        auto input_1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 224, 224});
+        auto shape_input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});
+
+        auto upscale_const = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
+        auto mul_1 = make_shared<ov::op::v1::Multiply>(shape_input, upscale_const);
+        auto axis_const = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
+        auto final_float_shape = make_shared<ov::op::v1::ReduceProd>(mul_1, axis_const);
+        auto final_int_shape = make_shared<ov::op::v0::Convert>(final_float_shape, element::i64);
+        auto reshape_1 = make_shared<ov::op::v1::Reshape>(input_1, final_int_shape, false);
+
+        model = make_shared<Model>(NodeVector{reshape_1}, ParameterVector{input_1, shape_input});
+
+        type_to_fuse_map empty_type_to_fuse_map = {};
+        bool keep_precision_sensitive_in_fp32 = true;
+        manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
+                                                       empty_type_to_fuse_map,
+                                                       keep_precision_sensitive_in_fp32);
+        manager.run_passes(model);
+    }
+
+    {
+        auto input_1 = make_shared<ov::op::v0::Parameter>(element::f16, Shape{1, 3, 224, 224});
+        auto shape_input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});
+
+        // even for FP16 compressed model shape subgraph should be kept in fp32
+        auto upscale_const = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
+        auto mul_1 = make_shared<ov::op::v1::Multiply>(shape_input, upscale_const);
+        auto axis_const = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
+        auto final_float_shape = make_shared<ov::op::v1::ReduceProd>(mul_1, axis_const);
+        auto final_int_shape = make_shared<ov::op::v0::Convert>(final_float_shape, element::i64);
+        auto reshape_1 = make_shared<ov::op::v1::Reshape>(input_1, final_int_shape, false);
+
+        model_ref = make_shared<Model>(NodeVector{reshape_1}, ParameterVector{input_1, shape_input});
+    }
+    const FunctionsComparator func_comparator = FunctionsComparator::with_default();
+    FunctionsComparator::Result result = func_comparator(model_ref, model);
+    ASSERT_TRUE(result.valid) << result.message;
+}
+
+TEST(TransformationTests, align_mixed_fp16_fp32_with_parameter_for_shape_2) {
+    shared_ptr<Model> model, model_ref;
+    pass::Manager manager;
+    {
+        auto input_1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 224, 224});
+        auto shape_input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});
+
+        auto upscale_const = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
+        auto mul_1 = make_shared<ov::op::v1::Multiply>(shape_input, upscale_const);
+        auto axis_const = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
+        auto final_float_shape = make_shared<ov::op::v1::ReduceProd>(mul_1, axis_const);
+        auto final_int_shape = make_shared<ov::op::v0::Convert>(final_float_shape, element::i64);
+        auto reshape_1 = make_shared<ov::op::v1::Reshape>(input_1, final_int_shape, false);
+
+        model = make_shared<Model>(NodeVector{reshape_1}, ParameterVector{input_1, shape_input});
+
+        type_to_fuse_map empty_type_to_fuse_map = {};
+        bool keep_precision_sensitive_in_fp32 = true;
+        const bool convert_input_output_precision = false;
+        manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
+                                                       empty_type_to_fuse_map,
+                                                       keep_precision_sensitive_in_fp32,
+                                                       convert_input_output_precision);
+        manager.run_passes(model);
+    }
+
+    {
+        auto input_1 = make_shared<ov::op::v0::Parameter>(element::f32, Shape{1, 3, 224, 224});
+        auto convert_to_f16 = make_shared<ov::op::v0::Convert>(input_1, element::f16);
+        auto shape_input = make_shared<ov::op::v0::Parameter>(element::f32, Shape{2});
+
+        // even for FP16 compressed model shape subgraph should be kept in fp32
+        auto upscale_const = ov::op::v0::Constant::create(element::f32, Shape{1}, {2.0f});
+        auto mul_1 = make_shared<ov::op::v1::Multiply>(shape_input, upscale_const);
+        auto axis_const = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
+        auto final_float_shape = make_shared<ov::op::v1::ReduceProd>(mul_1, axis_const);
+        auto final_int_shape = make_shared<ov::op::v0::Convert>(final_float_shape, element::i64);
+        auto reshape_1 = make_shared<ov::op::v1::Reshape>(convert_to_f16, final_int_shape, false);
+        auto convert_to_f32 = make_shared<ov::op::v0::Convert>(reshape_1, element::f32);
+
+        model_ref = make_shared<Model>(NodeVector{convert_to_f32}, ParameterVector{input_1, shape_input});
+    }
+    const FunctionsComparator func_comparator = FunctionsComparator::with_default();
+    FunctionsComparator::Result result = func_comparator(model_ref, model);
+    ASSERT_TRUE(result.valid) << result.message;
+}
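The two tests above compare full reference models. As a quicker mental model, here is a standalone sketch (same public API as the tests; exact include paths are an assumption) of what the fix means for Parameter element types after the pass runs with the default convert_input_output_precision.

// Sketch only: mirrors align_mixed_fp16_fp32_with_parameter_for_shape_1 and then
// inspects the Parameter types directly instead of comparing reference models.
#include <memory>
#include "openvino/core/model.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/convert_precision.hpp"

int main() {
    using namespace ov;
    auto data = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 3, 224, 224});
    auto shape_in = std::make_shared<opset10::Parameter>(element::f32, Shape{2});

    // Shape subgraph: f32 arithmetic on the shape input, converted to i64 for Reshape.
    auto scale = opset10::Constant::create(element::f32, Shape{1}, {2.0f});
    auto mul = std::make_shared<opset10::Multiply>(shape_in, scale);
    auto axis = opset10::Constant::create(element::i64, Shape{1}, {0});
    auto prod = std::make_shared<opset10::ReduceProd>(mul, axis);
    auto to_i64 = std::make_shared<opset10::Convert>(prod, element::i64);
    auto reshape = std::make_shared<opset10::Reshape>(data, to_i64, false);
    auto model = std::make_shared<Model>(NodeVector{reshape}, ParameterVector{data, shape_in});

    pass::Manager manager;
    manager.register_pass<pass::ConvertPrecision>(precisions_map{{element::f32, element::f16}},
                                                  type_to_fuse_map{},
                                                  /*keep_precision_sensitive_in_fp32=*/true);
    manager.run_passes(model);

    // Expected with this fix (matches model_ref in the first test): the data input is
    // compressed to f16, while the Parameter feeding the shape subgraph stays f32.
    // data->get_element_type()     == element::f16
    // shape_in->get_element_type() == element::f32
    return 0;
}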