[GPU] Added custom canonicalize_shapes for Gather (#16733)
This commit is contained in:
@@ -99,6 +99,30 @@ public:
|
||||
return {params, optional_params};
|
||||
}
|
||||
|
||||
static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params) {
|
||||
auto updated_impl_params = canonicalize_fused_shapes(impl_params);
|
||||
const auto& prim = impl_params.typed_desc<gather>();
|
||||
|
||||
auto input_pshape = updated_impl_params.input_layouts[0].get_partial_shape();
|
||||
auto& out_layout = updated_impl_params.output_layouts[0];
|
||||
auto output_pshape = out_layout.get_partial_shape();
|
||||
|
||||
OPENVINO_ASSERT(input_pshape.size() <= output_pshape.size() || input_pshape.size() - output_pshape.size() == 1,
|
||||
"[GPU] Gather output rank must be greater than or equal to the input rank, or less by one");
|
||||
|
||||
if (input_pshape.size() > output_pshape.size()) {
|
||||
output_pshape.insert(output_pshape.begin() + prim->axis, ov::Dimension(1));
|
||||
out_layout.set_partial_shape(output_pshape);
|
||||
out_layout.format = format::adjust_to_rank(out_layout.format, output_pshape.size());
|
||||
}
|
||||
|
||||
return primitive_impl::static_canonicalize_shapes(updated_impl_params);
|
||||
}
|
||||
|
||||
kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const override {
|
||||
return static_canonicalize_shapes(impl_params);
|
||||
}
|
||||
|
||||
void update_dispatch_data(const kernel_impl_params& impl_param) override {
|
||||
auto kernel_params = get_kernel_params(impl_param, true);
|
||||
(_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data);
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "eltwise_inst.h"
|
||||
#include "fully_connected_inst.h"
|
||||
#include "gemm_inst.h"
|
||||
#include "gather_inst.h"
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
@@ -185,6 +186,46 @@ TEST(canonicalization, gemm) {
|
||||
}
|
||||
}
|
||||
|
||||
struct gather_params {
|
||||
int64_t axis;
|
||||
int64_t batch_dim;
|
||||
bool support_neg_ind;
|
||||
};
|
||||
|
||||
std::vector<std::pair<Shapes, gather_params>> gather_shapes_with_params {
|
||||
{
|
||||
{{{8, 2, 3}, {}}, {{8, 2, 3, 1}, {1, 1, 1, 1}}, {{1, 2, 3, 1}}},
|
||||
{0, 0, false}
|
||||
},
|
||||
{
|
||||
{{{8, -1, -1, 2}, {}}, {{8, -1, -1, 2}, {1, 1, 1, 1}}, {{1, -1, -1, 2}}},
|
||||
{0, 0, false}
|
||||
},
|
||||
{
|
||||
{{{8, 2, 3}, {1}}, {{8, 2, 3, 1}, {1, 1, 1, 1}}, {{1, 2, 3, 1}}},
|
||||
{0, 0, false}
|
||||
},
|
||||
{
|
||||
{{{8, 2, 3, 4}, {8}}, {{8, 2, 3, 4}, {8, 1, 1, 1}}, {{8, 2, 1, 4}}},
|
||||
{2, 1, false}
|
||||
}
|
||||
};
|
||||
|
||||
TEST(canonicalization, gather) {
|
||||
for (const auto& params : gather_shapes_with_params) {
|
||||
layout data_layout = create_default_layout(std::get<0>(params.first)[0]);
|
||||
layout indices_layout = create_default_layout(std::get<0>(params.first)[1]);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("data", data_layout));
|
||||
topology.add(input_layout("indices", indices_layout));
|
||||
topology.add(gather("gather", input_info("data"), input_info("indices"), params.second.axis,
|
||||
ov::Shape{}, params.second.batch_dim, params.second.support_neg_ind));
|
||||
|
||||
canonicalization_test(topology, "gather", std::get<1>(params.first), std::get<2>(params.first));
|
||||
}
|
||||
}
|
||||
|
||||
struct fusing_gemm_eltwise_params {
|
||||
ov::PartialShape input_gemm_first;
|
||||
ov::PartialShape weights_gemm_first;
|
||||
|
||||
@@ -169,6 +169,11 @@ const std::vector<GatherShapeParams> dynamicInputShapeConstTargetShape = {
|
||||
ov::test::InputShape(ov::PartialShape({}), {{2, 1}}),
|
||||
3, 2
|
||||
},
|
||||
{
|
||||
ov::test::InputShape(ov::PartialShape({8, -1, -1, 2}), {{8, 2, 3, 2}, {8, 4, 5, 2}}),
|
||||
ov::test::InputShape(ov::PartialShape({}), {{}}),
|
||||
0, 0
|
||||
},
|
||||
{
|
||||
ov::test::InputShape(ov::PartialShape({-1, -1, -1, -1, -1}), {{2, 6, 7, 8, 9}, {2, 6, 9, 1, 2}}),
|
||||
ov::test::InputShape(ov::PartialShape({}), {{2, 6}}),
|
||||
@@ -193,7 +198,7 @@ const std::vector<GatherShapeParams> dynamicInputShapeConstTargetShape = {
|
||||
ov::test::InputShape(ov::PartialShape({-1, -1, -1, -1, -1, -1}), {{2, 4, 2, 3, 1, 3}, {2, 4, 7, 8, 9, 10}}),
|
||||
ov::test::InputShape(ov::PartialShape({}), {{2, 4}}),
|
||||
2, 2
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_input_shapes_const_target_shapes, GatherGPUTest,
|
||||
|
||||
Reference in New Issue
Block a user