[GNA] Fuse gather and transpose (#18648)

This commit is contained in:
Mikhail Ryzhov 2023-07-26 15:44:44 +02:00 committed by GitHub
parent a930c74143
commit 0569bb8c5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 702 additions and 2 deletions

View File

@ -388,14 +388,14 @@ inline ov::Shape transpose_shape(const ov::Shape& shape, std::vector<size_t> ord
* @param order the permutation array to apply to the input shape
* @return vector with indexes to gather
*/
inline std::vector<size_t> make_gather_indexes_from_transpose_axes(const Shape& input_shape, const Shape& order) {
inline std::vector<size_t> make_gather_indexes_from_transpose_axes(const Shape& input_shape, const AxisVector& order) {
// Supported shape ranks: 2d, 3d, 4d
if (input_shape.size() < 2 || input_shape.size() > 4) {
THROW_GNA_EXCEPTION << "Usupported shape size: " << input_shape.size();
}
ov::Shape input_shape_4d = input_shape;
ov::Shape order_4d = order;
ov::AxisVector order_4d = order;
// Just to simplify the code we transform all shapes to 4d by adding dimension(s) equal to 1 at the end
while (input_shape_4d.size() < 4) {
input_shape_4d.push_back(1);
@ -685,6 +685,23 @@ inline bool has_n_consumers(const std::shared_ptr<ov::Node>& node, size_t n_cons
return node->output(0).get_target_inputs().size() == n_consumers;
}
/**
 * @brief Merge the index vectors of two consecutive gathers into one.
 *
 * Gathering with @p ids_in and then gathering the result with @p ids_out is the same as a single
 * gather whose i-th index is ids_in[ids_out[i]].
 * @param ids_in vector with indexes to 1st gather
 * @param ids_out vector with indexes to 2nd gather
 * @return vector with indexes to merged gather; an empty vector when the two inputs have different
 *         sizes or when @p ids_out holds an index that is out of range for @p ids_in
 */
inline std::vector<size_t> combine_gather_indexes(const std::vector<size_t>& ids_in,
                                                  const std::vector<size_t>& ids_out) {
    if (ids_in.size() != ids_out.size())
        return {};
    std::vector<size_t> result;
    result.reserve(ids_out.size());
    for (const size_t id : ids_out) {
        // an out-of-range index would read past the end of ids_in
        if (id >= ids_in.size())
            return {};
        result.push_back(ids_in[id]);
    }
    return result;
}
} // namespace graph_utils
} // namespace intel_gna
} // namespace ov

View File

@ -34,6 +34,7 @@
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
#include "transformations/fuse_conv_bias_activation.hpp"
#include "transformations/gather_sinking.hpp"
#include "transformations/gather_sinking_transpose.hpp"
#include "transformations/handle_transposes_around_matmul.hpp"
#include "transformations/init_node_info.hpp"
#include "transformations/insert_copy_layer.hpp"
@ -139,6 +140,7 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
if (has_convolution || has_maxpool || has_mvn || has_matmul) {
manager.register_pass<ov::intel_gna::pass::ReplaceGnaNHWCLayers>();
manager.register_pass<ov::intel_gna::pass::InsertConvolutionTransposeHW>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingTranspose>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::intel_gna::pass::TransposeCompress>();
manager.register_pass<ov::intel_gna::pass::TSConcatForward>();

View File

@ -0,0 +1,155 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/gather_sinking_transpose.hpp"
#include <utility>
#include "common/graph_utils.hpp"
#include "openvino/cc/ngraph/itt.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace ov::intel_gna;
using namespace ov::intel_gna::pass;
namespace {
/**
 * @brief Fuse a Gather that is followed by a Transpose into a single Gather plus a Reshape.
 *
 * The Transpose is expressed as gather indexes over the flattened tensor and composed with the
 * indexes of the existing Gather. Since the Gather runs first, the fused index is
 * gather_ids[transpose_ids[i]]. The Transpose is then replaced by a Reshape to its output shape.
 * @param gather Gather node located before the Transpose
 * @param transpose Transpose node to be removed
 * @return the created nodes {Gather, Reshape}; an empty vector (graph left untouched) when the
 *         nodes cannot be fused
 */
inline std::vector<std::shared_ptr<ov::Node>> merge_nodes_forward(std::shared_ptr<ov::Node> gather,
                                                                  std::shared_ptr<ov::Node> transpose) {
    auto transpose_const = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
    auto gather_ids = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(1));
    auto gather_axis = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(2));
    // all three inputs are read as constants below
    if (!transpose_const || !gather_ids || !gather_axis) {
        return {};
    }
    // make_gather_indexes_from_transpose_axes() throws for ranks other than 2d, 3d, 4d
    const ov::Shape transpose_shape_in = transpose->get_input_shape(0);
    if (transpose_shape_in.size() < 2 || transpose_shape_in.size() > 4) {
        return {};
    }
    // transpose ids -> gather indexes
    const ov::AxisVector transpose_ids =
        graph_utils::make_gather_indexes_from_transpose_axes(transpose_shape_in,
                                                             transpose_const->get_axis_vector_val());
    const ov::AxisVector gather_old_ids = gather_ids->get_axis_vector_val();
    // index vectors of different length cannot be composed
    if (gather_old_ids.size() != transpose_ids.size()) {
        return {};
    }
    // merge gather indexes
    const ov::AxisVector gather_new_ids = graph_utils::combine_gather_indexes(gather_old_ids, transpose_ids);
    // new gather
    auto gather_new_const_ids =
        std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_new_ids.size()}, gather_new_ids);
    auto gather_new = std::make_shared<Gather>(gather->input_value(0), gather_new_const_ids, gather_axis);
    // the Transpose is replaced by a Reshape that restores its output shape
    ov::Shape shape_out = transpose->get_output_shape(0);
    auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
    auto reshape_out = std::make_shared<Reshape>(gather_new, reshape_out_const, false);
    ov::replace_node_update_name(transpose, reshape_out);
    return std::vector<std::shared_ptr<ov::Node>>({gather_new, reshape_out});
}
/**
 * @brief Fuse a Transpose that is followed by a Gather into a Reshape plus a single Gather.
 *
 * The Transpose is expressed as gather indexes over the flattened tensor and composed with the
 * indexes of the existing Gather. Since the Transpose runs first, the fused index is
 * transpose_ids[gather_ids[i]]. The input of the Transpose is reshaped to the input shape of the
 * Gather and fed to the new Gather, which replaces the old one.
 * @param gather Gather node located after the Transpose
 * @param transpose Transpose node to be bypassed
 * @return the created nodes {Reshape, Gather}; an empty vector (graph left untouched) when the
 *         nodes cannot be fused
 */
