Enable lower input rank for BatchToSpace operation by fallback (#6553)
* Enable lower input rank for BatchToSpace operation by fallback * Added conditions to check valid data input rank for evaluate and has_evaluate methods
This commit is contained in:
parent
340583fa35
commit
a4fef45e0c
@ -16,6 +16,39 @@ const std::vector<InferenceEngine::Precision> net_precisions = {
|
||||
InferenceEngine::Precision::I32
|
||||
};
|
||||
|
||||
const std::vector<std::vector<size_t>> data_shapes_2D = {
|
||||
{12, 4},
|
||||
{48, 3}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<int64_t>> block_shapes_2D = {
|
||||
{1, 2},
|
||||
{1, 6}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<int64_t>> crops_2D = {
|
||||
{0, 0},
|
||||
{0, 1}
|
||||
};
|
||||
|
||||
const auto batch_to_space_2d_tests = ::testing::Combine(
|
||||
::testing::ValuesIn(block_shapes_2D),
|
||||
::testing::ValuesIn(crops_2D),
|
||||
::testing::ValuesIn(crops_2D),
|
||||
::testing::ValuesIn(data_shapes_2D),
|
||||
::testing::ValuesIn(net_precisions),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_BatchToSpace_2D,
|
||||
BatchToSpaceLayerTest,
|
||||
batch_to_space_2d_tests,
|
||||
BatchToSpaceLayerTest::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<size_t>> data_shapes_4D = {
|
||||
{4, 1, 2, 2},
|
||||
{4, 3, 2, 2},
|
||||
@ -39,7 +72,7 @@ const std::vector<std::vector<int64_t>> crops_end_4D = {
|
||||
{0, 0, 0, 2}
|
||||
};
|
||||
|
||||
const auto space_to_batch_4d_spatial_dims_tests = ::testing::Combine(
|
||||
const auto batch_to_space_4d_spatial_dims_tests = ::testing::Combine(
|
||||
::testing::Values(block_shapes_4D[0]),
|
||||
::testing::ValuesIn(crops_begin_4D),
|
||||
::testing::ValuesIn(crops_end_4D),
|
||||
@ -51,7 +84,7 @@ const auto space_to_batch_4d_spatial_dims_tests = ::testing::Combine(
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU));
|
||||
|
||||
const auto space_to_batch_4d_channel_dim_tests = ::testing::Combine(
|
||||
const auto batch_to_space_4d_channel_dim_tests = ::testing::Combine(
|
||||
::testing::Values(block_shapes_4D[1]),
|
||||
::testing::Values(crops_begin_4D[0]),
|
||||
::testing::Values(crops_end_4D[0]),
|
||||
@ -66,13 +99,13 @@ const auto space_to_batch_4d_channel_dim_tests = ::testing::Combine(
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_BatchToSpace_4D_spatial_dims,
|
||||
BatchToSpaceLayerTest,
|
||||
space_to_batch_4d_spatial_dims_tests,
|
||||
batch_to_space_4d_spatial_dims_tests,
|
||||
BatchToSpaceLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_BatchToSpace_4D_channel_dim,
|
||||
BatchToSpaceLayerTest,
|
||||
space_to_batch_4d_channel_dim_tests,
|
||||
batch_to_space_4d_channel_dim_tests,
|
||||
BatchToSpaceLayerTest::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<size_t>> data_shapes_5D = {
|
||||
@ -96,7 +129,7 @@ const std::vector<std::vector<int64_t>> crops_end_5D = {
|
||||
{0, 0, 0, 0, 1}
|
||||
};
|
||||
|
||||
const auto space_to_batch_5d_spatial_dims_tests = ::testing::Combine(
|
||||
const auto batch_to_space_5d_spatial_dims_tests = ::testing::Combine(
|
||||
::testing::Values(block_shapes_5D[0]),
|
||||
::testing::ValuesIn(crops_begin_5D),
|
||||
::testing::ValuesIn(crops_end_5D),
|
||||
@ -108,7 +141,7 @@ const auto space_to_batch_5d_spatial_dims_tests = ::testing::Combine(
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU));
|
||||
|
||||
const auto space_to_batch_5d_channel_dim_tests = ::testing::Combine(
|
||||
const auto batch_to_space_5d_channel_dim_tests = ::testing::Combine(
|
||||
::testing::Values(block_shapes_5D[1]),
|
||||
::testing::Values(crops_begin_5D[0]),
|
||||
::testing::Values(crops_end_5D[0]),
|
||||
@ -123,13 +156,13 @@ const auto space_to_batch_5d_channel_dim_tests = ::testing::Combine(
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_BatchToSpace_5D_spatial_dims,
|
||||
BatchToSpaceLayerTest,
|
||||
space_to_batch_5d_spatial_dims_tests,
|
||||
batch_to_space_5d_spatial_dims_tests,
|
||||
BatchToSpaceLayerTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_BatchToSpace_5D_channel_dim,
|
||||
BatchToSpaceLayerTest,
|
||||
space_to_batch_5d_channel_dim_tests,
|
||||
batch_to_space_5d_channel_dim_tests,
|
||||
BatchToSpaceLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
@ -88,8 +88,8 @@ void op::v1::BatchToSpace::validate_and_infer_types()
|
||||
if (data_rank.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
(data_rank.get_length() >= 4),
|
||||
"data input must have rank greater than or equal to 4. Got: ",
|
||||
(data_rank.get_length() >= 2),
|
||||
"data input must have rank greater or equal than 2. Got: ",
|
||||
data_rank.get_length());
|
||||
|
||||
if (inputs_same_ps.is_static())
|
||||
@ -197,7 +197,7 @@ namespace
|
||||
}
|
||||
auto data_shape = data->get_shape();
|
||||
auto data_rank = data_shape.size();
|
||||
if (!(data_rank == 4 || data_rank == 5))
|
||||
if (data_rank < 2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@ -346,7 +346,6 @@ bool ngraph::op::v1::BatchToSpace::evaluate(const HostTensorVector& outputs,
|
||||
bool ngraph::op::v1::BatchToSpace::has_evaluate() const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_BatchToSpace_has_evaluate);
|
||||
return !get_input_partial_shape(0).is_dynamic() &&
|
||||
(get_input_shape(0).size() == 4 || get_input_shape(0).size() == 5) &&
|
||||
return !get_input_partial_shape(0).is_dynamic() && get_input_shape(0).size() >= 2 &&
|
||||
get_input_shape(0).size() <= shape_size(get_input_shape(1));
|
||||
}
|
||||
|
@ -79,6 +79,27 @@ NGRAPH_TEST_P(${BACKEND_NAME}, BatchToSpaceTestFloat, BatchToSpaceTestFloatCases
|
||||
BatchToSpaceTestExecute(GetParam());
|
||||
}
|
||||
|
||||
const test::NDArray<float, 2> input_with_shape_4x3(
|
||||
{{1.0f, 2.0f, 3.0f},
|
||||
{4.0f, 5.0f, 6.0f},
|
||||
{7.0f, 8.0f, 9.0f},
|
||||
{10.0f, 11.0f, 12.0f}});
|
||||
|
||||
const test::NDArray<int64_t, 1> zero_crops_2d({0, 0});
|
||||
|
||||
NGRAPH_INSTANTIATE_TEST_SUITE_P(
|
||||
${BACKEND_NAME},
|
||||
batch_to_space_2d_without_crops,
|
||||
BatchToSpaceTestFloat,
|
||||
testing::Values(
|
||||
BatchToSpaceParams<float>{input_with_shape_4x3,
|
||||
test::NDArray<int64_t, 1>({1, 2}),
|
||||
zero_crops_2d,
|
||||
zero_crops_2d,
|
||||
test::NDArray<float, 2>(
|
||||
{{1.0f, 7.0f, 2.0f, 8.0f, 3.0f, 9.0f},
|
||||
{4.0f, 10.0f, 5.0f, 11.0f, 6.0f, 12.0f}})}));
|
||||
|
||||
const test::NDArray<float, 4> input_with_shape_4x1x1x3(
|
||||
{{{{1.0f, 2.0f, 3.0f}}},
|
||||
{{{4.0f, 5.0f, 6.0f}}},
|
||||
|
@ -49,8 +49,8 @@ TEST(type_prop, batch_to_space_incompatible_input_element_types)
|
||||
element::Type integer64_et = element::i64;
|
||||
element::Type integer32_et = element::i32;
|
||||
|
||||
Shape data_sshape{10, 26, 4, 4};
|
||||
Shape inputs_sshape{4};
|
||||
Shape data_sshape{10, 26};
|
||||
Shape inputs_sshape{2};
|
||||
|
||||
vector<BatchToSpaceInputParams> test_cases;
|
||||
test_cases.push_back(
|
||||
@ -97,8 +97,8 @@ TEST(type_prop, batch_to_space_invalid_input_element_types)
|
||||
{
|
||||
element::Type float_et = element::f32;
|
||||
|
||||
Shape data_sshape{10, 26, 4, 4};
|
||||
Shape inputs_sshape{4};
|
||||
Shape data_sshape{10, 26};
|
||||
Shape inputs_sshape{2};
|
||||
|
||||
const BatchToSpaceInputParams params{
|
||||
InputInfo{float_et, data_sshape},
|
||||
@ -124,7 +124,7 @@ TEST(type_prop, batch_to_space_invalid_input_element_types)
|
||||
|
||||
TEST(type_prop, batch_to_space_invalid_data_input_rank)
|
||||
{
|
||||
Shape data_sshape{4, 2};
|
||||
Shape data_sshape{4};
|
||||
element::Type data_et = element::f32;
|
||||
|
||||
Shape inputs_sshape{2};
|
||||
@ -143,7 +143,7 @@ TEST(type_prop, batch_to_space_invalid_data_input_rank)
|
||||
}
|
||||
catch(const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "data input must have rank greater than or equal to 4");
|
||||
EXPECT_HAS_SUBSTRING(error.what(), "data input must have rank greater or equal than 2.");
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
@ -153,11 +153,11 @@ TEST(type_prop, batch_to_space_invalid_data_input_rank)
|
||||
|
||||
TEST(type_prop, batch_to_space_incompatible_secondary_inputs_shapes)
|
||||
{
|
||||
Shape data_sshape{10, 26, 4, 4};
|
||||
Shape data_sshape{10, 26};
|
||||
element::Type data_et = element::f32;
|
||||
|
||||
Shape inputs_sshape_1D{4};
|
||||
Shape inputs_sshape_2D{4, 1};
|
||||
Shape inputs_sshape_1D{2};
|
||||
Shape inputs_sshape_2D{2, 1};
|
||||
element::Type inputs_et = element::i64;
|
||||
|
||||
vector<BatchToSpaceInputParams> test_cases;
|
||||
@ -203,10 +203,10 @@ TEST(type_prop, batch_to_space_incompatible_secondary_inputs_shapes)
|
||||
|
||||
TEST(type_prop, batch_to_space_invalid_secondary_inputs_rank)
|
||||
{
|
||||
Shape data_sshape{10, 26, 4, 4};
|
||||
Shape data_sshape{10, 26};
|
||||
element::Type data_et = element::f32;
|
||||
|
||||
Shape inputs_sshape_2D{4, 1};
|
||||
Shape inputs_sshape_2D{2, 1};
|
||||
element::Type inputs_et = element::i64;
|
||||
|
||||
const BatchToSpaceInputParams params{
|
||||
@ -233,7 +233,7 @@ TEST(type_prop, batch_to_space_invalid_secondary_inputs_rank)
|
||||
|
||||
TEST(type_prop, batch_to_space_incompatible_data_and_secondary_inputs_shapes)
|
||||
{
|
||||
Shape data_sshape{10, 26, 4, 4};
|
||||
Shape data_sshape{10, 26};
|
||||
element::Type data_et = element::f32;
|
||||
|
||||
Shape inputs_sshape{5};
|
||||
@ -414,6 +414,22 @@ TEST(type_prop, batch_to_space_invalid_crops_out_of_bounds)
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, batch_to_space_output_shape_2D)
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{10, 26});
|
||||
auto block_shape =
|
||||
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 5});
|
||||
auto crops_begin =
|
||||
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 2});
|
||||
auto crops_end =
|
||||
make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 0});
|
||||
auto batch_to_space =
|
||||
make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
|
||||
|
||||
ASSERT_EQ(batch_to_space->get_element_type(), element::f32);
|
||||
ASSERT_EQ(batch_to_space->get_shape(), (Shape{10 / 5, 26 * 5 - 2}));
|
||||
}
|
||||
|
||||
TEST(type_prop, batch_to_space_output_shape_4D)
|
||||
{
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{100, 7, 13, 3});
|
||||
|
@ -16,10 +16,10 @@ using ngraph::test::NodeBuilder;
|
||||
TEST(attributes, batch_to_space_op)
|
||||
{
|
||||
NodeBuilder::get_ops().register_factory<op::v1::BatchToSpace>();
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{128, 4, 2, 2});
|
||||
auto block_shape = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{1, 2, 2, 2});
|
||||
auto crops_begin = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 2, 0, 1});
|
||||
auto crops_end = make_shared<op::Constant>(element::i64, Shape{4}, vector<int64_t>{0, 0, 1, 0});
|
||||
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 128});
|
||||
auto block_shape = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
|
||||
auto crops_begin = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 2});
|
||||
auto crops_end = make_shared<op::Constant>(element::i64, Shape{2}, vector<int64_t>{0, 0});
|
||||
auto batch2space = make_shared<op::v1::BatchToSpace>(data, block_shape, crops_begin, crops_end);
|
||||
|
||||
NodeBuilder builder(batch2space);
|
||||
|
Loading…
Reference in New Issue
Block a user