[Snippets] Fixed Convert elimination in AlignElementType (#20701)
This commit is contained in:
parent
b1ce297bde
commit
4e41678502
@ -3,6 +3,8 @@
|
||||
//
|
||||
|
||||
#include "snippets/pass/align_element_types.hpp"
|
||||
|
||||
#include "snippets/pass/propagate_precision.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -40,6 +42,20 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
|
||||
consumer = transpose;
|
||||
}
|
||||
|
||||
// If there is already a Convert[needed_in_type->original_type] and it has only one consumer, we can remove the Convert,
|
||||
// since the sequence existing Convert[needed_in_type->original_type] -> new Convert[original_type->needed_in_type] is redundant
|
||||
if (const auto existing_convert = ov::as_type_ptr<ov::snippets::op::ConvertSaturation>(parent_output.get_node_shared_ptr())) {
|
||||
const auto actual_before = existing_convert->get_input_element_type(0);
|
||||
const auto actual_after = existing_convert->get_output_element_type(0);
|
||||
const auto required_after = needed_out_type;
|
||||
if (ov::snippets::pass::PropagatePrecision::can_be_removed(actual_before, actual_after, required_after) &&
|
||||
parent_output.get_target_inputs().size() == 1) {
|
||||
// remove existing convert
|
||||
existing_convert->output(0).replace(existing_convert->input_value(0));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const auto convert = std::make_shared<op::ConvertSaturation>(parent_output, needed_out_type);
|
||||
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
|
||||
|
||||
@ -85,6 +101,20 @@ bool pass::AlignElementTypes::run_on_model(const std::shared_ptr<ov::Model>& m)
|
||||
consumer_inputs = parent_output.get_target_inputs();
|
||||
}
|
||||
|
||||
// If there is already a Convert[original_type->needed_in_type] and it is the only consumer, we can remove the Convert,
|
||||
// since the sequence new Convert[needed_in_type->original_type] -> existing Convert[original_type->needed_in_type] is redundant
|
||||
if (const auto existing_convert = ov::as_type_ptr<ov::snippets::op::ConvertSaturation>(consumer_inputs.cbegin()->get_node()->shared_from_this())) {
|
||||
const auto actual_before = needed_in_type;
|
||||
const auto actual_after = original_type;
|
||||
const auto required_after = existing_convert->get_element_type();
|
||||
if (ov::snippets::pass::PropagatePrecision::can_be_removed(actual_before, actual_after, required_after) &&
|
||||
consumer_inputs.size() == 1) {
|
||||
// remove existing convert
|
||||
existing_convert->output(0).replace(parent_output);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const auto& convert = std::make_shared<ov::snippets::op::ConvertSaturation>(parent_output, original_type);
|
||||
ov::copy_runtime_info(parent_output.get_node_shared_ptr(), convert);
|
||||
|
||||
|
@ -128,7 +128,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, MHA,
|
||||
::testing::ValuesIn({false}),
|
||||
::testing::Values(MHA::default_thread_count),
|
||||
::testing::Values(7),
|
||||
::testing::Values(7),
|
||||
::testing::Values(6),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU),
|
||||
::testing::Values(CPUTestUtils::cpuBF16PluginConfig)),
|
||||
MHA::getTestCaseName);
|
||||
|
Loading…
Reference in New Issue
Block a user