[GNA] Fuse gather and transpose (#18648)
This commit is contained in:
parent
a930c74143
commit
0569bb8c5d
@ -388,14 +388,14 @@ inline ov::Shape transpose_shape(const ov::Shape& shape, std::vector<size_t> ord
|
||||
* @param order the permutation array to apply to the input shape
|
||||
* @return vector with indexes to gather
|
||||
*/
|
||||
inline std::vector<size_t> make_gather_indexes_from_transpose_axes(const Shape& input_shape, const Shape& order) {
|
||||
inline std::vector<size_t> make_gather_indexes_from_transpose_axes(const Shape& input_shape, const AxisVector& order) {
|
||||
// Supported shape ranks: 2d, 3d, 4d
|
||||
if (input_shape.size() < 2 || input_shape.size() > 4) {
|
||||
THROW_GNA_EXCEPTION << "Usupported shape size: " << input_shape.size();
|
||||
}
|
||||
|
||||
ov::Shape input_shape_4d = input_shape;
|
||||
ov::Shape order_4d = order;
|
||||
ov::AxisVector order_4d = order;
|
||||
// Just to simplify the code we transform all shapes to 4d by adding dimension(s) equal to 1 at the end
|
||||
while (input_shape_4d.size() < 4) {
|
||||
input_shape_4d.push_back(1);
|
||||
@ -685,6 +685,23 @@ inline bool has_n_consumers(const std::shared_ptr<ov::Node>& node, size_t n_cons
|
||||
return node->output(0).get_target_inputs().size() == n_consumers;
|
||||
}
|
||||
|
||||
/**
 * @brief Merge the indexes of two consecutive gathers into the indexes of a single gather.
 *
 * Element i of the result is ids_in[ids_out[i]], i.e. the 2nd gather selects among
 * the elements already selected by the 1st gather.
 *
 * @param ids_in vector with indexes to 1st gather
 * @param ids_out vector with indexes to 2nd gather
 * @return vector with indexes to merged gather; an empty vector when the two vectors
 *         differ in size or when ids_out contains an index outside of ids_in
 */
inline std::vector<size_t> combine_gather_indexes(const std::vector<size_t>& ids_in,
                                                  const std::vector<size_t>& ids_out) {
    if (ids_in.size() != ids_out.size())
        return {};

    std::vector<size_t> result;
    result.reserve(ids_out.size());
    for (const size_t id : ids_out) {
        // An out-of-range index would read past the end of ids_in; report it the same
        // way as a size mismatch instead of invoking undefined behaviour.
        if (id >= ids_in.size())
            return {};
        result.push_back(ids_in[id]);
    }
    return result;
}
|
||||
|
||||
} // namespace graph_utils
|
||||
} // namespace intel_gna
|
||||
} // namespace ov
|
||||
|
@ -34,6 +34,7 @@
|
||||
#include "transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp"
|
||||
#include "transformations/fuse_conv_bias_activation.hpp"
|
||||
#include "transformations/gather_sinking.hpp"
|
||||
#include "transformations/gather_sinking_transpose.hpp"
|
||||
#include "transformations/handle_transposes_around_matmul.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
#include "transformations/insert_copy_layer.hpp"
|
||||
@ -139,6 +140,7 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
|
||||
if (has_convolution || has_maxpool || has_mvn || has_matmul) {
|
||||
manager.register_pass<ov::intel_gna::pass::ReplaceGnaNHWCLayers>();
|
||||
manager.register_pass<ov::intel_gna::pass::InsertConvolutionTransposeHW>();
|
||||
manager.register_pass<ov::intel_gna::pass::GatherSinkingTranspose>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<ov::intel_gna::pass::TransposeCompress>();
|
||||
manager.register_pass<ov::intel_gna::pass::TSConcatForward>();
|
||||
|
@ -0,0 +1,155 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/gather_sinking_transpose.hpp"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "common/graph_utils.hpp"
|
||||
#include "openvino/cc/ngraph/itt.hpp"
|
||||
#include "openvino/opsets/opset12.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::intel_gna;
|
||||
using namespace ov::intel_gna::pass;
|
||||
|
||||
namespace {
|
||||
|
||||
/**
 * @brief Fuses a Gather that is followed by a Transpose into one Gather plus a Reshape.
 *
 * The Transpose order is converted to gather indexes, composed with the indexes of
 * @p gather (Gather is applied first, Transpose second) and a new Gather with the
 * composed indexes is created on the data input of @p gather. A Reshape to the
 * Transpose output shape then replaces @p transpose in the graph.
 *
 * @param gather Gather node; its inputs 1 (indexes) and 2 (axis) must be Constants
 * @param transpose Transpose node; its input 1 (order) must be a Constant
 * @return the created nodes {Gather, Reshape}; an empty vector, with the graph left
 *         untouched, when a required Constant is missing, the Transpose rank is not
 *         supported or the indexes cannot be combined
 */
inline std::vector<std::shared_ptr<ov::Node>> merge_nodes_forward(std::shared_ptr<ov::Node> gather,
                                                                  std::shared_ptr<ov::Node> transpose) {
    auto transpose_const = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
    auto gather_ids = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(1));
    auto gather_axis = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(2));
    // as_type_ptr returns nullptr when the producer is not a Constant; nothing can be fused then
    if (!transpose_const || !gather_ids || !gather_axis) {
        return {};
    }

    // make_gather_indexes_from_transpose_axes throws for ranks outside 2d..4d, skip such cases
    const ov::Shape transpose_shape_in = transpose->get_input_shape(0);
    if (transpose_shape_in.size() < 2 || transpose_shape_in.size() > 4) {
        return {};
    }

    // transpose ids -> gather indexes
    const ov::AxisVector transpose_ids =
        graph_utils::make_gather_indexes_from_transpose_axes(transpose_shape_in,
                                                             transpose_const->get_axis_vector_val());
    // merge gather indexes: Gather is the 1st operation, Transpose the 2nd one
    const ov::AxisVector gather_new_ids =
        graph_utils::combine_gather_indexes(gather_ids->get_axis_vector_val(), transpose_ids);
    // an empty result means the index vectors are incompatible
    if (gather_new_ids.empty()) {
        return {};
    }

    // new gather
    auto gather_new_const_ids =
        std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_new_ids.size()}, gather_new_ids);
    auto gather_new = std::make_shared<Gather>(gather->input_value(0), gather_new_const_ids, gather_axis);

    // restore the shape that the Transpose used to produce
    ov::Shape shape_out = transpose->get_output_shape(0);
    auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
    auto reshape_out = std::make_shared<Reshape>(gather_new, reshape_out_const, false);

    replace_node_update_name(transpose, reshape_out);

    return std::vector<std::shared_ptr<ov::Node>>({gather_new, reshape_out});
}
|
||||
|
||||
/**
 * @brief Fuses a Transpose that is followed by a Gather into a Reshape plus one Gather.
 *
 * The data input of @p transpose is reshaped to the input shape of @p gather, and a
 * new Gather whose indexes combine the Transpose permutation with the original
 * Gather indexes replaces @p gather in the graph.
 *
 * @param gather Gather node; its inputs 1 (indexes) and 2 (axis) must be Constants
 * @param transpose Transpose node; its input 1 (order) must be a Constant
 * @return the created nodes {Reshape, Gather}; an empty vector, with the graph left
 *         untouched, when a required Constant is missing, the Transpose rank is not
 *         supported or the indexes cannot be combined
 */
