diff --git a/src/common/snippets/src/pass/align_element_types.cpp b/src/common/snippets/src/pass/align_element_types.cpp index da1ab1cb2c0..ebf0580ae8f 100644 --- a/src/common/snippets/src/pass/align_element_types.cpp +++ b/src/common/snippets/src/pass/align_element_types.cpp @@ -3,6 +3,8 @@ // #include "snippets/pass/align_element_types.hpp" + +#include "snippets/pass/propagate_precision.hpp" #include "snippets/itt.hpp" namespace ov { @@ -40,6 +42,20 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr& m) consumer = transpose; } + // If there is already Convert[needed_in_type->original_type] and this node has only one consumer, we can remove the Convert, + // since the sequence existing Convert[needed_in_type->original_type] -> new Convert[original_type->needed_in_type] is redundant + if (const auto existing_convert = ov::as_type_ptr(parent_output.get_node_shared_ptr())) { + const auto actual_before = existing_convert->get_input_element_type(0); + const auto actual_after = existing_convert->get_output_element_type(0); + const auto required_after = needed_out_type; + if (ov::snippets::pass::PropagatePrecision::can_be_removed(actual_before, actual_after, required_after) && + parent_output.get_target_inputs().size() == 1) { + // remove existing convert + existing_convert->output(0).replace(existing_convert->input_value(0)); + continue; + } + } + const auto convert = std::make_shared(parent_output, needed_out_type); ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert); @@ -85,6 +101,20 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr& m) consumer_inputs = parent_output.get_target_inputs(); } + // If there is already Convert[original_type->needed_in_type] and this node is alone consumer, we can remove the Convert, + // since the sequence new Convert[needed_in_type->original_type] -> existing Convert[original_type->needed_in_type] is redundant + if (const auto existing_convert = ov::as_type_ptr(consumer_inputs.cbegin()->get_node()->shared_from_this())) { + const auto actual_before = needed_in_type; + const auto actual_after = original_type; + const auto required_after = existing_convert->get_element_type(); + if (ov::snippets::pass::PropagatePrecision::can_be_removed(actual_before, actual_after, required_after) && + consumer_inputs.size() == 1) { + // remove existing convert + existing_convert->output(0).replace(parent_output); + continue; + } + } + const auto& convert = std::make_shared(parent_output, original_type); ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index b05bf845538..0aae38f4f48 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -128,7 +128,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, MHA, ::testing::ValuesIn({false}), ::testing::Values(MHA::default_thread_count), ::testing::Values(7), - ::testing::Values(7), + ::testing::Values(6), ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::Values(CPUTestUtils::cpuBF16PluginConfig)), MHA::getTestCaseName);