Revise GatherTree reference implementation (#7275)

* Add visitor api test

* Review ngraph op shell with type_prop tests

* Add op to list of trusted operations

* Change name of struct with information of inputs

* Add include of array data structure to fix windowds compilation error

* Add template plugin test class

* Remove usage of CoordinateTransform index function call from reference implementation

* Rename SLT test suite

* Add template plugin unit test

* Add serialization SLTs

* Add indentation on GatherTreeParams class data members
This commit is contained in:
Gabriele Galiero Casay 2021-09-10 13:02:49 +02:00 committed by GitHub
parent 288a7633bf
commit deeb96440f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 522 additions and 109 deletions

View File

@ -0,0 +1,100 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ie_core.hpp>
#include <ie_ngraph_utils.hpp>
#include <limits>
#include <algorithm>
#include <ngraph/ngraph.hpp>
#include <shared_test_classes/base/layer_test_utils.hpp>
#include "base_reference_test.hpp"
using namespace reference_tests;
using namespace ngraph;
using namespace InferenceEngine;
namespace {
struct GatherTreeParams {
template <class IN_ET>
GatherTreeParams(const ngraph::Shape inShape, std::vector<IN_ET> stepIds, const std::vector<IN_ET> parentIds,
const std::vector<IN_ET> maxSeqLen, const std::vector<IN_ET> endToken, std::vector<IN_ET> output) :
stepIdsTensor(inShape, element::from<IN_ET>(), stepIds), parentIdsTensor(inShape, element::from<IN_ET>(), parentIds),
maxSeqLenTensor(ngraph::Shape{inShape[1]}, element::from<IN_ET>(), maxSeqLen), endTokenTensor(ngraph::Shape{}, element::from<IN_ET>(), endToken),
expectedTensor(inShape, element::from<IN_ET>(), output) {}
Tensor stepIdsTensor;
Tensor parentIdsTensor;
Tensor maxSeqLenTensor;
Tensor endTokenTensor;
Tensor expectedTensor;
};
class ReferenceGatherTreeTest : public testing::TestWithParam<GatherTreeParams>, public CommonReferenceTest {
public:
void SetUp() override {
auto params = GetParam();
function = CreateFunction(params);
inputData = {params.stepIdsTensor.data, params.parentIdsTensor.data, params.maxSeqLenTensor.data, params.endTokenTensor.data};
refOutData = {params.expectedTensor.data};
}
static std::string getTestCaseName(const testing::TestParamInfo<GatherTreeParams>& obj) {
auto param = obj.param;
std::ostringstream result;
result << "iType=" << param.stepIdsTensor.type << "_";
result << "iShape=" << param.stepIdsTensor.shape;
return result.str();
}
private:
static std::shared_ptr<Function> CreateFunction(const GatherTreeParams& params) {
const auto stepIds = std::make_shared<op::Parameter>(params.stepIdsTensor.type, params.stepIdsTensor.shape);
const auto parentIds = std::make_shared<op::Parameter>(params.parentIdsTensor.type, params.parentIdsTensor.shape);
const auto maxSeqLen = std::make_shared<op::Parameter>(params.maxSeqLenTensor.type, params.maxSeqLenTensor.shape);
const auto endToken = std::make_shared<op::Parameter>(params.endTokenTensor.type, params.endTokenTensor.shape);
const auto gatherTree = std::make_shared<op::v1::GatherTree>(stepIds, parentIds, maxSeqLen, endToken);
return std::make_shared<Function>(NodeVector {gatherTree}, ParameterVector {stepIds, parentIds, maxSeqLen, endToken});
}
};
TEST_P(ReferenceGatherTreeTest, CompareWithRefs) {
Exec();
}
template <element::Type_t IN_ET>
std::vector<GatherTreeParams> generateGatherTreeParams() {
using T = typename element_type_traits<IN_ET>::value_type;
std::vector<GatherTreeParams> gatherTreeParams {
GatherTreeParams(Shape{4, 1, 3},
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1},
std::vector<T>{0, 0, 0, 0, 1, 1, 2, 1, 2, -1, -1, -1},
std::vector<T>{3},
std::vector<T>{10},
std::vector<T>{2, 2, 2, 6, 5, 6, 7, 8, 9, 10, 10, 10}),
GatherTreeParams(Shape{2, 2, 2},
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8},
std::vector<T>{0, 0, 0, 0, 0, 0, 0, 0},
std::vector<T>{2, 4},
std::vector<T>{0},
std::vector<T>{1, 1, 3, 3, 5, 6, 7, 8})
};
return gatherTreeParams;
}
std::vector<GatherTreeParams> generateGatherTreeCombinedParams() {
const std::vector<std::vector<GatherTreeParams>> gatherTreeTypeParams {
generateGatherTreeParams<element::Type_t::f32>(),
generateGatherTreeParams<element::Type_t::i32>()};
std::vector<GatherTreeParams> combinedParams;
for (const auto& params : gatherTreeTypeParams) {
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
}
return combinedParams;
}
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_With_Hardcoded_Refs, ReferenceGatherTreeTest,
testing::ValuesIn(generateGatherTreeCombinedParams()), ReferenceGatherTreeTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,41 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "shared_test_classes/single_layer/gather_tree.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
TEST_P(GatherTreeLayerTest, Serialize) {
Serialize();
}
const std::vector<InferenceEngine::Precision> netPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::I32
};
const std::vector<std::vector<size_t>> inputShapes = { {5, 1, 10}, {1, 1, 10}, {20, 1, 10}, {20, 20, 10} };
const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
ngraph::helpers::InputLayerType::CONSTANT,
ngraph::helpers::InputLayerType::PARAMETER
};
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_Serialization, GatherTreeLayerTest,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(secondaryInputTypes),
::testing::ValuesIn(netPrecisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
GatherTreeLayerTest::getTestCaseName);
} // namespace

