Fix constant folding in MulMulMulFusion (#20803)

* Fix constant folding in MulMulMulFusion
by add f64 precision in Multiply to perform evaluate for const folding

* Do not transform if input has not supported type
This commit is contained in:
Pawel Raasz 2023-11-07 11:57:29 +01:00 committed by Alexander Nesterov
parent dd5886ed46
commit 84ee240bfa
4 changed files with 53 additions and 1 deletions

View File

@ -17,6 +17,13 @@
using namespace ov; using namespace ov;
namespace {
const auto is_eltwise_supported_type = [](const Output<Node>& output) -> bool {
const auto is_single_output = pass::pattern::consumers_count(1);
return is_single_output(output) && output.get_node()->has_evaluate();
};
}
ov::pass::AddMultiplyFusion::AddMultiplyFusion() { ov::pass::AddMultiplyFusion::AddMultiplyFusion() {
MATCHER_SCOPE(AddMultiplyFusion); MATCHER_SCOPE(AddMultiplyFusion);
// Create Add->Multiply pattern where Add has exactly one consumer // Create Add->Multiply pattern where Add has exactly one consumer
@ -105,7 +112,7 @@ ov::pass::MultiplyMultiplyFusion::MultiplyMultiplyFusion() {
auto m_data = pass::pattern::any_input(); auto m_data = pass::pattern::any_input();
auto m_mul1_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(); auto m_mul1_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto m_mul1 = auto m_mul1 =
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_data, m_mul1_constant}, pattern::consumers_count(1)); ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_data, m_mul1_constant}, is_eltwise_supported_type);
auto m_mul2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>(); auto m_mul2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
auto m_mul2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_mul1, m_mul2_constant}); auto m_mul2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_mul1, m_mul2_constant});

View File

@ -79,6 +79,48 @@ TEST_F(TransformationTestsF, MulMulMulFusion) {
} }
} }
TEST_F(TransformationTestsF, MulMulMulFusion_f64) {
{
auto input = std::make_shared<opset3::Parameter>(element::f64, Shape{1, 128, 3072});
auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {2});
auto mul2_const = opset3::Constant::create(element::f64, Shape{128, 1}, {3});
auto mul3_const = opset3::Constant::create(element::f64, Shape{1}, {3});
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
model = std::make_shared<ov::Model>(NodeVector{mul2}, ParameterVector{input});
manager.register_pass<ov::pass::LinOpSequenceFusion>();
}
{
auto input = std::make_shared<opset3::Parameter>(element::f64, Shape{1, 128, 3072});
auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {12});
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
model_ref = std::make_shared<ov::Model>(NodeVector{mul1}, ParameterVector{input});
}
}
TEST_F(TransformationTestsF, MulMulMulFusion_not_supported_type) {
constexpr auto et = element::u8;
{
auto input = std::make_shared<opset3::Parameter>(et, Shape{1, 128, 3072});
auto mul1_const = opset3::Constant::create(et, Shape{128, 1}, {2});
auto mul2_const = opset3::Constant::create(et, Shape{128, 1}, {3});
auto mul3_const = opset3::Constant::create(et, Shape{1}, {3});
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
model = std::make_shared<ov::Model>(NodeVector{mul2}, ParameterVector{input});
manager.register_pass<ov::pass::LinOpSequenceFusion>();
}
}
TEST_F(TransformationTestsF, AddAddAddFusion) { TEST_F(TransformationTestsF, AddAddAddFusion) {
{ {
auto input = std::make_shared<opset3::Parameter>(element::f32, Shape{1, 128, 3072}); auto input = std::make_shared<opset3::Parameter>(element::f32, Shape{1, 128, 3072});

View File

@ -41,6 +41,7 @@ bool evaluate_multiply(const HostTensorPtr& arg0,
OPENVINO_TYPE_CASE(evaluate_multiply, u64, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, u64, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_multiply, f16, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, f16, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_multiply, f32, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, f32, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_multiply, f64, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_multiply, bf16, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, bf16, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_multiply, u8, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, u8, arg0, arg1, out, broadcast_spec);
OPENVINO_TYPE_CASE(evaluate_multiply, i16, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, i16, arg0, arg1, out, broadcast_spec);
@ -80,6 +81,7 @@ bool op::v1::Multiply::has_evaluate() const {
case ngraph::element::u64: case ngraph::element::u64:
case ngraph::element::f16: case ngraph::element::f16:
case ngraph::element::f32: case ngraph::element::f32:
case ngraph::element::f64:
case ngraph::element::bf16: case ngraph::element::bf16:
return true; return true;
default: default:

View File

@ -117,6 +117,7 @@ std::vector<MultiplyParams> generateParamsForMultiplyFloat() {
std::vector<MultiplyParams> generateCombinedParamsForMultiply() { std::vector<MultiplyParams> generateCombinedParamsForMultiply() {
const std::vector<std::vector<MultiplyParams>> allTypeParams{generateParamsForMultiply<element::Type_t::f32>(), const std::vector<std::vector<MultiplyParams>> allTypeParams{generateParamsForMultiply<element::Type_t::f32>(),
generateParamsForMultiply<element::Type_t::f64>(),
generateParamsForMultiply<element::Type_t::f16>(), generateParamsForMultiply<element::Type_t::f16>(),
generateParamsForMultiply<element::Type_t::bf16>(), generateParamsForMultiply<element::Type_t::bf16>(),
generateParamsForMultiply<element::Type_t::i64>(), generateParamsForMultiply<element::Type_t::i64>(),