inline std::vector<std::shared_ptr<ov::Node>> merge_nodes_backward(std::shared_ptr<ov::Node> gather,
                                                                   std::shared_ptr<ov::Node> transpose) {
    auto transpose_const = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
    auto gather_ids = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(1));
    auto gather_axis = ov::as_type_ptr<Constant>(gather->get_input_node_shared_ptr(2));
    // as_type_ptr returns nullptr when the producer is not a Constant; nothing can be fused then
    if (!transpose_const || !gather_ids || !gather_axis) {
        return {};
    }

    // make_gather_indexes_from_transpose_axes throws for ranks outside 2d..4d, skip such cases
    const ov::Shape transpose_shape_in = transpose->get_input_shape(0);
    if (transpose_shape_in.size() < 2 || transpose_shape_in.size() > 4) {
        return {};
    }

    ov::Shape shape_in = gather->get_input_shape(0);
    auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
    auto reshape_in = std::make_shared<Reshape>(transpose->input_value(0), reshape_in_const, false);

    // transpose ids -> gather indexes
    const ov::AxisVector transpose_ids =
        graph_utils::make_gather_indexes_from_transpose_axes(transpose_shape_in,
                                                             transpose_const->get_axis_vector_val());
    // merge gather indexes: here the Transpose is the 1st operation and the Gather the 2nd one,
    // so output element i has to read transpose_ids[gather_ids[i]]
    const ov::AxisVector gather_new_ids =
        graph_utils::combine_gather_indexes(transpose_ids, gather_ids->get_axis_vector_val());
    // an empty result means the index vectors are incompatible
    if (gather_new_ids.empty()) {
        return {};
    }

    // new gather
    auto gather_new_const_ids =
        std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_new_ids.size()}, gather_new_ids);
    auto gather_new = std::make_shared<Gather>(reshape_in, gather_new_const_ids, gather_axis);

    ov::replace_node_update_name(gather, gather_new);

    return std::vector<std::shared_ptr<ov::Node>>({reshape_in, gather_new});
}
|
||||
|
||||
/**
 * @brief Tells whether a node may be stepped over while searching for the Transpose.
 * @param node node to examine
 * @return true when the node is a Reshape whose output 0 feeds exactly one input
 */