View File

@ -23,7 +23,7 @@ const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
ngraph::helpers::InputLayerType::PARAMETER
};
INSTANTIATE_TEST_SUITE_P(Basic_smoke, GatherTreeLayerTest,
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree, GatherTreeLayerTest,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(secondaryInputTypes),

View File

@ -72,11 +72,12 @@ void runtime::reference::gather_tree(const char* step_ids,
throw ngraph_error("max_seq_len must have size of BATCH_SIZE");
}
NGRAPH_SUPPRESS_DEPRECATED_START
ngraph::CoordinateTransform cordinate_transform(step_ids_shape);
const auto in_strides = row_major_strides(step_ids_shape);
ngraph::CoordinateTransformBasic cordinate_transform(step_ids_shape);
for (const auto& coord : cordinate_transform) {
memcpy(out + cordinate_transform.index(coord) * elem_size, end_token, elem_size);
const auto out_idx = std::inner_product(coord.begin(), coord.end(), in_strides.begin(), 0);
memcpy(out + out_idx * elem_size, end_token, elem_size);
}
for (size_t batch = 0; batch < batch_size; ++batch) {
@ -87,31 +88,35 @@ void runtime::reference::gather_tree(const char* step_ids,
continue;
}
auto offset = cordinate_transform.index({max_seq_in_beam - 1, batch, beam}) * elem_size;
const auto coord = Coordinate({max_seq_in_beam - 1, batch, beam});
const auto offset = std::inner_product(coord.begin(), coord.end(), in_strides.begin(), 0) * elem_size;
memcpy(out + offset, step_ids + offset, elem_size);
size_t parent = _asIndex(parent_ids + offset, element_type);
for (size_t level = max_seq_in_beam - 1; level-- > 0;) {
memcpy(out + cordinate_transform.index({level, batch, beam}) * elem_size,
step_ids + cordinate_transform.index({level, batch, parent}) * elem_size,
elem_size);
const auto coord_beam = Coordinate({level, batch, beam});
const auto out_idx = std::inner_product(coord_beam.begin(), coord_beam.end(), in_strides.begin(), 0);
parent =
_asIndex(parent_ids + cordinate_transform.index({level, batch, parent}) * elem_size, element_type);
const auto coord_parent = Coordinate({level, batch, parent});
const auto step_ids_idx =
std::inner_product(coord_parent.begin(), coord_parent.end(), in_strides.begin(), 0);
memcpy(out + out_idx * elem_size, step_ids + step_ids_idx * elem_size, elem_size);
parent = _asIndex(parent_ids + step_ids_idx * elem_size, element_type);
}
bool finished = false;
for (size_t time = 0; time < max_seq_in_beam; ++time) {
const auto out_coord = Coordinate({time, batch, beam});
const auto out_idx = std::inner_product(out_coord.begin(), out_coord.end(), in_strides.begin(), 0);
if (finished) {
memcpy(out + cordinate_transform.index({time, batch, beam}) * elem_size, end_token, elem_size);
} else if (_asIndex(out + cordinate_transform.index({time, batch, beam}) * elem_size, element_type) ==
_asIndex(end_token, element_type)) {
memcpy(out + out_idx * elem_size, end_token, elem_size);
} else if (_asIndex(out + out_idx * elem_size, element_type) == _asIndex(end_token, element_type)) {
finished = true;
}
}
}
}
NGRAPH_SUPPRESS_DEPRECATED_END
}

