Fix constant folding in MulMulMulFusion (#20803)
* Fix constant folding in MulMulMulFusion by add f64 precision in Multiply to perform evaluate for const folding * Do not transform if input has not supported type
This commit is contained in:
parent
dd5886ed46
commit
84ee240bfa
@ -17,6 +17,13 @@
|
|||||||
|
|
||||||
using namespace ov;
|
using namespace ov;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
const auto is_eltwise_supported_type = [](const Output<Node>& output) -> bool {
|
||||||
|
const auto is_single_output = pass::pattern::consumers_count(1);
|
||||||
|
return is_single_output(output) && output.get_node()->has_evaluate();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
ov::pass::AddMultiplyFusion::AddMultiplyFusion() {
|
ov::pass::AddMultiplyFusion::AddMultiplyFusion() {
|
||||||
MATCHER_SCOPE(AddMultiplyFusion);
|
MATCHER_SCOPE(AddMultiplyFusion);
|
||||||
// Create Add->Multiply pattern where Add has exactly one consumer
|
// Create Add->Multiply pattern where Add has exactly one consumer
|
||||||
@ -105,7 +112,7 @@ ov::pass::MultiplyMultiplyFusion::MultiplyMultiplyFusion() {
|
|||||||
auto m_data = pass::pattern::any_input();
|
auto m_data = pass::pattern::any_input();
|
||||||
auto m_mul1_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
|
auto m_mul1_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
|
||||||
auto m_mul1 =
|
auto m_mul1 =
|
||||||
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_data, m_mul1_constant}, pattern::consumers_count(1));
|
ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_data, m_mul1_constant}, is_eltwise_supported_type);
|
||||||
auto m_mul2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
|
auto m_mul2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
|
||||||
auto m_mul2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_mul1, m_mul2_constant});
|
auto m_mul2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_mul1, m_mul2_constant});
|
||||||
|
|
||||||
|
@ -79,6 +79,48 @@ TEST_F(TransformationTestsF, MulMulMulFusion) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, MulMulMulFusion_f64) {
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<opset3::Parameter>(element::f64, Shape{1, 128, 3072});
|
||||||
|
auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {2});
|
||||||
|
auto mul2_const = opset3::Constant::create(element::f64, Shape{128, 1}, {3});
|
||||||
|
auto mul3_const = opset3::Constant::create(element::f64, Shape{1}, {3});
|
||||||
|
|
||||||
|
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
|
||||||
|
auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
|
||||||
|
auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
|
||||||
|
|
||||||
|
model = std::make_shared<ov::Model>(NodeVector{mul2}, ParameterVector{input});
|
||||||
|
manager.register_pass<ov::pass::LinOpSequenceFusion>();
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<opset3::Parameter>(element::f64, Shape{1, 128, 3072});
|
||||||
|
auto mul1_const = opset3::Constant::create(element::f64, Shape{128, 1}, {12});
|
||||||
|
|
||||||
|
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
|
||||||
|
|
||||||
|
model_ref = std::make_shared<ov::Model>(NodeVector{mul1}, ParameterVector{input});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, MulMulMulFusion_not_supported_type) {
|
||||||
|
constexpr auto et = element::u8;
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<opset3::Parameter>(et, Shape{1, 128, 3072});
|
||||||
|
auto mul1_const = opset3::Constant::create(et, Shape{128, 1}, {2});
|
||||||
|
auto mul2_const = opset3::Constant::create(et, Shape{128, 1}, {3});
|
||||||
|
auto mul3_const = opset3::Constant::create(et, Shape{1}, {3});
|
||||||
|
|
||||||
|
auto mul1 = std::make_shared<opset3::Multiply>(input, mul1_const);
|
||||||
|
auto mul2 = std::make_shared<opset3::Multiply>(mul1, mul2_const);
|
||||||
|
auto mul3 = std::make_shared<opset3::Multiply>(mul2, mul3_const);
|
||||||
|
|
||||||
|
model = std::make_shared<ov::Model>(NodeVector{mul2}, ParameterVector{input});
|
||||||
|
manager.register_pass<ov::pass::LinOpSequenceFusion>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, AddAddAddFusion) {
|
TEST_F(TransformationTestsF, AddAddAddFusion) {
|
||||||
{
|
{
|
||||||
auto input = std::make_shared<opset3::Parameter>(element::f32, Shape{1, 128, 3072});
|
auto input = std::make_shared<opset3::Parameter>(element::f32, Shape{1, 128, 3072});
|
||||||
|
@ -41,6 +41,7 @@ bool evaluate_multiply(const HostTensorPtr& arg0,
|
|||||||
OPENVINO_TYPE_CASE(evaluate_multiply, u64, arg0, arg1, out, broadcast_spec);
|
OPENVINO_TYPE_CASE(evaluate_multiply, u64, arg0, arg1, out, broadcast_spec);
|
||||||
OPENVINO_TYPE_CASE(evaluate_multiply, f16, arg0, arg1, out, broadcast_spec);
|
OPENVINO_TYPE_CASE(evaluate_multiply, f16, arg0, arg1, out, broadcast_spec);
|
||||||
OPENVINO_TYPE_CASE(evaluate_multiply, f32, arg0, arg1, out, broadcast_spec);
|
OPENVINO_TYPE_CASE(evaluate_multiply, f32, arg0, arg1, out, broadcast_spec);
|
||||||
|
OPENVINO_TYPE_CASE(evaluate_multiply, f64, arg0, arg1, out, broadcast_spec);
|
||||||
OPENVINO_TYPE_CASE(evaluate_multiply, bf16, arg0, arg1, out, broadcast_spec);
|
OPENVINO_TYPE_CASE(evaluate_multiply, bf16, arg0, arg1, out, broadcast_spec);
|
||||||
OPENVINO_TYPE_CASE(evaluate_multiply, u8, arg0, arg1, out, broadcast_spec);
|
OPENVINO_TYPE_CASE(evaluate_multiply, u8, arg0, arg1, out, broadcast_spec);
|
||||||
OPENVINO_TYPE_CASE(evaluate_multiply, i16, arg0, arg1, out, broadcast_spec);
|
OPENVINO_TYPE_CASE(evaluate_multiply, i16, arg0, arg1, out, broadcast_spec);
|
||||||
@ -80,6 +81,7 @@ bool op::v1::Multiply::has_evaluate() const {
|
|||||||
case ngraph::element::u64:
|
case ngraph::element::u64:
|
||||||
case ngraph::element::f16:
|
case ngraph::element::f16:
|
||||||
case ngraph::element::f32:
|
case ngraph::element::f32:
|
||||||
|
case ngraph::element::f64:
|
||||||
case ngraph::element::bf16:
|
case ngraph::element::bf16:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
|
@ -117,6 +117,7 @@ std::vector<MultiplyParams> generateParamsForMultiplyFloat() {
|
|||||||
|
|
||||||
std::vector<MultiplyParams> generateCombinedParamsForMultiply() {
|
std::vector<MultiplyParams> generateCombinedParamsForMultiply() {
|
||||||
const std::vector<std::vector<MultiplyParams>> allTypeParams{generateParamsForMultiply<element::Type_t::f32>(),
|
const std::vector<std::vector<MultiplyParams>> allTypeParams{generateParamsForMultiply<element::Type_t::f32>(),
|
||||||
|
generateParamsForMultiply<element::Type_t::f64>(),
|
||||||
generateParamsForMultiply<element::Type_t::f16>(),
|
generateParamsForMultiply<element::Type_t::f16>(),
|
||||||
generateParamsForMultiply<element::Type_t::bf16>(),
|
generateParamsForMultiply<element::Type_t::bf16>(),
|
||||||
generateParamsForMultiply<element::Type_t::i64>(),
|
generateParamsForMultiply<element::Type_t::i64>(),
|
||||||
|
Loading…
Reference in New Issue
Block a user