Enable lower input rank for BatchToSpace operation by fallback (#6553)

* Enable lower input rank for BatchToSpace operation by fallback

* Added conditions to check valid data input rank for evaluate and has_evaluate methods
This commit is contained in:
Gabriele Galiero Casay 2021-07-08 06:47:01 +02:00 committed by GitHub
parent 340583fa35
commit a4fef45e0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 99 additions and 30 deletions

View File

@ -16,6 +16,39 @@ const std::vector<InferenceEngine::Precision> net_precisions = {
InferenceEngine::Precision::I32
};
const std::vector<std::vector<size_t>> data_shapes_2D = {
{12, 4},
{48, 3}
};
const std::vector<std::vector<int64_t>> block_shapes_2D = {
{1, 2},
{1, 6}
};
const std::vector<std::vector<int64_t>> crops_2D = {
{0, 0},
{0, 1}
};
const auto batch_to_space_2d_tests = ::testing::Combine(
::testing::ValuesIn(block_shapes_2D),
::testing::ValuesIn(crops_2D),
::testing::ValuesIn(crops_2D),
::testing::ValuesIn(data_shapes_2D),
::testing::ValuesIn(net_precisions),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(CommonTestUtils::DEVICE_CPU));
INSTANTIATE_TEST_CASE_P(
smoke_BatchToSpace_2D,
BatchToSpaceLayerTest,
batch_to_space_2d_tests,
BatchToSpaceLayerTest::getTestCaseName);
const std::vector<std::vector<size_t>> data_shapes_4D = {
{4, 1, 2, 2},
{4, 3, 2, 2},
@ -39,7 +72,7 @@ const std::vector<std::vector<int64_t>> crops_end_4D = {
{0, 0, 0, 2}
};
const auto space_to_batch_4d_spatial_dims_tests = ::testing::Combine(
const auto batch_to_space_4d_spatial_dims_tests = ::testing::Combine(
::testing::Values(block_shapes_4D[0]),
::testing::ValuesIn(crops_begin_4D),
::testing::ValuesIn(crops_end_4D),
@ -51,7 +84,7 @@ const auto space_to_batch_4d_spatial_dims_tests = ::testing::Combine(
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(CommonTestUtils::DEVICE_CPU));
const auto space_to_batch_4d_channel_dim_tests = ::testing::Combine(
const auto batch_to_space_4d_channel_dim_tests = ::testing::Combine(
::testing::Values(block_shapes_4D[1]),
::testing::Values(crops_begin_4D[0]),
::testing::Values(crops_end_4D[0]),
@ -66,13 +99,13 @@ const auto space_to_batch_4d_channel_dim_tests = ::testing::Combine(
INSTANTIATE_TEST_CASE_P(
smoke_BatchToSpace_4D_spatial_dims,
BatchToSpaceLayerTest,
space_to_batch_4d_spatial_dims_tests,
batch_to_space_4d_spatial_dims_tests,
BatchToSpaceLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(
smoke_BatchToSpace_4D_channel_dim,
BatchToSpaceLayerTest,
space_to_batch_4d_channel_dim_tests,
batch_to_space_4d_channel_dim_tests,
BatchToSpaceLayerTest::getTestCaseName);
const std::vector<std::vector<size_t>> data_shapes_5D = {
@ -96,7 +129,7 @@ const std::vector<std::vector<int64_t>> crops_end_5D = {
{0, 0, 0, 0, 1}
};
const auto space_to_batch_5d_spatial_dims_tests = ::testing::Combine(
const auto batch_to_space_5d_spatial_dims_tests = ::testing::Combine(
::testing::Values(block_shapes_5D[0]),
::testing::ValuesIn(crops_begin_5D),
::testing::ValuesIn(crops_end_5D),
@ -108,7 +141,7 @@ const auto space_to_batch_5d_spatial_dims_tests = ::testing::Combine(
::testing::Values(InferenceEngine::Layout::ANY),
::testing::Values(CommonTestUtils::DEVICE_CPU));
const auto space_to_batch_5d_channel_dim_tests = ::testing::Combine(
const auto batch_to_space_5d_channel_dim_tests = ::testing::Combine(
::testing::Values(block_shapes_5D[1]),
::testing::Values(crops_begin_5D[0]),
::testing::Values(crops_end_5D[0]),
@ -123,13 +156,13 @@ const auto space_to_batch_5d_channel_dim_tests = ::testing::Combine(
INSTANTIATE_TEST_CASE_P(
smoke_BatchToSpace_5D_spatial_dims,
BatchToSpaceLayerTest,
space_to_batch_5d_spatial_dims_tests,
batch_to_space_5d_spatial_dims_tests,
BatchToSpaceLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(
smoke_BatchToSpace_5D_channel_dim,
BatchToSpaceLayerTest,
space_to_batch_5d_channel_dim_tests,
batch_to_space_5d_channel_dim_tests,
BatchToSpaceLayerTest::getTestCaseName);
} // namespace
} // namespace

View File

@ -88,8 +88,8 @@ void op::v1::BatchToSpace::validate_and_infer_types()
if (data_rank.is_static())
{
NODE_VALIDATION_CHECK(this,
(data_rank.get_length() >= 4),
"data input must have rank greater than or equal to 4. Got: ",
(data_rank.get_length() >= 2),
"data input must have rank greater or equal than 2. Got: ",
data_rank.get_length());
if (inputs_same_ps.is_static())
@ -197,7 +197,7 @@ namespace
}
auto data_shape = data->get_shape();
auto data_rank = data_shape.size();
if (!(data_rank == 4 || data_rank == 5))
if (data_rank < 2)
{
return false;
}
@ -346,7 +346,6 @@ bool ngraph::op::v1::BatchToSpace::evaluate(const HostTensorVector& outputs,
bool ngraph::op::v1::BatchToSpace::has_evaluate() const
{
NGRAPH_OP_SCOPE(v1_BatchToSpace_has_evaluate);
return !get_input_partial_shape(0).is_dynamic() &&
(get_input_shape(0).size() == 4 || get_input_shape(0).size() == 5) &&
return !get_input_partial_shape(0).is_dynamic() && get_input_shape(0).size() >= 2 &&
get_input_shape(0).size() <= shape_size(get_input_shape(1));
}

View File

@ -79,6 +79,27 @@ NGRAPH_TEST_P(${BACKEND_NAME}, BatchToSpaceTestFloat, BatchToSpaceTestFloatCases
BatchToSpaceTestExecute(GetParam());
}
const test::NDArray<float, 2> input_with_shape_4x3(
{{1.0f, 2.0f, 3.0f},
{4.0f, 5.0f, 6.0f},
{7.0f, 8.0f, 9.0f},
{10.0f, 11.0f, 12.0f}});
const test::NDArray<int64_t, 1> zero_crops_2d({0, 0});
NGRAPH_INSTANTIATE_TEST_SUITE_P(
${BACKEND_NAME},
batch_to_space_2d_without_crops,
BatchToSpaceTestFloat,
testing::Values(
BatchToSpaceParams<float>{input_with_shape_4x3,
test::NDArray<int64_t, 1>({1, 2}),
zero_crops_2d,
zero_crops_2d,
test::NDArray<float, 2>(
{{1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f},
{4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f}})}));
const test::NDArray<float, 4> input_with_shape_4x1x1x3(
{{{{1.0f, 2.0f, 3.0f}}},
{{{4.0f, 5.0f, 6.0f}}},

View File

@ -49,8 +49,8 @@ TEST(type_prop, batch_to_space_incompatible_input_element_types)
element::Type integer64_et = element::i64;
element::Type integer32_et = element::i32;
Shape data_sshape{10, 26, 4, 4};
Shape inputs_sshape{4};
Shape data_sshape{10, 26};
Shape inputs_sshape{2};
vector<BatchToSpaceInputParams> test_cases;
test_cases.push_back(
@ -97,8 +97,8 @@ TEST(type_prop, batch_to_space_invalid_input_element_types)
{
element::Type float_et = element::f32;
Shape data_sshape{10, 26, 4, 4};
Shape inputs_sshape{4};
Shape data_sshape{10, 26};
Shape inputs_sshape{2};
const BatchToSpaceInputParams params{
InputInfo{float_et, data_sshape},
@ -124,7 +124,7 @@ TEST(type_prop, batch_to_space_invalid_input_element_types)
TEST(type_prop, batch_to_space_invalid_data_input_rank)
{
Shape data_sshape{4, 2};
Shape data_sshape{4};
element::Type data_et = element::f32;
Shape inputs_sshape{2};
@ -143,7 +143,7 @@ TEST(type_prop, batch_to_space_invalid_data_input_rank)
}
catch(const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "data input must have rank greater than or equal to 4");
EXPECT_HAS_SUBSTRING(error.what(), "data input must have rank greater or equal than 2.");
}
catch (...)
{
@ -153,11 +153,11 @@ TEST(type_prop, batch_to_space_invalid_data_input_rank)
TEST(type_prop, batch_to_space_incompatible_secondary_inputs_shapes)
{
Shape data_sshape{10, 26, 4, 4};
Shape data_sshape{10, 26};
element::Type data_et = element::f32;
Shape inputs_sshape_1D{4};
Shape inputs_sshape_2D{4, 1};
Shape inputs_sshape_1D{2};
Shape inputs_sshape_2D{2, 1};
element::Type inputs_et = element::i64;
vector<BatchToSpaceInputParams> test_cases;
@ -203,10 +203,10 @@ TEST(type_prop, batch_to_space_incompatible_secondary_inputs_shapes)
TEST(type_prop, batch_to_space_invalid_secondary_inputs_rank)
{
Shape data_sshape{10, 26, 4, 4};
Shape data_sshape{10, 26};
element::Type data_et = element::f32;
Shape inputs_sshape_2D{4, 1};
Shape inputs_sshape_2D{2, 1};
element::Type inputs_et = element::i64;
const BatchToSpaceInputParams params{
@ -233,7 +233,7 @@ TEST(type_prop, batch_to_space_invalid_secondary_inputs_rank)
TEST(type_prop, batch_to_space_incompatible_data_and_secondary_inputs_shapes)
{
Shape data_sshape{10, 26, 4, 4};
Shape data_sshape{10, 26};
element::Type data_et = element::f32;
Shape inputs_sshape{5};
@ -414,6 +414,22 @@ TEST(type_prop, batch_to_space_invalid_crops_out_of_bounds)
}
}
TEST(type_prop, batch_to_space_output_shape_2D)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{10, 26});
auto block_shape =
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 5});
auto crops_begin =
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 2});
auto crops_end =
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 0});
auto batch_to_space =
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
ASSERT_EQ(batch_to_space->get_shape(), (Shape{10 / 5, 26 * 5 - 2}));
}
TEST(type_prop, batch_to_space_output_shape_4D)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{100, 7, 13, 3});

View File

@ -16,10 +16,10 @@ using ngraph::test::NodeBuilder;
TEST(attributes, batch_to_space_op)
{
NodeBuilder::get_ops().register_factory<op::v1::BatchToSpace>();
auto data = make_shared<op::Parameter>(element::f32, Shape{128, 4, 2, 2});
auto block_shape = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 2, 2, 2});
auto crops_begin = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 2, 0, 1});
auto crops_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 0, 1, 0});
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 128});
auto block_shape = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto crops_begin = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 2});
auto crops_end = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 0});
auto batch2space = make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
NodeBuilder builder(batch2space);