diff --git a/src/common/transformations/src/transformations/common_optimizations/lin_op_sequence_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/lin_op_sequence_fusion.cpp index eaaa2a52312..6c45867b007 100644 --- a/src/common/transformations/src/transformations/common_optimizations/lin_op_sequence_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/lin_op_sequence_fusion.cpp @@ -17,6 +17,13 @@ using namespace ov; +namespace { +const auto is_eltwise_supported_type = [](const Output& output) -> bool { + const auto is_single_output = pass::pattern::consumers_count(1); + return is_single_output(output) && output.get_node()->has_evaluate(); +}; +} + ov::pass::AddMultiplyFusion::AddMultiplyFusion() { MATCHER_SCOPE(AddMultiplyFusion); // Create Add->Multiply pattern where Add has exactly one consumer @@ -105,7 +112,7 @@ ov::pass::MultiplyMultiplyFusion::MultiplyMultiplyFusion() { auto m_data = pass::pattern::any_input(); auto m_mul1_constant = ov::pass::pattern::wrap_type(); auto m_mul1 = - ov::pass::pattern::wrap_type({m_data, m_mul1_constant}, pattern::consumers_count(1)); + ov::pass::pattern::wrap_type({m_data, m_mul1_constant}, is_eltwise_supported_type); auto m_mul2_constant = ov::pass::pattern::wrap_type(); auto m_mul2 = ov::pass::pattern::wrap_type({m_mul1, m_mul2_constant}); diff --git a/src/common/transformations/tests/common_optimizations/lin_op_sequence_fusion_test.cpp b/src/common/transformations/tests/common_optimizations/lin_op_sequence_fusion_test.cpp index 71c16dea124..cfef2d2b2ac 100644 --- a/src/common/transformations/tests/common_optimizations/lin_op_sequence_fusion_test.cpp +++ b/src/common/transformations/tests/common_optimizations/lin_op_sequence_fusion_test.cpp @@ -79,6 +79,48 @@ TEST_F(TransformationTestsF, MulMulMulFusion) { } } +TEST_F(TransformationTestsF, MulMulMulFusion_f64) { + { + auto input = std::make_shared(element::f64, Shape{1, 128, 3072}); + auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {2}); + auto mul2_const = opset3::Constant::create(element::f64, Shape{128, 1}, {3}); + auto mul3_const = opset3::Constant::create(element::f64, Shape{1}, {3}); + + auto mul1 = std::make_shared(input, mul1_const); + auto mul2 = std::make_shared(mul1, mul2_const); + auto mul3 = std::make_shared(mul2, mul3_const); + + model = std::make_shared(NodeVector{mul2}, ParameterVector{input}); + manager.register_pass(); + } + + { + auto input = std::make_shared(element::f64, Shape{1, 128, 3072}); + auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {12}); + + auto mul1 = std::make_shared(input, mul1_const); + + model_ref = std::make_shared(NodeVector{mul1}, ParameterVector{input}); + } +} + +TEST_F(TransformationTestsF, MulMulMulFusion_not_supported_type) { + constexpr auto et = element::u8; + { + auto input = std::make_shared(et, Shape{1, 128, 3072}); + auto mul1_const = opset3::Constant::create(et, Shape{128, 1}, {2}); + auto mul2_const = opset3::Constant::create(et, Shape{128, 1}, {3}); + auto mul3_const = opset3::Constant::create(et, Shape{1}, {3}); + + auto mul1 = std::make_shared(input, mul1_const); + auto mul2 = std::make_shared(mul1, mul2_const); + auto mul3 = std::make_shared(mul2, mul3_const); + + model = std::make_shared(NodeVector{mul2}, ParameterVector{input}); + manager.register_pass(); + } +} + TEST_F(TransformationTestsF, AddAddAddFusion) { { auto input = std::make_shared(element::f32, Shape{1, 128, 3072}); diff --git a/src/core/src/op/multiply.cpp b/src/core/src/op/multiply.cpp index 04ccc8d05e3..b30c2adaa7d 100644 --- a/src/core/src/op/multiply.cpp +++ b/src/core/src/op/multiply.cpp @@ -41,6 +41,7 @@ bool evaluate_multiply(const HostTensorPtr& arg0, OPENVINO_TYPE_CASE(evaluate_multiply, u64, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, f16, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, f32, arg0, arg1, out, broadcast_spec); + OPENVINO_TYPE_CASE(evaluate_multiply, f64, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, bf16, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, u8, arg0, arg1, out, broadcast_spec); OPENVINO_TYPE_CASE(evaluate_multiply, i16, arg0, arg1, out, broadcast_spec); @@ -80,6 +81,7 @@ bool op::v1::Multiply::has_evaluate() const { case ngraph::element::u64: case ngraph::element::f16: case ngraph::element::f32: + case ngraph::element::f64: case ngraph::element::bf16: return true; default: diff --git a/src/plugins/template/tests/functional/op_reference/multiply.cpp b/src/plugins/template/tests/functional/op_reference/multiply.cpp index bd3b27f500a..726917eac9c 100644 --- a/src/plugins/template/tests/functional/op_reference/multiply.cpp +++ b/src/plugins/template/tests/functional/op_reference/multiply.cpp @@ -117,6 +117,7 @@ std::vector generateParamsForMultiplyFloat() { std::vector generateCombinedParamsForMultiply() { const std::vector> allTypeParams{generateParamsForMultiply(), + generateParamsForMultiply(), generateParamsForMultiply(), generateParamsForMultiply(), generateParamsForMultiply(),