inline std::vector<std::shared_ptr<ov::Node>> merge_nodes_backward(std::shared_ptr<ov::Node> gather,
                                                                   std::shared_ptr<ov::Node> transpose) {
    auto transpose_const = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
    auto gather_ids = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(1));
    auto gather_axis = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(2));
    // all three inputs are read as constants below
    if (!transpose_const || !gather_ids || !gather_axis) {
        return {};
    }
    // make_gather_indexes_from_transpose_axes() throws for ranks other than 2d, 3d, 4d
    const ov::Shape transpose_shape_in = transpose->get_input_shape(0);
    if (transpose_shape_in.size() < 2 || transpose_shape_in.size() > 4) {
        return {};
    }
    // transpose ids -> gather indexes
    const ov::AxisVector transpose_ids =
        graph_utils::make_gather_indexes_from_transpose_axes(transpose_shape_in,
                                                             transpose_const->get_axis_vector_val());
    const ov::AxisVector gather_old_ids = gather_ids->get_axis_vector_val();
    // index vectors of different length cannot be composed
    if (gather_old_ids.size() != transpose_ids.size()) {
        return {};
    }
    // merge gather indexes: the transpose is applied first, the gather second
    const ov::AxisVector gather_new_ids = graph_utils::combine_gather_indexes(transpose_ids, gather_old_ids);
    // the data now bypasses the Transpose, so it is reshaped to what the Gather expects
    ov::Shape shape_in = gather->get_input_shape(0);
    auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
    auto reshape_in = std::make_shared<Reshape>(transpose->input_value(0), reshape_in_const, false);
    // new gather
    auto gather_new_const_ids =
        std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_new_ids.size()}, gather_new_ids);
    auto gather_new = std::make_shared<Gather>(reshape_in, gather_new_const_ids, gather_axis);
    ov::replace_node_update_name(gather, gather_new);
    return std::vector<std::shared_ptr<ov::Node>>({reshape_in, gather_new});
}
/**
 * @brief Tell whether a node may be stepped over while looking for the Transpose.
 * @param node node to check
 * @return true when the node is a Reshape whose first output has exactly one consumer
 */
