[GPU] Update GatherTree Op to use ngraph shape infer (#13626)
This commit is contained in:
parent
20bd062d5e
commit
1e75a4b427
@ -3,6 +3,7 @@
|
||||
//
|
||||
|
||||
#include "gather_tree_inst.h"
|
||||
#include "gather_tree_shape_inference.hpp"
|
||||
|
||||
#include "intel_gpu/runtime/error_handler.hpp"
|
||||
#include "json_object.h"
|
||||
@ -20,6 +21,34 @@ layout gather_tree_inst::calc_output_layout(gather_tree_node const& node, kernel
|
||||
return input_layout;
|
||||
}
|
||||
|
||||
template<typename ShapeType>
|
||||
std::vector<layout> gather_tree_inst::calc_output_layouts(gather_tree_node const& /*node*/, const kernel_impl_params& impl_param) {
|
||||
auto desc = impl_param.typed_desc<gather_tree>();
|
||||
auto input0_layout = impl_param.get_input_layout(0);
|
||||
|
||||
auto output_type = input0_layout.data_type;
|
||||
if (impl_param.has_fused_primitives()) {
|
||||
output_type = impl_param.get_fused_output_layout().data_type;
|
||||
}
|
||||
|
||||
ov::op::v1::GatherTree op;
|
||||
|
||||
std::vector<ShapeType> output_shapes = {ShapeType()};
|
||||
std::vector<ShapeType> input_shapes = {
|
||||
impl_param.get_input_layout(0).get<ShapeType>(),
|
||||
impl_param.get_input_layout(1).get<ShapeType>(),
|
||||
impl_param.get_input_layout(2).get<ShapeType>(),
|
||||
impl_param.get_input_layout(3).get<ShapeType>(),
|
||||
};
|
||||
ov::op::v1::shape_infer(&op, input_shapes, output_shapes);
|
||||
|
||||
format output_format = format::adjust_to_rank(input0_layout.format, output_shapes[0].size());
|
||||
|
||||
return { layout{output_shapes[0], output_type, output_format} };
|
||||
}
|
||||
|
||||
// Explicit instantiation of calc_output_layouts for ov::PartialShape.
template std::vector<layout> gather_tree_inst::calc_output_layouts<ov::PartialShape>(gather_tree_node const& node, const kernel_impl_params& impl_param);
|
||||
|
||||
std::string gather_tree_inst::to_string(gather_tree_node const& node) {
|
||||
std::stringstream primitive_description;
|
||||
node.desc_to_json()->dump(primitive_description);
|
||||
@ -27,6 +56,14 @@ std::string gather_tree_inst::to_string(gather_tree_node const& node) {
|
||||
}
|
||||
|
||||
gather_tree_inst::typed_primitive_inst(network& network, gather_tree_node const& node) : parent(network, node) {
|
||||
auto dependencies = node.get_dependencies();
|
||||
|
||||
for (auto& dep : dependencies) {
|
||||
if (dep.first->get_output_layout().is_dynamic()) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto input_layout = node.input().get_output_layout();
|
||||
|
||||
const auto input_format = input_layout.format;
|
||||
@ -48,8 +85,6 @@ gather_tree_inst::typed_primitive_inst(network& network, gather_tree_node const&
|
||||
format::bs_fs_yx_bsv32_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv32);
|
||||
|
||||
auto dependencies = node.get_dependencies();
|
||||
|
||||
// check input dims
|
||||
CLDNN_ERROR_NOT_EQUAL(node.id(),
|
||||
"input0 size", dependencies.at(0).first->get_output_layout().get_tensor(), "output size", input_layout.get_tensor(),
|
||||
|
@ -19,6 +19,8 @@ class typed_primitive_inst<gather_tree> : public typed_primitive_inst_base<gathe
|
||||
using parent::parent;
|
||||
|
||||
public:
|
||||
template<typename ShapeType>
|
||||
static std::vector<layout> calc_output_layouts(gather_tree_node const& /*node*/, const kernel_impl_params& impl_param);
|
||||
static layout calc_output_layout(gather_tree_node const& node, kernel_impl_params const& impl_param);
|
||||
static std::string to_string(gather_tree_node const& node);
|
||||
typed_primitive_inst(network& network, gather_tree_node const& node);
|
||||
|
120
src/plugins/intel_gpu/tests/shape_infer/gather_tree_si_test.cpp
Normal file
120
src/plugins/intel_gpu/tests/shape_infer/gather_tree_si_test.cpp
Normal file
@ -0,0 +1,120 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/gather_tree.hpp>
|
||||
#include <intel_gpu/primitives/data.hpp>
|
||||
|
||||
#include "gather_tree_inst.h"
|
||||
|
||||
#include "program_wrapper.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
namespace shape_infer_tests {
|
||||
|
||||
struct gather_tree_test_params {
|
||||
layout step_ids;
|
||||
layout parent_ids;
|
||||
layout max_seq_len;
|
||||
layout end_token;
|
||||
layout expected_layout;
|
||||
};
|
||||
|
||||
// Value-parameterized fixture; each gather_tree_test_params instance drives one test run.
class gather_tree_test : public testing::TestWithParam<gather_tree_test_params> { };
|
||||
|
||||
// Wires four input_layout primitives into a single gather_tree primitive inside a
// bare program and verifies the layout produced by calc_output_layouts.
TEST_P(gather_tree_test, shape_infer) {
    auto params = GetParam();

    auto& engine = get_test_engine();

    // Primitives: one input_layout per gather_tree input.
    auto step_ids_prim = std::make_shared<input_layout>("input0", params.step_ids);
    auto parent_ids_prim = std::make_shared<input_layout>("input1", params.parent_ids);
    auto max_seq_len_prim = std::make_shared<input_layout>("input2", params.max_seq_len);
    auto end_token_prim = std::make_shared<input_layout>("input3", params.end_token);
    auto gather_tree_prim = std::make_shared<gather_tree>("output",
                                                          input_info("input0"),
                                                          input_info("input1"),
                                                          input_info("input2"),
                                                          input_info("input3"));

    cldnn::program prog(engine);

    // Nodes are created for the inputs first, then for gather_tree itself.
    auto& step_ids_node = prog.get_or_create(step_ids_prim);
    auto& parent_ids_node = prog.get_or_create(parent_ids_prim);
    auto& max_seq_len_node = prog.get_or_create(max_seq_len_prim);
    auto& end_token_node = prog.get_or_create(end_token_prim);

    auto& gt_node = prog.get_or_create(gather_tree_prim);

    // Connections are added in input order (0..3).
    program_wrapper::add_connection(prog, step_ids_node, gt_node);
    program_wrapper::add_connection(prog, parent_ids_node, gt_node);
    program_wrapper::add_connection(prog, max_seq_len_node, gt_node);
    program_wrapper::add_connection(prog, end_token_node, gt_node);

    auto layouts = gather_tree_inst::calc_output_layouts<ov::PartialShape>(gt_node, *gt_node.get_kernel_impl_params());

    ASSERT_EQ(layouts.size(), 1);
    ASSERT_EQ(layouts[0], params.expected_layout);
}
|
||||
|
||||
// Scenarios for gather_tree_test. Field order per entry:
// step_ids, parent_ids, max_seq_len, end_token, expected_layout.
INSTANTIATE_TEST_SUITE_P(smoke, gather_tree_test,
    testing::ValuesIn(std::vector<gather_tree_test_params>{
        // All inputs static.
        {
            layout{ov::PartialShape{100, 1, 10}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{100, 1, 10}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{ 1 }, data_types::f32, format::bfyx},
            layout{ov::PartialShape{}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{100, 1, 10}, data_types::f32, format::bfyx}
        },
        // All inputs static, second set of dimensions.
        {
            layout{ov::PartialShape{20, 4, 5}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{20, 4, 5}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{ 4 }, data_types::f32, format::bfyx},
            layout{ov::PartialShape{}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{20, 4, 5}, data_types::f32, format::bfyx}
        },
        // Dynamic step_ids, static parent_ids.
        {
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape{20, 4, 5}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{ 4 }, data_types::f32, format::bfyx},
            layout{ov::PartialShape{}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{20, 4, 5}, data_types::f32, format::bfyx}
        },
        // Static step_ids, dynamic parent_ids.
        {
            layout{ov::PartialShape{20, 4, 5}, data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape{ 4 }, data_types::f32, format::bfyx},
            layout{ov::PartialShape{}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{20, 4, 5}, data_types::f32, format::bfyx}
        },
        // Dynamic step_ids and parent_ids, static max_seq_len.
        {
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape{ 4 }, data_types::f32, format::bfyx},
            layout{ov::PartialShape{}, data_types::f32, format::bfyx},
            layout{ov::PartialShape{-1, 4, -1}, data_types::f32, format::bfyx}
        },
        // Dynamic step_ids, parent_ids and max_seq_len; static scalar end_token.
        {
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(1), data_types::f32, format::bfyx},
            layout{ov::PartialShape{}, data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}
        },
        // Every input given through PartialShape::dynamic(rank).
        {
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(1), data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(0), data_types::f32, format::bfyx},
            layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}
        },
    }));
|
||||
|
||||
} // namespace shape_infer_tests
|
@ -0,0 +1,182 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shared_test_classes/single_layer/gather_tree.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "ie_precision.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "common_test_utils/ov_tensor_utils.hpp"
|
||||
#include <string>
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace ov::test;
|
||||
|
||||
namespace GPULayerTestsDefinitions {
|
||||
|
||||
typedef std::tuple<
|
||||
InputShape, // Input tensors shape
|
||||
ngraph::helpers::InputLayerType, // Secondary input type
|
||||
ov::element::Type_t, // Network precision
|
||||
std::string // Device name
|
||||
> GatherTreeGPUTestParams;
|
||||
|
||||
// Single-layer GatherTree test on top of SubgraphBaseTest.
// The first GatherTree input is always a Parameter; the remaining three are either
// Parameters or Constants, selected by the InputLayerType test parameter.
class GatherTreeLayerGPUTest : public testing::WithParamInterface<GatherTreeGPUTestParams>,
                               virtual public SubgraphBaseTest {
public:
    // Builds the test-case name from the shapes, secondary input type, precision and device.
    static std::string getTestCaseName(const testing::TestParamInfo<GatherTreeGPUTestParams> &obj) {
        const InputShape& shapes = std::get<0>(obj.param);
        const ngraph::helpers::InputLayerType secondaryType = std::get<1>(obj.param);
        const ov::element::Type_t precision = std::get<2>(obj.param);
        const std::string& device = std::get<3>(obj.param);

        std::ostringstream name;
        name << "IS=" << CommonTestUtils::partialShape2str({shapes.first}) << "_";
        name << "TS=";
        for (const auto& staticShape : shapes.second) {
            name << CommonTestUtils::vec2str(staticShape) << "_";
        }
        name << "secondaryInputType=" << secondaryType << "_"
             << "netPRC=" << precision << "_"
             << "trgDev=" << device;

        return name.str();
    }

protected:
    // Derives the shapes of all four inputs from the single InputShape parameter and
    // assembles the GatherTree function under test.
    void SetUp() override {
        InputShape inputShape;
        ngraph::helpers::InputLayerType secondaryInputType;
        ov::element::Type netPrecision;
        std::tie(inputShape, secondaryInputType, netPrecision, targetDevice) = this->GetParam();

        // The third input is 1D and takes its only dimension from dimension 1 of the first input.
        InputShape::first_type seqLenPartial;
        if (inputShape.first.is_dynamic()) {
            seqLenPartial = {inputShape.first[1]};
        }
        InputShape::second_type seqLenStatic;
        seqLenStatic.reserve(inputShape.second.size());
        for (const auto& staticShape : inputShape.second) {
            seqLenStatic.emplace_back(std::initializer_list<size_t>{staticShape[1]});
        }
        InputShape maxSeqLenShape{std::move(seqLenPartial), std::move(seqLenStatic)};

        // The second input reuses the shape of the first one.
        init_input_shapes({inputShape, inputShape, maxSeqLenShape});

        // initialization of scalar input as it cannot be done properly in init_input_shapes
        inputDynamicShapes.push_back({});
        for (auto& staticShapes : targetStaticShapes) {
            staticShapes.push_back({});
        }

        std::shared_ptr<ngraph::Node> secondInput;
        std::shared_ptr<ngraph::Node> thirdInput;
        std::shared_ptr<ngraph::Node> fourthInput;

        auto paramsIn = ngraph::builder::makeDynamicParams(netPrecision, {inputDynamicShapes[0]});
        switch (secondaryInputType) {
        case ngraph::helpers::InputLayerType::PARAMETER:
            secondInput = ngraph::builder::makeDynamicInputLayer(netPrecision, secondaryInputType, inputDynamicShapes[1]);
            thirdInput = ngraph::builder::makeDynamicInputLayer(netPrecision, secondaryInputType, inputDynamicShapes[2]);
            fourthInput = ngraph::builder::makeDynamicInputLayer(netPrecision, secondaryInputType, inputDynamicShapes[3]);

            // Secondary inputs become additional function parameters, in input order.
            for (const auto& node : {secondInput, thirdInput, fourthInput}) {
                paramsIn.push_back(std::dynamic_pointer_cast<ngraph::opset1::Parameter>(node));
            }
            break;
        case ngraph::helpers::InputLayerType::CONSTANT: {
            // Constants are shaped after the first target static shape only.
            auto maxBeamIndex = inputShape.second.front().at(2) - 1;

            secondInput = ngraph::builder::makeConstant<float>(netPrecision, inputShape.second.front(), {}, true, maxBeamIndex);
            thirdInput = ngraph::builder::makeConstant<float>(netPrecision, {inputShape.second.front().at(1)}, {}, true, maxBeamIndex);
            fourthInput = ngraph::builder::makeConstant<float>(netPrecision, {}, {}, true, maxBeamIndex);
            break;
        }
        default:
            throw std::runtime_error("Unsupported inputType");
        }

        auto gatherTree = std::make_shared<ngraph::opset4::GatherTree>(paramsIn.front(), secondInput, thirdInput, fourthInput);

        ngraph::ResultVector results{std::make_shared<ngraph::opset4::Result>(gatherTree)};
        function = std::make_shared<ngraph::Function>(results, paramsIn, "GatherTree");
    }

    // Creates one tensor per function input. The value passed as the third argument of
    // create_and_fill_tensor is the size of dimension 2 of the first input minus one;
    // inputs with index 2 and 3 additionally pass half of that value as the fourth argument.
    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
        inputs.clear();
        const auto maxBeamIndex = targetInputStaticShapes.front().at(2) - 1;
        const auto& funcInputs = function->inputs();

        for (size_t idx = 0; idx < funcInputs.size(); ++idx) {
            const bool isThirdOrFourthInput = (idx == 2 || idx == 3);
            auto tensor = ov::test::utils::create_and_fill_tensor(funcInputs[idx].get_element_type(),
                                                                  targetInputStaticShapes[idx],
                                                                  maxBeamIndex,
                                                                  isThirdOrFourthInput ? maxBeamIndex / 2 : 0);
            inputs.insert({funcInputs[idx].get_node_shared_ptr(), tensor});
        }
    }
};
|
||||
|
||||
// Executes the subgraph test through run() unless the current test is disabled.
TEST_P(GatherTreeLayerGPUTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()

    run();
}
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<ov::element::Type_t> netPrecisions = {
|
||||
ov::element::f32,
|
||||
ov::element::i32
|
||||
};
|
||||
|
||||
const std::vector<InputShape> inputDynamicShapesParameter = {
|
||||
{
|
||||
{-1, 1, -1}, {{7, 1, 10}, {8, 1, 20}}
|
||||
},
|
||||
{
|
||||
{-1, 1, {5, 10}}, {{2, 1, 7}, {5, 1, 8}}
|
||||
},
|
||||
{
|
||||
{-1, {1, 5}, 10}, {{20, 1, 10}, {17, 2, 10}}
|
||||
},
|
||||
{
|
||||
{-1, -1, -1}, {{20, 20, 15}, {30, 30, 10}}
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<InputShape> inputDynamicShapesConstant = {
|
||||
{
|
||||
{-1, 1, -1}, {{7, 1, 10}}
|
||||
},
|
||||
{
|
||||
{-1, 1, {5, 10}}, {{2, 1, 7}}
|
||||
},
|
||||
{
|
||||
{-1, {1, 5}, 10}, {{20, 1, 10}}
|
||||
},
|
||||
{
|
||||
{-1, -1, -1}, {{20, 20, 15}}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_gathertree_parameter_compareWithRefs_dynamic, GatherTreeLayerGPUTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputDynamicShapesParameter),
|
||||
::testing::Values(ngraph::helpers::InputLayerType::PARAMETER),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherTreeLayerGPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_gathertree_constant_compareWithRefs_dynamic, GatherTreeLayerGPUTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputDynamicShapesConstant),
|
||||
::testing::Values(ngraph::helpers::InputLayerType::CONSTANT),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU)),
|
||||
GatherTreeLayerGPUTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace GPULayerTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user