Revise BatchToSpace reference implementation (#6289)

* Add visitor unit test

* BatchToSpace operation class refactored

 * Add RTTI definiton/declaration
 * Add checks in validate_and_infer_types
 * Add input/output checks in evaluate method

* Add type_prop tests covering checks in validate_and_infer_types method

* Fix incorrect variable names to specify crops_begin and crops_end inputs

* Add backend tests

* Use array stl to handle batch to space inputs in type_prop tests

* Add more backend tests

* Clean up manifest

* Address review comments

* Add node validation check for out of bounds crops

* Add validation checks in evaluate method

* Add error message to ngraph checks

* Modify check for input rank to allow rank 4 or greater to align with spec and plugins impl
This commit is contained in:
Gabriele Galiero Casay 2021-07-07 07:37:58 +02:00 committed by GitHub
parent 0d69e7fea2
commit 79f26cea7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 775 additions and 90 deletions

View File

@ -5,7 +5,7 @@
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
@ -27,8 +27,7 @@ namespace ngraph
class NGRAPH_API BatchToSpace : public Op
{
public:
static constexpr NodeTypeInfo type_info{"BatchToSpace", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
BatchToSpace() = default;
/// \brief Constructs a BatchToSpace operation.
///

View File

@ -23,7 +23,7 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::BatchToSpace::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::BatchToSpace, "BatchToSpace", 1);
ngraph::op::v1::BatchToSpace::BatchToSpace(const ngraph::Output<ngraph::Node>& data,
const ngraph::Output<ngraph::Node>& block_shape,
@ -37,83 +37,135 @@ ngraph::op::v1::BatchToSpace::BatchToSpace(const ngraph::Output<ngraph::Node>& d
void op::v1::BatchToSpace::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v1_BatchToSpace_validate_and_infer_types);
PartialShape data_pshape = get_input_partial_shape(0);
const auto& data_type = get_input_element_type(0);
const auto& block_shape_type = get_input_element_type(1);
const auto& crops_begin_type = get_input_element_type(2);
const auto& crops_end_type = get_input_element_type(3);
const auto& data_et = get_input_element_type(0);
const auto& block_shape_et = get_input_element_type(1);
const auto& crops_begin_et = get_input_element_type(2);
const auto& crops_end_et = get_input_element_type(3);
element::Type inputs_integer_et{};
NODE_VALIDATION_CHECK(
this,
element::Type::merge(inputs_integer_et, crops_begin_et, crops_end_et) &&
element::Type::merge(inputs_integer_et, inputs_integer_et, block_shape_et),
"block_shape, crops_begin and crops_end inputs must have same element type. Got: ",
block_shape_et,
", ",
crops_begin_et,
" and ",
crops_end_et);
NODE_VALIDATION_CHECK(this,
block_shape_type.is_integral_number(),
"block_shape must be an integral number but got (",
block_shape_type,
").");
inputs_integer_et.is_integral_number(),
"block_shape and crops inputs must have integer element type. Got: ",
inputs_integer_et);
const PartialShape& data_pshape = get_input_partial_shape(0);
const PartialShape& block_shape_ps = get_input_partial_shape(1);
const PartialShape& crops_begin_ps = get_input_partial_shape(2);
const PartialShape& crops_end_ps = get_input_partial_shape(3);
PartialShape inputs_same_ps{PartialShape::dynamic()};
NODE_VALIDATION_CHECK(
this,
PartialShape::merge_into(inputs_same_ps, crops_begin_ps) &&
PartialShape::merge_into(inputs_same_ps, crops_end_ps) &&
PartialShape::merge_into(inputs_same_ps, block_shape_ps),
"block_shape, crops_begin and crops_end inputs must have the same shape. Got: ",
block_shape_ps,
", ",
crops_begin_ps,
" and ",
crops_end_ps);
const Rank inputs_rank_one = inputs_same_ps.rank();
NODE_VALIDATION_CHECK(this,
crops_begin_type.is_integral_number(),
"crops_begin must be an integral number but got (",
crops_begin_type,
").");
inputs_rank_one.compatible(1),
"block_shape and crops inputs must have rank 1. Got: ",
inputs_rank_one);
NODE_VALIDATION_CHECK(this,
crops_end_type.is_integral_number(),
"crops_end must be an integral number but got (",
crops_end_type,
").");
const Rank data_rank = data_pshape.rank();
if (data_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
(data_rank.get_length() >= 4),
"data input must have rank greater than or equal to 4. Got: ",
data_rank.get_length());
auto data = input_value(0);
auto block = input_value(1);
auto crops_begin = input_value(2);
auto crops_end = input_value(3);
if (inputs_same_ps.is_static())
{
NODE_VALIDATION_CHECK(this,
data_rank.get_length() == inputs_same_ps[0].get_length(),
"block_shape and crop inputs must have same number of elements "
"as data input rank. Got: ",
inputs_same_ps[0],
" and ",
data_rank);
}
}
auto block_const = get_constant_from_source(block);
auto crops_begin_const = get_constant_from_source(crops_begin);
auto crops_end_const = get_constant_from_source(crops_end);
const auto block_const = get_constant_from_source(input_value(1));
const auto crops_begin_const = get_constant_from_source(input_value(2));
const auto crops_end_const = get_constant_from_source(input_value(3));
if (block_const && crops_begin_const && crops_end_const && data_pshape.is_static())
{
const auto& data_shape = data.get_shape();
NODE_VALIDATION_CHECK(
this,
(data_shape.size() >= 2),
"The data tensor with rank lower than 2 is not supported (data rank: ",
data_shape.size(),
")");
const Shape& data_sshape = data_pshape.to_shape();
auto block_val = block_const->cast_vector<int64_t>();
auto crops_begin_val = crops_begin_const->cast_vector<int64_t>();
auto crops_end_val = crops_end_const->cast_vector<int64_t>();
int64_t block_prod = 1;
for (long val : block_val)
{
NODE_VALIDATION_CHECK(this, val > 0, "block_shape values must be greater than 0");
block_prod *= val;
}
bool block_vals_valid =
std::all_of(begin(block_val), end(block_val), [](int64_t elem) { return elem >= 1; });
NODE_VALIDATION_CHECK(this,
block_vals_valid,
"Elements of block_shape input must be greater or equal to one.");
bool crops_begin_vals_valid = std::all_of(
begin(crops_begin_val), end(crops_begin_val), [](int64_t elem) { return elem >= 0; });
bool crops_end_vals_valid = std::all_of(
begin(crops_end_val), end(crops_end_val), [](int64_t elem) { return elem >= 0; });
NODE_VALIDATION_CHECK(
this,
crops_begin_vals_valid && crops_end_vals_valid,
"Elements of crops_begin and crops_end inputs must be greater or equal to zero.");
int64_t block_prod =
std::accumulate(begin(block_val), end(block_val), 1, std::multiplies<int64_t>());
NODE_VALIDATION_CHECK(this,
data_shape.at(0) % block_prod == 0,
"BatchToSpace: The input data's 'batch' axis size: ",
data_shape.at(0),
" must be a multiple of ",
data_sshape[0] % block_prod == 0,
"The input data's 'batch' axis size: ",
data_sshape[0],
" must be a multiple of",
" product of block_shape values: ",
block_prod);
Shape output_shape = {static_cast<size_t>(data_shape[0] / block_prod)};
for (size_t idx = 1; idx < data_shape.size(); ++idx)
for (size_t idx = 0; idx < data_sshape.size(); idx++)
{
output_shape.push_back(static_cast<size_t>(data_shape[idx] * block_val[idx] -
crops_begin_val[idx] - crops_end_val[idx]));
const bool is_valid_crops_and_shape =
crops_begin_val[idx] + crops_end_val[idx] <=
block_val[idx] * static_cast<int64_t>(data_sshape[idx]);
NODE_VALIDATION_CHECK(this,
is_valid_crops_and_shape,
"crops_begin[i] + crops_end[i] must be less or equal to "
"block_shape[i] * input_shape[i]");
}
Shape output_sshape = {static_cast<size_t>(data_sshape[0] / block_prod)};
for (size_t idx = 1; idx < data_sshape.size(); ++idx)
{
output_sshape.push_back(static_cast<size_t>(data_sshape[idx] * block_val[idx] -
crops_begin_val[idx] - crops_end_val[idx]));
}
set_output_size(1);
set_output_type(0, data_type, output_shape);
set_output_type(0, data_et, output_sshape);
}
else
{
set_output_type(0, data_type, PartialShape::dynamic(data_pshape.rank()));
set_output_type(0, data_et, PartialShape::dynamic(data_rank));
}
}
@ -144,16 +196,52 @@ namespace
return false;
}
auto data_shape = data->get_shape();
if (!(data->get_shape().size() == 4 || data->get_shape().size() == 5))
auto data_rank = data_shape.size();
if (!(data_rank == 4 || data_rank == 5))
{
return false;
}
size_t block_values_size = shape_size(inputs[1]->get_shape());
size_t crops_begin_size = shape_size(inputs[2]->get_shape());
size_t crops_end_size = shape_size(inputs[3]->get_shape());
NGRAPH_CHECK(
block_values_size == data_rank && crops_begin_size == data_rank &&
crops_end_size == data_rank,
"Invalid block_shape/crops_begin/crops_end shape with respect to rank of data input");
const auto* block_values = inputs[1]->get_data_ptr<int64_t>();
const auto* crops_begin_values = inputs[2]->get_data_ptr<int64_t>();
const auto* crops_end_values = inputs[3]->get_data_ptr<int64_t>();
const bool block_vals_valid = std::all_of(
block_values, block_values + block_values_size, [](int64_t elem) { return elem >= 1; });
NGRAPH_CHECK(block_vals_valid, "Invalid element values of block_shape input");
const bool crops_begin_vals_valid = std::all_of(crops_begin_values,
crops_begin_values + crops_begin_size,
[](int64_t elem) { return elem >= 0; });
const bool crops_end_vals_valid = std::all_of(crops_end_values,
crops_end_values + crops_end_size,
[](int64_t elem) { return elem >= 0; });
NGRAPH_CHECK(crops_begin_vals_valid && crops_end_vals_valid,
"Invalid element values of crops_begin/crops_end input/s");
const std::size_t block_prod = std::accumulate(
block_values, block_values + block_values_size, 1UL, std::multiplies<std::size_t>());
NGRAPH_CHECK(data_shape[0] % block_prod == 0,
"Invalid batch axis of data input with respect to block_shape values");
for (size_t i = 0; i < data_rank; i++)
{
const bool is_valid_crops_and_shape =
crops_begin_values[i] + crops_end_values[i] <=
block_values[i] * static_cast<int64_t>(data_shape[i]);
NGRAPH_CHECK(
is_valid_crops_and_shape,
"Invalid crops values (out of bounds) with respect to the shape of data input");
}
Shape dispersed_shape(1);
dispersed_shape.insert(dispersed_shape.end(), data_shape.begin(), data_shape.end());
std::vector<size_t> axes_order(block_values_size + 1);
@ -249,7 +337,9 @@ namespace
bool ngraph::op::v1::BatchToSpace::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
NGRAPH_OP_SCOPE(v1_BatchToSpace);
NGRAPH_OP_SCOPE(v1_BatchToSpace_evaluate);
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 4));
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
return batch_to_space_evaluate(outputs, inputs);
}

