From c62be51cc1e41e3267ec2f2abfbec488b4b9569d Mon Sep 17 00:00:00 2001 From: Tomasz Jankowski Date: Tue, 14 Feb 2023 11:45:10 +0100 Subject: [PATCH] [Transformations] Enable dynamic decomposition BTS and STB ops (#15179) * Add dynamism for BatchToSpace conversion * Extend dynamism for BatchToSpace conversion Only block input needs to be const 'cse axes_order_const is freaky * Enhace dynamism for BatchToSpace conversion Block input need not be const now. * Add dynamism for STB by elements conversion * Remove const need for crops for BTS by_elements * temp for review * Try to fix output tensor overwrite * Make test to reproduce invalid shape inference * Reproduce the error with template plugin * Fix code style * Fix 0D inputs issue * Remove 0D shape parts before Concat * Apply nested namespaces * Enable non-constant STB Block input * Fix BTS runtime info * Fix STB by elems runtime info * Add dynamism for STB conversion * Add BTS dynamic data test * Add STB dynamic data test * Reduce STB concats * Add tests naming * Edit * style * Consider other block element types * Enhance type test * Use opset10 only * Check block shape --- .../op_conversions/convert_batch_to_space.cpp | 282 +++++++++--------- .../op_conversions/convert_space_to_batch.cpp | 234 +++++++-------- .../batch_to_space_decomposition_test.cpp | 159 ++++++++++ 3 files changed, 412 insertions(+), 263 deletions(-) diff --git a/src/common/transformations/src/transformations/op_conversions/convert_batch_to_space.cpp b/src/common/transformations/src/transformations/op_conversions/convert_batch_to_space.cpp index 91e6b5cac53..7a9f75c5091 100644 --- a/src/common/transformations/src/transformations/op_conversions/convert_batch_to_space.cpp +++ b/src/common/transformations/src/transformations/op_conversions/convert_batch_to_space.cpp @@ -4,94 +4,82 @@ #include "transformations/op_conversions/convert_batch_to_space.hpp" +#include +#include #include #include #include -#include +#include #include #include "itt.hpp" +using namespace std; +using namespace ov::opset10; +using namespace ov::element; + void ov::pass::ConvertBatchToSpace::convert_batch_to_space() { MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space); - auto batch_to_space = ngraph::pattern::wrap_type(); - matcher_pass_callback callback = [](pattern::Matcher& m) { - auto batch_to_space = std::dynamic_pointer_cast(m.get_match_root()); - if (!batch_to_space) { + const auto batch_to_space = pattern::wrap_type(); + matcher_pass_callback callback = [this](pattern::Matcher& m) { + const auto batch_to_space = dynamic_pointer_cast(m.get_match_root()); + if (!batch_to_space || transformation_callback(batch_to_space)) { return false; } - NodeVector new_ops; - auto data = batch_to_space->input_value(0); - auto block = batch_to_space->input_value(1); - auto crops_begin = batch_to_space->input_value(2); - auto crops_end = batch_to_space->input_value(3); + NodeRegistry rg; + const auto data = batch_to_space->input_value(0); + const auto block = batch_to_space->input_value(1); + const auto crops_begin = batch_to_space->input_value(2); + const auto crops_end = batch_to_space->input_value(3); - if (data.get_partial_shape().is_dynamic()) { - return false; - } - const auto& data_shape = data.get_shape(); - - const auto block_const = std::dynamic_pointer_cast(block.get_node_shared_ptr()); - const auto crops_begin_const = std::dynamic_pointer_cast(crops_begin.get_node_shared_ptr()); - const auto crops_end_const = std::dynamic_pointer_cast(crops_end.get_node_shared_ptr()); - - if (!block_const || !crops_begin_const || !crops_end_const) { - return false; + const auto data_shape_rank = data.get_partial_shape().rank(); + if (data_shape_rank.is_dynamic()) { + return false; // because StridedSlice masks are std::vector } - const std::vector& block_values = block_const->cast_vector(); - const std::vector& crops_end_values = crops_end_const->cast_vector(); + if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) { + return false; + } + const auto block_length = static_cast(block.get_shape()[0]); // First we have to disperse the data from batch, then rearrange them // so as appropriate chunks of data where close to their destination place. - // Finally squeeze data from respective dimensions.ss - std::vector dispersed_shape; - int64_t b_dim_divider = 1; - for (const auto& el : block_values) { - b_dim_divider *= el; - } + // Finally squeeze data from respective dimensions + + const auto zero = rg.make(i64, Shape{1}, 0); + const auto shape_of_data = rg.make(data, block.get_element_type()); + const auto batch = rg.make(shape_of_data, zero, zero); + const auto block_prod = rg.make(block, zero); + const auto batch_div = rg.make(batch, block_prod); // note: B_0 is expected to be 1. // x' = reshape(`data`, [B_1, ..., B_{N - 1}, batch / (B_1 * ... B_{N - 1}), D_1, D_2, ..., // D_{N - 1}]), // where B_i = block_shape[i] - dispersed_shape.insert(dispersed_shape.begin(), block_values.begin() + 1, block_values.end()); - dispersed_shape.push_back(data_shape.at(0) / b_dim_divider); - for (size_t i = 1; i < data_shape.size(); ++i) { - dispersed_shape.push_back(data_shape.at(i)); - } - - const auto out_pattern_1 = - opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape); + const auto one = rg.make(i64, Shape{1}, 1); + const auto end = rg.make(i64, Shape{1}, block_length); + const auto block_tail = rg.make(block, one, end, one); + const auto data_shape_tail = rg.make(shape_of_data, one, end, one); + const auto dispersed_shape = rg.make(OutputVector{block_tail, batch_div, data_shape_tail}, 0); const bool special_zero = false; - std::shared_ptr flat_node = std::make_shared(data, out_pattern_1, special_zero); - new_ops.push_back(flat_node); + shared_ptr flat_node = rg.make(data, dispersed_shape, special_zero); + // calculate axes to transpose // x'' = transpose(x', [N, N + 1, 0, N + 2, 1, ..., N + N - 1, N - 1]) - std::vector axes_order{block_values.size() - 1}; - for (size_t i = 0; i < block_values.size() - 1; ++i) { - axes_order.push_back(i + block_values.size()); + vector axes_order{block_length - 1}; + for (int64_t i = 0; i < block_length - 1; ++i) { + axes_order.push_back(i + block_length); axes_order.push_back(i); } + const auto axes_order_const = rg.make(i64, Shape{axes_order.size()}, axes_order); + flat_node = rg.make(flat_node, axes_order_const); - const auto axes_order_const = - opset3::Constant::create(element::i64, - Shape{axes_order.size()}, - std::vector(axes_order.begin(), axes_order.end())); - flat_node = std::make_shared(flat_node, axes_order_const); - new_ops.push_back(flat_node); // x''' = reshape(x'', [batch / (B_1 * ... * B_{N - 1}), D_1 * B_1, D_2 * B_2, ... , D_{N - 1} // * B_{N - 1}]) - std::vector squeezed_shape; - squeezed_shape.push_back(data_shape.at(0) / b_dim_divider); - for (size_t i = 1; i < block_values.size(); ++i) { - squeezed_shape.push_back(data_shape.at(i) * block_values.at(i)); - } - - const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape); - flat_node = std::make_shared(flat_node, out_pattern_2, special_zero); - new_ops.push_back(flat_node); + const auto squeezed_shape_tail = rg.make(block_tail, data_shape_tail); + const auto squeezed_shape = rg.make(OutputVector{batch_div, squeezed_shape_tail}, 0); + flat_node = rg.make(flat_node, squeezed_shape, special_zero); // Crop the start and end of dimensions according to `crops_begin`, `crops_end` to produce // the output of shape: @@ -99,129 +87,133 @@ void ov::pass::ConvertBatchToSpace::convert_batch_to_space() { // `y = [batch / (B_1 * ... * B_{N - 1}), crop(D_1 * B_1, crops_begin[1], crops_end[1]), // crop(D_2 * B_2, crops_begin[2], crops_end[2]), ... , // crop(D_{N - 1} * B_{N - 1}, crops_begin[N - 1], crops_end[N - 1])]` - std::vector upperbounds_values; - auto flat_node_shape = flat_node->get_shape(); - for (size_t i = 0; i < flat_node_shape.size(); ++i) { - upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i)); - } + const auto shape_of_flat_node = rg.make(flat_node, crops_end.get_element_type()); + const auto upperbounds = rg.make(shape_of_flat_node, crops_end); - const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(), - Shape{upperbounds_values.size()}, - upperbounds_values); - - std::vector begin_mask(data_shape.size(), 0); - std::vector end_mask(data_shape.size(), 0); - flat_node = - std::make_shared(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask); - new_ops.push_back(flat_node); + const auto begin_mask = vector(data_shape_rank.get_length(), 0); + const auto& end_mask = begin_mask; + flat_node = rg.make(flat_node, crops_begin, upperbounds, begin_mask, end_mask); flat_node->set_friendly_name(batch_to_space->get_friendly_name()); - ngraph::copy_runtime_info(batch_to_space, new_ops); - ngraph::replace_node(batch_to_space, flat_node); + copy_runtime_info(batch_to_space, rg.get()); + replace_node(batch_to_space, flat_node); return true; }; - auto m = std::make_shared(batch_to_space, matcher_name); + const auto m = make_shared(batch_to_space, matcher_name); this->register_matcher(m, callback); } void ov::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() { MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space_by_elements); - auto batch_to_space = ngraph::pattern::wrap_type(); + const auto batch_to_space = pattern::wrap_type(); matcher_pass_callback callback = [this](pattern::Matcher& m) { - auto batch_to_space = std::dynamic_pointer_cast(m.get_match_root()); - if (!batch_to_space) { + const auto batch_to_space = dynamic_pointer_cast(m.get_match_root()); + if (!batch_to_space || transformation_callback(batch_to_space)) { return false; } - auto data = batch_to_space->input_value(0); + const auto data = batch_to_space->input_value(0); - if (data.get_partial_shape().is_dynamic()) { - return false; - } - auto data_shape = data.get_shape(); - - if (transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) { - return false; - } - auto block = batch_to_space->input_value(1); - auto crops_begin = batch_to_space->input_value(2); - auto crops_end = batch_to_space->input_value(3); - - const auto block_const = ov::as_type_ptr(block.get_node_shared_ptr()); - const auto crops_begin_const = ov::as_type_ptr(crops_begin.get_node_shared_ptr()); - const auto crops_end_const = ov::as_type_ptr(crops_end.get_node_shared_ptr()); - - const std::vector& block_values = block_const->cast_vector(); - const std::vector& crops_end_values = crops_end_const->cast_vector(); - - std::vector dispersed_shape(1); - dispersed_shape.insert(dispersed_shape.end(), data_shape.begin(), data_shape.end()); - std::vector axes_order(block_values.size() + 1); - std::vector squeezed_shape(data_shape.begin(), data_shape.end()); - if (squeezed_shape.size() > block_values.size()) { - return false; + const auto data_shape_rank = data.get_partial_shape().rank(); + if (data_shape_rank.is_dynamic()) { + return false; // because StridedSlice masks are std::vector } - NodeVector new_ops; + const auto block = batch_to_space->input_value(1); + const auto crops_begin = batch_to_space->input_value(2); + const auto crops_end = batch_to_space->input_value(3); - std::shared_ptr flat_node = data.get_node_shared_ptr(); - for (size_t block_idx = 1; block_idx < block_values.size(); ++block_idx) { - dispersed_shape[0] = block_values[block_idx]; - dispersed_shape[1] /= block_values[block_idx]; - const auto out_pattern_1 = - opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape); - const bool special_zero = false; - flat_node = std::make_shared(flat_node, out_pattern_1, special_zero); - new_ops.push_back(flat_node); + if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) { + return false; + } + const auto block_length = static_cast(block.get_shape()[0]); - size_t val = 1; - for (size_t axis_idx = 0; axis_idx <= block_values.size(); ++axis_idx) { - if ((block_idx + 1) == axis_idx) { + NodeRegistry rg; + const auto zero = rg.make(i64, Shape{1}, 0); + const auto one = rg.make(i64, Shape{1}, 1); + const auto two = rg.make(i64, Shape{1}, 2); + const auto int_max = rg.make(i64, Shape{1}, INT_MAX); + + const auto shape_of_data = rg.make(data, block.get_element_type()); + const auto et_zero = rg.make(block.get_element_type(), Shape{1}, 0); + shared_ptr dispersed_shape = rg.make(OutputVector{et_zero, shape_of_data}, 0); + shared_ptr squeezed_shape = shape_of_data; + + shared_ptr flat_node = data.get_node_shared_ptr(); + + const auto make_concat = [&](OutputVector nodes) { + nodes.erase(remove_if(nodes.begin(), + nodes.end(), + [](const Output& n) { + return n.get_partial_shape().is_static() && n.get_shape().size() > 0 && + n.get_shape()[0] == 0; + }), + nodes.end()); + return rg.make(nodes, 0); + }; + + shared_ptr div; + for (int64_t b_idx = 1; b_idx < block_length; ++b_idx) { + const auto block_index = rg.make(i64, Shape{1}, b_idx); + const auto block_index_next = rg.make(i64, Shape{1}, b_idx + 1); + const auto block_value = rg.make(block, block_index, zero); + + // dispersed_shape[0] = block[b_idx]; + // dispersed_shape[1] /= block[b_idx]; + if (!div) { + const auto batch = rg.make(shape_of_data, zero, zero); + div = rg.make(batch, block_value); + } else { + div = rg.make(div, block_value); + } + auto ds_tail = rg.make(dispersed_shape, two, int_max, one); + dispersed_shape = make_concat({block_value, div, ds_tail}); + constexpr auto special_zero = false; + flat_node = rg.make(flat_node, dispersed_shape, special_zero); + + vector axes_order(block_length + 1); + int64_t val = 1; + for (int64_t axis_idx = 0; axis_idx <= block_length; ++axis_idx) { + if ((b_idx + 1) == axis_idx) { axes_order[axis_idx] = 0; } else { axes_order[axis_idx] = val; val++; } } + const auto axes_order_const = rg.make(i64, Shape{axes_order.size()}, axes_order); + flat_node = rg.make(flat_node, axes_order_const); - const auto axes_order_const = - ov::opset3::Constant::create(element::i64, - Shape{axes_order.size()}, - std::vector(axes_order.begin(), axes_order.end())); - flat_node = std::make_shared(flat_node, axes_order_const); - new_ops.push_back(flat_node); + // squeezed_shape[0] = dispersed_shape[1]; + // squeezed_shape[b_idx] *= block[b_idx]; + const auto sq_slice = rg.make(squeezed_shape, one, block_index, one); + const auto sq_bidx_dim = rg.make(squeezed_shape, block_index, zero); + const auto sq_mul = rg.make(sq_bidx_dim, block_value); + const auto sq_shape_tail = rg.make(squeezed_shape, block_index_next, int_max, one); + squeezed_shape.reset(); + squeezed_shape = make_concat({div, sq_slice, sq_mul, sq_shape_tail}); + flat_node = rg.make(flat_node, squeezed_shape, special_zero); - squeezed_shape[0] = dispersed_shape[1]; - squeezed_shape[block_idx] *= block_values[block_idx]; - dispersed_shape[block_idx + 1] = squeezed_shape[block_idx]; - const auto out_pattern_2 = - opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape); - flat_node = std::make_shared(flat_node, out_pattern_2, special_zero); - new_ops.push_back(flat_node); + // dispersed_shape[b_idx + 1] = squeezed_shape[b_idx]; + const auto ds_front = rg.make(dispersed_shape, zero, block_index_next, one); + ds_tail = rg.make(dispersed_shape, rg.make(i64, Shape{1}, b_idx + 2), int_max, one); + dispersed_shape = make_concat({ds_front, sq_mul, ds_tail}); } - std::vector upperbounds_values; - auto flat_node_shape = flat_node->get_shape(); - for (size_t i = 0; i < flat_node_shape.size(); ++i) { - upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i)); - } - const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(), - Shape{upperbounds_values.size()}, - upperbounds_values); + const auto shape_of_flat_node = rg.make(flat_node, crops_end.get_element_type()); + const auto upperbounds = rg.make(shape_of_flat_node, crops_end); - std::vector begin_mask(data_shape.size(), 0); - std::vector end_mask(data_shape.size(), 0); - flat_node = - std::make_shared(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask); - new_ops.push_back(flat_node); + const auto begin_mask = vector(data_shape_rank.get_length(), 0); + const auto& end_mask = begin_mask; + flat_node = rg.make(flat_node, crops_begin, upperbounds, begin_mask, end_mask); flat_node->set_friendly_name(batch_to_space->get_friendly_name()); - ngraph::copy_runtime_info(batch_to_space, new_ops); - ngraph::replace_node(batch_to_space, flat_node); + copy_runtime_info(batch_to_space, rg.get()); + replace_node(batch_to_space, flat_node); return true; }; - auto m = std::make_shared(batch_to_space, matcher_name); + const auto m = make_shared(batch_to_space, matcher_name); this->register_matcher(m, callback); } diff --git a/src/common/transformations/src/transformations/op_conversions/convert_space_to_batch.cpp b/src/common/transformations/src/transformations/op_conversions/convert_space_to_batch.cpp index 0eaaceaa535..4434c58d7d9 100644 --- a/src/common/transformations/src/transformations/op_conversions/convert_space_to_batch.cpp +++ b/src/common/transformations/src/transformations/op_conversions/convert_space_to_batch.cpp @@ -4,42 +4,43 @@ #include "transformations/op_conversions/convert_space_to_batch.hpp" +#include #include #include #include -#include +#include #include #include "itt.hpp" +using namespace std; +using namespace ov::opset10; +using namespace ov::element; + void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() { MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch); - auto space_to_batch = ngraph::pattern::wrap_type(); - matcher_pass_callback callback = [](pattern::Matcher& m) { - auto space_to_batch = std::dynamic_pointer_cast(m.get_match_root()); - if (!space_to_batch) { + const auto space_to_batch = pattern::wrap_type(); + matcher_pass_callback callback = [this](pattern::Matcher& m) { + const auto space_to_batch = dynamic_pointer_cast(m.get_match_root()); + if (!space_to_batch || transformation_callback(space_to_batch)) { return false; } - NodeVector new_ops; - auto data = space_to_batch->input_value(0); - auto block = space_to_batch->input_value(1); - auto pads_begin = space_to_batch->input_value(2); - auto pads_end = space_to_batch->input_value(3); - - if (data.get_partial_shape().is_dynamic()) { + const auto data = space_to_batch->input_value(0); + if (data.get_partial_shape().rank().is_dynamic()) { return false; } - const auto block_const = std::dynamic_pointer_cast(block.get_node_shared_ptr()); - const auto pads_begin_const = std::dynamic_pointer_cast(pads_begin.get_node_shared_ptr()); - const auto pads_end_const = std::dynamic_pointer_cast(pads_end.get_node_shared_ptr()); + const auto block = space_to_batch->input_value(1); + const auto pads_begin = space_to_batch->input_value(2); + const auto pads_end = space_to_batch->input_value(3); - if (!block_const || !pads_begin_const || !pads_end_const) { + if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) { return false; } + const auto block_length = static_cast(block.get_shape()[0]); - const std::vector& block_values = block_const->cast_vector(); + NodeRegistry rg; // Zero-pad the start and end of dimensions [D_0, ..., D_{N - 1}] of the input according to // `pads_begin` @@ -47,162 +48,159 @@ void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() { // note: P_0 for batch dimension is expected to be 0 (no-padding). // x = [batch + P_0, D_1 + P_1, D_2 + P_2, ..., D_{N - 1} + P_{N - 1}], where P_i = // pads_begin[i] + pads_end[i] - std::shared_ptr flat_node = - std::make_shared(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT); - auto out_shape = flat_node->get_shape(); - new_ops.push_back(flat_node); + shared_ptr flat_node = rg.make(data, pads_begin, pads_end, op::PadMode::CONSTANT); + const auto out_shape = rg.make(flat_node, block.get_element_type()); + + const auto zero = rg.make(i64, Shape{1}, 0); + const auto one = rg.make(i64, Shape{1}, 1); + const auto int_max = rg.make(i64, Shape{1}, INT_MAX); // First we have to disperse the data from spatial dimensions, then // rearrange them so as appropriate chunks of data where close to their // destination place. Finally squeeze data from respective dimensions. - Shape dispersed_shape{out_shape.at(0)}; // note: B_0 for batch is ignored. // x' = reshape(x, [batch, (D_1 + P_1) / B_1, B_1, (D_2 + P_2) / B_2, B_2, ..., // (D_{N - 1} + P_{N - 1}) / B_{N - 1}, B_{N - 1}]), where B_i = block_shape[i] - for (size_t i = 1; i < block_values.size(); ++i) { - dispersed_shape.push_back(out_shape.at(i) / block_values.at(i)); - dispersed_shape.push_back(block_values.at(i)); - } + const auto batch = rg.make(out_shape, zero, zero); + const auto out_shape_tail = rg.make(out_shape, one, int_max, one); + const auto block_tail = rg.make(block, one, int_max, one); + const auto os_tail_div = rg.make(out_shape_tail, block_tail); - const auto out_pattern = opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape); - flat_node = std::make_shared(flat_node, out_pattern, false); - new_ops.push_back(flat_node); + // interleave os_tail_div with block_tail + const auto c = rg.make(NodeVector{os_tail_div, block_tail}, 0); + const auto r = + rg.make(c, rg.make(i64, Shape{2}, vector{2, block_length - 1}), false); + const auto t = rg.make(r, rg.make(i64, Shape{2}, vector{1, 0})); + const auto interleaved = rg.make(t, rg.make(i64, Shape{1}, 2 * (block_length - 1)), false); + + const auto dispersed_shape = rg.make(NodeVector{batch, interleaved}, 0); + flat_node = rg.make(flat_node, dispersed_shape, false); // x'' = transpose(x', [2, 4, ..., (N - 1) + (N - 1), 0, 1, 3, ..., N + (N - 1)]) - std::vector axes_order; - for (size_t i = 0, j = 2; i < block_values.size() - 1; ++i, j += 2) { + vector axes_order; + for (int64_t i = 0, j = 2; i < block_length - 1; ++i, j += 2) { axes_order.push_back(j); } axes_order.push_back(0); - for (size_t i = 0, j = 1; i < block_values.size() - 1; ++i, j += 2) { + for (int64_t i = 0, j = 1; i < block_length - 1; ++i, j += 2) { axes_order.push_back(j); } - const auto axes_order_const = - opset3::Constant::create(element::i64, - Shape{axes_order.size()}, - std::vector(axes_order.begin(), axes_order.end())); - flat_node = std::make_shared(flat_node, axes_order_const); - new_ops.push_back(flat_node); + const auto axes_order_const = rg.make(i64, Shape{axes_order.size()}, axes_order); + flat_node = rg.make(flat_node, axes_order_const); - Shape squeezed_shape; - int64_t prod = 1; - for (const auto& el : block_values) { - prod *= el; - } - - // y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ... - // , + // y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ..., // (D_{N - 1} + P_{N - 1}) / B_{N - 1}]) - squeezed_shape.push_back(out_shape.at(0) * prod); - for (size_t i = 1; i < block_values.size(); ++i) { - squeezed_shape.push_back(out_shape.at(i) / block_values.at(i)); - } - - const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape); - flat_node = std::make_shared(flat_node, out_pattern_2, false); - new_ops.push_back(flat_node); + // note: B_0 is assumed to be 1 by op definion + const auto block_prod = rg.make(block, zero); + const auto squeezed_shape = rg.make(NodeVector{rg.make(batch, block_prod), os_tail_div}, 0); + flat_node = rg.make(flat_node, squeezed_shape, false); flat_node->set_friendly_name(space_to_batch->get_friendly_name()); - ngraph::copy_runtime_info(space_to_batch, new_ops); - ngraph::replace_node(space_to_batch, flat_node); + copy_runtime_info(space_to_batch, rg.get()); + replace_node(space_to_batch, flat_node); return true; }; - auto m = std::make_shared(space_to_batch, matcher_name); + const auto m = make_shared(space_to_batch, matcher_name); this->register_matcher(m, callback); } void ov::pass::ConvertSpaceToBatch::convert_space_to_batch_by_elements() { MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch_by_elements); - auto space_to_batch = ngraph::pattern::wrap_type(); + const auto space_to_batch = pattern::wrap_type(); matcher_pass_callback callback = [this](pattern::Matcher& m) { - auto space_to_batch = std::dynamic_pointer_cast(m.get_match_root()); - if (!space_to_batch) { + const auto space_to_batch = dynamic_pointer_cast(m.get_match_root()); + if (!space_to_batch || transformation_callback(space_to_batch)) { return false; } - auto data = space_to_batch->input_value(0); - - if (data.get_partial_shape().is_dynamic()) { - return false; - } - const auto& data_shape = data.get_shape(); - - if (transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) { + const auto data = space_to_batch->input_value(0); + if (data.get_partial_shape().rank().is_dynamic()) { return false; } - auto block = space_to_batch->input_value(1); - auto pads_begin = space_to_batch->input_value(2); - auto pads_end = space_to_batch->input_value(3); + const auto block = space_to_batch->input_value(1); + const auto pads_begin = space_to_batch->input_value(2); + const auto pads_end = space_to_batch->input_value(3); - const auto block_const = ov::as_type_ptr(block.get_node_shared_ptr()); - const auto pads_begin_const = ov::as_type_ptr(pads_begin.get_node_shared_ptr()); - const auto pads_end_const = ov::as_type_ptr(pads_end.get_node_shared_ptr()); - - if (!block_const || !pads_begin_const || !pads_end_const) { + if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) { return false; } - const std::vector& block_values = block_const->cast_vector(); + const auto block_length = static_cast(block.get_shape()[0]); - NodeVector new_ops; + NodeRegistry rg; - std::shared_ptr flat_node = - std::make_shared(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT); - new_ops.push_back(flat_node); - auto out_shape = flat_node->get_shape(); + shared_ptr flat_node = rg.make(data, pads_begin, pads_end, op::PadMode::CONSTANT); - std::vector dispersed_shape(block_values.size() + 1); - std::vector axes_order(block_values.size() + 1); - std::vector squeezed_shape(out_shape.begin(), out_shape.end()); - for (int64_t block_idx = block_values.size() - 1; block_idx >= 0; --block_idx) { - int64_t sq_shape_idx = block_values.size() - 1; + shared_ptr squeezed_shape = rg.make(flat_node, block.get_element_type()); + + const auto zero = rg.make(i64, Shape{1}, 0); + const auto one = rg.make(i64, Shape{1}, 1); + const auto int_max = rg.make(i64, Shape{1}, INT_MAX); + + for (int64_t b_idx = block_length - 1; b_idx >= 0; --b_idx) { + const auto block_index = rg.make(i64, Shape{1}, b_idx); + const auto block_index_next = rg.make(i64, Shape{1}, b_idx + 1); + const auto block_value = rg.make(block, block_index, zero); + + NodeVector dispersed_shape_prep; + dispersed_shape_prep.reserve(block_length + 1); + if (b_idx > 0) // avoid addind empty Slice into Concat + dispersed_shape_prep.push_back(rg.make(squeezed_shape, zero, block_index, one)); + const auto squeezed_element = rg.make(squeezed_shape, block_index, zero); + dispersed_shape_prep.push_back(rg.make(squeezed_element, block_value)); + dispersed_shape_prep.push_back(block_value); + if (b_idx + 1 < block_length) // avoid addind empty Slice into Concat + dispersed_shape_prep.push_back(rg.make(squeezed_shape, block_index_next, int_max, one)); + + const auto dispersed_shape = rg.make(dispersed_shape_prep, 0); + constexpr auto special_zero = false; + flat_node = rg.make(flat_node, dispersed_shape, special_zero); + + vector axes_order(block_length + 1); int64_t axis_idx = axes_order.size() - 1; - for (int64_t shape_idx = dispersed_shape.size() - 1; shape_idx >= 0; --shape_idx) { - if (shape_idx == (block_idx + 1)) { - dispersed_shape[shape_idx] = block_values[block_idx]; - axes_order[0] = shape_idx; - } else if (shape_idx == block_idx) { - dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx] / block_values[block_idx]; - axes_order[axis_idx] = shape_idx; + for (int64_t ds_idx = block_length; ds_idx >= 0; --ds_idx) { + if (ds_idx == (b_idx + 1)) { + axes_order[0] = ds_idx; + } else if (ds_idx == b_idx) { + axes_order[axis_idx] = ds_idx; axis_idx--; - sq_shape_idx--; } else { - dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx]; - axes_order[axis_idx] = shape_idx; + axes_order[axis_idx] = ds_idx; axis_idx--; - sq_shape_idx--; } } + const auto axes_order_const = rg.make(i64, Shape{axes_order.size()}, axes_order); + flat_node = rg.make(flat_node, axes_order_const); - const auto out_pattern_1 = - opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape); - const bool special_zero = false; - flat_node = std::make_shared(flat_node, out_pattern_1, special_zero); - new_ops.push_back(flat_node); + // don't change squeezed_shape at the last iteration, block[0] is assumed to be 1 by op definion + if (b_idx > 0) { + NodeVector squeezed_shape_prep; + squeezed_shape_prep.reserve(block_length); + squeezed_shape_prep.push_back( + rg.make(rg.make(squeezed_shape, zero, zero), block_value)); + if (b_idx > 1) { // avoid addind empty Slice into Concat + squeezed_shape_prep.push_back(rg.make(squeezed_shape, one, block_index, one)); + } + squeezed_shape_prep.push_back( + rg.make(rg.make(squeezed_shape, block_index, zero), block_value)); + if (b_idx + 1 < block_length) { // avoid addind empty Slice into Concat + squeezed_shape_prep.push_back(rg.make(squeezed_shape, block_index_next, int_max, one)); + } - const auto axes_order_const = - opset3::Constant::create(element::i64, - Shape{axes_order.size()}, - std::vector(axes_order.begin(), axes_order.end())); - flat_node = std::make_shared(flat_node, axes_order_const); - new_ops.push_back(flat_node); - squeezed_shape[0] *= block_values[block_idx]; - squeezed_shape[block_idx] /= block_values[block_idx]; - const auto out_pattern_2 = - opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape); - flat_node = std::make_shared(flat_node, out_pattern_2, special_zero); - new_ops.push_back(flat_node); + squeezed_shape = rg.make(squeezed_shape_prep, 0); + } + flat_node = rg.make(flat_node, squeezed_shape, special_zero); } flat_node->set_friendly_name(space_to_batch->get_friendly_name()); - ngraph::copy_runtime_info(space_to_batch, new_ops); - ngraph::replace_node(space_to_batch, flat_node); + copy_runtime_info(space_to_batch, rg.get()); + replace_node(space_to_batch, flat_node); return true; }; - auto m = std::make_shared(space_to_batch, matcher_name); + const auto m = make_shared(space_to_batch, matcher_name); this->register_matcher(m, callback); } diff --git a/src/common/transformations/tests/op_conversions/batch_to_space_decomposition_test.cpp b/src/common/transformations/tests/op_conversions/batch_to_space_decomposition_test.cpp index d351a607341..5e498a51469 100644 --- a/src/common/transformations/tests/op_conversions/batch_to_space_decomposition_test.cpp +++ b/src/common/transformations/tests/op_conversions/batch_to_space_decomposition_test.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -19,6 +20,7 @@ #include "common_test_utils/ngraph_test_utils.hpp" #include "common_test_utils/test_common.hpp" +using namespace std; using namespace testing; using namespace ngraph; @@ -35,6 +37,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecompositionByElements) { std::make_shared(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); manager.register_pass(); + manager.register_pass(); } { @@ -93,6 +96,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecompositionByElements) { std::make_shared(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); manager.register_pass(); + manager.register_pass(); } { @@ -159,6 +163,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecomposition) { std::make_shared(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); manager.register_pass(false); + manager.register_pass(); } { @@ -195,6 +200,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) { std::make_shared(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); manager.register_pass(false); + manager.register_pass(); } { @@ -218,3 +224,156 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) { function_ref = std::make_shared(ngraph::NodeVector{ss}, ngraph::ParameterVector{data}); } } + +template +void op_convertion_type_test(const Params& params) { + using namespace ov::opset10; + using namespace ov::pass; + + const auto by_elements = get<0>(params); + const auto block_elem_type = get<1>(params); + + const auto data = make_shared(element::f32, Shape{1, 1}); + const auto block_p = Constant::create(block_elem_type, Shape{2}, {1, 1}); + const auto input_2_p = Constant::create(block_elem_type, Shape{2}, {0, 0}); + const auto input_3_p = Constant::create(block_elem_type, Shape{2}, {0, 0}); + const auto bts_or_stb = make_shared(data, block_p, input_2_p, input_3_p); + const auto f = make_shared(NodeVector{bts_or_stb}, ParameterVector{data}); + + Manager m; + m.register_pass(by_elements); + m.register_pass(); + ASSERT_NO_THROW(m.run_passes(f)); + EXPECT_EQ(f->get_result()->get_input_shape(0), (Shape{1, 1})); +} + +using ElementTypeParams = tuple; + +class BatchToSpaceDecomposition2D : public testing::WithParamInterface, + public TransformationTests {}; + +TEST_P(BatchToSpaceDecomposition2D, BlockElemType) { + op_convertion_type_test(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P(TransformationTests, + BatchToSpaceDecomposition2D, + ::testing::Combine(::testing::ValuesIn({false, true}), + ::testing::ValuesIn({element::i32, element::i64}))); + +class SpaceToBatchDecomposition2D : public testing::WithParamInterface, + public TransformationTests {}; + +TEST_P(SpaceToBatchDecomposition2D, BlockElemType) { + op_convertion_type_test(GetParam()); +} + +INSTANTIATE_TEST_SUITE_P(TransformationTests, + SpaceToBatchDecomposition2D, + ::testing::Combine(::testing::ValuesIn({false, true}), + ::testing::ValuesIn({element::i32, element::i64}))); + +template +void op_convertion_test(const Params& params) { + using namespace ov::opset10; + using namespace ov::pass; + + const bool by_elements = get<0>(params); + Shape data_shape; + Shape expected_output_shape; + vector block; + vector input_2; // crops_begin or pads_begin + vector input_3; // crops_end or pads_end + tie(data_shape, block, input_2, input_3, expected_output_shape) = get<1>(params); + + const auto data = make_shared(element::f32, PartialShape::dynamic(data_shape.size())); + const auto block_p = Constant::create(element::i64, Shape{block.size()}, block); + const auto input_2_p = Constant::create(element::i64, Shape{input_2.size()}, input_2); + const auto input_3_p = Constant::create(element::i64, Shape{input_3.size()}, input_3); + const auto bts_or_stb = make_shared(data, block_p, input_2_p, input_3_p); + const auto f = make_shared(NodeVector{bts_or_stb}, ParameterVector{data}); + + Manager m; + m.set_per_pass_validation(false); + m.register_pass(by_elements); + m.run_passes(f); + ASSERT_EQ(count_ops_of_type(f), 0); + EXPECT_TRUE(f->get_result()->get_input_partial_shape(0).is_dynamic()); + + data->set_partial_shape(data_shape); + f->validate_nodes_and_infer_types(); + ASSERT_EQ(f->get_result()->get_input_shape(0), expected_output_shape); +} + +template +string get_test_name(testing::TestParamInfo obj) { + const auto& params = obj.param; + const bool by_elements = get<0>(params); + const auto& data_shape = get<0>(get<1>(params)); + + ostringstream result; + result << data_shape.size() << "D" << (by_elements ? "_by_elements" : ""); + return result.str(); +} + +using BatchToSpaceParams = tuple, // block + vector, // crops_begin + vector, // crops_end + Shape // expected_output_shape + >; + +using BatchToSpaceDecomposeParams = tuple; + +class BatchToSpaceDecompositionWithParams : public testing::WithParamInterface, + public TransformationTests {}; + +TEST_P(BatchToSpaceDecompositionWithParams, DynamicInputs) { + op_convertion_test(GetParam()); +} + +static vector batch_to_space_params = { + {{4, 3}, {1, 2}, {0, 0}, {0, 0}, {2, 6}}, + {{6, 5, 7}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {1, 8, 17}}, + {{30, 4, 1, 1}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {1, 20, 3, 2}}, + {{96, 3, 5, 7, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 12, 15, 14, 1}}, +}; + +INSTANTIATE_TEST_SUITE_P(TransformationTests, + BatchToSpaceDecompositionWithParams, + ::testing::Combine(::testing::ValuesIn({false, true}), + ::testing::ValuesIn(batch_to_space_params)), + get_test_name); + +using SpaceToBatchParams = tuple, // block + vector, // pads_begin + vector, // pads_end + Shape // expected_output_shape + >; + +using SpaceToBatchDecomposeParams = tuple; + +class SpaceToBatchDecompositionWithParams : public testing::WithParamInterface, + public TransformationTests {}; + +TEST_P(SpaceToBatchDecompositionWithParams, DynamicInputs) { + op_convertion_test(GetParam()); +} + +static vector space_to_batch_params = { + {{2, 6}, {1, 2}, {0, 0}, {0, 0}, {4, 3}}, + {{1, 8, 17}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {6, 5, 7}}, + {{1, 20, 3, 2}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {30, 4, 1, 1}}, + {{4, 12, 15, 14, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {96, 3, 5, 7, 1}}, +}; + +INSTANTIATE_TEST_SUITE_P(TransformationTests, + SpaceToBatchDecompositionWithParams, + ::testing::Combine(::testing::ValuesIn({false, true}), + ::testing::ValuesIn(space_to_batch_params)), + get_test_name);