[GPU] GatherElements shape infer (#12517)
This commit is contained in:
parent
e7e6d7883c
commit
fcf20dee86
@ -32,6 +32,9 @@ public:
|
||||
int64_t get_axis() const {
|
||||
return m_axis;
|
||||
}
|
||||
void set_axis(int64_t axis) {
|
||||
m_axis = axis;
|
||||
}
|
||||
|
||||
private:
|
||||
int64_t m_axis{0};
|
||||
|
@ -36,6 +36,14 @@ struct gather_elements : public primitive_base<gather_elements> {
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {data, indices}, ext_prim_id, output_padding), output_format(output_format), output_shape(output_shape), axis(axis) {}
|
||||
|
||||
gather_elements(const primitive_id& id,
|
||||
const primitive_id& data,
|
||||
const primitive_id& indices,
|
||||
const int64_t axis,
|
||||
const primitive_id& ext_prim_id = "",
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {data, indices}, ext_prim_id, output_padding), output_format({}), output_shape({}), axis(axis) {}
|
||||
|
||||
/// @brief Gather Elements output format
|
||||
format output_format;
|
||||
/// @brief Gather Elements output shape
|
||||
|
@ -3,6 +3,7 @@
|
||||
//
|
||||
|
||||
#include "gather_elements_inst.h"
|
||||
#include "gather_elements_shape_inference.hpp"
|
||||
|
||||
#include "primitive_type_base.h"
|
||||
#include "intel_gpu/runtime/error_handler.hpp"
|
||||
@ -32,6 +33,31 @@ layout gather_elements_inst::calc_output_layout(gather_elements_node const& node
|
||||
return layout(output_type, output_format, output_shape);
|
||||
}
|
||||
|
||||
template<typename ShapeType>
|
||||
std::vector<layout> gather_elements_inst::calc_output_layouts(gather_elements_node const& /*node*/, const kernel_impl_params& impl_param) {
|
||||
auto desc = impl_param.typed_desc<gather_elements>();
|
||||
auto input_layout = impl_param.get_input_layout(0);
|
||||
|
||||
auto output_type = input_layout.data_type;
|
||||
if (impl_param.has_fused_primitives()) {
|
||||
output_type = impl_param.get_fused_output_layout().data_type;
|
||||
}
|
||||
|
||||
ov::op::v6::GatherElements op;
|
||||
op.set_axis(desc->axis);
|
||||
|
||||
std::vector<ShapeType> output_shapes = {ShapeType()};
|
||||
std::vector<ShapeType> input_shapes = {
|
||||
impl_param.get_input_layout(0).get<ShapeType>(),
|
||||
impl_param.get_input_layout(1).get<ShapeType>()
|
||||
};
|
||||
ov::op::v6::shape_infer(&op, input_shapes, output_shapes);
|
||||
|
||||
format output_format = format::adjust_to_rank(input_layout.format, output_shapes[0].size());
|
||||
|
||||
return { layout{output_shapes[0], output_type, output_format} };
|
||||
}
|
||||
|
||||
std::string gather_elements_inst::to_string(gather_elements_node const& node) {
|
||||
auto desc = node.get_primitive();
|
||||
auto node_info = node.desc_to_json();
|
||||
|
@ -29,6 +29,7 @@ public:
|
||||
using parent::parent;
|
||||
|
||||
program_node& input(size_t index = 0) const { return get_dependency(index); }
|
||||
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
|
||||
};
|
||||
|
||||
using gather_elements_node = typed_program_node<gather_elements>;
|
||||
@ -38,6 +39,8 @@ class typed_primitive_inst<gather_elements> : public typed_primitive_inst_base<g
|
||||
using parent = typed_primitive_inst_base<gather_elements>;
|
||||
|
||||
public:
|
||||
template<typename ShapeType>
|
||||
static std::vector<layout> calc_output_layouts(gather_elements_node const& /*node*/, const kernel_impl_params& impl_param);
|
||||
static layout calc_output_layout(gather_elements_node const& node, kernel_impl_params const& impl_param);
|
||||
static std::string to_string(gather_elements_node const& node);
|
||||
|
||||
|
@ -0,0 +1,82 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/gather_elements.hpp>
|
||||
#include <intel_gpu/primitives/data.hpp>
|
||||
|
||||
#include "gather_elements_inst.h"
|
||||
|
||||
#include "program_wrapper.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
namespace shape_infer_tests {
|
||||
|
||||
struct gather_elements_test_params {
|
||||
layout data_layout;
|
||||
layout indices_layout;
|
||||
int64_t axis;
|
||||
layout expected_layout;
|
||||
};
|
||||
|
||||
class gather_elements_test : public testing::TestWithParam<gather_elements_test_params> { };
|
||||
|
||||
TEST_P(gather_elements_test, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto data_layout_prim = std::make_shared<input_layout>("data", p.data_layout);
|
||||
auto indices_layout_prim = std::make_shared<input_layout>("indices", p.indices_layout);
|
||||
auto gather_prim = std::make_shared<gather_elements>("output", "data", "indices", p.axis);
|
||||
|
||||
cldnn::program prog(engine);
|
||||
|
||||
auto& data_layout_node = prog.get_or_create(data_layout_prim);
|
||||
auto& indices_layout_node = prog.get_or_create(indices_layout_prim);
|
||||
auto& gather_elements_node = prog.get_or_create(gather_prim);
|
||||
program_wrapper::add_connection(prog, data_layout_node, gather_elements_node);
|
||||
program_wrapper::add_connection(prog, indices_layout_node, gather_elements_node);
|
||||
auto res = gather_elements_inst::calc_output_layouts<ov::PartialShape>(gather_elements_node, *gather_elements_node.get_kernel_impl_params());
|
||||
|
||||
ASSERT_EQ(res.size(), 1);
|
||||
ASSERT_EQ(res[0], p.expected_layout);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke, gather_elements_test,
|
||||
testing::ValuesIn(std::vector<gather_elements_test_params>{
|
||||
{
|
||||
layout{ov::PartialShape{3, 7, 5}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx},
|
||||
1,
|
||||
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx},
|
||||
1,
|
||||
layout{ov::PartialShape{3, 10, 5}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape{3, 7, 5}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||
1,
|
||||
layout{ov::PartialShape{3, -1, 5}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||
1,
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}
|
||||
}
|
||||
}));
|
||||
|
||||
} // shape_infer_tests
|
Loading…
Reference in New Issue
Block a user