View File

@ -231,6 +231,7 @@ set(SRC
visitors/op/adaptive_max_pool.cpp
visitors/op/atan.cpp
visitors/op/batch_norm.cpp
visitors/op/batch_to_space.cpp
visitors/op/broadcast.cpp
visitors/op/bucketize.cpp
visitors/op/ceiling.cpp
@ -373,6 +374,7 @@ set(MULTI_TEST_SRC
backend/auto_broadcast.in.cpp
backend/avg_pool.in.cpp
backend/batch_norm.in.cpp
backend/batch_to_space.in.cpp
backend/broadcast.in.cpp
backend/bucketize.in.cpp
backend/builder_reduce_ops_opset1.in.cpp

View File

@ -0,0 +1,179 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_control.hpp"
using namespace std;
using namespace ngraph;
static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
namespace
{
template<typename dataType>
struct BatchToSpaceParams
{
using Data = test::NDArrayBase<dataType>;
using BlockShape = test::NDArrayBase<int64_t>;
using Crops = test::NDArrayBase<int64_t>;
BatchToSpaceParams(Data in_data,
BlockShape block_shape,
Crops crops_begin,
Crops crops_end,
Data expected_output)
: m_data{std::move(in_data)}
, m_block_shape{std::move(block_shape)}
, m_crops_begin{std::move(crops_begin)}
, m_crops_end{std::move(crops_end)}
, m_expected_output{std::move(expected_output)}
{
}
Data m_data;
BlockShape m_block_shape;
Crops m_crops_begin;
Crops m_crops_end;
Data m_expected_output;
};
template <typename dataType>
static void BatchToSpaceTestExecute(const BatchToSpaceParams<dataType>& params)
{
const auto data =
make_shared<op::Parameter>(element::from<dataType>(), params.m_data.get_shape());
const auto block_shape = op::Constant::create(
element::i64, params.m_block_shape.get_shape(), params.m_block_shape.get_vector());
const auto crops_begin = op::Constant::create(
element::i64, params.m_crops_begin.get_shape(), params.m_crops_begin.get_vector());
const auto crops_end = op::Constant::create(
element::i64, params.m_crops_end.get_shape(), params.m_crops_end.get_vector());
const auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
auto f = make_shared<Function>(batch_to_space, ParameterVector{data});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input(params.m_data.get_vector());
test_case.add_expected_output(params.m_expected_output.get_vector());
test_case.run_with_tolerance_as_fp(1e-4f);
}
class BatchToSpaceTestFloat : public testing::TestWithParam<BatchToSpaceParams<float>>
{
};
} // namespace
NGRAPH_TEST_P(${BACKEND_NAME}, BatchToSpaceTestFloat, BatchToSpaceTestFloatCases)
{
BatchToSpaceTestExecute(GetParam());
}
const test::NDArray<float, 4> input_with_shape_4x1x1x3(
{{{{1.0f, 2.0f, 3.0f}}},
{{{4.0f, 5.0f, 6.0f}}},
{{{7.0f, 8.0f, 9.0f}}},
{{{10.0f, 11.0f, 12.0f}}}});
const test::NDArray<float, 4> input_with_shape_4x1x2x3(
{{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}},
{{{7.0f, 8.0f, 9.0f}, {10.0f, 11.0f, 12.0f}}},
{{{13.0f, 14.0f, 15.0f}, {16.0f, 17.0f, 18.0f}}},
{{{19.0f, 20.0f, 21.0f}, {22.0f, 23.0f, 24.0f}}}});
const test::NDArray<int64_t, 1> zero_crops_4d({0, 0, 0, 0});
NGRAPH_INSTANTIATE_TEST_SUITE_P(
${BACKEND_NAME},
batch_to_space_4d_without_crops,
BatchToSpaceTestFloat,
testing::Values(
BatchToSpaceParams<float>{input_with_shape_4x1x1x3,
test::NDArray<int64_t, 1>({1, 1, 1, 2}),
zero_crops_4d,
zero_crops_4d,
test::NDArray<float, 4>(
{{{{1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f}}},
{{{4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f}}}})},
BatchToSpaceParams<float>{input_with_shape_4x1x1x3,
test::NDArray<int64_t, 1>({1, 1, 2, 1}),
zero_crops_4d,
zero_crops_4d,
test::NDArray<float, 4>(
{{{{1.0f, 2.0f, 3.0f}, {7.0f, 8.0f, 9.0f}}},
{{{4.0f, 5.0f, 6.0f}, {10.0f, 11.0f, 12.0f}}}})},
BatchToSpaceParams<float>{input_with_shape_4x1x1x3,
test::NDArray<int64_t, 1>({1, 1, 2, 2}),
zero_crops_4d,
zero_crops_4d,
test::NDArray<float, 4>(
{{{{1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f},
{7.0f, 10.0f, 8.0f, 11.0f, 9.0f, 12.0f}}}})},
BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
test::NDArray<int64_t, 1>({1, 1, 1, 2}),
zero_crops_4d,
zero_crops_4d,
test::NDArray<float, 4>(
{{{{1.0f, 13.0f, 2.0f, 14.0f, 3.0f, 15.0f},
{4.0f, 16.0f, 5.0f, 17.0f, 6.0f, 18.0f}}},
{{{7.0f, 19.0f, 8.0f, 20.0f, 9.0f, 21.0f},
{10.0f, 22.0f, 11.0f, 23.0f, 12.0f, 24.0f}}}})},
BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
test::NDArray<int64_t, 1>({1, 1, 2, 1}),
zero_crops_4d,
zero_crops_4d,
test::NDArray<float, 4>(
{{{{1.0f, 2.0f, 3.0f}, {13.0f, 14.0f, 15.0f},
{4.0f, 5.0f, 6.0f}, {16.0f, 17.0f, 18.0f}}},
{{{7.0f, 8.0f, 9.0f}, {19.0f, 20.0f, 21.0f},
{10.0f, 11.0f, 12.0f}, {22.0f, 23.0f, 24.0f}}}})},
BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
test::NDArray<int64_t, 1>({1, 1, 2, 2}),
zero_crops_4d,
zero_crops_4d,
test::NDArray<float, 4>(
{{{{1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f},
{13.0f, 19.0f, 14.0f, 20.0f, 15.0f, 21.0f},
{4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f},
{16.0f, 22.0f, 17.0f, 23.0f, 18.0f, 24.0f}}}})}));
NGRAPH_INSTANTIATE_TEST_SUITE_P(
${BACKEND_NAME},
batch_to_space_4d_crops,
BatchToSpaceTestFloat,
testing::Values(
BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
test::NDArray<int64_t, 1>({1, 1, 2, 2}),
test::NDArray<int64_t, 1>({0, 0, 0, 0}),
test::NDArray<int64_t, 1>({0, 0, 0, 2}),
test::NDArray<float, 4>(
{{{{1.0f, 7.0f, 2.0f, 8.0f},
{13.0f, 19.0f, 14.0f, 20.0f},
{4.0f, 10.0f, 5.0f, 11.0f},
{16.0f, 22.0f, 17.0f, 23.0f}}}})},
BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
test::NDArray<int64_t, 1>({1, 1, 2, 2}),
test::NDArray<int64_t, 1>({0, 0, 0, 2}),
test::NDArray<int64_t, 1>({0, 0, 0, 0}),
test::NDArray<float, 4>(
{{{{2.0f, 8.0f, 3.0f, 9.0f},
{14.0f, 20.0f, 15.0f, 21.0f},
{5.0f, 11.0f, 6.0f, 12.0f},
{17.0f, 23.0f, 18.0f, 24.0f}}}})},
BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
test::NDArray<int64_t, 1>({1, 1, 2, 2}),
test::NDArray<int64_t, 1>({0, 0, 1, 0}),
test::NDArray<int64_t, 1>({0, 0, 1, 0}),
test::NDArray<float, 4>(
{{{{13.0f, 19.0f, 14.0f, 20.0f, 15.0f, 21.0f},
{4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f}}}})}));

