[CPU] Fixed naming in SwapConvertTranspose pass (#18208)

* [CPU] Fixed naming in SwapConvertTranspose pass

* Applied Vladislav comments
This commit is contained in:
Alexandra Sidorova 2023-07-04 13:34:07 +04:00 committed by GitHub
parent 4cc70e22e5
commit 211c56acf9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 95 additions and 11 deletions

View File

@ -14,12 +14,12 @@ NGRAPH_RTTI_DEFINITION(ov::intel_cpu::SwapConvertTranspose, "SwapConvertTranspos
ov::intel_cpu::SwapConvertTranspose::SwapConvertTranspose() {
MATCHER_SCOPE(SwapConvertTranspose);
ngraph::element::TypeVector param_precisions{ ngraph::element::i8, ngraph::element::u8 };
auto input_m = ngraph::pattern::wrap_type<ngraph::op::v0::Parameter>(ngraph::pattern::type_matches_any(param_precisions));
auto convert_m = ngraph::pattern::wrap_type<ngraph::op::v0::Convert>({input_m}, ngraph::pattern::type_matches(ngraph::element::f32));
auto transpose_m = ngraph::pattern::wrap_type<ngraph::op::v1::Transpose>({convert_m, ngraph::pattern::any_input()});
ov::element::TypeVector param_precisions{ ov::element::i8, ov::element::u8 };
auto input_m = ov::pass::pattern::wrap_type<ov::op::v0::Parameter>(ov::pass::pattern::type_matches_any(param_precisions));
auto convert_m = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({input_m}, ov::pass::pattern::type_matches(ov::element::f32));
auto transpose_m = ov::pass::pattern::wrap_type<ov::op::v1::Transpose>({convert_m, ov::pass::pattern::any_input()});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
// Swap
// Input -> [i8/u8] -> Convert -> [f32] -> Transpose -> [f32]
// to
@ -28,19 +28,24 @@ ov::intel_cpu::SwapConvertTranspose::SwapConvertTranspose() {
auto convert = pattern_map.at(convert_m).get_node_shared_ptr();
auto transpose = pattern_map.at(transpose_m).get_node_shared_ptr();
ngraph::OutputVector transposeInputs = transpose->input_values();
if (convert->get_output_target_inputs(0).size() != 1)
return false;
ov::OutputVector transposeInputs = transpose->input_values();
transposeInputs[0] = convert->input_value(0);
auto newTranspose = transpose->clone_with_new_inputs(transposeInputs);
ngraph::copy_runtime_info(transpose, newTranspose);
newTranspose->set_friendly_name(transpose->get_friendly_name());
newTranspose->set_friendly_name(transpose->get_friendly_name() + "_original");
ngraph::OutputVector convertInputs = convert->input_values();
ov::OutputVector convertInputs = convert->input_values();
convertInputs[0] = newTranspose;
auto newConvert = convert->clone_with_new_inputs(convertInputs);
ngraph::replace_node(transpose, newConvert);
ov::replace_node(transpose, newConvert);
newConvert->set_friendly_name(transpose->get_friendly_name());
ov::copy_runtime_info(transpose, { newTranspose, newConvert });
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_m, matcher_name);
auto m = std::make_shared<ov::pass::pattern::Matcher>(transpose_m, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -0,0 +1,79 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/ngraph_test_utils.hpp"
#include <transformations/cpu_opset/common/pass/swap_convert_transpose.hpp>
#include <transformations/init_node_info.hpp>
#include "openvino/opsets/opset1.hpp"
using namespace testing;
class SwapConvertTransposeTest: public TransformationTestsF {
public:
SwapConvertTransposeTest() : TransformationTestsF() {
comparator.enable(FunctionsComparator::CmpValues::NAMES);
}
};
TEST_F(SwapConvertTransposeTest, SwapConvertTranspose) {
const ov::Shape shape{1, 224, 224, 3};
const std::vector<int64_t> input_order = {0, 3, 1, 2};
const ov::element::Type in_type = ov::element::u8;
const ov::element::Type out_type = ov::element::f32;
const std::string transpose_name = "Transpose";
{
auto input = std::make_shared<ov::op::v0::Parameter>(in_type, shape);
auto convert = std::make_shared<ov::op::v0::Convert>(input, out_type);
auto transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
auto transpose = std::make_shared<ov::op::v1::Transpose>(convert, transpose_const);
transpose->set_friendly_name(transpose_name);
function = std::make_shared<ov::Model>(ov::NodeVector{transpose}, ov::ParameterVector{input});
manager.register_pass<ov::intel_cpu::SwapConvertTranspose>();
}
{
auto input = std::make_shared<ov::op::v0::Parameter>(in_type, shape);
auto transpose_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
auto transpose = std::make_shared<ov::op::v1::Transpose>(input, transpose_const);
auto convert = std::make_shared<ov::op::v0::Convert>(transpose, out_type);
transpose->set_friendly_name(transpose_name + "_original");
convert->set_friendly_name(transpose_name);
function_ref = std::make_shared<ov::Model>(ov::NodeVector{convert}, ov::ParameterVector{input});
}
}
TEST_F(SwapConvertTransposeTest, SwapConvertTransposeImpossible) {
const ov::Shape shape{1, 224, 224, 3};
const std::vector<int64_t> input_order = {0, 3, 1, 2};
const ov::element::Type in_type = ov::element::u8;
const ov::element::Type out_type = ov::element::f32;
const std::string transpose_name = "Transpose";
{
auto input = std::make_shared<ov::op::v0::Parameter>(in_type, shape);
auto convert = std::make_shared<ov::op::v0::Convert>(input, out_type);
auto transpose0_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
auto transpose0 = std::make_shared<ov::op::v1::Transpose>(convert, transpose0_const);
transpose0->set_friendly_name(transpose_name + "_0");
auto transpose1_const = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{input_order.size()}, input_order);
auto transpose1 = std::make_shared<ov::op::v1::Transpose>(convert, transpose1_const);
transpose1->set_friendly_name(transpose_name + "_1");
function = std::make_shared<ov::Model>(ov::NodeVector{transpose0, transpose1}, ov::ParameterVector{input});
manager.register_pass<ov::intel_cpu::SwapConvertTranspose>();
}
}