Fix constant folding in MulMulMulFusion (#20803)
* Fix constant folding in MulMulMulFusion by adding f64 precision support to Multiply so that it can be evaluated during constant folding * Do not transform if the input has an unsupported type
This commit is contained in:
parent
dd5886ed46
commit
84ee240bfa
@ -17,6 +17,13 @@
|
||||
|
||||
using namespace ov;
|
||||
|
||||
// File-local helpers shared by the fusion matchers defined in this file.
namespace {
|
||||
// Pattern predicate for an eltwise node output.
// Returns true only when the output has a single consumer AND the node that
// produces it reports has_evaluate(); otherwise the pattern does not match.
const auto is_eltwise_supported_type = [](const Output<Node>& output) -> bool {
|
||||
// Predicate that checks the output is consumed by exactly one node.
const auto is_single_output = pass::pattern::consumers_count(1);
|
||||
// Short-circuits: has_evaluate() is queried only for single-consumer outputs.
return is_single_output(output) && output.get_node()->has_evaluate();
|
||||
};
|
||||
}
|
||||
|
||||
ov::pass::AddMultiplyFusion::AddMultiplyFusion() {
|
||||
MATCHER_SCOPE(AddMultiplyFusion);
|
||||
// Create Add->Multiply pattern where Add has exactly one consumer
|
||||
@ -105,7 +112,7 @@ ov::pass::MultiplyMultiplyFusion::MultiplyMultiplyFusion() {
|
||||
auto m_data = pass::pattern::any_input();
|
||||
auto m_mul1_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
|
||||
auto m_mul1 =
|
||||
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_data, m_mul1_constant}, pattern::consumers_count(1));
|
||||
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_data, m_mul1_constant}, is_eltwise_supported_type);
|
||||
auto m_mul2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
|
||||
auto m_mul2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_mul1, m_mul2_constant});
|
||||
|
||||
|
@ -79,6 +79,48 @@ TEST_F(TransformationTestsF, MulMulMulFusion) {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs LinOpSequenceFusion on a chain of f64 Multiply-by-constant nodes and
// compares the result with a reference model holding a single Multiply.
// NOTE(review): mul3 is constructed but the model output is mul2, so mul3 is not
// part of the tested graph; the reference constant is 12 while 2*3 = 6 and
// 2*3*3 = 18 — confirm whether the fixture compares constant values at all.
TEST_F(TransformationTestsF, MulMulMulFusion_f64) {
|
||||
// Model under test: input -> Multiply(2) -> Multiply(3), all in f64.
{
|
||||
auto input = std::make_shared<opset3::Parameter>(element::f64, Shape{1, 128, 3072});
|
||||
auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {2});
|
||||
auto mul2_const = opset3::Constant::create(element::f64, Shape{128, 1}, {3});
|
||||
auto mul3_const = opset3::Constant::create(element::f64, Shape{1}, {3});
|
||||
|
||||
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
|
||||
auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
|
||||
// Not referenced by the model created below.
auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
|
||||
|
||||
model = std::make_shared<ov::Model>(NodeVector{mul2}, ParameterVector{input});
|
||||
manager.register_pass<ov::pass::LinOpSequenceFusion>();
|
||||
}
|
||||
|
||||
// Reference model: a single Multiply of the input by one f64 constant.
{
|
||||
auto input = std::make_shared<opset3::Parameter>(element::f64, Shape{1, 128, 3072});
|
||||
auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {12});
|
||||
|
||||
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
|
||||
|
||||
model_ref = std::make_shared<ov::Model>(NodeVector{mul1}, ParameterVector{input});
|
||||
}
|
||||
}
|
||||
|
||||
// Runs LinOpSequenceFusion on a chain of u8 Multiply-by-constant nodes.
// No model_ref is assigned here; presumably the fixture then compares against the
// untransformed model, i.e. the test expects no fusion — confirm against
// TransformationTestsF.
// NOTE(review): mul3 is constructed but the model output is mul2, so mul3 is not
// part of the tested graph.
TEST_F(TransformationTestsF, MulMulMulFusion_not_supported_type) {
|
||||
// Element type used for the parameter and every constant in this test.
constexpr auto et = element::u8;
|
||||
{
|
||||
auto input = std::make_shared<opset3::Parameter>(et, Shape{1, 128, 3072});
|
||||
auto mul1_const = opset3::Constant::create(et, Shape{128, 1}, {2});
|
||||
auto mul2_const = opset3::Constant::create(et, Shape{128, 1}, {3});
|
||||
auto mul3_const = opset3::Constant::create(et, Shape{1}, {3});
|
||||
|
||||
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
|
||||
auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
|
||||
// Not referenced by the model created below.
auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
|
||||
|
||||
model = std::make_shared<ov::Model>(NodeVector{mul2}, ParameterVector{input});
|
||||
manager.register_pass<ov::pass::LinOpSequenceFusion>();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, AddAddAddFusion) {
|
||||
{
|
||||
auto input = std::make_shared<opset3::Parameter>(element::f32, Shape{1, 128, 3072});
|
||||
|
@ -41,6 +41,7 @@ bool evaluate_multiply(const HostTensorPtr& arg0,
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, u64, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, f16, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, f32, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, f64, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, bf16, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, u8, arg0, arg1, out, broadcast_spec);
|
||||
OPENVINO_TYPE_CASE(evaluate_multiply, i16, arg0, arg1, out, broadcast_spec);
|
||||
@ -80,6 +81,7 @@ bool op::v1::Multiply::has_evaluate() const {
|
||||
case ngraph::element::u64:
|
||||
case ngraph::element::f16:
|
||||
case ngraph::element::f32:
|
||||
case ngraph::element::f64:
|
||||
case ngraph::element::bf16:
|
||||
return true;
|
||||
default:
|
||||
|
@ -117,6 +117,7 @@ std::vector<MultiplyParams> generateParamsForMultiplyFloat() {
|
||||
|
||||
std::vector<MultiplyParams> generateCombinedParamsForMultiply() {
|
||||
const std::vector<std::vector<MultiplyParams>> allTypeParams{generateParamsForMultiply<element::Type_t::f32>(),
|
||||
generateParamsForMultiply<element::Type_t::f64>(),
|
||||
generateParamsForMultiply<element::Type_t::f16>(),
|
||||
generateParamsForMultiply<element::Type_t::bf16>(),
|
||||
generateParamsForMultiply<element::Type_t::i64>(),
|
||||
|
Loading…
Reference in New Issue
Block a user