Revise BatchToSpace reference implementation (#6289)
* Add visitor unit test
* Refactor the BatchToSpace operation class
* Add RTTI definition/declaration
* Add checks in validate_and_infer_types
* Add input/output checks in the evaluate method
* Add type_prop tests covering the checks in validate_and_infer_types
* Fix incorrect variable names to refer to the crops_begin and crops_end inputs
* Add backend tests
* Use std::array to handle BatchToSpace inputs in type_prop tests
* Add more backend tests
* Clean up the manifest
* Address review comments
* Add node validation check for out-of-bounds crops
* Add validation checks in the evaluate method
* Add error messages to ngraph checks
* Allow data input rank 4 or greater, aligning with the spec and plugin implementations
This commit is contained in:
parent 0d69e7fea2
commit 79f26cea7a
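Not part of the diff below: a minimal sketch of how the revised op is typically constructed. The shapes and values here are illustrative assumptions of mine, not taken from this commit; constructing the node runs validate_and_infer_types(), which applies the checks listed above.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

#include "ngraph/ngraph.hpp"

using namespace ngraph;

// Illustrative only: shapes/values below are my own, not taken from this commit.
std::shared_ptr<op::v1::BatchToSpace> make_example_batch_to_space()
{
    auto data = std::make_shared<op::Parameter>(element::f32, Shape{8, 1, 2, 3});
    auto block_shape =
        op::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{1, 1, 2, 2});
    auto crops_begin =
        op::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 0, 0, 0});
    auto crops_end =
        op::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 0, 0, 0});
    // Construction runs validate_and_infer_types(), i.e. the element type, shape,
    // rank and crops checks added in this commit; the output shape here is {2, 1, 4, 6}.
    return std::make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
}
```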
@@ -5,7 +5,7 @@
#pragma once

#include "ngraph/node.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/op.hpp"

namespace ngraph
{
@@ -27,8 +27,7 @@ namespace ngraph
class NGRAPH_API BatchToSpace : public Op
{
public:
    static constexpr NodeTypeInfo type_info{"BatchToSpace", 1};
    const NodeTypeInfo& get_type_info() const override { return type_info; }
    NGRAPH_RTTI_DECLARATION;
    BatchToSpace() = default;
    /// \brief Constructs a BatchToSpace operation.
    ///
@@ -23,7 +23,7 @@
using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::v1::BatchToSpace::type_info;
NGRAPH_RTTI_DEFINITION(op::v1::BatchToSpace, "BatchToSpace", 1);

ngraph::op::v1::BatchToSpace::BatchToSpace(const ngraph::Output<ngraph::Node>& data,
                                           const ngraph::Output<ngraph::Node>& block_shape,
@@ -37,83 +37,135 @@ ngraph::op::v1::BatchToSpace::BatchToSpace(const ngraph::Output<ngraph::Node>& d
void op::v1::BatchToSpace::validate_and_infer_types()
{
    NGRAPH_OP_SCOPE(v1_BatchToSpace_validate_and_infer_types);
    PartialShape data_pshape = get_input_partial_shape(0);

    const auto& data_type = get_input_element_type(0);
    const auto& block_shape_type = get_input_element_type(1);
    const auto& crops_begin_type = get_input_element_type(2);
    const auto& crops_end_type = get_input_element_type(3);
    const auto& data_et = get_input_element_type(0);
    const auto& block_shape_et = get_input_element_type(1);
    const auto& crops_begin_et = get_input_element_type(2);
    const auto& crops_end_et = get_input_element_type(3);

    element::Type inputs_integer_et{};
    NODE_VALIDATION_CHECK(
        this,
        element::Type::merge(inputs_integer_et, crops_begin_et, crops_end_et) &&
            element::Type::merge(inputs_integer_et, inputs_integer_et, block_shape_et),
        "block_shape, crops_begin and crops_end inputs must have same element type. Got: ",
        block_shape_et,
        ", ",
        crops_begin_et,
        " and ",
        crops_end_et);

    NODE_VALIDATION_CHECK(this,
                          block_shape_type.is_integral_number(),
                          "block_shape must be an integral number but got (",
                          block_shape_type,
                          ").");
                          inputs_integer_et.is_integral_number(),
                          "block_shape and crops inputs must have integer element type. Got: ",
                          inputs_integer_et);

    const PartialShape& data_pshape = get_input_partial_shape(0);
    const PartialShape& block_shape_ps = get_input_partial_shape(1);
    const PartialShape& crops_begin_ps = get_input_partial_shape(2);
    const PartialShape& crops_end_ps = get_input_partial_shape(3);

    PartialShape inputs_same_ps{PartialShape::dynamic()};
    NODE_VALIDATION_CHECK(
        this,
        PartialShape::merge_into(inputs_same_ps, crops_begin_ps) &&
            PartialShape::merge_into(inputs_same_ps, crops_end_ps) &&
            PartialShape::merge_into(inputs_same_ps, block_shape_ps),
        "block_shape, crops_begin and crops_end inputs must have the same shape. Got: ",
        block_shape_ps,
        ", ",
        crops_begin_ps,
        " and ",
        crops_end_ps);

    const Rank inputs_rank_one = inputs_same_ps.rank();
    NODE_VALIDATION_CHECK(this,
                          crops_begin_type.is_integral_number(),
                          "crops_begin must be an integral number but got (",
                          crops_begin_type,
                          ").");
                          inputs_rank_one.compatible(1),
                          "block_shape and crops inputs must have rank 1. Got: ",
                          inputs_rank_one);

    NODE_VALIDATION_CHECK(this,
                          crops_end_type.is_integral_number(),
                          "crops_end must be an integral number but got (",
                          crops_end_type,
                          ").");
    const Rank data_rank = data_pshape.rank();
    if (data_rank.is_static())
    {
        NODE_VALIDATION_CHECK(this,
                              (data_rank.get_length() >= 4),
                              "data input must have rank greater than or equal to 4. Got: ",
                              data_rank.get_length());

    auto data = input_value(0);
    auto block = input_value(1);
    auto crops_begin = input_value(2);
    auto crops_end = input_value(3);
        if (inputs_same_ps.is_static())
        {
            NODE_VALIDATION_CHECK(this,
                                  data_rank.get_length() == inputs_same_ps[0].get_length(),
                                  "block_shape and crop inputs must have same number of elements "
                                  "as data input rank. Got: ",
                                  inputs_same_ps[0],
                                  " and ",
                                  data_rank);
        }
    }

    auto block_const = get_constant_from_source(block);
    auto crops_begin_const = get_constant_from_source(crops_begin);
    auto crops_end_const = get_constant_from_source(crops_end);
    const auto block_const = get_constant_from_source(input_value(1));
    const auto crops_begin_const = get_constant_from_source(input_value(2));
    const auto crops_end_const = get_constant_from_source(input_value(3));

    if (block_const && crops_begin_const && crops_end_const && data_pshape.is_static())
    {
        const auto& data_shape = data.get_shape();

        NODE_VALIDATION_CHECK(
            this,
            (data_shape.size() >= 2),
            "The data tensor with rank lower than 2 is not supported (data rank: ",
            data_shape.size(),
            ")");
        const Shape& data_sshape = data_pshape.to_shape();

        auto block_val = block_const->cast_vector<int64_t>();
        auto crops_begin_val = crops_begin_const->cast_vector<int64_t>();
        auto crops_end_val = crops_end_const->cast_vector<int64_t>();

        int64_t block_prod = 1;
        for (long val : block_val)
        {
            NODE_VALIDATION_CHECK(this, val > 0, "block_shape values must be greater than 0");
            block_prod *= val;
        }
        bool block_vals_valid =
            std::all_of(begin(block_val), end(block_val), [](int64_t elem) { return elem >= 1; });
        NODE_VALIDATION_CHECK(this,
                              block_vals_valid,
                              "Elements of block_shape input must be greater or equal to one.");

        bool crops_begin_vals_valid = std::all_of(
            begin(crops_begin_val), end(crops_begin_val), [](int64_t elem) { return elem >= 0; });
        bool crops_end_vals_valid = std::all_of(
            begin(crops_end_val), end(crops_end_val), [](int64_t elem) { return elem >= 0; });
        NODE_VALIDATION_CHECK(
            this,
            crops_begin_vals_valid && crops_end_vals_valid,
            "Elements of crops_begin and crops_end inputs must be greater or equal to zero.");

        int64_t block_prod =
            std::accumulate(begin(block_val), end(block_val), 1, std::multiplies<int64_t>());

        NODE_VALIDATION_CHECK(this,
                              data_shape.at(0) % block_prod == 0,
                              "BatchToSpace: The input data's 'batch' axis size: ",
                              data_shape.at(0),
                              " must be a multiple of ",
                              data_sshape[0] % block_prod == 0,
                              "The input data's 'batch' axis size: ",
                              data_sshape[0],
                              " must be a multiple of",
                              " product of block_shape values: ",
                              block_prod);

        Shape output_shape = {static_cast<size_t>(data_shape[0] / block_prod)};
        for (size_t idx = 1; idx < data_shape.size(); ++idx)
        for (size_t idx = 0; idx < data_sshape.size(); idx++)
        {
            output_shape.push_back(static_cast<size_t>(data_shape[idx] * block_val[idx] -
                                                       crops_begin_val[idx] - crops_end_val[idx]));
            const bool is_valid_crops_and_shape =
                crops_begin_val[idx] + crops_end_val[idx] <=
                block_val[idx] * static_cast<int64_t>(data_sshape[idx]);
            NODE_VALIDATION_CHECK(this,
                                  is_valid_crops_and_shape,
                                  "crops_begin[i] + crops_end[i] must be less or equal to "
                                  "block_shape[i] * input_shape[i]");
        }

        Shape output_sshape = {static_cast<size_t>(data_sshape[0] / block_prod)};
        for (size_t idx = 1; idx < data_sshape.size(); ++idx)
        {
            output_sshape.push_back(static_cast<size_t>(data_sshape[idx] * block_val[idx] -
                                                        crops_begin_val[idx] - crops_end_val[idx]));
        }

        set_output_size(1);
        set_output_type(0, data_type, output_shape);
        set_output_type(0, data_et, output_sshape);
    }
    else
    {
        set_output_type(0, data_type, PartialShape::dynamic(data_pshape.rank()));
        set_output_type(0, data_et, PartialShape::dynamic(data_rank));
    }
}
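As a standalone worked example of the shape rule implemented in validate_and_infer_types above (a sketch of my own, not ngraph code): the output batch is the input batch divided by the product of block_shape, and every other dimension is data[i] * block[i] - crops_begin[i] - crops_end[i]. The numbers mirror the existing 4D type_prop test.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Sketch of the output-shape rule used by the new validate_and_infer_types:
//   out[0] = data[0] / prod(block_shape)
//   out[i] = data[i] * block_shape[i] - crops_begin[i] - crops_end[i]   (i > 0)
std::vector<std::size_t> batch_to_space_out_shape(const std::vector<std::size_t>& data,
                                                  const std::vector<int64_t>& block,
                                                  const std::vector<int64_t>& crops_begin,
                                                  const std::vector<int64_t>& crops_end)
{
    const int64_t block_prod =
        std::accumulate(block.begin(), block.end(), int64_t{1}, std::multiplies<int64_t>());
    std::vector<std::size_t> out{data[0] / static_cast<std::size_t>(block_prod)};
    for (std::size_t i = 1; i < data.size(); ++i)
    {
        out.push_back(static_cast<std::size_t>(static_cast<int64_t>(data[i]) * block[i] -
                                               crops_begin[i] - crops_end[i]));
    }
    return out;
}

int main()
{
    // Mirrors the existing 4D type_prop case: data {100,7,13,3}, block {1,10,5,1},
    // crops_begin {0,3,1,0}, crops_end {0,3,0,0}  ->  {2, 64, 64, 3}.
    const auto out =
        batch_to_space_out_shape({100, 7, 13, 3}, {1, 10, 5, 1}, {0, 3, 1, 0}, {0, 3, 0, 0});
    assert((out == std::vector<std::size_t>{2, 64, 64, 3}));
    return 0;
}
```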
@@ -144,16 +196,52 @@ namespace
            return false;
        }
        auto data_shape = data->get_shape();

        if (!(data->get_shape().size() == 4 || data->get_shape().size() == 5))
        auto data_rank = data_shape.size();
        if (!(data_rank == 4 || data_rank == 5))
        {
            return false;
        }

        size_t block_values_size = shape_size(inputs[1]->get_shape());
        size_t crops_begin_size = shape_size(inputs[2]->get_shape());
        size_t crops_end_size = shape_size(inputs[3]->get_shape());
        NGRAPH_CHECK(
            block_values_size == data_rank && crops_begin_size == data_rank &&
                crops_end_size == data_rank,
            "Invalid block_shape/crops_begin/crops_end shape with respect to rank of data input");

        const auto* block_values = inputs[1]->get_data_ptr<int64_t>();
        const auto* crops_begin_values = inputs[2]->get_data_ptr<int64_t>();
        const auto* crops_end_values = inputs[3]->get_data_ptr<int64_t>();

        const bool block_vals_valid = std::all_of(
            block_values, block_values + block_values_size, [](int64_t elem) { return elem >= 1; });
        NGRAPH_CHECK(block_vals_valid, "Invalid element values of block_shape input");

        const bool crops_begin_vals_valid = std::all_of(crops_begin_values,
                                                        crops_begin_values + crops_begin_size,
                                                        [](int64_t elem) { return elem >= 0; });
        const bool crops_end_vals_valid = std::all_of(crops_end_values,
                                                      crops_end_values + crops_end_size,
                                                      [](int64_t elem) { return elem >= 0; });
        NGRAPH_CHECK(crops_begin_vals_valid && crops_end_vals_valid,
                     "Invalid element values of crops_begin/crops_end input/s");

        const std::size_t block_prod = std::accumulate(
            block_values, block_values + block_values_size, 1UL, std::multiplies<std::size_t>());
        NGRAPH_CHECK(data_shape[0] % block_prod == 0,
                     "Invalid batch axis of data input with respect to block_shape values");

        for (size_t i = 0; i < data_rank; i++)
        {
            const bool is_valid_crops_and_shape =
                crops_begin_values[i] + crops_end_values[i] <=
                block_values[i] * static_cast<int64_t>(data_shape[i]);
            NGRAPH_CHECK(
                is_valid_crops_and_shape,
                "Invalid crops values (out of bounds) with respect to the shape of data input");
        }

        Shape dispersed_shape(1);
        dispersed_shape.insert(dispersed_shape.end(), data_shape.begin(), data_shape.end());
        std::vector<size_t> axes_order(block_values_size + 1);
@@ -249,7 +337,9 @@ namespace
bool ngraph::op::v1::BatchToSpace::evaluate(const HostTensorVector& outputs,
                                            const HostTensorVector& inputs) const
{
    NGRAPH_OP_SCOPE(v1_BatchToSpace);
    NGRAPH_OP_SCOPE(v1_BatchToSpace_evaluate);
    NGRAPH_CHECK(validate_host_tensor_vector(inputs, 4));
    NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
    return batch_to_space_evaluate(outputs, inputs);
}
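For reference, the per-axis bound enforced by the new NGRAPH_CHECKs above, restated as a standalone helper under my own naming (a sketch, not the ngraph sources):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Standalone restatement of the bound checked per axis in the evaluate path:
// crops_begin[i] + crops_end[i] must not exceed block[i] * data_shape[i],
// otherwise the crop would remove more than the expanded axis contains.
bool crops_within_bounds(const std::vector<std::size_t>& data_shape,
                         const std::vector<int64_t>& block,
                         const std::vector<int64_t>& crops_begin,
                         const std::vector<int64_t>& crops_end)
{
    for (std::size_t i = 0; i < data_shape.size(); ++i)
    {
        if (crops_begin[i] + crops_end[i] > block[i] * static_cast<int64_t>(data_shape[i]))
        {
            return false;
        }
    }
    return true;
}
```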
@@ -231,6 +231,7 @@ set(SRC
    visitors/op/adaptive_max_pool.cpp
    visitors/op/atan.cpp
    visitors/op/batch_norm.cpp
    visitors/op/batch_to_space.cpp
    visitors/op/broadcast.cpp
    visitors/op/bucketize.cpp
    visitors/op/ceiling.cpp
@@ -373,6 +374,7 @@ set(MULTI_TEST_SRC
    backend/auto_broadcast.in.cpp
    backend/avg_pool.in.cpp
    backend/batch_norm.in.cpp
    backend/batch_to_space.in.cpp
    backend/broadcast.in.cpp
    backend/bucketize.in.cpp
    backend/builder_reduce_ops_opset1.in.cpp
ngraph/test/backend/batch_to_space.in.cpp (new file, 179 lines)
@@ -0,0 +1,179 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_control.hpp"

using namespace std;
using namespace ngraph;

static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});

namespace
{
    template<typename dataType>
    struct BatchToSpaceParams
    {
        using Data = test::NDArrayBase<dataType>;
        using BlockShape = test::NDArrayBase<int64_t>;
        using Crops = test::NDArrayBase<int64_t>;

        BatchToSpaceParams(Data in_data,
                           BlockShape block_shape,
                           Crops crops_begin,
                           Crops crops_end,
                           Data expected_output)
            : m_data{std::move(in_data)}
            , m_block_shape{std::move(block_shape)}
            , m_crops_begin{std::move(crops_begin)}
            , m_crops_end{std::move(crops_end)}
            , m_expected_output{std::move(expected_output)}
        {
        }

        Data m_data;
        BlockShape m_block_shape;
        Crops m_crops_begin;
        Crops m_crops_end;
        Data m_expected_output;
    };

    template <typename dataType>
    static void BatchToSpaceTestExecute(const BatchToSpaceParams<dataType>& params)
    {
        const auto data =
            make_shared<op::Parameter>(element::from<dataType>(), params.m_data.get_shape());

        const auto block_shape = op::Constant::create(
            element::i64, params.m_block_shape.get_shape(), params.m_block_shape.get_vector());

        const auto crops_begin = op::Constant::create(
            element::i64, params.m_crops_begin.get_shape(), params.m_crops_begin.get_vector());

        const auto crops_end = op::Constant::create(
            element::i64, params.m_crops_end.get_shape(), params.m_crops_end.get_vector());

        const auto batch_to_space =
            make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);

        auto f = make_shared<Function>(batch_to_space, ParameterVector{data});
        auto test_case = test::TestCase<TestEngine>(f);
        test_case.add_input(params.m_data.get_vector());
        test_case.add_expected_output(params.m_expected_output.get_vector());
        test_case.run_with_tolerance_as_fp(1e-4f);
    }

    class BatchToSpaceTestFloat : public testing::TestWithParam<BatchToSpaceParams<float>>
    {
    };
} // namespace

NGRAPH_TEST_P(${BACKEND_NAME}, BatchToSpaceTestFloat, BatchToSpaceTestFloatCases)
{
    BatchToSpaceTestExecute(GetParam());
}

const test::NDArray<float, 4> input_with_shape_4x1x1x3(
    {{{{1.0f, 2.0f, 3.0f}}},
     {{{4.0f, 5.0f, 6.0f}}},
     {{{7.0f, 8.0f, 9.0f}}},
     {{{10.0f, 11.0f, 12.0f}}}});

const test::NDArray<float, 4> input_with_shape_4x1x2x3(
    {{{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}},
     {{{7.0f, 8.0f, 9.0f}, {10.0f, 11.0f, 12.0f}}},
     {{{13.0f, 14.0f, 15.0f}, {16.0f, 17.0f, 18.0f}}},
     {{{19.0f, 20.0f, 21.0f}, {22.0f, 23.0f, 24.0f}}}});

const test::NDArray<int64_t, 1> zero_crops_4d({0, 0, 0, 0});

NGRAPH_INSTANTIATE_TEST_SUITE_P(
    ${BACKEND_NAME},
    batch_to_space_4d_without_crops,
    BatchToSpaceTestFloat,
    testing::Values(
        BatchToSpaceParams<float>{input_with_shape_4x1x1x3,
                                  test::NDArray<int64_t, 1>({1, 1, 1, 2}),
                                  zero_crops_4d,
                                  zero_crops_4d,
                                  test::NDArray<float, 4>(
                                      {{{{1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f}}},
                                       {{{4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f}}}})},
        BatchToSpaceParams<float>{input_with_shape_4x1x1x3,
                                  test::NDArray<int64_t, 1>({1, 1, 2, 1}),
                                  zero_crops_4d,
                                  zero_crops_4d,
                                  test::NDArray<float, 4>(
                                      {{{{1.0f, 2.0f, 3.0f}, {7.0f, 8.0f, 9.0f}}},
                                       {{{4.0f, 5.0f, 6.0f}, {10.0f, 11.0f, 12.0f}}}})},
        BatchToSpaceParams<float>{input_with_shape_4x1x1x3,
                                  test::NDArray<int64_t, 1>({1, 1, 2, 2}),
                                  zero_crops_4d,
                                  zero_crops_4d,
                                  test::NDArray<float, 4>(
                                      {{{{1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f},
                                         {7.0f, 10.0f, 8.0f, 11.0f, 9.0f, 12.0f}}}})},
        BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
                                  test::NDArray<int64_t, 1>({1, 1, 1, 2}),
                                  zero_crops_4d,
                                  zero_crops_4d,
                                  test::NDArray<float, 4>(
                                      {{{{1.0f, 13.0f, 2.0f, 14.0f, 3.0f, 15.0f},
                                         {4.0f, 16.0f, 5.0f, 17.0f, 6.0f, 18.0f}}},
                                       {{{7.0f, 19.0f, 8.0f, 20.0f, 9.0f, 21.0f},
                                         {10.0f, 22.0f, 11.0f, 23.0f, 12.0f, 24.0f}}}})},
        BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
                                  test::NDArray<int64_t, 1>({1, 1, 2, 1}),
                                  zero_crops_4d,
                                  zero_crops_4d,
                                  test::NDArray<float, 4>(
                                      {{{{1.0f, 2.0f, 3.0f}, {13.0f, 14.0f, 15.0f},
                                         {4.0f, 5.0f, 6.0f}, {16.0f, 17.0f, 18.0f}}},
                                       {{{7.0f, 8.0f, 9.0f}, {19.0f, 20.0f, 21.0f},
                                         {10.0f, 11.0f, 12.0f}, {22.0f, 23.0f, 24.0f}}}})},
        BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
                                  test::NDArray<int64_t, 1>({1, 1, 2, 2}),
                                  zero_crops_4d,
                                  zero_crops_4d,
                                  test::NDArray<float, 4>(
                                      {{{{1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f},
                                         {13.0f, 19.0f, 14.0f, 20.0f, 15.0f, 21.0f},
                                         {4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f},
                                         {16.0f, 22.0f, 17.0f, 23.0f, 18.0f, 24.0f}}}})}));

NGRAPH_INSTANTIATE_TEST_SUITE_P(
    ${BACKEND_NAME},
    batch_to_space_4d_crops,
    BatchToSpaceTestFloat,
    testing::Values(
        BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
                                  test::NDArray<int64_t, 1>({1, 1, 2, 2}),
                                  test::NDArray<int64_t, 1>({0, 0, 0, 0}),
                                  test::NDArray<int64_t, 1>({0, 0, 0, 2}),
                                  test::NDArray<float, 4>(
                                      {{{{1.0f, 7.0f, 2.0f, 8.0f},
                                         {13.0f, 19.0f, 14.0f, 20.0f},
                                         {4.0f, 10.0f, 5.0f, 11.0f},
                                         {16.0f, 22.0f, 17.0f, 23.0f}}}})},
        BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
                                  test::NDArray<int64_t, 1>({1, 1, 2, 2}),
                                  test::NDArray<int64_t, 1>({0, 0, 0, 2}),
                                  test::NDArray<int64_t, 1>({0, 0, 0, 0}),
                                  test::NDArray<float, 4>(
                                      {{{{2.0f, 8.0f, 3.0f, 9.0f},
                                         {14.0f, 20.0f, 15.0f, 21.0f},
                                         {5.0f, 11.0f, 6.0f, 12.0f},
                                         {17.0f, 23.0f, 18.0f, 24.0f}}}})},
        BatchToSpaceParams<float>{input_with_shape_4x1x2x3,
                                  test::NDArray<int64_t, 1>({1, 1, 2, 2}),
                                  test::NDArray<int64_t, 1>({0, 0, 1, 0}),
                                  test::NDArray<int64_t, 1>({0, 0, 1, 0}),
                                  test::NDArray<float, 4>(
                                      {{{{13.0f, 19.0f, 14.0f, 20.0f, 15.0f, 21.0f},
                                         {4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f}}}})}));
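A quick sanity sketch (assumptions mine) for the first zero-crops case above: with block_shape {1, 1, 1, 2} the four input batches fold pairwise into the last axis, so the output shape is {2, 1, 1, 6} and no cropping is applied.

```cpp
// Sanity sketch for the first zero-crops case: input {4,1,1,3}, block {1,1,1,2}.
static_assert(4 / (1 * 1 * 1 * 2) == 2, "output batch = input batch / prod(block_shape)");
static_assert(3 * 2 - 0 - 0 == 6, "output width = width * block - crops_begin - crops_end");
```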
@@ -676,10 +676,6 @@ conv_bias_bprop_2d
# Cannot cast ngraph node ConvolutionBiasAdd to CNNLayer!
conv_bias_add_2d

# Unsupported primitive of type: SpaceToBatch
space_to_batch
batch_to_space

# [Validation] Argument must have rank >= 2 and <= 4 (argument shape: {1,2,2,2,3})
normalize_across_1axis_5d
normalize_across_123axes_5d
@@ -2,6 +2,8 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <array>

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
@@ -9,18 +11,407 @@
using namespace std;
using namespace ngraph;

TEST(type_prop, batch_to_space_output_shape_2D)
namespace {
    constexpr size_t data_input_idx = 0;
    constexpr size_t block_shape_input_idx = 1;
    constexpr size_t crops_begin_input_idx = 2;
    constexpr size_t crops_end_input_idx = 3;
    constexpr size_t batch_to_space_required_inputs = 4;
    struct InputInfo
    {
        element::Type in_et;
        PartialShape in_pshape;
    };

    using BatchToSpaceInputParams = std::array<InputInfo, batch_to_space_required_inputs>;

    std::shared_ptr<Node> makeBatchToSpaceOp(const BatchToSpaceInputParams& p)
    {
        if(p.size() != batch_to_space_required_inputs)
        {
            throw runtime_error("BatchToSpace requires 4 inputs");
        }
        auto data = make_shared<op::Parameter>(
            p.at(data_input_idx).in_et, p.at(data_input_idx).in_pshape);
        auto block_shape = make_shared<op::Parameter>(
            p.at(block_shape_input_idx).in_et, p.at(block_shape_input_idx).in_pshape);
        auto crops_begin = make_shared<op::Parameter>(
            p.at(crops_begin_input_idx).in_et, p.at(crops_begin_input_idx).in_pshape);
        auto crops_end = make_shared<op::Parameter>(
            p.at(crops_end_input_idx).in_et, p.at(crops_end_input_idx).in_pshape);
        return make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
    }
} // namespace

TEST(type_prop, batch_to_space_incompatible_input_element_types)
{
    auto data = make_shared<op::Parameter>(element::f32, Shape{10, 26});
    auto block_shape = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 5});
    auto pads_begin = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 2});
    auto pads_end = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 0});
    element::Type float_et = element::f32;
    element::Type integer64_et = element::i64;
    element::Type integer32_et = element::i32;

    auto batch_to_space =
        make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
    Shape data_sshape{10, 26, 4, 4};
    Shape inputs_sshape{4};

    ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
    ASSERT_EQ(batch_to_space->get_shape(), (Shape{10 / 5, 26 * 5 - 2}));
    vector<BatchToSpaceInputParams> test_cases;
    test_cases.push_back(
        BatchToSpaceInputParams{
            InputInfo{float_et, data_sshape},
            InputInfo{integer64_et, inputs_sshape},
            InputInfo{integer32_et, inputs_sshape},
            InputInfo{integer32_et, inputs_sshape}});

    test_cases.push_back(
        BatchToSpaceInputParams{
            InputInfo{float_et, data_sshape},
            InputInfo{integer32_et, inputs_sshape},
            InputInfo{integer64_et, inputs_sshape},
            InputInfo{integer32_et, inputs_sshape}});

    test_cases.push_back(
        BatchToSpaceInputParams{
            InputInfo{float_et, data_sshape},
            InputInfo{integer64_et, inputs_sshape},
            InputInfo{float_et, inputs_sshape},
            InputInfo{float_et, inputs_sshape}});

    for (const auto& test_case : test_cases)
    {
        try
        {
            auto batch_to_space = makeBatchToSpaceOp(test_case);
            FAIL() << "Incompatible element types for block_shape/crops_begin/crops_end inputs not detected";
        }
        catch(const NodeValidationFailure& error)
        {
            EXPECT_HAS_SUBSTRING(error.what(),
                "block_shape, crops_begin and crops_end inputs must have same element type.");
        }
        catch (...)
        {
            FAIL() << "Element type check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
        }
    }
}

TEST(type_prop, batch_to_space_invalid_input_element_types)
{
    element::Type float_et = element::f32;

    Shape data_sshape{10, 26, 4, 4};
    Shape inputs_sshape{4};

    const BatchToSpaceInputParams params{
        InputInfo{float_et, data_sshape},
        InputInfo{float_et, inputs_sshape},
        InputInfo{float_et, inputs_sshape},
        InputInfo{float_et, inputs_sshape}};

    try
    {
        auto batch_to_space = makeBatchToSpaceOp(params);
        FAIL() << "Invalid non-integer element type for block_shape/crops_begin/crops_end inputs not detected";
    }
    catch(const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "block_shape and crops inputs must have integer element type.");
    }
    catch (...)
    {
        FAIL() << "Element type check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_invalid_data_input_rank)
{
    Shape data_sshape{4, 2};
    element::Type data_et = element::f32;

    Shape inputs_sshape{2};
    element::Type inputs_et = element::i64;

    const BatchToSpaceInputParams params{
        InputInfo{data_et, data_sshape},
        InputInfo{inputs_et, inputs_sshape},
        InputInfo{inputs_et, inputs_sshape},
        InputInfo{inputs_et, inputs_sshape}};

    try
    {
        auto batch_to_space = makeBatchToSpaceOp(params);
        FAIL() << "Invalid rank of data input not detected";
    }
    catch(const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), "data input must have rank greater than or equal to 4");
    }
    catch (...)
    {
        FAIL() << "Rank check for data input failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_incompatible_secondary_inputs_shapes)
{
    Shape data_sshape{10, 26, 4, 4};
    element::Type data_et = element::f32;

    Shape inputs_sshape_1D{4};
    Shape inputs_sshape_2D{4, 1};
    element::Type inputs_et = element::i64;

    vector<BatchToSpaceInputParams> test_cases;
    test_cases.push_back(
        BatchToSpaceInputParams{
            InputInfo{data_et, data_sshape},
            InputInfo{inputs_et, inputs_sshape_2D},
            InputInfo{inputs_et, inputs_sshape_1D},
            InputInfo{inputs_et, inputs_sshape_1D}});

    test_cases.push_back(
        BatchToSpaceInputParams{
            InputInfo{data_et, data_sshape},
            InputInfo{inputs_et, inputs_sshape_1D},
            InputInfo{inputs_et, inputs_sshape_2D},
            InputInfo{inputs_et, inputs_sshape_1D}});

    test_cases.push_back(
        BatchToSpaceInputParams{
            InputInfo{data_et, data_sshape},
            InputInfo{inputs_et, inputs_sshape_1D},
            InputInfo{inputs_et, inputs_sshape_2D},
            InputInfo{inputs_et, inputs_sshape_2D}});

    for (const auto& test_case : test_cases)
    {
        try
        {
            auto batch_to_space = makeBatchToSpaceOp(test_case);
            FAIL() << "Incompatible shapes for block_shape/crops_begin/crops_end inputs not detected";
        }
        catch(const NodeValidationFailure& error)
        {
            EXPECT_HAS_SUBSTRING(error.what(),
                "block_shape, crops_begin and crops_end inputs must have the same shape.");
        }
        catch (...)
        {
            FAIL() << "Shapes check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
        }
    }
}

TEST(type_prop, batch_to_space_invalid_secondary_inputs_rank)
{
    Shape data_sshape{10, 26, 4, 4};
    element::Type data_et = element::f32;

    Shape inputs_sshape_2D{4, 1};
    element::Type inputs_et = element::i64;

    const BatchToSpaceInputParams params{
        InputInfo{data_et, data_sshape},
        InputInfo{inputs_et, inputs_sshape_2D},
        InputInfo{inputs_et, inputs_sshape_2D},
        InputInfo{inputs_et, inputs_sshape_2D}};

    try
    {
        auto batch_to_space = makeBatchToSpaceOp(params);
        FAIL() << "Invalid rank for block_shape/crops_begin/crops_end inputs not detected";
    }
    catch(const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "block_shape and crops inputs must have rank 1.");
    }
    catch (...)
    {
        FAIL() << "Rank check for block_shape/crops_begin/crops_end inputs failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_incompatible_data_and_secondary_inputs_shapes)
{
    Shape data_sshape{10, 26, 4, 4};
    element::Type data_et = element::f32;

    Shape inputs_sshape{5};
    element::Type inputs_et = element::i64;

    const BatchToSpaceInputParams params{
        InputInfo{data_et, data_sshape},
        InputInfo{inputs_et, inputs_sshape},
        InputInfo{inputs_et, inputs_sshape},
        InputInfo{inputs_et, inputs_sshape}};

    try
    {
        auto batch_to_space = makeBatchToSpaceOp(params);
        FAIL() << "Incompatible shapes for data and block_shape/crops_begin/crops_end inputs not detected";
    }
    catch(const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "block_shape and crop inputs must have same number of elements "
            "as data input rank.");
    }
    catch (...)
    {
        FAIL() << "Compatibility shape check for data and block_shape/crops_begin/crops_end inputs failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_invalid_block_shape_input)
{
    Shape data_sshape{100, 7, 13, 3};
    element::Type data_et = element::f32;

    Shape inputs_sshape{4};
    element::Type inputs_et = element::i64;

    auto data = make_shared<op::Parameter>(data_et, data_sshape);
    auto block_shape = make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 10, 5, 1});
    auto crops_begin = make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 0});
    auto crops_end = make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 0});

    try
    {
        auto batch_to_space =
            make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
        FAIL() << "Invalid elements of block_shape input not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "Elements of block_shape input must be greater or equal to one.");
    }
    catch (...)
    {
        FAIL() << "Greater than zero elements of block_shape input check failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_invalid_crops_input_values)
{
    Shape data_sshape{100, 7, 13, 3};
    element::Type data_et = element::f32;

    Shape inputs_sshape{4};
    element::Type inputs_et = element::i64;

    try
    {
        auto data = make_shared<op::Parameter>(data_et, data_sshape);
        auto block_shape =
            make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 10, 5, 1});
        auto crops_begin =
            make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, -1});
        auto crops_end =
            make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 0});
        auto batch_to_space =
            make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
        FAIL() << "Invalid crops_begin input values not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "Elements of crops_begin and crops_end inputs must be greater or equal to zero.");
    }
    catch (...)
    {
        FAIL() << "Non-negative element check of crops_begin input values failed for unexpected reason";
    }

    try
    {
        auto data = make_shared<op::Parameter>(data_et, data_sshape);
        auto block_shape =
            make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 10, 5, 1});
        auto crops_begin =
            make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 0});
        auto crops_end =
            make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, -1, 0});
        auto batch_to_space =
            make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
        FAIL() << "Invalid crops_end input values not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "Elements of crops_begin and crops_end inputs must be greater or equal to zero.");
    }
    catch (...)
    {
        FAIL() << "Non-negative element check of crops_end input values failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_incompatible_block_shape_input_values_with_data_shape)
{
    Shape data_sshape{80, 7, 13, 3};
    element::Type data_et = element::f32;

    Shape inputs_sshape{4};
    element::Type inputs_et = element::i64;

    auto data = make_shared<op::Parameter>(data_et, data_sshape);
    auto block_shape =
        make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 10, 5, 1});
    auto crops_begin =
        make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 0});
    auto crops_end =
        make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 0});

    try
    {
        auto batch_to_space =
            make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
        FAIL() << "Incompatible data shape and block_shape input values not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "The input data's 'batch' axis size: 80 must be a multiple of product of block_shape values: 50");
    }
    catch (...)
    {
        FAIL() << "Data shape and block_shape input values check failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_invalid_crops_out_of_bounds)
{
    Shape data_sshape{32, 4, 1, 3};
    element::Type data_et = element::f32;

    Shape inputs_sshape{4};
    element::Type inputs_et = element::i64;

    auto data = make_shared<op::Parameter>(data_et, data_sshape);
    auto block_shape =
        make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{1, 2, 2, 1});
    auto crops_begin =
        make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 1, 2});
    auto crops_end =
        make_shared<op::Constant>(inputs_et, inputs_sshape, vector<int64_t>{0, 3, 0, 2});

    try
    {
        auto batch_to_space =
            make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
        FAIL() << "Invalid out of bound crops values not detected";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
            "crops_begin[i] + crops_end[i] must be less or equal to block_shape[i] * input_shape[i]");
    }
    catch (...)
    {
        FAIL() << "Crops values check failed for unexpected reason";
    }
}

TEST(type_prop, batch_to_space_output_shape_4D)
@@ -28,12 +419,12 @@ TEST(type_prop, batch_to_space_output_shape_4D)
    auto data = make_shared<op::Parameter>(element::f32, Shape{100, 7, 13, 3});
    auto block_shape =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 10, 5, 1});
    auto pads_begin =
    auto crops_begin =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 0});
    auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});

    auto crops_end =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
    auto batch_to_space =
        make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
        make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);

    ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
    ASSERT_EQ(batch_to_space->get_shape(), (Shape{100 / (10 * 5), 7 * 10 - 3 - 3, 13 * 5 - 1, 3}));
@@ -44,13 +435,12 @@ TEST(type_prop, batch_to_space_output_shape_5D)
    auto data = make_shared<op::Parameter>(element::f32, Shape{960, 6, 13, 128, 16});
    auto block_shape =
        make_shared<op::Constant>(element::i32, Shape{5}, vector<int64_t>{1, 6, 5, 1, 16});
    auto pads_begin =
    auto crops_begin =
        make_shared<op::Constant>(element::i32, Shape{5}, vector<int64_t>{0, 2, 0, 0, 0});
    auto pads_end =
    auto crops_end =
        make_shared<op::Constant>(element::i32, Shape{5}, vector<int64_t>{0, 2, 1, 0, 0});

    auto batch_to_space =
        make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
        make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);

    ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
    ASSERT_EQ(batch_to_space->get_shape(),
@@ -62,19 +452,19 @@ TEST(type_prop, batch_to_space_and_space_to_batch)
    auto data = make_shared<op::Parameter>(element::f32, Shape{4800, 9, 11, 2});
    auto block_shape =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 12, 100, 2});
    auto pads_begin =
    auto crops_begin =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 38, 1});
    auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 5, 38, 0});

    auto crops_end =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 5, 38, 0});
    auto batch_to_space =
        make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
        make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);

    ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
    ASSERT_EQ(batch_to_space->get_shape(),
              (Shape{4800 / (12 * 100 * 2), 9 * 12 - 3 - 5, 11 * 100 - 38 - 38, 2 * 2 - 1}));

    auto space_to_batch =
        make_shared<op::v1::SpaceToBatch>(batch_to_space, block_shape, pads_begin, pads_end);
        make_shared<op::v1::SpaceToBatch>(batch_to_space, block_shape, crops_begin, crops_end);
    ASSERT_EQ(space_to_batch->get_element_type(), element::f32);
    ASSERT_EQ(space_to_batch->get_shape(), (Shape{4800, 9, 11, 2}));
}
@@ -84,12 +474,12 @@ TEST(type_prop, batch_to_space_dynamic_shape_static_rank)
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(4));
    auto block_shape =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 10, 5, 1});
    auto pads_begin =
    auto crops_begin =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 0});
    auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});

    auto crops_end =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
    auto batch_to_space =
        make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
        make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);

    ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
    ASSERT_EQ(batch_to_space->get_output_partial_shape(0), PartialShape::dynamic(4));
@@ -100,12 +490,12 @@ TEST(type_prop, batch_to_space_dynamic_shape_dynamic_rank)
    auto data = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto block_shape =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 10, 5, 1});
    auto pads_begin =
    auto crops_begin =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 1, 0});
    auto pads_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});

    auto crops_end =
        make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 3, 0, 0});
    auto batch_to_space =
        make_shared<op::v1::BatchToSpace>(data, block_shape, pads_begin, pads_end);
        make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);

    ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
    ASSERT_EQ(batch_to_space->get_output_partial_shape(0), PartialShape::dynamic());
ngraph/test/visitors/op/batch_to_space.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"

#include "util/visitor.hpp"

using namespace std;
using namespace ngraph;
using ngraph::test::NodeBuilder;

TEST(attributes, batch_to_space_op)
{
    NodeBuilder::get_ops().register_factory<op::v1::BatchToSpace>();
    auto data = make_shared<op::Parameter>(element::f32, Shape{128, 4, 2, 2});
    auto block_shape = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 2, 2, 2});
    auto crops_begin = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 2, 0, 1});
    auto crops_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 0, 1, 0});
    auto batch2space = make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);

    NodeBuilder builder(batch2space);
    const auto expected_attr_count = 0;

    EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
}