inline bool is_skip_operation(const std::shared_ptr<ov::Node>& node) {
    if (!std::dynamic_pointer_cast<Reshape>(node)) {
        return false;
    }
    const auto consumers = node->output(0).get_target_inputs();
    return consumers.size() == 1;
}
} // namespace
/**
 * Registers a matcher for Gather -> Reshape. From the matched Reshape the callback steps over
 * the nodes accepted by is_skip_operation() and, when it arrives at a Transpose, fuses the
 * Transpose into the Gather.
 */
GatherSinkingTransposeForward::GatherSinkingTransposeForward() {
    MATCHER_SCOPE(GatherSinkingTransposeForward);
    auto gather_ids_label = wrap_type<Constant>(graph_utils::is_constant_1d);
    // the axis is read as a Constant during the merge, so only constant axes are matched
    auto gather_axis_label = wrap_type<Constant>();
    auto gather_label = wrap_type<Gather>({any_input(), gather_ids_label, gather_axis_label});
    auto reshape_label = wrap_type<Reshape>({gather_label, any_input()});
    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto gather = ov::as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
        auto reshape = ov::as_type_ptr<Reshape>(pattern_to_output.at(reshape_label).get_node_shared_ptr());
        // skip all the Reshape layers
        std::shared_ptr<ov::Node> non_reshape_node =
            graph_utils::get_next_node_skipping_certain(reshape, is_skip_operation);
        auto transpose = std::dynamic_pointer_cast<Transpose>(non_reshape_node);
        if (!transpose) {
            return false;
        }
        // the transpose order has to be known to turn it into gather indexes
        if (!ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1))) {
            return false;
        }
        const std::vector<std::shared_ptr<ov::Node>> new_nodes = merge_nodes_forward(gather, transpose);
        for (const auto& node : new_nodes) {
            register_new_node(node);
        }
        // report a change only when nodes were actually created
        return !new_nodes.empty();
    };
    auto m = std::make_shared<Matcher>(reshape_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
/**
 * Registers a matcher for Reshape -> Gather. From the matched Reshape the callback steps back
 * over the nodes accepted by is_skip_operation() and, when it arrives at a Transpose, fuses the
 * Transpose into the Gather.
 */
GatherSinkingTransposeBackward::GatherSinkingTransposeBackward() {
    MATCHER_SCOPE(GatherSinkingTransposeBackward);
    auto reshape_label = wrap_type<Reshape>({any_input(), any_input()});
    auto gather_ids_label = wrap_type<Constant>(graph_utils::is_constant_1d);
    // the axis is read as a Constant during the merge, so only constant axes are matched
    auto gather_axis_label = wrap_type<Constant>();
    auto gather_label = wrap_type<Gather>({reshape_label, gather_ids_label, gather_axis_label});
    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto reshape = ov::as_type_ptr<Reshape>(pattern_to_output.at(reshape_label).get_node_shared_ptr());
        auto gather = ov::as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
        // skip all the Reshape layers
        std::shared_ptr<ov::Node> non_reshape_node =
            graph_utils::get_prev_node_skipping_certain(reshape, is_skip_operation);
        auto transpose = std::dynamic_pointer_cast<Transpose>(non_reshape_node);
        if (!transpose) {
            return false;
        }
        // the transpose order has to be known to turn it into gather indexes
        if (!ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1))) {
            return false;
        }
        const std::vector<std::shared_ptr<ov::Node>> new_nodes = merge_nodes_backward(gather, transpose);
        for (const auto& node : new_nodes) {
            register_new_node(node);
        }
        // report a change only when nodes were actually created
        return !new_nodes.empty();
    };
    auto m = std::make_shared<Matcher>(gather_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
/**
 * Runs the forward and then the backward Gather/Transpose fusion on the model through a nested
 * pass manager that shares this pass' configuration.
 * @return always false
 */
bool GatherSinkingTranspose::run_on_model(const std::shared_ptr<ov::Model>& model) {
    RUN_ON_FUNCTION_SCOPE(GatherSinkingTranspose);
    ov::pass::Manager nested_manager(get_pass_config());
    nested_manager.register_pass<GatherSinkingTransposeForward>();
    nested_manager.register_pass<GatherSinkingTransposeBackward>();
    nested_manager.run_passes(model);
    return false;
}

View File

@ -0,0 +1,80 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace intel_gna {
namespace pass {
/**
 * @brief Merge Gather with Transpose through Reshape or sequence of Reshapes.
 *
 * The Transpose is folded into the indexes of the Gather and replaced by a Reshape to the
 * Transpose output shape.
 *
 *      Any1[a,b,c]              Any1[a,b,c]
 *           |                        |
 *   Reshape[1,a*b*c]         Reshape[1,a*b*c]
 *           |                        |
 *   Gather[1,a*b*c]          Gather[1,a*b*c]
 *           |                        |
 *    Reshape1[...]                   |
 *           |                        |
 *          ...                       |
 *           |                        |
 *   ReshapeN[a,b,c]                  |
 *           |            =>          |
 *   Transpose[c,b,a]         Reshape[c,b,a]
 *           |                        |
 *      Any2[c,b,a]              Any2[c,b,a]
 *
 * Gather restrictions:
 *  - only scalar or 1D indexes are supported,
 *    i.e. [1, 64] or [64]
 */
class GatherSinkingTransposeForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GatherSinkingTransposeForward", "0");
GatherSinkingTransposeForward();
};
/**
 * @brief Merge Transpose with Gather through Reshape or sequence of Reshapes.
 *
 * The Transpose is folded into the indexes of the Gather; the new Gather reads the Transpose
 * input through a Reshape to the Gather input shape.
 *
 *      Any1[a,b,c]              Any1[a,b,c]
 *           |                        |
 *   Transpose[c,b,a]                 |
 *           |                        |
 *    Reshape1[...]                   |
 *           |                        |
 *          ...                       |
 *           |                        |
 *   Reshape[1, a*b*c]        Reshape[1, a*b*c]
 *           |            =>          |
 *   Gather[1, a*b*c]         Gather[1, a*b*c]
 *           |                        |
 *    Reshape[c,b,a]           Reshape[c,b,a]
 *           |                        |
 *      Any2[c,b,a]              Any2[c,b,a]
 *
 * Gather restrictions:
 *  - only scalar or 1D indexes are supported,
 *    i.e. [1, 64] or [64]
 */
class GatherSinkingTransposeBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GatherSinkingTransposeBackward", "0");
GatherSinkingTransposeBackward();
};
/**
 * @brief Model pass that applies GatherSinkingTransposeForward and then
 * GatherSinkingTransposeBackward to the whole model.
 */
class GatherSinkingTranspose : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("GatherSinkingTranspose", "0");
// Runs both matcher passes on the given model.
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
};
} // namespace pass
} // namespace intel_gna
} // namespace ov

