[GPU] GatherElements shape infer (#12517)

This commit is contained in:
Roman Lyamin 2022-08-25 09:38:53 +04:00 committed by GitHub
parent e7e6d7883c
commit fcf20dee86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 122 additions and 0 deletions

View File

@ -32,6 +32,9 @@ public:
int64_t get_axis() const {
return m_axis;
}
void set_axis(int64_t axis) {
m_axis = axis;
}
private:
int64_t m_axis{0};

View File

@ -36,6 +36,14 @@ struct gather_elements : public primitive_base<gather_elements> {
const padding& output_padding = padding())
: primitive_base(id, {data, indices}, ext_prim_id, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {}
gather_elements(const primitive_id& id,
const primitive_id& data,
const primitive_id& indices,
const int64_t axis,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
: primitive_base(id, {data, indices}, ext_prim_id, output_padding), output_format({}), output_shape({}), axis(axis) {}
/// @brief Gather Elements output format
format output_format;
/// @brief Gather Elements output shape

View File

@ -3,6 +3,7 @@
//
#include "gather_elements_inst.h"
#include "gather_elements_shape_inference.hpp"
#include "primitive_type_base.h"
#include "intel_gpu/runtime/error_handler.hpp"
@ -32,6 +33,31 @@ layout gather_elements_inst::calc_output_layout(gather_elements_node const& node
return layout(output_type, output_format, output_shape);
}
template<typename ShapeType>
std::vector<layout> gather_elements_inst::calc_output_layouts(gather_elements_node const& /*node*/, const kernel_impl_params& impl_param) {
auto desc = impl_param.typed_desc<gather_elements>();
auto input_layout = impl_param.get_input_layout(0);
auto output_type = input_layout.data_type;
if (impl_param.has_fused_primitives()) {
output_type = impl_param.get_fused_output_layout().data_type;
}
ov::op::v6::GatherElements op;
op.set_axis(desc->axis);
std::vector<ShapeType> output_shapes = {ShapeType()};
std::vector<ShapeType> input_shapes = {
impl_param.get_input_layout(0).get<ShapeType>(),
impl_param.get_input_layout(1).get<ShapeType>()
};
ov::op::v6::shape_infer(&op, input_shapes, output_shapes);
format output_format = format::adjust_to_rank(input_layout.format, output_shapes[0].size());
return { layout{output_shapes[0], output_type, output_format} };
}
std::string gather_elements_inst::to_string(gather_elements_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();

View File

@ -29,6 +29,7 @@ public:
using parent::parent;
program_node& input(size_t index = 0) const { return get_dependency(index); }
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
};
using gather_elements_node = typed_program_node<gather_elements>;
@ -38,6 +39,8 @@ class typed_primitive_inst<gather_elements> : public typed_primitive_inst_base<g
using parent = typed_primitive_inst_base<gather_elements>;
public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(gather_elements_node const& /*node*/, const kernel_impl_params& impl_param);
static layout calc_output_layout(gather_elements_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(gather_elements_node const& node);

View File

@ -0,0 +1,82 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils.h"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/gather_elements.hpp>
#include <intel_gpu/primitives/data.hpp>
#include "gather_elements_inst.h"
#include "program_wrapper.h"
#include <cmath>
#include <algorithm>
using namespace cldnn;
using namespace ::tests;
namespace shape_infer_tests {
struct gather_elements_test_params {
layout data_layout;
layout indices_layout;
int64_t axis;
layout expected_layout;
};
class gather_elements_test : public testing::TestWithParam<gather_elements_test_params> { };
TEST_P(gather_elements_test, shape_infer) {
auto p = GetParam();
auto& engine = get_test_engine();
auto data_layout_prim = std::make_shared<input_layout>("data", p.data_layout);
auto indices_layout_prim = std::make_shared<input_layout>("indices", p.indices_layout);
auto gather_prim = std::make_shared<gather_elements>("output", "data", "indices", p.axis);
cldnn::program prog(engine);
auto& data_layout_node = prog.get_or_create(data_layout_prim);
auto& indices_layout_node = prog.get_or_create(indices_layout_prim);
auto& gather_elements_node = prog.get_or_create(gather_prim);
program_wrapper::add_connection(prog, data_layout_node, gather_elements_node);
program_wrapper::add_connection(prog, indices_layout_node, gather_elements_node);
auto res = gather_elements_inst::calc_output_layouts<ov::PartialShape>(gather_elements_node, *gather_elements_node.get_kernel_impl_params());
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[0], p.expected_layout);
}
INSTANTIATE_TEST_SUITE_P(smoke, gather_elements_test,
testing::ValuesIn(std::vector<gather_elements_test_params>{
{
layout{ov::PartialShape{3, 7, 5}, data_types::f32, format::bfyx},
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx},
1,
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx}
},
{
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx},
1,
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx}
},
{
layout{ov::PartialShape{3, 7, 5}, data_types::f32, format::bfyx},
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
1,
layout{ov::PartialShape{3, -1, 5}, data_types::f32, format::bfyx}
},
{
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
1,
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}
}
}));
} // shape_infer_tests