View File

@ -33,35 +33,68 @@ bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor) {
void op::v1::GatherTree::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v1_GatherTree_validate_and_infer_types);
const auto& step_ids_rank = get_input_partial_shape(0);
const auto& parent_idx_rank = get_input_partial_shape(1);
const auto& max_seq_len_rank = get_input_partial_shape(2);
const auto& end_token_rank = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
step_ids_rank.rank().is_dynamic() || step_ids_rank.rank().get_length() == 3,
"step_ids input rank must equal to 3 (step_ids rank: ",
step_ids_rank.rank().get_length(),
")");
NODE_VALIDATION_CHECK(this,
parent_idx_rank.rank().is_dynamic() || parent_idx_rank.rank().get_length() == 3,
"parent_idx input rank must equal to 3 (parent_idx rank: ",
parent_idx_rank.rank().get_length(),
")");
NODE_VALIDATION_CHECK(this,
max_seq_len_rank.rank().is_dynamic() || max_seq_len_rank.rank().get_length() == 1,
"max_seq_len input rank must equal to 1 (max_seq_len rank: ",
max_seq_len_rank.rank().get_length(),
")");
NODE_VALIDATION_CHECK(this,
end_token_rank.rank().is_dynamic() || end_token_rank.rank().get_length() == 0,
"end_token input rank must be scalar (end_token rank: ",
end_token_rank.rank().get_length(),
")");
const auto& step_ids_et = get_input_element_type(0);
set_output_type(0, step_ids_et, step_ids_rank);
const auto& parent_idx_et = get_input_element_type(1);
const auto& max_seq_len_et = get_input_element_type(2);
const auto& end_token_et = get_input_element_type(3);
element::Type result_et;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, step_ids_et, parent_idx_et) &&
element::Type::merge(result_et, result_et, max_seq_len_et) &&
element::Type::merge(result_et, result_et, end_token_et),
"Inputs must have the same element type. Got: step_ids (",
step_ids_et,
"), parent_idx_et (",
parent_idx_et,
"), max_seq_len (",
max_seq_len_et,
"), end_token (",
end_token_et,
")");
NODE_VALIDATION_CHECK(this,
result_et.is_real() || result_et.is_integral_number(),
"Element type of inputs must be numeric. Got: ",
result_et);
const auto& step_ids_pshape = get_input_partial_shape(0);
const auto& parent_idx_pshape = get_input_partial_shape(1);
const auto& max_seq_len_pshape = get_input_partial_shape(2);
const auto& end_token_pshape = get_input_partial_shape(3);
PartialShape result_pshape{PartialShape::dynamic()};
NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(result_pshape, step_ids_pshape) &&
PartialShape::merge_into(result_pshape, parent_idx_pshape) &&
result_pshape.rank().compatible(3),
"step_ids and parent_idx inputs must have the same shape with rank 3. Got: ",
step_ids_pshape,
" and ",
parent_idx_pshape,
", respectively");
NODE_VALIDATION_CHECK(this,
max_seq_len_pshape.rank().compatible(1),
"max_seq_len input must have rank 1. Got: ",
max_seq_len_pshape);
if (result_pshape.rank().is_static() && max_seq_len_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(this,
Dimension::merge(result_pshape[1], result_pshape[1], max_seq_len_pshape[0]),
"Number of elements of max_seq_len input must match BATCH_SIZE dimension of "
"step_ids/parent_idx inputs. Got: ",
result_pshape[1],
" and ",
max_seq_len_pshape[0],
", respectively");
}
NODE_VALIDATION_CHECK(this,
end_token_pshape.rank().compatible(0),
"end_token input must be scalar. Got: ",
end_token_pshape);
set_output_type(0, result_et, result_pshape);
}

View File

