Revise GatherTree reference implementation (#7275)
* Add visitor api test * Review ngraph op shell with type_prop tests * Add op to list of trusted operations * Change name of struct with information of inputs * Add include of array data structure to fix windowds compilation error * Add template plugin test class * Remove usage of CoordinateTransform index function call from reference implementation * Rename SLT test suite * Add template plugin unit test * Add serialization SLTs * Add indentation on GatherTreeParams class data members
This commit is contained in:
parent
288a7633bf
commit
deeb96440f
@ -0,0 +1,100 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ie_core.hpp>
|
||||
#include <ie_ngraph_utils.hpp>
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <shared_test_classes/base/layer_test_utils.hpp>
|
||||
|
||||
#include "base_reference_test.hpp"
|
||||
|
||||
using namespace reference_tests;
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace {
|
||||
struct GatherTreeParams {
|
||||
template <class IN_ET>
|
||||
GatherTreeParams(const ngraph::Shape inShape, std::vector<IN_ET> stepIds, const std::vector<IN_ET> parentIds,
|
||||
const std::vector<IN_ET> maxSeqLen, const std::vector<IN_ET> endToken, std::vector<IN_ET> output) :
|
||||
stepIdsTensor(inShape, element::from<IN_ET>(), stepIds), parentIdsTensor(inShape, element::from<IN_ET>(), parentIds),
|
||||
maxSeqLenTensor(ngraph::Shape{inShape[1]}, element::from<IN_ET>(), maxSeqLen), endTokenTensor(ngraph::Shape{}, element::from<IN_ET>(), endToken),
|
||||
expectedTensor(inShape, element::from<IN_ET>(), output) {}
|
||||
Tensor stepIdsTensor;
|
||||
Tensor parentIdsTensor;
|
||||
Tensor maxSeqLenTensor;
|
||||
Tensor endTokenTensor;
|
||||
Tensor expectedTensor;
|
||||
};
|
||||
|
||||
class ReferenceGatherTreeTest : public testing::TestWithParam<GatherTreeParams>, public CommonReferenceTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto params = GetParam();
|
||||
function = CreateFunction(params);
|
||||
inputData = {params.stepIdsTensor.data, params.parentIdsTensor.data, params.maxSeqLenTensor.data, params.endTokenTensor.data};
|
||||
refOutData = {params.expectedTensor.data};
|
||||
}
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<GatherTreeParams>& obj) {
|
||||
auto param = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "iType=" << param.stepIdsTensor.type << "_";
|
||||
result << "iShape=" << param.stepIdsTensor.shape;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
private:
|
||||
static std::shared_ptr<Function> CreateFunction(const GatherTreeParams& params) {
|
||||
const auto stepIds = std::make_shared<op::Parameter>(params.stepIdsTensor.type, params.stepIdsTensor.shape);
|
||||
const auto parentIds = std::make_shared<op::Parameter>(params.parentIdsTensor.type, params.parentIdsTensor.shape);
|
||||
const auto maxSeqLen = std::make_shared<op::Parameter>(params.maxSeqLenTensor.type, params.maxSeqLenTensor.shape);
|
||||
const auto endToken = std::make_shared<op::Parameter>(params.endTokenTensor.type, params.endTokenTensor.shape);
|
||||
const auto gatherTree = std::make_shared<op::v1::GatherTree>(stepIds, parentIds, maxSeqLen, endToken);
|
||||
return std::make_shared<Function>(NodeVector {gatherTree}, ParameterVector {stepIds, parentIds, maxSeqLen, endToken});
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ReferenceGatherTreeTest, CompareWithRefs) {
|
||||
Exec();
|
||||
}
|
||||
|
||||
template <element::Type_t IN_ET>
|
||||
std::vector<GatherTreeParams> generateGatherTreeParams() {
|
||||
using T = typename element_type_traits<IN_ET>::value_type;
|
||||
std::vector<GatherTreeParams> gatherTreeParams {
|
||||
GatherTreeParams(Shape{4, 1, 3},
|
||||
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1},
|
||||
std::vector<T>{0, 0, 0, 0, 1, 1, 2, 1, 2, -1, -1, -1},
|
||||
std::vector<T>{3},
|
||||
std::vector<T>{10},
|
||||
std::vector<T>{2, 2, 2, 6, 5, 6, 7, 8, 9, 10, 10, 10}),
|
||||
GatherTreeParams(Shape{2, 2, 2},
|
||||
std::vector<T>{1, 2, 3, 4, 5, 6, 7, 8},
|
||||
std::vector<T>{0, 0, 0, 0, 0, 0, 0, 0},
|
||||
std::vector<T>{2, 4},
|
||||
std::vector<T>{0},
|
||||
std::vector<T>{1, 1, 3, 3, 5, 6, 7, 8})
|
||||
};
|
||||
return gatherTreeParams;
|
||||
}
|
||||
|
||||
std::vector<GatherTreeParams> generateGatherTreeCombinedParams() {
|
||||
const std::vector<std::vector<GatherTreeParams>> gatherTreeTypeParams {
|
||||
generateGatherTreeParams<element::Type_t::f32>(),
|
||||
generateGatherTreeParams<element::Type_t::i32>()};
|
||||
std::vector<GatherTreeParams> combinedParams;
|
||||
|
||||
for (const auto& params : gatherTreeTypeParams) {
|
||||
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
|
||||
}
|
||||
return combinedParams;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_With_Hardcoded_Refs, ReferenceGatherTreeTest,
|
||||
testing::ValuesIn(generateGatherTreeCombinedParams()), ReferenceGatherTreeTest::getTestCaseName);
|
||||
} // namespace
|
@ -0,0 +1,41 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "shared_test_classes/single_layer/gather_tree.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
|
||||
TEST_P(GatherTreeLayerTest, Serialize) {
|
||||
Serialize();
|
||||
}
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::I32
|
||||
};
|
||||
|
||||
const std::vector<std::vector<size_t>> inputShapes = { {5, 1, 10}, {1, 1, 10}, {20, 1, 10}, {20, 20, 10} };
|
||||
|
||||
const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
|
||||
ngraph::helpers::InputLayerType::CONSTANT,
|
||||
ngraph::helpers::InputLayerType::PARAMETER
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree_Serialization, GatherTreeLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn(secondaryInputTypes),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GatherTreeLayerTest::getTestCaseName);
|
||||
} // namespace
|
@ -23,7 +23,7 @@ const std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
|
||||
ngraph::helpers::InputLayerType::PARAMETER
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(Basic_smoke, GatherTreeLayerTest,
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_GatherTree, GatherTreeLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inputShapes),
|
||||
::testing::ValuesIn(secondaryInputTypes),
|
||||
|
@ -72,11 +72,12 @@ void runtime::reference::gather_tree(const char* step_ids,
|
||||
throw ngraph_error("max_seq_len must have size of BATCH_SIZE");
|
||||
}
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
ngraph::CoordinateTransform cordinate_transform(step_ids_shape);
|
||||
const auto in_strides = row_major_strides(step_ids_shape);
|
||||
ngraph::CoordinateTransformBasic cordinate_transform(step_ids_shape);
|
||||
|
||||
for (const auto& coord : cordinate_transform) {
|
||||
memcpy(out + cordinate_transform.index(coord) * elem_size, end_token, elem_size);
|
||||
const auto out_idx = std::inner_product(coord.begin(), coord.end(), in_strides.begin(), 0);
|
||||
memcpy(out + out_idx * elem_size, end_token, elem_size);
|
||||
}
|
||||
|
||||
for (size_t batch = 0; batch < batch_size; ++batch) {
|
||||
@ -87,31 +88,35 @@ void runtime::reference::gather_tree(const char* step_ids,
|
||||
continue;
|
||||
}
|
||||
|
||||
auto offset = cordinate_transform.index({max_seq_in_beam - 1, batch, beam}) * elem_size;
|
||||
|
||||
const auto coord = Coordinate({max_seq_in_beam - 1, batch, beam});
|
||||
const auto offset = std::inner_product(coord.begin(), coord.end(), in_strides.begin(), 0) * elem_size;
|
||||
memcpy(out + offset, step_ids + offset, elem_size);
|
||||
|
||||
size_t parent = _asIndex(parent_ids + offset, element_type);
|
||||
|
||||
for (size_t level = max_seq_in_beam - 1; level-- > 0;) {
|
||||
memcpy(out + cordinate_transform.index({level, batch, beam}) * elem_size,
|
||||
step_ids + cordinate_transform.index({level, batch, parent}) * elem_size,
|
||||
elem_size);
|
||||
const auto coord_beam = Coordinate({level, batch, beam});
|
||||
const auto out_idx = std::inner_product(coord_beam.begin(), coord_beam.end(), in_strides.begin(), 0);
|
||||
|
||||
parent =
|
||||
_asIndex(parent_ids + cordinate_transform.index({level, batch, parent}) * elem_size, element_type);
|
||||
const auto coord_parent = Coordinate({level, batch, parent});
|
||||
const auto step_ids_idx =
|
||||
std::inner_product(coord_parent.begin(), coord_parent.end(), in_strides.begin(), 0);
|
||||
|
||||
memcpy(out + out_idx * elem_size, step_ids + step_ids_idx * elem_size, elem_size);
|
||||
|
||||
parent = _asIndex(parent_ids + step_ids_idx * elem_size, element_type);
|
||||
}
|
||||
|
||||
bool finished = false;
|
||||
for (size_t time = 0; time < max_seq_in_beam; ++time) {
|
||||
const auto out_coord = Coordinate({time, batch, beam});
|
||||
const auto out_idx = std::inner_product(out_coord.begin(), out_coord.end(), in_strides.begin(), 0);
|
||||
if (finished) {
|
||||
memcpy(out + cordinate_transform.index({time, batch, beam}) * elem_size, end_token, elem_size);
|
||||
} else if (_asIndex(out + cordinate_transform.index({time, batch, beam}) * elem_size, element_type) ==
|
||||
_asIndex(end_token, element_type)) {
|
||||
memcpy(out + out_idx * elem_size, end_token, elem_size);
|
||||
} else if (_asIndex(out + out_idx * elem_size, element_type) == _asIndex(end_token, element_type)) {
|
||||
finished = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
@ -33,35 +33,68 @@ bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor) {
|
||||
|
||||
void op::v1::GatherTree::validate_and_infer_types() {
|
||||
NGRAPH_OP_SCOPE(v1_GatherTree_validate_and_infer_types);
|
||||
const auto& step_ids_rank = get_input_partial_shape(0);
|
||||
const auto& parent_idx_rank = get_input_partial_shape(1);
|
||||
const auto& max_seq_len_rank = get_input_partial_shape(2);
|
||||
const auto& end_token_rank = get_input_partial_shape(3);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
step_ids_rank.rank().is_dynamic() || step_ids_rank.rank().get_length() == 3,
|
||||
"step_ids input rank must equal to 3 (step_ids rank: ",
|
||||
step_ids_rank.rank().get_length(),
|
||||
")");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
parent_idx_rank.rank().is_dynamic() || parent_idx_rank.rank().get_length() == 3,
|
||||
"parent_idx input rank must equal to 3 (parent_idx rank: ",
|
||||
parent_idx_rank.rank().get_length(),
|
||||
")");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
max_seq_len_rank.rank().is_dynamic() || max_seq_len_rank.rank().get_length() == 1,
|
||||
"max_seq_len input rank must equal to 1 (max_seq_len rank: ",
|
||||
max_seq_len_rank.rank().get_length(),
|
||||
")");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
end_token_rank.rank().is_dynamic() || end_token_rank.rank().get_length() == 0,
|
||||
"end_token input rank must be scalar (end_token rank: ",
|
||||
end_token_rank.rank().get_length(),
|
||||
")");
|
||||
|
||||
const auto& step_ids_et = get_input_element_type(0);
|
||||
set_output_type(0, step_ids_et, step_ids_rank);
|
||||
const auto& parent_idx_et = get_input_element_type(1);
|
||||
const auto& max_seq_len_et = get_input_element_type(2);
|
||||
const auto& end_token_et = get_input_element_type(3);
|
||||
|
||||
element::Type result_et;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
element::Type::merge(result_et, step_ids_et, parent_idx_et) &&
|
||||
element::Type::merge(result_et, result_et, max_seq_len_et) &&
|
||||
element::Type::merge(result_et, result_et, end_token_et),
|
||||
"Inputs must have the same element type. Got: step_ids (",
|
||||
step_ids_et,
|
||||
"), parent_idx_et (",
|
||||
parent_idx_et,
|
||||
"), max_seq_len (",
|
||||
max_seq_len_et,
|
||||
"), end_token (",
|
||||
end_token_et,
|
||||
")");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
result_et.is_real() || result_et.is_integral_number(),
|
||||
"Element type of inputs must be numeric. Got: ",
|
||||
result_et);
|
||||
|
||||
const auto& step_ids_pshape = get_input_partial_shape(0);
|
||||
const auto& parent_idx_pshape = get_input_partial_shape(1);
|
||||
const auto& max_seq_len_pshape = get_input_partial_shape(2);
|
||||
const auto& end_token_pshape = get_input_partial_shape(3);
|
||||
|
||||
PartialShape result_pshape{PartialShape::dynamic()};
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
PartialShape::merge_into(result_pshape, step_ids_pshape) &&
|
||||
PartialShape::merge_into(result_pshape, parent_idx_pshape) &&
|
||||
result_pshape.rank().compatible(3),
|
||||
"step_ids and parent_idx inputs must have the same shape with rank 3. Got: ",
|
||||
step_ids_pshape,
|
||||
" and ",
|
||||
parent_idx_pshape,
|
||||
", respectively");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
max_seq_len_pshape.rank().compatible(1),
|
||||
"max_seq_len input must have rank 1. Got: ",
|
||||
max_seq_len_pshape);
|
||||
|
||||
if (result_pshape.rank().is_static() && max_seq_len_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(result_pshape[1], result_pshape[1], max_seq_len_pshape[0]),
|
||||
"Number of elements of max_seq_len input must match BATCH_SIZE dimension of "
|
||||
"step_ids/parent_idx inputs. Got: ",
|
||||
result_pshape[1],
|
||||
" and ",
|
||||
max_seq_len_pshape[0],
|
||||
", respectively");
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
end_token_pshape.rank().compatible(0),
|
||||
"end_token input must be scalar. Got: ",
|
||||
end_token_pshape);
|
||||
|
||||
set_output_type(0, result_et, result_pshape);
|
||||
}
|
||||
|
@ -270,6 +270,7 @@ set(SRC
|
||||
visitors/op/floor_mod.cpp
|
||||
visitors/op/floor.cpp
|
||||
visitors/op/gather.cpp
|
||||
visitors/op/gather_tree.cpp
|
||||
visitors/op/gelu.cpp
|
||||
visitors/op/greater_equal.cpp
|
||||
visitors/op/greater.cpp
|
||||
|
@ -2,6 +2,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <array>
|
||||
#include <utility>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
@ -9,78 +12,280 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
namespace {
|
||||
constexpr size_t step_ids_input_idx = 0;
|
||||
constexpr size_t parent_idx_input_idx = 1;
|
||||
constexpr size_t max_seq_len_input_idx = 2;
|
||||
constexpr size_t end_token_input_idx = 3;
|
||||
constexpr size_t gather_tree_required_inputs = 4;
|
||||
struct GatherTreeInputInfo {
|
||||
element::Type in_et;
|
||||
PartialShape in_pshape;
|
||||
};
|
||||
|
||||
using GatherTreeInputParams = std::array<GatherTreeInputInfo, gather_tree_required_inputs>;
|
||||
|
||||
std::shared_ptr<Node> makeGatherTreeOp(const GatherTreeInputParams& p) {
|
||||
if (p.size() != gather_tree_required_inputs) {
|
||||
throw runtime_error("GatherTree requires 4 inputs");
|
||||
}
|
||||
auto step_ids = make_shared<op::Parameter>(p.at(step_ids_input_idx).in_et, p.at(step_ids_input_idx).in_pshape);
|
||||
auto parent_idx =
|
||||
make_shared<op::Parameter>(p.at(parent_idx_input_idx).in_et, p.at(parent_idx_input_idx).in_pshape);
|
||||
auto max_seq_len =
|
||||
make_shared<op::Parameter>(p.at(max_seq_len_input_idx).in_et, p.at(max_seq_len_input_idx).in_pshape);
|
||||
auto end_token = make_shared<op::Parameter>(p.at(end_token_input_idx).in_et, p.at(end_token_input_idx).in_pshape);
|
||||
return make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(type_prop, gather_tree_invalid_input_element_type) {
|
||||
Shape scalar_shape{};
|
||||
Shape vector_shape{2};
|
||||
Shape tensor_shape{1, 2, 3};
|
||||
|
||||
element::Type input_et = element::boolean;
|
||||
GatherTreeInputParams params{GatherTreeInputInfo{input_et, tensor_shape},
|
||||
GatherTreeInputInfo{input_et, tensor_shape},
|
||||
GatherTreeInputInfo{input_et, vector_shape},
|
||||
GatherTreeInputInfo{input_et, scalar_shape}};
|
||||
try {
|
||||
auto gather_tree = makeGatherTreeOp(params);
|
||||
FAIL() << "Invalid element types for inputs not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "Element type of inputs must be numeric.");
|
||||
} catch (...) {
|
||||
FAIL() << "Element type check for inputs failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_incompatible_input_element_types) {
|
||||
element::Type float_et = element::f32;
|
||||
element::Type integer_et = element::i32;
|
||||
|
||||
Shape scalar_shape{};
|
||||
Shape vector_shape{2};
|
||||
Shape tensor_shape{1, 2, 3};
|
||||
|
||||
vector<GatherTreeInputParams> test_cases = {// step_ids input has incompatible element type
|
||||
GatherTreeInputParams{GatherTreeInputInfo{integer_et, tensor_shape},
|
||||
GatherTreeInputInfo{float_et, tensor_shape},
|
||||
GatherTreeInputInfo{float_et, vector_shape},
|
||||
GatherTreeInputInfo{float_et, scalar_shape}},
|
||||
// parent_idx input has incompatible element type
|
||||
GatherTreeInputParams{GatherTreeInputInfo{float_et, tensor_shape},
|
||||
GatherTreeInputInfo{integer_et, tensor_shape},
|
||||
GatherTreeInputInfo{float_et, vector_shape},
|
||||
GatherTreeInputInfo{float_et, scalar_shape}},
|
||||
// max_seq_len input has incompatible element type
|
||||
GatherTreeInputParams{GatherTreeInputInfo{float_et, tensor_shape},
|
||||
GatherTreeInputInfo{float_et, tensor_shape},
|
||||
GatherTreeInputInfo{integer_et, vector_shape},
|
||||
GatherTreeInputInfo{float_et, scalar_shape}},
|
||||
// end_token input has incompatible element type
|
||||
GatherTreeInputParams{GatherTreeInputInfo{float_et, tensor_shape},
|
||||
GatherTreeInputInfo{float_et, tensor_shape},
|
||||
GatherTreeInputInfo{float_et, vector_shape},
|
||||
GatherTreeInputInfo{integer_et, scalar_shape}}};
|
||||
|
||||
for (const auto& test_case : test_cases) {
|
||||
try {
|
||||
auto gather_tree = makeGatherTreeOp(test_case);
|
||||
FAIL() << "Incompatible element types for inputs not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "Inputs must have the same element type.");
|
||||
} catch (...) {
|
||||
FAIL() << "Element type check for inputs failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_input_element_types) {
|
||||
Shape scalar_shape{};
|
||||
Shape vector_shape{2};
|
||||
Shape tensor_shape{1, 2, 3};
|
||||
|
||||
std::vector<element::Type> element_types{element::u4,
|
||||
element::u8,
|
||||
element::u16,
|
||||
element::u32,
|
||||
element::i8,
|
||||
element::i16,
|
||||
element::i32,
|
||||
element::i64,
|
||||
element::f32,
|
||||
element::f64,
|
||||
element::u32};
|
||||
std::vector<GatherTreeInputParams> test_cases;
|
||||
std::for_each(std::begin(element_types), std::end(element_types), [&](element::Type et) {
|
||||
GatherTreeInputParams params{GatherTreeInputInfo{et, tensor_shape},
|
||||
GatherTreeInputInfo{et, tensor_shape},
|
||||
GatherTreeInputInfo{et, vector_shape},
|
||||
GatherTreeInputInfo{et, scalar_shape}};
|
||||
test_cases.insert(test_cases.end(), params);
|
||||
});
|
||||
for (const auto& test_case : test_cases) {
|
||||
try {
|
||||
EXPECT_NO_THROW(makeGatherTreeOp(test_case));
|
||||
} catch (...) {
|
||||
FAIL() << "Inputs element type validation check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_invalid_step_ids_and_parent_idx_input_shapes) {
|
||||
element::Type et = element::f32;
|
||||
|
||||
Shape scalar_shape{};
|
||||
PartialShape vector_shape{Dimension()};
|
||||
|
||||
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
|
||||
{PartialShape{1}, PartialShape{1, 2, 3}},
|
||||
{PartialShape{1, 2, 3}, PartialShape{3, 3, 3, 3}},
|
||||
{PartialShape{Dimension(), Dimension(), 3}, PartialShape::dynamic(4)},
|
||||
{PartialShape::dynamic(2), PartialShape::dynamic()},
|
||||
{PartialShape{1, 2, 3}, PartialShape{Dimension(), Dimension(3, 5), 3}}};
|
||||
std::vector<GatherTreeInputParams> test_cases;
|
||||
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
|
||||
GatherTreeInputParams params{GatherTreeInputInfo{et, shapes.first},
|
||||
GatherTreeInputInfo{et, shapes.second},
|
||||
GatherTreeInputInfo{et, vector_shape},
|
||||
GatherTreeInputInfo{et, scalar_shape}};
|
||||
test_cases.insert(test_cases.end(), params);
|
||||
});
|
||||
for (const auto& test_case : test_cases) {
|
||||
try {
|
||||
auto gather_tree = makeGatherTreeOp(test_case);
|
||||
FAIL() << "Incompatible shapes for inputs step_ids and parent_idx not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "step_ids and parent_idx inputs must have the same shape with rank 3.");
|
||||
} catch (...) {
|
||||
FAIL() << "Shape check for step_ids and parent_idx inputs failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_invalid_max_seq_len_rank) {
|
||||
element::Type et = element::f32;
|
||||
|
||||
Shape tensor_shape{1, 2, 3};
|
||||
Shape scalar_shape{};
|
||||
|
||||
std::vector<PartialShape> max_seq_len_shapes = {{}, {Dimension(), 1}, PartialShape::dynamic(3), {1, 2, 3, 4}};
|
||||
|
||||
std::vector<GatherTreeInputParams> test_cases;
|
||||
std::for_each(std::begin(max_seq_len_shapes), std::end(max_seq_len_shapes), [&](PartialShape shape) {
|
||||
GatherTreeInputParams params{GatherTreeInputInfo{et, tensor_shape},
|
||||
GatherTreeInputInfo{et, tensor_shape},
|
||||
GatherTreeInputInfo{et, shape},
|
||||
GatherTreeInputInfo{et, scalar_shape}};
|
||||
test_cases.insert(test_cases.end(), params);
|
||||
});
|
||||
for (const auto& test_case : test_cases) {
|
||||
try {
|
||||
auto gather_tree = makeGatherTreeOp(test_case);
|
||||
FAIL() << "Invalid shapes for max_seq_len input not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "max_seq_len input must have rank 1.");
|
||||
} catch (...) {
|
||||
FAIL() << "Shape check for max_seq_len input failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_incompatible_step_ids_and_max_seq_len_shapes) {
|
||||
element::Type et = element::f32;
|
||||
|
||||
Shape scalar_shape{};
|
||||
|
||||
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
|
||||
{PartialShape{1, 2, 3}, PartialShape{4}},
|
||||
{PartialShape{Dimension(), 2, 3}, PartialShape{Dimension(3, 6)}}};
|
||||
std::vector<GatherTreeInputParams> test_cases;
|
||||
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
|
||||
GatherTreeInputParams params{GatherTreeInputInfo{et, shapes.first},
|
||||
GatherTreeInputInfo{et, shapes.first},
|
||||
GatherTreeInputInfo{et, shapes.second},
|
||||
GatherTreeInputInfo{et, scalar_shape}};
|
||||
test_cases.insert(test_cases.end(), params);
|
||||
});
|
||||
for (const auto& test_case : test_cases) {
|
||||
try {
|
||||
auto gather_tree = makeGatherTreeOp(test_case);
|
||||
FAIL() << "Incompatible shapes for inputs step_ids and max_seq_len not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
"Number of elements of max_seq_len input must match BATCH_SIZE dimension of "
|
||||
"step_ids/parent_idx inputs.");
|
||||
} catch (...) {
|
||||
FAIL() << "Shape check for step_ids and max_seq_len inputs failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_output_shape) {
|
||||
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
element::Type et = element::f32;
|
||||
Shape scalar_shape{};
|
||||
|
||||
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
|
||||
{PartialShape{1, 2, 3}, PartialShape{2}},
|
||||
{PartialShape{1, 2, 3}, PartialShape::dynamic(1)},
|
||||
{PartialShape{Dimension(), 2, Dimension()}, PartialShape{2}},
|
||||
{
|
||||
PartialShape::dynamic(3),
|
||||
PartialShape{4},
|
||||
},
|
||||
{PartialShape{Dimension(), Dimension(3, 5), Dimension()}, PartialShape{Dimension(1, 3)}},
|
||||
{PartialShape::dynamic(), PartialShape::dynamic()}};
|
||||
std::vector<GatherTreeInputParams> test_cases;
|
||||
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
|
||||
GatherTreeInputParams params{GatherTreeInputInfo{et, shapes.first},
|
||||
GatherTreeInputInfo{et, shapes.first},
|
||||
GatherTreeInputInfo{et, shapes.second},
|
||||
GatherTreeInputInfo{et, scalar_shape}};
|
||||
test_cases.insert(test_cases.end(), params);
|
||||
});
|
||||
for (const auto& test_case : test_cases) {
|
||||
try {
|
||||
auto gather_tree = makeGatherTreeOp(test_case);
|
||||
|
||||
ASSERT_EQ(gather_tree->get_output_shape(0), (Shape{1, 2, 3}));
|
||||
ASSERT_EQ(gather_tree->get_output_element_type(0), element::i64);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_pooling_step_ids_invalid_rank) {
|
||||
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
|
||||
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
try {
|
||||
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Ivalid step_ids input rank not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("step_ids input rank must equal to 3 (step_ids rank: 4)"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
PartialShape result_shape{test_case.at(step_ids_input_idx).in_pshape};
|
||||
PartialShape max_seq_len_shape{test_case.at(max_seq_len_input_idx).in_pshape};
|
||||
if (result_shape.rank().is_static() && max_seq_len_shape.rank().is_static()) {
|
||||
result_shape[1] = result_shape[1] & max_seq_len_shape[0];
|
||||
}
|
||||
ASSERT_EQ(gather_tree->get_output_partial_shape(0), result_shape);
|
||||
ASSERT_EQ(gather_tree->get_output_element_type(0), et);
|
||||
} catch (...) {
|
||||
FAIL() << "Output shape check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_parent_idx_invalid_rank) {
|
||||
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3, 4});
|
||||
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
try {
|
||||
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Ivalid parent_idx input rank not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("parent_idx input rank must equal to 3 (parent_idx rank: 4)"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
TEST(type_prop, gather_tree_invalid_end_token_rank) {
|
||||
element::Type et = element::f32;
|
||||
|
||||
TEST(type_prop, gather_tree_max_seq_len_invalid_rank) {
|
||||
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1, 2});
|
||||
auto end_token = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
try {
|
||||
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Ivalid parent_idx input rank not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("max_seq_len input rank must equal to 1 (max_seq_len rank: 2)"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
Shape tensor_shape{1, 2, 3};
|
||||
Shape vector_shape{2};
|
||||
|
||||
TEST(type_prop, gather_tree_end_token_invalid_rank) {
|
||||
auto step_ids = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto parent_idx = make_shared<op::Parameter>(element::i64, Shape{1, 2, 3});
|
||||
auto max_seq_len = make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
auto end_token = make_shared<op::Parameter>(element::i64, Shape{1});
|
||||
try {
|
||||
auto gather_tree = make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Ivalid end_token input rank not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("end_token input rank must be scalar (end_token rank: 1)"));
|
||||
} catch (...) {
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
std::vector<PartialShape> end_token_shapes = {{3}, {Dimension(), 1}, PartialShape::dynamic(3), {1, 2, 3, 4}};
|
||||
|
||||
std::vector<GatherTreeInputParams> test_cases;
|
||||
std::for_each(std::begin(end_token_shapes), std::end(end_token_shapes), [&](PartialShape shape) {
|
||||
GatherTreeInputParams params{GatherTreeInputInfo{et, tensor_shape},
|
||||
GatherTreeInputInfo{et, tensor_shape},
|
||||
GatherTreeInputInfo{et, vector_shape},
|
||||
GatherTreeInputInfo{et, shape}};
|
||||
test_cases.insert(test_cases.end(), params);
|
||||
});
|
||||
for (const auto& test_case : test_cases) {
|
||||
try {
|
||||
auto gather_tree = makeGatherTreeOp(test_case);
|
||||
FAIL() << "Invalid shapes for end_token input not detected";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "end_token input must be scalar.");
|
||||
} catch (...) {
|
||||
FAIL() << "Shape check for end_token input failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
28
ngraph/test/visitors/op/gather_tree.cpp
Normal file
28
ngraph/test/visitors/op/gather_tree.cpp
Normal file
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
TEST(attributes, gather_tree_op) {
|
||||
NodeBuilder::get_ops().register_factory<opset1::GatherTree>();
|
||||
|
||||
auto step_ids = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
|
||||
auto parent_idx = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
|
||||
auto max_seq_len = std::make_shared<op::Parameter>(element::f32, Shape{2});
|
||||
auto end_token = std::make_shared<op::Parameter>(element::f32, Shape{});
|
||||
|
||||
auto gather_tree = std::make_shared<opset1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
NodeBuilder builder(gather_tree);
|
||||
|
||||
const auto expected_attr_count = 0;
|
||||
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
|
||||
}
|
Loading…
Reference in New Issue
Block a user