[CPU] Fixed naming in SwapConvertTranspose pass (#18208)
* [CPU] Fixed naming in SwapConvertTranspose pass * Applied Vladislav comments
This commit is contained in:
parent
4cc70e22e5
commit
211c56acf9
@ -14,12 +14,12 @@ NGRAPH_RTTI_DEFINITION(ov::intel_cpu::SwapConvertTranspose, "SwapConvertTranspos
|
||||
|
||||
ov::intel_cpu::SwapConvertTranspose::SwapConvertTranspose() {
|
||||
MATCHER_SCOPE(SwapConvertTranspose);
|
||||
ngraph::element::TypeVector param_precisions{ ngraph::element::i8, ngraph::element::u8 };
|
||||
auto input_m = ngraph::pattern::wrap_type<ngraph::op::v0::Parameter>(ngraph::pattern::type_matches_any(param_precisions));
|
||||
auto convert_m = ngraph::pattern::wrap_type<ngraph::op::v0::Convert>({input_m}, ngraph::pattern::type_matches(ngraph::element::f32));
|
||||
auto transpose_m = ngraph::pattern::wrap_type<ngraph::op::v1::Transpose>({convert_m, ngraph::pattern::any_input()});
|
||||
ov::element::TypeVector param_precisions{ ov::element::i8, ov::element::u8 };
|
||||
auto input_m = ov::pass::pattern::wrap_type<ov::op::v0::Parameter>(ov::pass::pattern::type_matches_any(param_precisions));
|
||||
auto convert_m = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({input_m}, ov::pass::pattern::type_matches(ov::element::f32));
|
||||
auto transpose_m = ov::pass::pattern::wrap_type<ov::op::v1::Transpose>({convert_m, ov::pass::pattern::any_input()});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
|
||||
// Swap
|
||||
// Input -> [i8/u8] -> Convert -> [f32] -> Transpose -> [f32]
|
||||
// to
|
||||
@ -28,19 +28,24 @@ ov::intel_cpu::SwapConvertTranspose::SwapConvertTranspose() {
|
||||
auto convert = pattern_map.at(convert_m).get_node_shared_ptr();
|
||||
auto transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
|
||||
|
||||
ngraph::OutputVector transposeInputs = transpose->input_values();
|
||||
if (convert->get_output_target_inputs(0).size() != 1)
|
||||
return false;
|
||||
|
||||
ov::OutputVector transposeInputs = transpose->input_values();
|
||||
transposeInputs[0] = convert->input_value(0);
|
||||
auto newTranspose = transpose->clone_with_new_inputs(transposeInputs);
|
||||
ngraph::copy_runtime_info(transpose, newTranspose);
|
||||
newTranspose->set_friendly_name(transpose->get_friendly_name());
|
||||
newTranspose->set_friendly_name(transpose->get_friendly_name() + "_original");
|
||||
|
||||
ngraph::OutputVector convertInputs = convert->input_values();
|
||||
ov::OutputVector convertInputs = convert->input_values();
|
||||
convertInputs[0] = newTranspose;
|
||||
auto newConvert = convert->clone_with_new_inputs(convertInputs);
|
||||
ngraph::replace_node(transpose, newConvert);
|
||||
ov::replace_node(transpose, newConvert);
|
||||
newConvert->set_friendly_name(transpose->get_friendly_name());
|
||||
|
||||
ov::copy_runtime_info(transpose, { newTranspose, newConvert });
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_m, matcher_name);
|
||||
auto m = std::make_shared<ov::pass::pattern::Matcher>(transpose_m, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
@ -0,0 +1,79 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include <transformations/cpu_opset/common/pass/swap_convert_transpose.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "openvino/opsets/opset1.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
class SwapConvertTransposeTest: public TransformationTestsF {
|
||||
public:
|
||||
SwapConvertTransposeTest() : TransformationTestsF() {
|
||||
comparator.enable(FunctionsComparator::CmpValues::NAMES);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(SwapConvertTransposeTest, SwapConvertTranspose) {
|
||||
const ov::Shape shape{1, 224, 224, 3};
|
||||
const std::vector<int64_t> input_order = {0, 3, 1, 2};
|
||||
const ov::element::Type in_type = ov::element::u8;
|
||||
const ov::element::Type out_type = ov::element::f32;
|
||||
const std::string transpose_name = "Transpose";
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(in_type, shape);
|
||||
|
||||
auto convert = std::make_shared<ov::op::v0::Convert>(input, out_type);
|
||||
|
||||
auto transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
|
||||
auto transpose = std::make_shared<ov::op::v1::Transpose>(convert, transpose_const);
|
||||
transpose->set_friendly_name(transpose_name);
|
||||
|
||||
function = std::make_shared<ov::Model>(ov::NodeVector{transpose}, ov::ParameterVector{input});
|
||||
manager.register_pass<ov::intel_cpu::SwapConvertTranspose>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(in_type, shape);
|
||||
|
||||
auto transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
|
||||
auto transpose = std::make_shared<ov::op::v1::Transpose>(input, transpose_const);
|
||||
|
||||
auto convert = std::make_shared<ov::op::v0::Convert>(transpose, out_type);
|
||||
|
||||
transpose->set_friendly_name(transpose_name + "_original");
|
||||
convert->set_friendly_name(transpose_name);
|
||||
|
||||
function_ref = std::make_shared<ov::Model>(ov::NodeVector{convert}, ov::ParameterVector{input});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SwapConvertTransposeTest, SwapConvertTransposeImpossible) {
|
||||
const ov::Shape shape{1, 224, 224, 3};
|
||||
const std::vector<int64_t> input_order = {0, 3, 1, 2};
|
||||
const ov::element::Type in_type = ov::element::u8;
|
||||
const ov::element::Type out_type = ov::element::f32;
|
||||
const std::string transpose_name = "Transpose";
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(in_type, shape);
|
||||
|
||||
auto convert = std::make_shared<ov::op::v0::Convert>(input, out_type);
|
||||
|
||||
auto transpose0_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
|
||||
auto transpose0 = std::make_shared<ov::op::v1::Transpose>(convert, transpose0_const);
|
||||
transpose0->set_friendly_name(transpose_name + "_0");
|
||||
|
||||
auto transpose1_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
|
||||
auto transpose1 = std::make_shared<ov::op::v1::Transpose>(convert, transpose1_const);
|
||||
transpose1->set_friendly_name(transpose_name + "_1");
|
||||
|
||||
function = std::make_shared<ov::Model>(ov::NodeVector{transpose0, transpose1}, ov::ParameterVector{input});
|
||||
manager.register_pass<ov::intel_cpu::SwapConvertTranspose>();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user