@ -270,6 +270,7 @@ set(SRC
visitors/op/floor_mod.cpp
visitors/op/floor.cpp
visitors/op/gather.cpp
visitors/op/gather_tree.cpp
visitors/op/gelu.cpp
visitors/op/greater_equal.cpp
visitors/op/greater.cpp

View File

@ -2,6 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <array>
#include <utility>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
@ -9,78 +12,280 @@
using namespace std;
using namespace ngraph;
namespace {
constexpr size_t step_ids_input_idx = 0;
constexpr size_t parent_idx_input_idx = 1;
constexpr size_t max_seq_len_input_idx = 2;
constexpr size_t end_token_input_idx = 3;
constexpr size_t gather_tree_required_inputs = 4;
struct GatherTreeInputInfo {
element::Type in_et;
PartialShape in_pshape;
};
using GatherTreeInputParams = std::array<GatherTreeInputInfo, gather_tree_required_inputs>;
std::shared_ptr<Node> makeGatherTreeOp(const GatherTreeInputParams& p) {
if (p.size() != gather_tree_required_inputs) {
throw runtime_error("GatherTree requires 4 inputs");
}
auto step_ids = make_shared<op::Parameter>(p.at(step_ids_input_idx).in_et, p.at(step_ids_input_idx).in_pshape);
auto parent_idx =
make_shared<op::Parameter>(p.at(parent_idx_input_idx).in_et, p.at(parent_idx_input_idx).in_pshape);
auto max_seq_len =
make_shared<op::Parameter>(p.at(max_seq_len_input_idx).in_et, p.at(max_seq_len_input_idx).in_pshape);
auto end_token = make_shared<op::Parameter>(p.at(end_token_input_idx).in_et, p.at(end_token_input_idx).in_pshape);
return make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
}
} // namespace
TEST(type_prop, gather_tree_invalid_input_element_type) {
Shape scalar_shape{};
Shape vector_shape{2};
Shape tensor_shape{1, 2, 3};
element::Type input_et = element::boolean;
GatherTreeInputParams params{GatherTreeInputInfo{input_et, tensor_shape},
GatherTreeInputInfo{input_et, tensor_shape},
GatherTreeInputInfo{input_et, vector_shape},
GatherTreeInputInfo{input_et, scalar_shape}};
try {
auto gather_tree = makeGatherTreeOp(params);
FAIL() << "Invalid element types for inputs not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Element type of inputs must be numeric.");
} catch (...) {
FAIL() << "Element type check for inputs failed for unexpected reason";
}
}
TEST(type_prop, gather_tree_incompatible_input_element_types) {
element::Type float_et = element::f32;
element::Type integer_et = element::i32;
Shape scalar_shape{};
Shape vector_shape{2};
Shape tensor_shape{1, 2, 3};
vector<GatherTreeInputParams> test_cases = {// step_ids input has incompatible element type
GatherTreeInputParams{GatherTreeInputInfo{integer_et, tensor_shape},
GatherTreeInputInfo{float_et, tensor_shape},
GatherTreeInputInfo{float_et, vector_shape},
GatherTreeInputInfo{float_et, scalar_shape}},
// parent_idx input has incompatible element type
GatherTreeInputParams{GatherTreeInputInfo{float_et, tensor_shape},
GatherTreeInputInfo{integer_et, tensor_shape},
GatherTreeInputInfo{float_et, vector_shape},
GatherTreeInputInfo{float_et, scalar_shape}},
// max_seq_len input has incompatible element type
GatherTreeInputParams{GatherTreeInputInfo{float_et, tensor_shape},
GatherTreeInputInfo{float_et, tensor_shape},
GatherTreeInputInfo{integer_et, vector_shape},
GatherTreeInputInfo{float_et, scalar_shape}},
// end_token input has incompatible element type
GatherTreeInputParams{GatherTreeInputInfo{float_et, tensor_shape},
GatherTreeInputInfo{float_et, tensor_shape},
GatherTreeInputInfo{float_et, vector_shape},
GatherTreeInputInfo{integer_et, scalar_shape}}};
for (const auto& test_case : test_cases) {
try {
auto gather_tree = makeGatherTreeOp(test_case);
FAIL() << "Incompatible element types for inputs not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "Inputs must have the same element type.");
} catch (...) {
FAIL() << "Element type check for inputs failed for unexpected reason";
}
}
}
TEST(type_prop, gather_tree_input_element_types) {
Shape scalar_shape{};
Shape vector_shape{2};
Shape tensor_shape{1, 2, 3};
std::vector<element::Type> element_types{element::u4,
element::u8,
element::u16,
element::u32,
element::i8,
element::i16,
element::i32,
element::i64,
element::f32,
element::f64,
element::u32};
std::vector<GatherTreeInputParams> test_cases;
std::for_each(std::begin(element_types), std::end(element_types), [&](element::Type et) {
GatherTreeInputParams params{GatherTreeInputInfo{et, tensor_shape},
GatherTreeInputInfo{et, tensor_shape},
GatherTreeInputInfo{et, vector_shape},
GatherTreeInputInfo{et, scalar_shape}};
test_cases.insert(test_cases.end(), params);
});
for (const auto& test_case : test_cases) {
try {
EXPECT_NO_THROW(makeGatherTreeOp(test_case));
} catch (...) {
FAIL() << "Inputs element type validation check failed for unexpected reason";
}
}
}
TEST(type_prop, gather_tree_invalid_step_ids_and_parent_idx_input_shapes) {
element::Type et = element::f32;
Shape scalar_shape{};
PartialShape vector_shape{Dimension()};
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
{PartialShape{1}, PartialShape{1, 2, 3}},
{PartialShape{1, 2, 3}, PartialShape{3, 3, 3, 3}},
{PartialShape{Dimension(), Dimension(), 3}, PartialShape::dynamic(4)},
{PartialShape::dynamic(2), PartialShape::dynamic()},
{PartialShape{1, 2, 3}, PartialShape{Dimension(), Dimension(3, 5), 3}}};
std::vector<GatherTreeInputParams> test_cases;
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
GatherTreeInputParams params{GatherTreeInputInfo{et, shapes.first},
GatherTreeInputInfo{et, shapes.second},
GatherTreeInputInfo{et, vector_shape},
GatherTreeInputInfo{et, scalar_shape}};
test_cases.insert(test_cases.end(), params);
});
for (const auto& test_case : test_cases) {
try {
auto gather_tree = makeGatherTreeOp(test_case);
FAIL() << "Incompatible shapes for inputs step_ids and parent_idx not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "step_ids and parent_idx inputs must have the same shape with rank 3.");
} catch (...) {
FAIL() << "Shape check for step_ids and parent_idx inputs failed for unexpected reason";
}
}
}
TEST(type_prop, gather_tree_invalid_max_seq_len_rank) {
element::Type et = element::f32;
Shape tensor_shape{1, 2, 3};
Shape scalar_shape{};
std::vector<PartialShape> max_seq_len_shapes = {{}, {Dimension(), 1}, PartialShape::dynamic(3), {1, 2, 3, 4}};
std::vector<GatherTreeInputParams> test_cases;
std::for_each(std::begin(max_seq_len_shapes), std::end(max_seq_len_shapes), [&](PartialShape shape) {
GatherTreeInputParams params{GatherTreeInputInfo{et, tensor_shape},
GatherTreeInputInfo{et, tensor_shape},
GatherTreeInputInfo{et, shape},
GatherTreeInputInfo{et, scalar_shape}};
test_cases.insert(test_cases.end(), params);
});
for (const auto& test_case : test_cases) {
try {
auto gather_tree = makeGatherTreeOp(test_case);
FAIL() << "Invalid shapes for max_seq_len input not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "max_seq_len input must have rank 1.");
} catch (...) {
FAIL() << "Shape check for max_seq_len input failed for unexpected reason";
}
}
}
TEST(type_prop, gather_tree_incompatible_step_ids_and_max_seq_len_shapes) {
element::Type et = element::f32;
Shape scalar_shape{};
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
{PartialShape{1, 2, 3}, PartialShape{4}},
{PartialShape{Dimension(), 2, 3}, PartialShape{Dimension(3, 6)}}};
std::vector<GatherTreeInputParams> test_cases;
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
GatherTreeInputParams params{GatherTreeInputInfo{et, shapes.first},
GatherTreeInputInfo{et, shapes.first},
GatherTreeInputInfo{et, shapes.second},
GatherTreeInputInfo{et, scalar_shape}};
test_cases.insert(test_cases.end(), params);
});
for (const auto& test_case : test_cases) {
try {
auto gather_tree = makeGatherTreeOp(test_case);
FAIL() << "Incompatible shapes for inputs step_ids and max_seq_len not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
"Number of elements of max_seq_len input must match BATCH_SIZE dimension of "
"step_ids/parent_idx inputs.");
} catch (...) {
FAIL() << "Shape check for step_ids and max_seq_len inputs failed for unexpected reason";
}
}
}
TEST(type_prop, gather_tree_output_shape) {
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
element::Type et = element::f32;
Shape scalar_shape{};
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
{PartialShape{1, 2, 3}, PartialShape{2}},
{PartialShape{1, 2, 3}, PartialShape::dynamic(1)},
{PartialShape{Dimension(), 2, Dimension()}, PartialShape{2}},
{
PartialShape::dynamic(3),
PartialShape{4},
},
{PartialShape{Dimension(), Dimension(3, 5), Dimension()}, PartialShape{Dimension(1, 3)}},
{PartialShape::dynamic(), PartialShape::dynamic()}};
std::vector<GatherTreeInputParams> test_cases;
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
GatherTreeInputParams params{GatherTreeInputInfo{et, shapes.first},
GatherTreeInputInfo{et, shapes.first},
GatherTreeInputInfo{et, shapes.second},
GatherTreeInputInfo{et, scalar_shape}};
test_cases.insert(test_cases.end(), params);
});
for (const auto& test_case : test_cases) {
try {
auto gather_tree = makeGatherTreeOp(test_case);
ASSERT_EQ(gather_tree->get_output_shape(0), (Shape{1, 2, 3}));
ASSERT_EQ(gather_tree->get_output_element_type(0), element::i64);
}
TEST(type_prop, gather_tree_pooling_step_ids_invalid_rank) {
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try {
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid step_ids input rank not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("step_ids input rank must equal to 3 (step_ids rank: 4)"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
PartialShape result_shape{test_case.at(step_ids_input_idx).in_pshape};
PartialShape max_seq_len_shape{test_case.at(max_seq_len_input_idx).in_pshape};
if (result_shape.rank().is_static() && max_seq_len_shape.rank().is_static()) {
result_shape[1] = result_shape[1] & max_seq_len_shape[0];
}
ASSERT_EQ(gather_tree->get_output_partial_shape(0), result_shape);
ASSERT_EQ(gather_tree->get_output_element_type(0), et);
} catch (...) {
FAIL() << "Output shape check failed for unexpected reason";
}
}
}
TEST(type_prop, gather_tree_parent_idx_invalid_rank) {
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try {
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid parent_idx input rank not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("parent_idx input rank must equal to 3 (parent_idx rank: 4)"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, gather_tree_invalid_end_token_rank) {
element::Type et = element::f32;
TEST(type_prop, gather_tree_max_seq_len_invalid_rank) {
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1, 2});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
try {
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid parent_idx input rank not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("max_seq_len input rank must equal to 1 (max_seq_len rank: 2)"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
}
Shape tensor_shape{1, 2, 3};
Shape vector_shape{2};
TEST(type_prop, gather_tree_end_token_invalid_rank) {
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1});
try {
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Should have thrown, so fail if it didn't
FAIL() << "Ivalid end_token input rank not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("end_token input rank must be scalar (end_token rank: 1)"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
std::vector<PartialShape> end_token_shapes = {{3}, {Dimension(), 1}, PartialShape::dynamic(3), {1, 2, 3, 4}};
std::vector<GatherTreeInputParams> test_cases;
std::for_each(std::begin(end_token_shapes), std::end(end_token_shapes), [&](PartialShape shape) {
GatherTreeInputParams params{GatherTreeInputInfo{et, tensor_shape},
GatherTreeInputInfo{et, tensor_shape},
GatherTreeInputInfo{et, vector_shape},
GatherTreeInputInfo{et, shape}};
test_cases.insert(test_cases.end(), params);
});
for (const auto& test_case : test_cases) {
try {
auto gather_tree = makeGatherTreeOp(test_case);
FAIL() << "Invalid shapes for end_token input not detected";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), "end_token input must be scalar.");
} catch (...) {
FAIL() << "Shape check for end_token input failed for unexpected reason";
}
}
}

View File

@ -0,0 +1,28 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "util/visitor.hpp"
using namespace ngraph;
using ngraph::test::NodeBuilder;
using ngraph::test::ValueMap;
TEST(attributes, gather_tree_op) {
NodeBuilder::get_ops().register_factory<opset1::GatherTree>();
auto step_ids = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto parent_idx = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto max_seq_len = std::make_shared<op::Parameter>(element::f32, Shape{2});
auto end_token = std::make_shared<op::Parameter>(element::f32, Shape{});
auto gather_tree = std::make_shared<opset1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
NodeBuilder builder(gather_tree);
const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}