inline bool is_skip_operation(const std::shared_ptr<ov::Node>& node) {
    if (std::dynamic_pointer_cast<Reshape>(node) == nullptr) {
        return false;
    }
    const size_t consumers_count = node->output(0).get_target_inputs().size();
    return consumers_count == 1;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Registers a matcher rooted at a Reshape that consumes a Gather whose indexes input is
// a Constant accepted by graph_utils::is_constant_1d.
// The Transpose is not part of the pattern: the callback looks it up by walking forward
// from the matched Reshape while is_skip_operation holds for the visited nodes.
GatherSinkingTransposeForward::GatherSinkingTransposeForward() {
    MATCHER_SCOPE(GatherSinkingTransposeForward);
    auto gather_ids_label = wrap_type<Constant>(graph_utils::is_constant_1d);
    auto gather_label = wrap_type<Gather>({any_input(), gather_ids_label, any_input()});
    auto reshape_label = wrap_type<Reshape>({gather_label, any_input()});

    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& matched_values = m.get_pattern_value_map();
        auto gather_node = as_type_ptr<Gather>(matched_values.at(gather_label).get_node_shared_ptr());
        auto first_reshape = as_type_ptr<Reshape>(matched_values.at(reshape_label).get_node_shared_ptr());

        // the first node that is_skip_operation rejects has to be the Transpose
        std::shared_ptr<ov::Node> candidate =
            graph_utils::get_next_node_skipping_certain(first_reshape, is_skip_operation);
        auto transpose_node = std::dynamic_pointer_cast<Transpose>(candidate);
        if (transpose_node == nullptr) {
            return false;
        }

        const std::vector<std::shared_ptr<ov::Node>> created_nodes = merge_nodes_forward(gather_node, transpose_node);
        for (const auto& created : created_nodes) {
            register_new_node(created);
        }
        return true;
    };

    auto matcher = std::make_shared<Matcher>(reshape_label, matcher_name);
    register_matcher(matcher, matcher_pass_callback);
}
|
||||
|
||||
// Registers a matcher rooted at a Gather that consumes a Reshape and whose indexes input
// is a Constant accepted by graph_utils::is_constant_1d.
// The Transpose is not part of the pattern: the callback looks it up by walking backward
// from the matched Reshape while is_skip_operation holds for the visited nodes.
GatherSinkingTransposeBackward::GatherSinkingTransposeBackward() {
    MATCHER_SCOPE(GatherSinkingTransposeBackward);

    auto reshape_label = wrap_type<Reshape>({any_input(), any_input()});
    auto gather_ids_label = wrap_type<Constant>(graph_utils::is_constant_1d);
    auto gather_label = wrap_type<Gather>({reshape_label, gather_ids_label, any_input()});

    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& matched_values = m.get_pattern_value_map();
        auto last_reshape = as_type_ptr<Reshape>(matched_values.at(reshape_label).get_node_shared_ptr());
        auto gather_node = as_type_ptr<Gather>(matched_values.at(gather_label).get_node_shared_ptr());

        // the first node that is_skip_operation rejects has to be the Transpose
        std::shared_ptr<ov::Node> candidate =
            graph_utils::get_prev_node_skipping_certain(last_reshape, is_skip_operation);
        auto transpose_node = std::dynamic_pointer_cast<Transpose>(candidate);
        if (transpose_node == nullptr) {
            return false;
        }

        const std::vector<std::shared_ptr<ov::Node>> created_nodes = merge_nodes_backward(gather_node, transpose_node);
        for (const auto& created : created_nodes) {
            register_new_node(created);
        }
        return true;
    };

    auto matcher = std::make_shared<Matcher>(gather_label, matcher_name);
    register_matcher(matcher, matcher_pass_callback);
}
|
||||
|
||||
// Runs the forward and then the backward Gather/Transpose fusion on the model through a
// nested pass manager that shares this pass' configuration.
// Always returns false, regardless of what the nested passes did.
bool GatherSinkingTranspose::run_on_model(const std::shared_ptr<ov::Model>& model) {
    RUN_ON_FUNCTION_SCOPE(GatherSinkingTranspose);

    ov::pass::Manager nested_manager(get_pass_config());
    nested_manager.register_pass<GatherSinkingTransposeForward>();
    nested_manager.register_pass<GatherSinkingTransposeBackward>();
    nested_manager.run_passes(model);

    return false;
}
|
@ -0,0 +1,80 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_gna {
|
||||
namespace pass {
|
||||
|
||||
/**
|
||||
* @brief Merge Gather with Transpose through Reshape or sequence of Reshapes.
|
||||
*
|
||||
* Any1[a,b,c] Any1[a,b,c]
|
||||
* | |
|
||||
* Reshape[1,a*b*c] Reshape[1,a*b*c]
|
||||
* | |
|
||||
* Gather[1,a*b*c] Gather[1,a*b*c]
|
||||
* | |
|
||||
* Reshape1[...] |
|
||||
* | |
|
||||
* ... |
|
||||
* | |
|
||||
* ReshapeN[a,b,c] |
|
||||
* | => |
|
||||
* Transpose[c,b,a] Reshape[c,b,a]
|
||||
* | |
|
||||
* Any2[c,b,a] Any2[c,b,a]
|
||||
*
|
||||
* Gather restrictions:
|
||||
* - supported Scalar or 1D indexes
|
||||
* i.e. [1, 64] or [64]
|
||||
*/
|
||||
class GatherSinkingTransposeForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherSinkingTransposeForward", "0");
|
||||
GatherSinkingTransposeForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Merge Transpose with Gather through Reshape or sequence of Reshapes.
|
||||
*
|
||||
* Any1[a,b,c] Any1[a,b,c]
|
||||
* | |
|
||||
* Transpose[c,b,a] |
|
||||
* | |
|
||||
* Reshape1[...] |
|
||||
* | |
|
||||
* ... |
|
||||
* | |
|
||||
* Reshape[1, a*b*c] Reshape[1, a*b*c]
|
||||
* | => |
|
||||
* Gather[1, a*b*c] Gather[1, a*b*c]
|
||||
* | |
|
||||
* Reshape[c,b,a] Reshape[c,b,a]
|
||||
* | |
|
||||
* Any2[c,b,a] Any2[c,b,a]
|
||||
*
|
||||
* Gather restrictions:
|
||||
* - supported Scalar or 1D indexes
|
||||
* i.e. [1, 64] or [64]
|
||||
*/
|
||||
class GatherSinkingTransposeBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherSinkingTransposeBackward", "0");
|
||||
GatherSinkingTransposeBackward();
|
||||
};
|
||||
|
||||
class GatherSinkingTranspose : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherSinkingTranspose", "0");
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace intel_gna
|
||||
} // namespace ov
|
@ -0,0 +1,203 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <map>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::test;
|
||||
|
||||
namespace gather_transpose_merge_test {
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds the descending index sequence size-1, size-2, ..., 0.
inline std::vector<size_t> make_indexes(size_t size) {
    std::vector<size_t> descending(size);
    // counting up through the reverse iterators places 0 in the last slot
    std::iota(descending.rbegin(), descending.rend(), 0);
    return descending;
}
|
||||
|
||||
// Picks the transpose order used by the tests for the rank of input_shape:
// 2d -> {1, 0}, 3d -> {0, 2, 1}, 4d -> {0, 2, 3, 1}; any other rank yields an empty order.
inline std::vector<size_t> make_transpose_order(const std::vector<size_t>& input_shape) {
    const size_t rank = input_shape.size();
    if (rank == 2) {
        return {1, 0};
    }
    if (rank == 3) {
        return {0, 2, 1};
    }
    if (rank == 4) {
        return {0, 2, 3, 1};
    }
    return {};
}
|
||||
|
||||
} // namespace
|
||||
|
||||
typedef std::tuple<std::vector<size_t>, // Input shape
|
||||
ov::element::Type, // Net precision
|
||||
std::string, // Device name
|
||||
std::map<std::string, std::string>, // Configuration
|
||||
std::map<std::string, std::string> // Additional Configuration
|
||||
>
|
||||
GatherTransposeMergeTestParams;
|
||||
|
||||
// Common fixture of the Gather/Transpose merge functional tests: unpacks the test
// parameters, prepares the plugin configuration and the input shapes.
// The model itself is built by init_test_model() of the derived fixtures.
class GatherTransposeMergeTest : public testing::WithParamInterface<GatherTransposeMergeTestParams>,
                                 virtual public SubgraphBaseTest {
public:
    // Composes the test name from the shape, precision, device and the merged configuration
    // (entries of the additional configuration override the base ones).
    static std::string get_test_case_name(const testing::TestParamInfo<GatherTransposeMergeTestParams>& obj) {
        std::vector<size_t> input_shape;
        ov::element::Type net_type;
        std::string target_device;
        std::map<std::string, std::string> conf, conf_ext;

        std::tie(input_shape, net_type, target_device, conf, conf_ext) = obj.param;
        for (const auto& conf_item : conf_ext) {
            conf[conf_item.first] = conf_item.second;
        }

        std::ostringstream result;
        result << "Shape=" << CommonTestUtils::vec2str(input_shape) << "_";
        result << "netPRC=" << net_type << "_";
        result << "trgDev=" << target_device;
        for (const auto& conf_i : conf) {
            result << "_configItem=" << conf_i.first.c_str() << "_" << conf_i.second.c_str();
        }
        return result.str();
    }

protected:
    void SetUp() override {
        // NOTE(review): both thresholds are raised to the int32 maximum, presumably to make
        // the comparison with the reference tolerant of any difference - confirm the intent
        abs_threshold = std::numeric_limits<int32_t>::max();
        rel_threshold = std::numeric_limits<int32_t>::max();
        std::map<std::string, std::string> conf, conf_ext;
        std::tie(m_input_shape, m_net_type, targetDevice, conf, conf_ext) = this->GetParam();

        std::vector<InputShape> input_shapes = static_shapes_to_test_representation({m_input_shape, m_input_shape});
        configuration.insert(conf.begin(), conf.end());
        // the additional configuration overrides the entries of the base one
        for (const auto& conf_item : conf_ext) {
            configuration[conf_item.first] = conf_item.second;
        }
        init_input_shapes(input_shapes);
    }

    // Only declared here; each derived fixture provides its own init_test_model()
    void init_test_model();

    ov::element::Type m_net_type;
    std::vector<size_t> m_input_shape;
};
|
||||
|
||||
class TransposeGatherTest : public GatherTransposeMergeTest {
|
||||
protected:
|
||||
void init_test_model() {
|
||||
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
|
||||
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
|
||||
|
||||
std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
|
||||
auto transpose_const =
|
||||
std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
|
||||
auto transpose_node = std::make_shared<Transpose>(params[0], transpose_const);
|
||||
|
||||
std::vector<int8_t> shape_in = {1, -1};
|
||||
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
|
||||
auto reshape_in_node = std::make_shared<Reshape>(transpose_node, reshape_in_const, false);
|
||||
|
||||
const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
|
||||
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
|
||||
const size_t gather_axis = 1;
|
||||
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
|
||||
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
|
||||
|
||||
ov::Shape shape_out = transpose_node->get_output_shape(0);
|
||||
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
|
||||
auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
|
||||
|
||||
ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
|
||||
function = std::make_shared<ov::Model>(results, params, "gather_transpose_merge_test");
|
||||
}
|
||||
};
|
||||
|
||||
class GatherTransposeTest : public GatherTransposeMergeTest {
|
||||
protected:
|
||||
void init_test_model() {
|
||||
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
|
||||
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
|
||||
|
||||
std::vector<int8_t> shape_in = {1, -1};
|
||||
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
|
||||
auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
|
||||
|
||||
const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
|
||||
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
|
||||
const size_t gather_axis = 1;
|
||||
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
|
||||
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
|
||||
|
||||
ov::Shape shape_middle = m_input_shape;
|
||||
auto reshape_middle_const =
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_middle.size()}, shape_middle);
|
||||
auto reshape_middle_node = std::make_shared<Reshape>(gather_node, reshape_middle_const, false);
|
||||
|
||||
std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
|
||||
auto transpose_const =
|
||||
std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
|
||||
auto transpose_node = std::make_shared<Transpose>(reshape_middle_node, transpose_const);
|
||||
|
||||
ov::ResultVector results{std::make_shared<Result>(transpose_node)};
|
||||
function = std::make_shared<ov::Model>(results, params, "gather_transpose_merge_test");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(TransposeGatherTest, CompareWithRefs) {
|
||||
init_test_model();
|
||||
run();
|
||||
}
|
||||
|
||||
TEST_P(GatherTransposeTest, CompareWithRefs) {
|
||||
init_test_model();
|
||||
run();
|
||||
}
|
||||
|
||||
std::vector<std::map<std::string, std::string>> configs = {{{"GNA_DEVICE_MODE", "GNA_SW_EXACT"}}};
|
||||
|
||||
std::vector<std::map<std::string, std::string>> target_configs = {{{"GNA_DEVICE_MODE", "GNA_SW_FP32"}},
|
||||
{{"GNA_EXEC_TARGET", "GNA_TARGET_2_0"}},
|
||||
{{"GNA_EXEC_TARGET", "GNA_TARGET_3_0"}},
|
||||
{{"GNA_EXEC_TARGET", "GNA_TARGET_3_5"}}};
|
||||
|
||||
const std::vector<std::vector<size_t>> input_shapes = {{16, 64}, {1, 16, 64}, {1, 8, 16, 64}};
|
||||
|
||||
const ov::element::TypeVector input_precisions = {ov::element::f32};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
|
||||
TransposeGatherTest,
|
||||
::testing::Combine(::testing::ValuesIn(input_shapes),
|
||||
::testing::ValuesIn(input_precisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs),
|
||||
::testing::ValuesIn(target_configs)),
|
||||
TransposeGatherTest::get_test_case_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
|
||||
GatherTransposeTest,
|
||||
::testing::Combine(::testing::ValuesIn(input_shapes),
|
||||
::testing::ValuesIn(input_precisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(configs),
|
||||
::testing::ValuesIn(target_configs)),
|
||||
GatherTransposeTest::get_test_case_name);
|
||||
|
||||
} // namespace gather_transpose_merge_test
|
@ -0,0 +1,243 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common/graph_utils.hpp"
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/gather_sinking_transpose.hpp"
|
||||
#include "transformations/init_node_info.hpp"
|
||||
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::intel_gna;
|
||||
|
||||
namespace gather_transpose_merge_test {
|
||||
|
||||
namespace {
|
||||
|
||||
// Builds the descending index sequence size-1, size-2, ..., 0.
inline std::vector<size_t> make_indexes(size_t size) {
    std::vector<size_t> descending(size);
    for (size_t pos = 0; pos < size; ++pos) {
        descending[pos] = size - 1 - pos;
    }
    return descending;
}
|
||||
|
||||
// Picks the transpose order used by the tests for the rank of input_shape:
// 2d -> {1, 0}, 3d -> {0, 2, 1}, 4d -> {0, 2, 3, 1}; any other rank yields an empty order.
inline std::vector<size_t> make_transpose_order(const std::vector<size_t>& input_shape) {
    const size_t rank = input_shape.size();
    if (rank == 2) {
        return std::vector<size_t>{1, 0};
    } else if (rank == 3) {
        return std::vector<size_t>{0, 2, 1};
    } else if (rank == 4) {
        return std::vector<size_t>{0, 2, 3, 1};
    }
    return std::vector<size_t>{};
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Parameters of a single unit test case.
using GatherTransposeMergeTestParams = std::tuple<ov::element::Type,  // Net precision
                                                  ov::Shape           // Input shape
                                                  >;
|
||||
|
||||
class GatherTransposeMergeBase : public CommonTestUtils::TestsCommon,
|
||||
public ::testing::WithParamInterface<GatherTransposeMergeTestParams> {
|
||||
public:
|
||||
static std::string get_test_name(const testing::TestParamInfo<GatherTransposeMergeTestParams>& obj) {
|
||||
std::vector<size_t> input_shape;
|
||||
ov::element::Type net_type;
|
||||
std::tie(net_type, input_shape) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "netPRC=" << net_type << "_";
|
||||
result << "Shape=" << CommonTestUtils::vec2str(input_shape);
|
||||
|
||||
return result.str();
|
||||
}
|
||||
void SetUp() override {
|
||||
std::tie(m_net_type, m_input_shape) = this->GetParam();
|
||||
}
|
||||
|
||||
virtual void init_test_model(){};
|
||||
virtual void init_ref_model(){};
|
||||
|
||||
virtual void Validate() {
|
||||
ov::pass::Manager m;
|
||||
m.register_pass<ov::pass::InitNodeInfo>();
|
||||
m.register_pass<ov::pass::Serialize>("test_before.xml", "test_before.bin");
|
||||
m.register_pass<ov::intel_gna::pass::GatherSinkingTranspose>();
|
||||
m.register_pass<ov::pass::Serialize>("test_after.xml", "test_after.bin");
|
||||
m.run_passes(m_model);
|
||||
|
||||
const FunctionsComparator func_comparator =
|
||||
FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
FunctionsComparator::Result result = func_comparator(m_model, m_model_ref);
|
||||
|
||||
EXPECT_TRUE(result.valid) << result.message;
|
||||
}
|
||||
virtual void Run() {
|
||||
SetUp();
|
||||
init_test_model();
|
||||
init_ref_model();
|
||||
Validate();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ov::Model> m_model, m_model_ref;
|
||||
ov::element::Type m_net_type;
|
||||
ov::Shape m_input_shape, m_output_shape;
|
||||
ov::AxisVector m_gather_ids_ref;
|
||||
};
|
||||
|
||||
class TransposeGatherTest : public GatherTransposeMergeBase {
|
||||
public:
|
||||
void init_test_model() override {
|
||||
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
|
||||
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
|
||||
|
||||
std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
|
||||
auto transpose_const =
|
||||
std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
|
||||
auto transpose_node = std::make_shared<Transpose>(params[0], transpose_const);
|
||||
|
||||
ov::Shape shape_in = {1, input_shape_size};
|
||||
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
|
||||
auto reshape_in_node = std::make_shared<Reshape>(transpose_node, reshape_in_const, false);
|
||||
|
||||
const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
|
||||
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
|
||||
const size_t gather_axis = 1;
|
||||
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
|
||||
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
|
||||
|
||||
ov::Shape shape_out = transpose_node->get_output_shape(0);
|
||||
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
|
||||
auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
|
||||
|
||||
ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
|
||||
m_model = std::make_shared<ov::Model>(results, params, "concat");
|
||||
// save for the ref model
|
||||
const ov::AxisVector transpose_ids =
|
||||
graph_utils::make_gather_indexes_from_transpose_axes(m_input_shape, transpose_order);
|
||||
m_gather_ids_ref = graph_utils::combine_gather_indexes(gather_ids, transpose_ids);
|
||||
m_output_shape = shape_out;
|
||||
}
|
||||
|
||||
void init_ref_model() override {
|
||||
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
|
||||
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
|
||||
|
||||
ov::Shape shape_in = {1, input_shape_size};
|
||||
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
|
||||
auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
|
||||
|
||||
const std::vector<size_t> gather_ids = m_gather_ids_ref;
|
||||
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
|
||||
const size_t gather_axis = 1;
|
||||
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
|
||||
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
|
||||
|
||||
ov::Shape shape_out = m_output_shape;
|
||||
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
|
||||
auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
|
||||
|
||||
ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
|
||||
m_model_ref = std::make_shared<ov::Model>(results, params, "concat");
|
||||
}
|
||||
};
|
||||
|
||||
class GatherTransposeTest : public GatherTransposeMergeBase {
|
||||
public:
|
||||
void init_test_model() override {
|
||||
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
|
||||
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
|
||||
|
||||
ov::Shape shape_in = {1, input_shape_size};
|
||||
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
|
||||
auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
|
||||
|
||||
const std::vector<size_t> gather_ids = make_indexes(ov::shape_size(m_input_shape));
|
||||
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
|
||||
const size_t gather_axis = 1;
|
||||
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
|
||||
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
|
||||
|
||||
ov::Shape shape_middle = m_input_shape;
|
||||
auto reshape_middle_const =
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_middle.size()}, shape_middle);
|
||||
auto reshape_middle_node = std::make_shared<Reshape>(gather_node, reshape_middle_const, false);
|
||||
|
||||
std::vector<size_t> transpose_order = make_transpose_order(m_input_shape);
|
||||
auto transpose_const =
|
||||
std::make_shared<Constant>(ov::element::i8, ov::Shape{transpose_order.size()}, transpose_order);
|
||||
auto transpose_node = std::make_shared<Transpose>(reshape_middle_node, transpose_const);
|
||||
|
||||
ov::ResultVector results{std::make_shared<Result>(transpose_node)};
|
||||
m_model = std::make_shared<ov::Model>(results, params, "transpose_gather_test_model");
|
||||
|
||||
// save values for the ref model
|
||||
const ov::AxisVector transpose_ids =
|
||||
graph_utils::make_gather_indexes_from_transpose_axes(m_input_shape, transpose_order);
|
||||
m_gather_ids_ref = graph_utils::combine_gather_indexes(gather_ids, transpose_ids);
|
||||
m_output_shape = transpose_node->get_output_shape(0);
|
||||
}
|
||||
|
||||
void init_ref_model() override {
|
||||
auto params = ngraph::builder::makeParams(m_net_type, {m_input_shape});
|
||||
const size_t input_shape_size = ov::shape_size(params[0]->get_shape());
|
||||
|
||||
ov::Shape shape_in = {1, input_shape_size};
|
||||
auto reshape_in_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_in.size()}, shape_in);
|
||||
auto reshape_in_node = std::make_shared<Reshape>(params[0], reshape_in_const, false);
|
||||
|
||||
const std::vector<size_t> gather_ids = m_gather_ids_ref;
|
||||
auto gather_const_ids = Constant::create(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
|
||||
const size_t gather_axis = 1;
|
||||
auto gather_const_axis = Constant::create(ov::element::i64, ov::Shape{}, {gather_axis});
|
||||
auto gather_node = std::make_shared<Gather>(reshape_in_node, gather_const_ids, gather_const_axis);
|
||||
|
||||
ov::Shape shape_out = m_output_shape;
|
||||
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
|
||||
auto reshape_out_node = std::make_shared<Reshape>(gather_node, reshape_out_const, false);
|
||||
|
||||
ov::ResultVector results{std::make_shared<Result>(reshape_out_node)};
|
||||
m_model_ref = std::make_shared<ov::Model>(results, params, "transpose_gather_test_model");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(TransposeGatherTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
TEST_P(GatherTransposeTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
|
||||
ov::element::TypeVector input_precisions = {ov::element::f16, ov::element::f32};
|
||||
|
||||
const std::vector<ov::Shape> input_shapes = {{16, 64}, {1, 16, 64}, {1, 8, 16, 64}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
|
||||
TransposeGatherTest,
|
||||
::testing::Combine(::testing::ValuesIn(input_precisions), ::testing::ValuesIn(input_shapes)),
|
||||
TransposeGatherTest::get_test_name);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_merge_transpose_gather,
|
||||
GatherTransposeTest,
|
||||
::testing::Combine(::testing::ValuesIn(input_precisions), ::testing::ValuesIn(input_shapes)),
|
||||
GatherTransposeTest::get_test_name);
|
||||
|
||||
} // namespace gather_transpose_merge_test
|
Loading…
Reference in New Issue
Block a user