[GPU] Added custom canonicalize_shapes for Gather (#16733)

This commit is contained in:
Roman Lyamin
2023-04-06 10:50:57 +04:00
committed by GitHub
parent 362389c733
commit 38c8a3d15b
3 changed files with 71 additions and 1 deletions

View File

@@ -99,6 +99,30 @@ public:
return {params, optional_params};
}
static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
auto updated_impl_params = canonicalize_fused_shapes(impl_params);
const auto& prim = impl_params.typed_desc<gather>();
auto input_pshape = updated_impl_params.input_layouts[0].get_partial_shape();
auto& out_layout = updated_impl_params.output_layouts[0];
auto output_pshape = out_layout.get_partial_shape();
OPENVINO_ASSERT(input_pshape.size() <= output_pshape.size() || input_pshape.size() - output_pshape.size() == 1,
"[GPU] Gather output rank must be greater than or equal to the input rank, or less by one");
if (input_pshape.size() > output_pshape.size()) {
output_pshape.insert(output_pshape.begin() + prim->axis, ov::Dimension(1));
out_layout.set_partial_shape(output_pshape);
out_layout.format = format::adjust_to_rank(out_layout.format, output_pshape.size());
}
return primitive_impl::static_canonicalize_shapes(updated_impl_params);
}
kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
return static_canonicalize_shapes(impl_params);
}
void update_dispatch_data(const kernel_impl_params& impl_param) override {
auto kernel_params = get_kernel_params(impl_param, true);
(_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data);

View File

@@ -13,6 +13,7 @@
#include "eltwise_inst.h"
#include "fully_connected_inst.h"
#include "gemm_inst.h"
#include "gather_inst.h"
using namespace cldnn;
using namespace ::tests;
@@ -185,6 +186,46 @@ TEST(canonicalization, gemm) {
}
}
struct gather_params {
int64_t axis;
int64_t batch_dim;
bool support_neg_ind;
};
std::vector<std::pair<Shapes, gather_params>> gather_shapes_with_params {
{
{{{8, 2, 3}, {}}, {{8, 2, 3, 1}, {1, 1, 1, 1}}, {{1, 2, 3, 1}}},
{0, 0, false}
},
{
{{{8, -1, -1, 2}, {}}, {{8, -1, -1, 2}, {1, 1, 1, 1}}, {{1, -1, -1, 2}}},
{0, 0, false}
},
{
{{{8, 2, 3}, {1}}, {{8, 2, 3, 1}, {1, 1, 1, 1}}, {{1, 2, 3, 1}}},
{0, 0, false}
},
{
{{{8, 2, 3, 4}, {8}}, {{8, 2, 3, 4}, {8, 1, 1, 1}}, {{8, 2, 1, 4}}},
{2, 1, false}
}
};
TEST(canonicalization, gather) {
for (const auto& params : gather_shapes_with_params) {
layout data_layout = create_default_layout(std::get<0>(params.first)[0]);
layout indices_layout = create_default_layout(std::get<0>(params.first)[1]);
topology topology;
topology.add(input_layout("data", data_layout));
topology.add(input_layout("indices", indices_layout));
topology.add(gather("gather", input_info("data"), input_info("indices"), params.second.axis,
ov::Shape{}, params.second.batch_dim, params.second.support_neg_ind));
canonicalization_test(topology, "gather", std::get<1>(params.first), std::get<2>(params.first));
}
}
struct fusing_gemm_eltwise_params {
ov::PartialShape input_gemm_first;
ov::PartialShape weights_gemm_first;

View File

@@ -169,6 +169,11 @@ const std::vector<GatherShapeParams> dynamicInputShapeConstTargetShape = {
ov::test::InputShape(ov::PartialShape({}), {{2, 1}}),
3, 2
},
{
ov::test::InputShape(ov::PartialShape({8, -1, -1, 2}), {{8, 2, 3, 2}, {8, 4, 5, 2}}),
ov::test::InputShape(ov::PartialShape({}), {{}}),
0, 0
},
{
ov::test::InputShape(ov::PartialShape({-1, -1, -1, -1, -1}), {{2, 6, 7, 8, 9}, {2, 6, 9, 1, 2}}),
ov::test::InputShape(ov::PartialShape({}), {{2, 6}}),
@@ -193,7 +198,7 @@ const std::vector<GatherShapeParams> dynamicInputShapeConstTargetShape = {
ov::test::InputShape(ov::PartialShape({-1, -1, -1, -1, -1, -1}), {{2, 4, 2, 3, 1, 3}, {2, 4, 7, 8, 9, 10}}),
ov::test::InputShape(ov::PartialShape({}), {{2, 4}}),
2, 2
},
}
};
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_input_shapes_const_target_shapes, GatherGPUTest,