[Snippets] Fixed Convert elimination in AlignElementType (#20701)

This commit is contained in:
Alexandra Sidorova 2023-10-27 13:50:51 +04:00 committed by GitHub
parent b1ce297bde
commit 4e41678502
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 1 deletions

View File

@ -3,6 +3,8 @@
//
#include "snippets/pass/align_element_types.hpp"
#include "snippets/pass/propagate_precision.hpp"
#include "snippets/itt.hpp"
namespace ov {
@ -40,6 +42,20 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
consumer = transpose;
}
// If there is already Convert[needed_in_type->original_type] and this node has only one consumer, we can remove the Convert,
// since the sequence existing Convert[needed_in_type->original_type] -> new Convert[original_type->needed_in_type] is redundant
if (const auto existing_convert = ov::as_type_ptr<ov::snippets::op::ConvertSaturation>(parent_output.get_node_shared_ptr())) {
const auto actual_before = existing_convert->get_input_element_type(0);
const auto actual_after = existing_convert->get_output_element_type(0);
const auto required_after = needed_out_type;
if (ov::snippets::pass::PropagatePrecision::can_be_removed(actual_before, actual_after, required_after) &&
parent_output.get_target_inputs().size() == 1) {
// remove existing convert
existing_convert->output(0).replace(existing_convert->input_value(0));
continue;
}
}
const auto convert = std::make_shared<op::ConvertSaturation>(parent_output, needed_out_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
@ -85,6 +101,20 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
consumer_inputs = parent_output.get_target_inputs();
}
// If there is already Convert[original_type->needed_in_type] and this node is alone consumer, we can remove the Convert,
// since the sequence new Convert[needed_in_type->original_type] -> existing Convert[original_type->needed_in_type] is redundant
if (const auto existing_convert = ov::as_type_ptr<ov::snippets::op::ConvertSaturation>(consumer_inputs.cbegin()->get_node()->shared_from_this())) {
const auto actual_before = needed_in_type;
const auto actual_after = original_type;
const auto required_after = existing_convert->get_element_type();
if (ov::snippets::pass::PropagatePrecision::can_be_removed(actual_before, actual_after, required_after) &&
consumer_inputs.size() == 1) {
// remove existing convert
existing_convert->output(0).replace(parent_output);
continue;
}
}
const auto& convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, original_type);
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);

View File

@ -128,7 +128,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, MHA,
::testing::ValuesIn({false}),
::testing::Values(MHA::default_thread_count),
::testing::Values(7),
::testing::Values(7),
::testing::Values(6),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuBF16PluginConfig)),
MHA::getTestCaseName);