diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/ref_convert_i64_i32.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/ref_convert_i64_i32.cpp new file mode 100644 index 00000000000..66295921faf --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/ref_convert_i64_i32.cpp @@ -0,0 +1,48 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + + +#include "ref_convert_i64_i32.hpp" +#include +#include "cpu_types.h" +#include + +#include "itt.hpp" + +ov::pass::RefConvertI64ToI32::RefConvertI64ToI32() { + MATCHER_SCOPE(RefConvertI64ToI32); + + auto i64_extension = [](const ov::Output& output) -> bool { + auto node = output.get_node_shared_ptr(); + return ov::intel_cpu::TypeFromName(node->get_type_name()) == ov::intel_cpu::Type::Unknown && + ov::pass::pattern::type_matches_any({ov::element::i64, ov::element::u64})(output); + }; + + auto ref_m = ov::pass::pattern::any_input(i64_extension); + + ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { + const auto ref = m.get_match_root(); + + for (auto& output : ref->outputs()) { + if (output.get_element_type() == ov::element::i64 || output.get_element_type() == ov::element::u64) { + auto targetInputs = output.get_target_inputs(); + auto convert = std::make_shared(output, ov::element::i32); + + for (const auto& targetInput : targetInputs) { + targetInput.replace_source_output(convert); + } + + auto& convertTensor = convert->output(0).get_tensor(); + if (!output.get_names().empty()) { + convertTensor.set_names(output.get_names()); + } + } + } + + return true; + }; + + auto m = std::make_shared(ref_m, matcher_name); + this->register_matcher(m, callback); +} diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/ref_convert_i64_i32.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/ref_convert_i64_i32.hpp new file mode 100644 index 00000000000..271c4a0a42d --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/ref_convert_i64_i32.hpp @@ -0,0 +1,21 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace pass { + +// This pass inserts Convert node from i64 to i32 for Reference nodes. + +class RefConvertI64ToI32: public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RefConvertI64ToI32", "0"); + RefConvertI64ToI32(); +}; + +} // namespace pass +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 3ee53279478..942c41b3048 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -99,6 +99,7 @@ #include "transformations/cpu_opset/arm/pass/mish_decomposition.hpp" #include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp" #include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp" +#include "transformations/cpu_opset/common/pass/ref_convert_i64_i32.hpp" #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp" // Snippets @@ -248,6 +249,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_REGISTER_PASS_COMMON(manager, ngraph::pass::low_precision::ConvertSubtractConstant, defaultPrecisions); } CPU_REGISTER_PASS_COMMON(manager, ov::pass::Validate); + CPU_REGISTER_PASS_COMMON(manager, ov::pass::RefConvertI64ToI32); CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertPrecision, precisions, type_to_fuse); CPU_REGISTER_PASS_COMMON(manager, ov::pass::EliminateConvert); CPU_REGISTER_PASS_COMMON(manager, SwapConvertTranspose);