[CPU] Convert i64->i32 for Reference node. (#16797)
This commit is contained in:
committed by
GitHub
parent
e238bfc1d0
commit
061ba1d773
@@ -0,0 +1,48 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
|
||||
#include "ref_convert_i64_i32.hpp"
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include "cpu_types.h"
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
ov::pass::RefConvertI64ToI32::RefConvertI64ToI32() {
|
||||
MATCHER_SCOPE(RefConvertI64ToI32);
|
||||
|
||||
auto i64_extension = [](const ov::Output<ov::Node>& output) -> bool {
|
||||
auto node = output.get_node_shared_ptr();
|
||||
return ov::intel_cpu::TypeFromName(node->get_type_name()) == ov::intel_cpu::Type::Unknown &&
|
||||
ov::pass::pattern::type_matches_any({ov::element::i64, ov::element::u64})(output);
|
||||
};
|
||||
|
||||
auto ref_m = ov::pass::pattern::any_input(i64_extension);
|
||||
|
||||
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
|
||||
const auto ref = m.get_match_root();
|
||||
|
||||
for (auto& output : ref->outputs()) {
|
||||
if (output.get_element_type() == ov::element::i64 || output.get_element_type() == ov::element::u64) {
|
||||
auto targetInputs = output.get_target_inputs();
|
||||
auto convert = std::make_shared<ov::opset10::Convert>(output, ov::element::i32);
|
||||
|
||||
for (const auto& targetInput : targetInputs) {
|
||||
targetInput.replace_source_output(convert);
|
||||
}
|
||||
|
||||
auto& convertTensor = convert->output(0).get_tensor();
|
||||
if (!output.get_names().empty()) {
|
||||
convertTensor.set_names(output.get_names());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(ref_m, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
// This pass inserts Convert node from i64 to i32 for Reference nodes.
|
||||
|
||||
class RefConvertI64ToI32: public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("RefConvertI64ToI32", "0");
|
||||
RefConvertI64ToI32();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@@ -99,6 +99,7 @@
|
||||
#include "transformations/cpu_opset/arm/pass/mish_decomposition.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/convert_fq_rnn_to_quantized_rnn.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/ref_convert_i64_i32.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
|
||||
|
||||
// Snippets
|
||||
@@ -248,6 +249,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
|
||||
CPU_REGISTER_PASS_COMMON(manager, ngraph::pass::low_precision::ConvertSubtractConstant, defaultPrecisions);
|
||||
}
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::Validate);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::RefConvertI64ToI32);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertPrecision, precisions, type_to_fuse);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::EliminateConvert);
|
||||
CPU_REGISTER_PASS_COMMON(manager, SwapConvertTranspose);
|
||||
|
||||
Reference in New Issue
Block a user