[CPU] Convert i64->i32 for Reference node. (#16797)

This commit is contained in:
Nikolay Shchegolev
2023-04-13 11:55:53 +04:00
committed by GitHub
parent e238bfc1d0
commit 061ba1d773
3 changed files with 71 additions and 0 deletions

View File

@@ -0,0 +1,48 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ref_convert_i64_i32.hpp"
#include <openvino/opsets/opset10.hpp>
#include "cpu_types.h"
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include "itt.hpp"
ov::pass::RefConvertI64ToI32::RefConvertI64ToI32() {
MATCHER_SCOPE(RefConvertI64ToI32);
auto i64_extension = [](const ov::Output<ov::Node>& output) -> bool {
auto node = output.get_node_shared_ptr();
return ov::intel_cpu::TypeFromName(node->get_type_name()) == ov::intel_cpu::Type::Unknown &&
ov::pass::pattern::type_matches_any({ov::element::i64, ov::element::u64})(output);
};
auto ref_m = ov::pass::pattern::any_input(i64_extension);
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
const auto ref = m.get_match_root();
for (auto& output : ref->outputs()) {
if (output.get_element_type() == ov::element::i64 || output.get_element_type() == ov::element::u64) {
auto targetInputs = output.get_target_inputs();
auto convert = std::make_shared<ov::opset10::Convert>(output, ov::element::i32);
for (const auto& targetInput : targetInputs) {
targetInput.replace_source_output(convert);
}
auto& convertTensor = convert->output(0).get_tensor();
if (!output.get_names().empty()) {
convertTensor.set_names(output.get_names());
}
}
}
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(ref_m, matcher_name);
this->register_matcher(m, callback);
}

View File

@@ -0,0 +1,21 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/pass/graph_rewrite.hpp>
namespace ov {
namespace pass {
// This pass inserts Convert node from i64 to i32 for Reference nodes.
class RefConvertI64ToI32: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RefConvertI64ToI32", "0");
RefConvertI64ToI32();
};
} // namespace pass
} // namespace ov

View File

@@ -99,6 +99,7 @@
#include "transformations/cpu_opset/arm/pass/mish_decomposition.hpp"
#include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp"
#include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp"
#include "transformations/cpu_opset/common/pass/ref_convert_i64_i32.hpp"
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
// Snippets
@@ -248,6 +249,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_REGISTER_PASS_COMMON(manager, ngraph::pass::low_precision::ConvertSubtractConstant, defaultPrecisions);
}
CPU_REGISTER_PASS_COMMON(manager, ov::pass::Validate);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::RefConvertI64ToI32);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertPrecision, precisions, type_to_fuse);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::EliminateConvert);
CPU_REGISTER_PASS_COMMON(manager, SwapConvertTranspose);