[GPU] Update ScatterNDUpdate Op to use ngraph shape infer (#15176)

This commit is contained in:
Kelvin Choi 2023-02-06 14:31:33 +09:00 committed by GitHub
parent 3bfd07d535
commit 8ed71a22fa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 307 additions and 3 deletions

View File

@ -17,6 +17,8 @@ class typed_primitive_inst<scatter_nd_update> : public typed_primitive_inst_base
using parent::parent;
public:
template<typename ShapeType>
static std::vector<layout> calc_output_layouts(scatter_nd_update_node const& /*node*/, const kernel_impl_params& impl_param);
static layout calc_output_layout(scatter_nd_update_node const& node, kernel_impl_params const& impl_param);
static std::string to_string(scatter_nd_update_node const& node);

View File

@ -28,6 +28,7 @@
#include "depth_to_space_inst.h"
#include "region_yolo_inst.h"
#include "prior_box_inst.h"
#include "scatter_nd_update_inst.h"
#include "to_string_utils.h"
#include <vector>
#include <memory>
@ -1601,9 +1602,12 @@ format layout_optimizer::get_preferred_format(program_node& node) {
// Let reorder_input pass to check input format instead of output_format in forward investigation, vice versa
auto out_lay_rank = node.get_output_layout(false).get_rank();
auto in_lay_rank = node.get_dependencies().size() > 0 ? node.get_dependency(0).get_output_layout(false).get_rank() : out_lay_rank;
if (in_lay_rank != out_lay_rank)
node.set_preferred_input_fmt(0, get_preferred_format(node.get_dependency(0)));
auto dep_size = node.get_dependencies().size();
for (size_t i = 0; i < dep_size; i++) {
auto in_lay_rank = node.get_dependency(i).get_output_layout(false).get_rank();
if (in_lay_rank != out_lay_rank)
node.set_preferred_input_fmt(i, get_preferred_format(node.get_dependency(i)));
}
// shape_infer_dep should be plain format because the memory is being read by ngraph shape infer as is
if (node.is_shape_infer_dep()) {

View File

@ -9,6 +9,8 @@
#include "json_object.h"
#include <string>
#include "scatter_nd_base_shape_inference.hpp"
namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(scatter_nd_update)
@ -26,6 +28,29 @@ layout scatter_nd_update_inst::calc_output_layout(scatter_nd_update_node const&
return layout{output_type, input_format, output_shape};
}
// Shape inference for ScatterNDUpdate using the ngraph/ov shape_infer helper.
// The output reuses the data input's (input 0) data type and format; its shape
// is whatever shape_infer derives (for ScatterNDUpdate this matches input 0).
template<typename ShapeType>
std::vector<layout> scatter_nd_update_inst::calc_output_layouts(scatter_nd_update_node const& /*node*/, const kernel_impl_params& impl_param) {
    const auto data_layout = impl_param.get_input_layout(0);
    const auto indices_layout = impl_param.get_input_layout(1);
    const auto updates_layout = impl_param.get_input_layout(2);

    std::vector<ShapeType> input_shapes;
    input_shapes.push_back(data_layout.get<ShapeType>());     // inputs_shape
    input_shapes.push_back(indices_layout.get<ShapeType>());  // indices_shape
    input_shapes.push_back(updates_layout.get<ShapeType>());  // updates_shape

    std::vector<ShapeType> output_shapes = {ShapeType()};

    ov::op::v3::ScatterNDUpdate op;
    shape_infer(&op, input_shapes, output_shapes);

    return { layout{output_shapes[0], data_layout.data_type, data_layout.format} };
}

template std::vector<layout>
scatter_nd_update_inst::calc_output_layouts<ov::PartialShape>(scatter_nd_update_node const& node, const kernel_impl_params& impl_param);
std::string scatter_nd_update_inst::to_string(scatter_nd_update_node const& node) {
auto desc = node.get_primitive();
auto node_info = node.desc_to_json();

View File

@ -0,0 +1,86 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils.h"
#include <intel_gpu/primitives/input_layout.hpp>
#include <intel_gpu/primitives/scatter_nd_update.hpp>
#include <intel_gpu/primitives/data.hpp>
#include "scatter_nd_update_inst.h"
#include "program_wrapper.h"
#include <cmath>
#include <algorithm>
using namespace cldnn;
using namespace ::tests;
namespace shape_infer_tests {
// Inputs and expected result for one scatter_nd_update shape-inference case.
struct scatter_nd_update_test_params {
layout data_layout;      // layout of input 0 (the tensor being updated)
layout indices_layout;   // layout of input 1 (indices)
layout updates_layout;   // layout of input 2 (updates)
int64_t indices_rank;    // rank of the indices tensor passed to the primitive
layout expected_layout;  // layout calc_output_layouts() is expected to return
};

// Parameterized fixture; each param set is one shape-inference scenario.
class scatter_nd_update_test : public testing::TestWithParam<scatter_nd_update_test_params> { };
// Builds a minimal program around a scatter_nd_update primitive and checks
// that calc_output_layouts() produces the expected layout.
TEST_P(scatter_nd_update_test, shape_infer) {
    const auto params = GetParam();
    auto& engine = get_test_engine();

    // The three inputs feeding the primitive under test.
    auto data_prim = std::make_shared<input_layout>("data", params.data_layout);
    auto indices_prim = std::make_shared<input_layout>("indices", params.indices_layout);
    auto updates_prim = std::make_shared<input_layout>("updates", params.updates_layout);
    auto scatter_prim = std::make_shared<scatter_nd_update>("output",
                                                            input_info("data"),
                                                            input_info("indices"),
                                                            input_info("updates"),
                                                            params.indices_rank);

    cldnn::program prog(engine);
    auto& data_node = prog.get_or_create(data_prim);
    auto& indices_node = prog.get_or_create(indices_prim);
    auto& updates_node = prog.get_or_create(updates_prim);
    auto& scatter_node = prog.get_or_create(scatter_prim);
    program_wrapper::add_connection(prog, data_node, scatter_node);
    program_wrapper::add_connection(prog, indices_node, scatter_node);
    program_wrapper::add_connection(prog, updates_node, scatter_node);

    const auto res = scatter_nd_update_inst::calc_output_layouts<ov::PartialShape>(
        scatter_node, *scatter_node.get_kernel_impl_params());

    ASSERT_EQ(res.size(), 1);
    ASSERT_EQ(res[0], params.expected_layout);
}
// Static and dynamic shape scenarios for the shape-inference test above.
INSTANTIATE_TEST_SUITE_P(smoke, scatter_nd_update_test,
testing::ValuesIn(std::vector<scatter_nd_update_test_params>{
// Rank-4 data with rank-3 indices: output shape equals the data shape.
{
layout{ov::PartialShape{1000, 256, 10, 15}, data_types::f32, format::bfyx},
layout{ov::PartialShape{25, 125, 3}, data_types::f32, format::bfyx},
layout{ov::PartialShape{25, 125, 15}, data_types::f32, format::bfyx},
3,
layout{ov::PartialShape{1000, 256, 10, 15}, data_types::f32, format::bfyx},
},
// Rank-2 data with rank-1 indices.
{
layout{ov::PartialShape{3, 5}, data_types::f32, format::bfyx},
layout{ov::PartialShape{2}, data_types::f32, format::bfyx},
layout{ov::PartialShape{3, 2}, data_types::f32, format::bfyx},
1,
layout{ov::PartialShape{3, 5}, data_types::f32, format::bfyx},
},
// Fully dynamic inputs: the dynamic data shape must be propagated as-is.
{
layout{ov::PartialShape::dynamic(2), data_types::f32, format::bfyx},
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
3,
layout{ov::PartialShape::dynamic(2), data_types::f32, format::bfyx},
}
}));
} // namespace shape_infer_tests

View File

@ -0,0 +1,187 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_test_classes/single_layer/scatter_ND_update.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "ie_precision.hpp"
#include "ngraph_functions/builders.hpp"
#include <common_test_utils/ov_tensor_utils.hpp>
#include <string>
using namespace ngraph;
using namespace InferenceEngine;
using namespace ov::test;
namespace GPULayerTestsDefinitions {
// Dynamic input shapes for the three ScatterNDUpdate inputs (data, indices, updates).
using ScatterNDUpdateShapes = std::vector<InputShape>;
// Fixed values written into the indices input on every inference.
using IndicesValues = std::vector<std::int64_t>;

// One functional test scenario: the input shapes plus the indices content.
struct ScatterNDUpdateLayerParams {
ScatterNDUpdateShapes inputShapes;
IndicesValues indicesValues;
};

// Full parameter tuple: scenario, data/updates precision, indices precision.
using ScatterUpdateParams = std::tuple<
    ScatterNDUpdateLayerParams,
    ElementType,   // input precision
    ElementType>;  // indices precision
// Dynamic-shape functional test for ScatterNDUpdate on GPU.
// Builds a 3-parameter subgraph (data, indices, updates) -> ScatterNDUpdate
// and compares GPU output against the reference implementation.
class ScatterNDUpdateLayerGPUTest : public testing::WithParamInterface<ScatterUpdateParams>,
                                    virtual public SubgraphBaseTest {
public:
    // Readable test name: precision, partial shapes, target shapes,
    // indices values and indices precision.
    static std::string getTestCaseName(testing::TestParamInfo<ScatterUpdateParams> obj) {
        ScatterNDUpdateLayerParams scatterParams;
        ElementType inputPrecision;
        ElementType idxPrecision;
        std::tie(scatterParams, inputPrecision, idxPrecision) = obj.param;
        const auto inputShapes = scatterParams.inputShapes;
        const auto indicesValues = scatterParams.indicesValues;
        std::ostringstream result;
        result << inputPrecision << "_IS=";
        for (const auto& shape : inputShapes) {
            result << CommonTestUtils::partialShape2str({ shape.first }) << "_";
        }
        result << "TS=";
        for (const auto& shape : inputShapes) {
            result << "(";
            for (const auto& targetShape : shape.second) {
                result << CommonTestUtils::vec2str(targetShape) << "_";
            }
            result << ")_";
        }
        result << "indices_values=" << CommonTestUtils::vec2str(indicesValues);
        result << "_idx_precision=" << idxPrecision;
        result << "_trgDev=GPU";  // fix: '_' separator was missing before trgDev
        return result.str();
    }
protected:
    // Fills inference inputs: input 1 (indices) gets the fixed values from the
    // test params; data/updates tensors are generated by the test utils.
    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
        inputs.clear();
        const auto& funcInputs = function->inputs();
        // fix: loop index was 'int', compared against size_t funcInputs.size()
        for (size_t i = 0; i < funcInputs.size(); ++i) {
            const auto& funcInput = funcInputs[i];
            const auto& inputPrecision = funcInput.get_element_type();
            const auto& targetShape = targetInputStaticShapes[i];
            ov::Tensor tensor;
            if (i == 1) {
                tensor = ov::Tensor{ inputPrecision, targetShape };
                const auto indicesVals = std::get<0>(this->GetParam()).indicesValues;
                if (inputPrecision == ElementType::i32) {
                    auto data = tensor.data<std::int32_t>();
                    // fix: inner index renamed to 'j' — it shadowed the outer 'i'
                    for (size_t j = 0; j < tensor.get_size(); ++j) {
                        data[j] = static_cast<std::int32_t>(indicesVals[j]);
                    }
                } else if (inputPrecision == ElementType::i64) {
                    auto data = tensor.data<std::int64_t>();
                    for (size_t j = 0; j < tensor.get_size(); ++j) {
                        data[j] = indicesVals[j];
                    }
                } else {
                    // fix: message said "GatherNDUpdate" in a ScatterNDUpdate test
                    IE_THROW() << "ScatterNDUpdate. Unsupported indices precision: " << inputPrecision;
                }
            } else {
                if (inputPrecision.is_real()) {
                    tensor = ov::test::utils::create_and_fill_tensor(inputPrecision, targetShape, 10, 0, 1000);
                } else {
                    tensor = ov::test::utils::create_and_fill_tensor(inputPrecision, targetShape);
                }
            }
            inputs.insert({ funcInput.get_node_shared_ptr(), tensor });
        }
    }
    // Builds the ScatterNDUpdate subgraph from the test parameters.
    void SetUp() override {
        targetDevice = CommonTestUtils::DEVICE_GPU;
        ScatterNDUpdateLayerParams scatterParams;
        ElementType inputPrecision;
        ElementType idxPrecision;
        std::tie(scatterParams, inputPrecision, idxPrecision) = this->GetParam();
        const auto inputShapes = scatterParams.inputShapes;
        init_input_shapes(inputShapes);
        // Parameters: [0] data, [1] updates share inputPrecision; indices use idxPrecision.
        auto dataParams = ngraph::builder::makeDynamicParams(inputPrecision, { inputDynamicShapes[0], inputDynamicShapes[2] });
        auto indicesParam = ngraph::builder::makeDynamicParams(idxPrecision, { inputDynamicShapes[1] });
        dataParams[0]->set_friendly_name("Param_1");
        indicesParam[0]->set_friendly_name("Param_2");
        dataParams[1]->set_friendly_name("Param_3");
        auto scatter = std::make_shared<ngraph::opset4::ScatterNDUpdate>(dataParams[0], indicesParam[0], dataParams[1]);
        ngraph::ParameterVector allParams{ dataParams[0], indicesParam[0], dataParams[1] };
        auto makeFunction = [](ParameterVector &params, const std::shared_ptr<Node> &lastNode) {
            ResultVector results;
            // fix: loop index was 'int', compared against size_t get_output_size()
            for (size_t i = 0; i < lastNode->get_output_size(); i++)
                results.push_back(std::make_shared<opset1::Result>(lastNode->output(i)));
            return std::make_shared<Function>(results, params, "ScatterNDUpdateLayerGPUTest");
        };
        function = makeFunction(allParams, scatter);
    }
};
// Runs the subgraph on GPU and compares against the reference implementation.
TEST_P(ScatterNDUpdateLayerGPUTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
}
namespace ScatterNDUpdate {
// Dynamic-shape scenarios: each entry lists per-input {dynamic shape, target shapes}
// for data / indices / updates, plus the values written into the indices tensor.
const std::vector<ScatterNDUpdateLayerParams> scatterParams = {
// Rank-5 data, rank-3 indices with last dim 2 (partial index), rank-6 updates.
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{-1, -1, -1, -1, -1}, {{10, 9, 10, 9, 10}, {10, 1, 11, 2, 5}, {10, 15, 8, 1, 7}}},
{{2, 2, 1}, {{2, 2, 1}, {2, 2, 1}, {2, 2, 1}}},
{{-1, -1, -1, -1, -1, -1}, {{2, 2, 9, 10, 9, 10}, {2, 2, 1, 11, 2, 5}, {2, 2, 15, 8, 1, 7}}},
},
IndicesValues{ 5, 6, 2, 8 }
},
// Rank-4 data, indices last dim 3, rank-2 updates; fully dynamic dims.
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{-1, -1, -1, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{2, 3}, {{2, 3}, {2, 3}, {2, 3}}},
{{-1, -1}, {{2, 11}, {2, 12}, {2, 8}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2 }
},
// Same as above but with bounded dimensions mixed with fully dynamic ones.
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{{3, 10}, -1, {3, 9}, -1}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{2, 3}, {{2, 3}, {2, 3}, {2, 3}}},
{{{2, 4}, -1}, {{2, 11}, {2, 12}, {2, 8}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2 }
},
// All data dimensions bounded.
ScatterNDUpdateLayerParams{
ScatterNDUpdateShapes{
{{{3, 10}, {4, 11}, {3, 9}, {8, 15}}, {{ 10, 9, 9, 11 }, { 7, 5, 3, 12 }, { 3, 4, 9, 8 }}},
{{2, 3}, {{2, 3}, {2, 3}, {2, 3}}},
{{{2, 4}, -1}, {{2, 11}, {2, 12}, {2, 8}}}
},
IndicesValues{ 0, 1, 1, 2, 2, 2 }
},
};
// Data/updates precisions under test.
const std::vector<ElementType> inputPrecisions = {
ElementType::f32,
};
// Indices precisions under test.
const std::vector<ElementType> constantPrecisions = {
ElementType::i32,
};
INSTANTIATE_TEST_SUITE_P(smoke_scatterndupdate_CompareWithRefs_dynamic, ScatterNDUpdateLayerGPUTest,
::testing::Combine(
::testing::ValuesIn(scatterParams),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(constantPrecisions)),
ScatterNDUpdateLayerGPUTest::getTestCaseName);
} // namespace ScatterNDUpdate
} // namespace GPULayerTestsDefinitions