View File

@ -676,10 +676,6 @@ conv_bias_bprop_2d
# Cannot cast ngraph node ConvolutionBiasAdd to CNNLayer!
conv_bias_add_2d
# Unsupported primitive of type: SpaceToBatch
space_to_batch
batch_to_space
# [Validation] Argument must have rank >= 2 and <= 4 (argument shape: {1,2,2,2,3})
normalize_across_1axis_5d
normalize_across_123axes_5d

View File

@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <array>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
@ -9,18 +11,407 @@
using namespace std;
using namespace ngraph;
TEST(type_prop, batch_to_space_output_shape_2D)
namespace {
constexpr size_t data_input_idx = 0;
constexpr size_t block_shape_input_idx = 1;
constexpr size_t crops_begin_input_idx = 2;
constexpr size_t crops_end_input_idx = 3;
constexpr size_t batch_to_space_required_inputs = 4;
struct InputInfo
{
element::Type in_et;
PartialShape in_pshape;
};
using BatchToSpaceInputParams = std::array<InputInfo, batch_to_space_required_inputs>;
std::shared_ptr<Node> makeBatchToSpaceOp(const BatchToSpaceInputParams& p)
{
if(p.size() != batch_to_space_required_inputs)
{
throw runtime_error("BatchToSpace requires 4 inputs");
}
auto data = make_shared<op::Parameter>(
p.at(data_input_idx).in_et, p.at(data_input_idx).in_pshape);
auto block_shape = make_shared<op::Parameter>(
p.at(block_shape_input_idx).in_et, p.at(block_shape_input_idx).in_pshape);
auto crops_begin = make_shared<op::Parameter>(
p.at(crops_begin_input_idx).in_et, p.at(crops_begin_input_idx).in_pshape);
auto crops_end = make_shared<op::Parameter>(
p.at(crops_end_input_idx).in_et, p.at(crops_end_input_idx).in_pshape);
return make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
}
} // namespace
TEST(type_prop, batch_to_space_incompatible_input_element_types)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{10, 26});
auto block_shape = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 5});
auto pads_begin = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 2});
auto pads_end = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 0});
element::Type float_et = element::f32;
element::Type integer64_et = element::i64;
element::Type integer32_et = element::i32;
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
Shape data_sshape{10, 26, 4, 4};
Shape inputs_sshape{4};
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
ASSERT_EQ(batch_to_space->get_shape(), (Shape{10 / 5, 26 * 5 - 2}));
vector<BatchToSpaceInputParams> test_cases;
test_cases.push_back(
BatchToSpaceInputParams{
InputInfo{float_et, data_sshape},
InputInfo{integer64_et, inputs_sshape},
InputInfo{integer32_et, inputs_sshape},
InputInfo{integer32_et, inputs_sshape}});
test_cases.push_back(
BatchToSpaceInputParams{
InputInfo{float_et, data_sshape},
InputInfo{integer32_et, inputs_sshape},
InputInfo{integer64_et, inputs_sshape},
InputInfo{integer32_et, inputs_sshape}});
test_cases.push_back(
BatchToSpaceInputParams{
InputInfo{float_et, data_sshape},
InputInfo{integer64_et, inputs_sshape},
InputInfo{float_et, inputs_sshape},
InputInfo{float_et, inputs_sshape}});
for (const auto& test_case : test_cases)
{
try
{
auto batch_to_space = makeBatchToSpaceOp(test_case);
FAIL() << "Incompatible element types for block_shape/crops_begin/crops_end inputs not detected";
}
catch(const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"block_shape, crops_begin and crops_end inputs must have same element type.");
}
catch (...)
{
FAIL() << "Element type check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
}
}
}
TEST(type_prop, batch_to_space_invalid_input_element_types)
{
element::Type float_et = element::f32;
Shape data_sshape{10, 26, 4, 4};
Shape inputs_sshape{4};
const BatchToSpaceInputParams params{
InputInfo{float_et, data_sshape},
InputInfo{float_et, inputs_sshape},
InputInfo{float_et, inputs_sshape},
InputInfo{float_et, inputs_sshape}};
try
{
auto batch_to_space = makeBatchToSpaceOp(params);
FAIL() << "Invalid non-integer element type for block_shape/crops_begin/crops_end inputs not detected";
}
catch(const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"block_shape and crops inputs must have integer element type.");
}
catch (...)
{
FAIL() << "Element type check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_invalid_data_input_rank)
{
Shape data_sshape{4, 2};
element::Type data_et = element::f32;
Shape inputs_sshape{2};
element::Type inputs_et = element::i64;
const BatchToSpaceInputParams params{
InputInfo{data_et, data_sshape},
InputInfo{inputs_et, inputs_sshape},
InputInfo{inputs_et, inputs_sshape},
InputInfo{inputs_et, inputs_sshape}};
try
{
auto batch_to_space = makeBatchToSpaceOp(params);
FAIL() << "Invalid rank of data input not detected";
}
catch(const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "data input must have rank greater than or equal to 4");
}
catch (...)
{
FAIL() << "Rank check for data input failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_incompatible_secondary_inputs_shapes)
{
Shape data_sshape{10, 26, 4, 4};
element::Type data_et = element::f32;
Shape inputs_sshape_1D{4};
Shape inputs_sshape_2D{4, 1};
element::Type inputs_et = element::i64;
vector<BatchToSpaceInputParams> test_cases;
test_cases.push_back(
BatchToSpaceInputParams{
InputInfo{data_et, data_sshape},
InputInfo{inputs_et, inputs_sshape_2D},
InputInfo{inputs_et, inputs_sshape_1D},
InputInfo{inputs_et, inputs_sshape_1D}});
test_cases.push_back(
BatchToSpaceInputParams{
InputInfo{data_et, data_sshape},
InputInfo{inputs_et, inputs_sshape_1D},
InputInfo{inputs_et, inputs_sshape_2D},
InputInfo{inputs_et, inputs_sshape_1D}});
test_cases.push_back(
BatchToSpaceInputParams{
InputInfo{data_et, data_sshape},
InputInfo{inputs_et, inputs_sshape_1D},
InputInfo{inputs_et, inputs_sshape_2D},
InputInfo{inputs_et, inputs_sshape_2D}});
for (const auto& test_case : test_cases)
{
try
{
auto batch_to_space = makeBatchToSpaceOp(test_case);
FAIL() << "Incompatible shapes for block_shape/crops_begin/crops_end inputs not detected";
}
catch(const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"block_shape, crops_begin and crops_end inputs must have the same shape.");
}
catch (...)
{
FAIL() << "Shapes check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
}
}
}
TEST(type_prop, batch_to_space_invalid_secondary_inputs_rank)
{
Shape data_sshape{10, 26, 4, 4};
element::Type data_et = element::f32;
Shape inputs_sshape_2D{4, 1};
element::Type inputs_et = element::i64;
const BatchToSpaceInputParams params{
InputInfo{data_et, data_sshape},
InputInfo{inputs_et, inputs_sshape_2D},
InputInfo{inputs_et, inputs_sshape_2D},
InputInfo{inputs_et, inputs_sshape_2D}};
try
{
auto batch_to_space = makeBatchToSpaceOp(params);
FAIL() << "Invalid rank for block_shape/crops_begin/crops_end inputs not detected";
}
catch(const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"block_shape and crops inputs must have rank 1.");
}
catch (...)
{
FAIL() << "Rank check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_incompatible_data_and_secondary_inputs_shapes)
{
Shape data_sshape{10, 26, 4, 4};
element::Type data_et = element::f32;
Shape inputs_sshape{5};
element::Type inputs_et = element::i64;
const BatchToSpaceInputParams params{
InputInfo{data_et, data_sshape},
InputInfo{inputs_et, inputs_sshape},
InputInfo{inputs_et, inputs_sshape},
InputInfo{inputs_et, inputs_sshape}};
try
{
auto batch_to_space = makeBatchToSpaceOp(params);
FAIL() << "Incompatible shapes for data and block_shape/crops_begin/crops_end inputs not detected";
}
catch(const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"block_shape and crop inputs must have same number of elements "
"as data input rank.");
}
catch (...)
{
FAIL() << "Compatibility shape check for data and block_shape/crops_begin/crops_end inputs failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_invalid_block_shape_input)
{
Shape data_sshape{100, 7, 13, 3};
element::Type data_et = element::f32;
Shape inputs_sshape{4};
element::Type inputs_et = element::i64;
auto data = make_shared<op::Parameter>(data_et, data_sshape);
auto block_shape = make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 10, 5, 1});
auto crops_begin = make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 0});
auto crops_end = make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 0});
try
{
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
FAIL() << "Invalid elements of block_shape input not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Elements of block_shape input must be greater or equal to one.");
}
catch (...)
{
FAIL() << "Greater than zero elements of block_shape input check failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_invalid_crops_input_values)
{
Shape data_sshape{100, 7, 13, 3};
element::Type data_et = element::f32;
Shape inputs_sshape{4};
element::Type inputs_et = element::i64;
try
{
auto data = make_shared<op::Parameter>(data_et, data_sshape);
auto block_shape =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 10, 5, 1});
auto crops_begin =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, -1});
auto crops_end =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
FAIL() << "Invalid crops_begin input values not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Elements of crops_begin and crops_end inputs must be greater or equal to zero.");
}
catch (...)
{
FAIL() << "Non-negative element check of crops_begin input values failed for unexpected reason";
}
try
{
auto data = make_shared<op::Parameter>(data_et, data_sshape);
auto block_shape =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 10, 5, 1});
auto crops_begin =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 0});
auto crops_end =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, -1, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
FAIL() << "Invalid crops_end input values not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Elements of crops_begin and crops_end inputs must be greater or equal to zero.");
}
catch (...)
{
FAIL() << "Non-negative element check of crops_end input values failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_incompatible_block_shape_input_values_with_data_shape)
{
Shape data_sshape{80, 7, 13, 3};
element::Type data_et = element::f32;
Shape inputs_sshape{4};
element::Type inputs_et = element::i64;
auto data = make_shared<op::Parameter>(data_et, data_sshape);
auto block_shape =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 10, 5, 1});
auto crops_begin =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 0});
auto crops_end =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 0});
try
{
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
FAIL() << "Incompatible data shape and block_shape input values not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"The input data's 'batch' axis size: 80 must be a multiple of product of block_shape values: 50");
}
catch (...)
{
FAIL() << "Data shape and block_shape input values check failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_invalid_crops_out_of_bounds)
{
Shape data_sshape{32, 4, 1, 3};
element::Type data_et = element::f32;
Shape inputs_sshape{4};
element::Type inputs_et = element::i64;
auto data = make_shared<op::Parameter>(data_et, data_sshape);
auto block_shape =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 2, 2, 1});
auto crops_begin =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 2});
auto crops_end =
make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 2});
try
{
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
FAIL() << "Invalid out of bound crops values not detected";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"crops_begin[i] + crops_end[i] must be less or equal to block_shape[i] * input_shape[i]");
}
catch (...)
{
FAIL() << "Crops values check failed for unexpected reason";
}
}
TEST(type_prop, batch_to_space_output_shape_4D)
@ -28,12 +419,12 @@ TEST(type_prop, batch_to_space_output_shape_4D)
auto data = make_shared<op::Parameter>(element::f32, Shape{100, 7, 13, 3});
auto block_shape =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 10, 5, 1});
auto pads_begin =
auto crops_begin =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 0});
auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
auto crops_end =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
ASSERT_EQ(batch_to_space->get_shape(), (Shape{100 / (10 * 5), 7 * 10 - 3 - 3, 13 * 5 - 1, 3}));
@ -44,13 +435,12 @@ TEST(type_prop, batch_to_space_output_shape_5D)
auto data = make_shared<op::Parameter>(element::f32, Shape{960, 6, 13, 128, 16});
auto block_shape =
make_shared<op::Constant>(element::i32, Shape{5}, vector<int64_t>{1, 6, 5, 1, 16});
auto pads_begin =
auto crops_begin =
make_shared<op::Constant>(element::i32, Shape{5}, vector<int64_t>{0, 2, 0, 0, 0});
auto pads_end =
auto crops_end =
make_shared<op::Constant>(element::i32, Shape{5}, vector<int64_t>{0, 2, 1, 0, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
ASSERT_EQ(batch_to_space->get_shape(),
@ -62,19 +452,19 @@ TEST(type_prop, batch_to_space_and_space_to_batch)
auto data = make_shared<op::Parameter>(element::f32, Shape{4800, 9, 11, 2});
auto block_shape =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 12, 100, 2});
auto pads_begin =
auto crops_begin =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 38, 1});
auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 5, 38, 0});
auto crops_end =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 5, 38, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
ASSERT_EQ(batch_to_space->get_shape(),
(Shape{4800 / (12 * 100 * 2), 9 * 12 - 3 - 5, 11 * 100 - 38 - 38, 2 * 2 - 1}));
auto space_to_batch =
make_shared<op::v1::SpaceToBatch>(batch_to_space, block_shape, pads_begin, pads_end);
make_shared<op::v1::SpaceToBatch>(batch_to_space, block_shape, crops_begin, crops_end);
ASSERT_EQ(space_to_batch->get_element_type(), element::f32);
ASSERT_EQ(space_to_batch->get_shape(), (Shape{4800, 9, 11, 2}));
}
@ -84,12 +474,12 @@ TEST(type_prop, batch_to_space_dynamic_shape_static_rank)
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
auto block_shape =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 10, 5, 1});
auto pads_begin =
auto crops_begin =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 0});
auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
auto crops_end =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
ASSERT_EQ(batch_to_space->get_output_partial_shape(0), PartialShape::dynamic(4));
@ -100,12 +490,12 @@ TEST(type_prop, batch_to_space_dynamic_shape_dynamic_rank)
auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto block_shape =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 10, 5, 1});
auto pads_begin =
auto crops_begin =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 0});
auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
auto crops_end =
make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
ASSERT_EQ(batch_to_space->get_output_partial_shape(0), PartialShape::dynamic());

View File

@ -0,0 +1,29 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "util/visitor.hpp"
using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;
TEST(attributes, batch_to_space_op)
{
NodeBuilder::get_ops().register_factory<op::v1::BatchToSpace>();
auto data = make_shared<op::Parameter>(element::f32, Shape{128, 4, 2, 2});
auto block_shape = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 2, 2, 2});
auto crops_begin = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 2, 0, 1});
auto crops_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 0, 1, 0});
auto batch2space = make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
NodeBuilder builder(batch2space);
const auto expected_attr_count = 0;
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}