View File

@ -0,0 +1,203 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "openvino/opsets/opset10.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
using namespace ov::opset10;
using namespace ov::test;
namespace gather_transpose_merge_test {
namespace {
// Builds the descending index sequence size-1, ..., 1, 0 (empty when size is 0).
inline std::vector<size_t> make_indexes(size_t size) {
    std::vector<size_t> descending(size);
    // counting up through the reverse iterators leaves the largest index first
    std::iota(descending.rbegin(), descending.rend(), 0);
    return descending;
}
// Picks the transpose order used by the tests from the rank of the input shape.
// Ranks other than 2, 3 and 4 produce an empty order.
inline std::vector<size_t> make_transpose_order(const std::vector<size_t>& input_shape) {
    const size_t rank = input_shape.size();
    if (rank == 2) {
        return {1, 0};
    }
    if (rank == 3) {
        return {0, 2, 1};
    }
    if (rank == 4) {
        return {0, 2, 3, 1};
    }
    return {};
}
} // namespace
// Parameter pack of a single functional test instance.
using GatherTransposeMergeTestParams = std::tuple<std::vector<size_t>,                 // Input shape
                                                  ov::element::Type,                   // Net precision
                                                  std::string,                         // Device name
                                                  std::map<std::string, std::string>,  // Configuration
                                                  std::map<std::string, std::string>   // Additional Configuration
                                                  >;
/**
 * @brief Common fixture of the Gather/Transpose merge functional tests.
 *
 * Unpacks the test parameters, merges the two configuration maps and initializes the input
 * shapes. Derived fixtures provide their own init_test_model() that builds the model.
 */
class GatherTransposeMergeTest : public testing::WithParamInterface<GatherTransposeMergeTestParams>,
                                 virtual public SubgraphBaseTest {
public:
    /**
     * @brief Compose the test case name from the parameters.
     * @param obj test parameters
     * @return name made of the shape, precision, device and every configuration item
     */
    static std::string get_test_case_name(const testing::TestParamInfo<GatherTransposeMergeTestParams>& obj) {
        std::vector<size_t> input_shape;
        ov::element::Type net_type;
        std::string target_device;
        std::map<std::string, std::string> conf, conf_ext;
        std::tie(input_shape, net_type, target_device, conf, conf_ext) = obj.param;
        // items of the additional configuration override the base ones
        for (const auto& conf_item : conf_ext) {
            conf[conf_item.first] = conf_item.second;
        }
        std::ostringstream result;
        result << "Shape=" << CommonTestUtils::vec2str(input_shape) << "_";
        result << "netPRC=" << net_type << "_";
        result << "trgDev=" << target_device;
        for (auto const& conf_i : conf) {
            result << "_configItem=" << conf_i.first.c_str() << "_" << conf_i.second.c_str();
        }
        return result.str();
    }

protected:
    void SetUp() override {
        // NOTE(review): thresholds this large look like they switch the accuracy comparison off,
        // confirm that this is intended
        abs_threshold = std::numeric_limits<int32_t>::max();
        rel_threshold = std::numeric_limits<int32_t>::max();
        std::map<std::string, std::string> conf, conf_ext;
        std::tie(m_input_shape, m_net_type, targetDevice, conf, conf_ext) = this->GetParam();
        std::vector<InputShape> input_shapes = static_shapes_to_test_representation({m_input_shape, m_input_shape});
        configuration.insert(conf.begin(), conf.end());
        // items of the additional configuration override the base ones
        for (const auto& conf_item : conf_ext) {
            configuration[conf_item.first] = conf_item.second;
        }
        init_input_shapes(input_shapes);
    }
    // Declared only; every derived fixture defines its own version.
    void init_test_model();
    ov::element::Type m_net_type;
    std::vector<size_t> m_input_shape;
};
/**
 * @brief Functional test model: Transpose -> Reshape -> Gather -> Reshape.
 */
class TransposeGatherTest : public GatherTransposeMergeTest {
protected:
    // Builds the model and stores it in `function`.
    void init_test_model() {
        auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
        // transpose of the input
        std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
        auto transpose_const =
            std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
        auto transpose_node = std::make_shared<Transpose>(params[0], transpose_const);
        // reshape with the pattern {1, -1}
        std::vector<int8_t> shape_in = {1, -1};
        auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
        auto reshape_in_node = std::make_shared<Reshape>(transpose_node, reshape_in_const, false);
        // gather with descending indexes along axis 1
        const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
        auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
        const size_t gather_axis = 1;
        auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
        auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
        // back to the output shape of the transpose
        ov::Shape shape_out = transpose_node->get_output_shape(0);
        auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
        auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
        ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
        function = std::make_shared<ov::Model>(results, params, "gather_transpose_merge_test");
    }
};
/**
 * @brief Functional test model: Reshape -> Gather -> Reshape -> Transpose.
 */
class GatherTransposeTest : public GatherTransposeMergeTest {
protected:
    // Builds the model and stores it in `function`.
    void init_test_model() {
        auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
        // reshape with the pattern {1, -1}
        std::vector<int8_t> shape_in = {1, -1};
        auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
        auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
        // gather with descending indexes along axis 1
        const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
        auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
        const size_t gather_axis = 1;
        auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
        auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
        // back to the original input shape
        ov::Shape shape_middle = m_input_shape;
        auto reshape_middle_const =
            std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_middle.size()}, shape_middle);
        auto reshape_middle_node = std::make_shared<Reshape>(gather_node, reshape_middle_const, false);
        // transpose as the last layer
        std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
        auto transpose_const =
            std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
        auto transpose_node = std::make_shared<Transpose>(reshape_middle_node, transpose_const);
        ov::ResultVector results{std::make_shared<Result>(transpose_node)};
        function = std::make_shared<ov::Model>(results, params, "gather_transpose_merge_test");
    }
};
// Builds the Transpose -> Reshape -> Gather -> Reshape model and hands it to the fixture's run().
TEST_P(TransposeGatherTest, CompareWithRefs) {
init_test_model();
run();
}
// Builds the Reshape -> Gather -> Reshape -> Transpose model and hands it to the fixture's run().
TEST_P(GatherTransposeTest, CompareWithRefs) {
init_test_model();
run();
}
// Base plugin configuration used by every test instance.
const std::vector<std::map<std::string, std::string>> configs = {{{"GNA_DEVICE_MODE", "GNA_SW_EXACT"}}};
// Additional configurations; the fixture applies them on top of the base configuration,
// so an item with the same key replaces the base one.
const std::vector<std::map<std::string, std::string>> target_configs = {{{"GNA_DEVICE_MODE", "GNA_SW_FP32"}},
                                                                        {{"GNA_EXEC_TARGET", "GNA_TARGET_2_0"}},
                                                                        {{"GNA_EXEC_TARGET", "GNA_TARGET_3_0"}},
                                                                        {{"GNA_EXEC_TARGET", "GNA_TARGET_3_5"}}};
// Input shapes under test: one 2D, one 3D and one 4D shape.
const std::vector<std::vector<size_t>> input_shapes = {{16, 64}, {1, 16, 64}, {1, 8, 16, 64}};
// Network precisions under test.
const ov::element::TypeVector input_precisions = {ov::element::f32};
// Instantiates TransposeGatherTest for every combination of shape, precision, base configuration
// and additional configuration on the GNA device.
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
TransposeGatherTest,
::testing::Combine(::testing::ValuesIn(input_shapes),
::testing::ValuesIn(input_precisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(target_configs)),
TransposeGatherTest::get_test_case_name);
// Instantiates GatherTransposeTest for every combination of shape, precision, base configuration
// and additional configuration on the GNA device.
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
GatherTransposeTest,
::testing::Combine(::testing::ValuesIn(input_shapes),
::testing::ValuesIn(input_precisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(configs),
::testing::ValuesIn(target_configs)),
GatherTransposeTest::get_test_case_name);
} // namespace gather_transpose_merge_test

View File

@ -0,0 +1,243 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common/graph_utils.hpp"
#include "common_test_utils/common_utils.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/gather_sinking_transpose.hpp"
#include "transformations/init_node_info.hpp"
using namespace ov::opset10;
using namespace ov::intel_gna;
namespace gather_transpose_merge_test {
namespace {
// Builds the descending index sequence size-1, ..., 1, 0 (empty when size is 0).
inline std::vector<size_t> make_indexes(size_t size) {
    std::vector<size_t> descending(size);
    // counting up through the reverse iterators leaves the largest index first
    std::iota(descending.rbegin(), descending.rend(), 0);
    return descending;
}
// Picks the transpose order used by the tests from the rank of the input shape.
// Ranks other than 2, 3 and 4 produce an empty order.
inline std::vector<size_t> make_transpose_order(const std::vector<size_t>& input_shape) {
    const size_t rank = input_shape.size();
    if (rank == 2) {
        return {1, 0};
    }
    if (rank == 3) {
        return {0, 2, 1};
    }
    if (rank == 4) {
        return {0, 2, 3, 1};
    }
    return {};
}
} // namespace
// Parameter pack of a single unit test instance.
using GatherTransposeMergeTestParams = std::tuple<ov::element::Type,  // Net precision
                                                  ov::Shape           // Input shape
                                                  >;
/**
 * @brief Common fixture of the GatherSinkingTranspose unit tests.
 *
 * A derived fixture builds the model under test and the reference model. Validate() runs the
 * transformation on the model under test and compares it with the reference.
 */
class GatherTransposeMergeBase : public CommonTestUtils::TestsCommon,
                                 public ::testing::WithParamInterface<GatherTransposeMergeTestParams> {
public:
    /**
     * @brief Compose the test name from the parameters.
     * @param obj test parameters
     * @return name made of the precision and the input shape
     */
    static std::string get_test_name(const testing::TestParamInfo<GatherTransposeMergeTestParams>& obj) {
        std::vector<size_t> input_shape;
        ov::element::Type net_type;
        std::tie(net_type, input_shape) = obj.param;
        std::ostringstream result;
        result << "netPRC=" << net_type << "_";
        result << "Shape=" << CommonTestUtils::vec2str(input_shape);
        return result.str();
    }
    void SetUp() override {
        std::tie(m_net_type, m_input_shape) = this->GetParam();
    }
    // Builds m_model; does nothing in the base fixture.
    virtual void init_test_model(){};
    // Builds m_model_ref; does nothing in the base fixture.
    virtual void init_ref_model(){};
    // Applies the transformation to m_model and expects it to match m_model_ref.
    virtual void Validate() {
        ov::pass::Manager m;
        m.register_pass<ov::pass::InitNodeInfo>();
        // the models are not serialized here: a unit test must not leave files in the working directory
        m.register_pass<ov::intel_gna::pass::GatherSinkingTranspose>();
        m.run_passes(m_model);
        const FunctionsComparator func_comparator =
            FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
        FunctionsComparator::Result result = func_comparator(m_model, m_model_ref);
        EXPECT_TRUE(result.valid) << result.message;
    }
    // Reads the parameters, builds both models and validates the transformation.
    virtual void Run() {
        SetUp();
        init_test_model();
        init_ref_model();
        Validate();
    }

protected:
    std::shared_ptr<ov::Model> m_model, m_model_ref;
    ov::element::Type m_net_type;
    ov::Shape m_input_shape, m_output_shape;
    // gather indexes expected after the merge, filled by init_test_model()
    ov::AxisVector m_gather_ids_ref;
};
class TransposeGatherTest : public GatherTransposeMergeBase {
public:
void init_test_model() override {
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
auto transpose_const =
std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
auto transpose_node = std::make_shared<Transpose>(params[0], transpose_const);
ov::Shape shape_in = {1, input_shape_size};
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
auto reshape_in_node = std::make_shared<Reshape>(transpose_node, reshape_in_const, false);
const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
const size_t gather_axis = 1;
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
ov::Shape shape_out = transpose_node->get_output_shape(0);
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
m_model = std::make_shared<ov::Model>(results, params, "concat");
// save for the ref model
const ov::AxisVector transpose_ids =
graph_utils::make_gather_indexes_from_transpose_axes(m_input_shape, transpose_order);
m_gather_ids_ref = graph_utils::combine_gather_indexes(gather_ids, transpose_ids);
m_output_shape = shape_out;
}
void init_ref_model() override {
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
ov::Shape shape_in = {1, input_shape_size};
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
const std::vector<size_t> gather_ids = m_gather_ids_ref;
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
const size_t gather_axis = 1;
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
ov::Shape shape_out = m_output_shape;
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
m_model_ref = std::make_shared<ov::Model>(results, params, "concat");
}
};
class GatherTransposeTest : public GatherTransposeMergeBase {
public:
void init_test_model() override {
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
ov::Shape shape_in = {1, input_shape_size};
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
const size_t gather_axis = 1;
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
ov::Shape shape_middle = m_input_shape;
auto reshape_middle_const =
std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_middle.size()}, shape_middle);
auto reshape_middle_node = std::make_shared<Reshape>(gather_node, reshape_middle_const, false);
std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
auto transpose_const =
std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
auto transpose_node = std::make_shared<Transpose>(reshape_middle_node, transpose_const);
ov::ResultVector results{std::make_shared<Result>(transpose_node)};
m_model = std::make_shared<ov::Model>(results, params, "transpose_gather_test_model");
// save values for the ref model
const ov::AxisVector transpose_ids =
graph_utils::make_gather_indexes_from_transpose_axes(m_input_shape, transpose_order);
m_gather_ids_ref = graph_utils::combine_gather_indexes(gather_ids, transpose_ids);
m_output_shape = transpose_node->get_output_shape(0);
}
void init_ref_model() override {
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
ov::Shape shape_in = {1, input_shape_size};
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
const std::vector<size_t> gather_ids = m_gather_ids_ref;
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
const size_t gather_axis = 1;
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
ov::Shape shape_out = m_output_shape;
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
m_model_ref = std::make_shared<ov::Model>(results, params, "transpose_gather_test_model");
}
};
// Transforms the Transpose -> Reshape -> Gather -> Reshape model and compares it with the reference.
TEST_P(TransposeGatherTest, CompareWithRefs) {
Run();
}
// Transforms the Reshape -> Gather -> Reshape -> Transpose model and compares it with the reference.
TEST_P(GatherTransposeTest, CompareWithRefs) {
Run();
}
// Network precisions under test.
const ov::element::TypeVector input_precisions = {ov::element::f16, ov::element::f32};
// Input shapes under test: one 2D, one 3D and one 4D shape.
const std::vector<ov::Shape> input_shapes = {{16, 64}, {1, 16, 64}, {1, 8, 16, 64}};
// Instantiates TransposeGatherTest for every combination of precision and input shape.
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
TransposeGatherTest,
::testing::Combine(::testing::ValuesIn(input_precisions), ::testing::ValuesIn(input_shapes)),
TransposeGatherTest::get_test_name);
// Instantiates GatherTransposeTest for every combination of precision and input shape.
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
GatherTransposeTest,
::testing::Combine(::testing::ValuesIn(input_precisions), ::testing::ValuesIn(input_shapes)),
GatherTransposeTest::get_test_name);
} // namespace gather_transpose_merge_test