[Transformations] Enable dynamic decomposition BTS and STB ops (#15179)

* Add dynamism for BatchToSpace conversion

* Extend dynamism for BatchToSpace conversion

Only block input needs to be const 'cse axes_order_const is freaky

* Enhace dynamism for BatchToSpace conversion

Block input need not be const now.

* Add dynamism for STB by elements conversion

* Remove const need for crops for BTS by_elements

* temp for review

* Try to fix output tensor overwrite

* Make test to reproduce invalid shape inference

* Reproduce the error with template plugin

* Fix code style

* Fix 0D inputs issue

* Remove 0D shape parts before Concat

* Apply nested namespaces

* Enable non-constant STB Block input

* Fix BTS runtime info

* Fix STB by elems runtime info

* Add dynamism for STB conversion

* Add BTS dynamic data test

* Add STB dynamic data test

* Reduce STB concats

* Add tests naming

* Edit

* style

* Consider other block element types

* Enhance type test

* Use opset10 only

* Check block shape
This commit is contained in:
Tomasz Jankowski 2023-02-14 11:45:10 +01:00 committed by GitHub
parent fac03ee5f7
commit c62be51cc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 412 additions and 263 deletions

View File

@ -4,94 +4,82 @@
#include "transformations/op_conversions/convert_batch_to_space.hpp" #include "transformations/op_conversions/convert_batch_to_space.hpp"
#include <algorithm>
#include <climits>
#include <memory> #include <memory>
#include <ngraph/pattern/op/wrap_type.hpp> #include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset3.hpp> #include <openvino/opsets/opset10.hpp>
#include <vector> #include <vector>
#include "itt.hpp" #include "itt.hpp"
using namespace std;
using namespace ov::opset10;
using namespace ov::element;
void ov::pass::ConvertBatchToSpace::convert_batch_to_space() { void ov::pass::ConvertBatchToSpace::convert_batch_to_space() {
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space); MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space);
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>(); const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
matcher_pass_callback callback = [](pattern::Matcher& m) { matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root()); const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
if (!batch_to_space) { if (!batch_to_space || transformation_callback(batch_to_space)) {
return false; return false;
} }
NodeVector new_ops; NodeRegistry rg;
auto data = batch_to_space->input_value(0); const auto data = batch_to_space->input_value(0);
auto block = batch_to_space->input_value(1); const auto block = batch_to_space->input_value(1);
auto crops_begin = batch_to_space->input_value(2); const auto crops_begin = batch_to_space->input_value(2);
auto crops_end = batch_to_space->input_value(3); const auto crops_end = batch_to_space->input_value(3);
if (data.get_partial_shape().is_dynamic()) { const auto data_shape_rank = data.get_partial_shape().rank();
return false; if (data_shape_rank.is_dynamic()) {
} return false; // because StridedSlice masks are std::vector
const auto& data_shape = data.get_shape();
const auto block_const = std::dynamic_pointer_cast<opset3::Constant>(block.get_node_shared_ptr());
const auto crops_begin_const = std::dynamic_pointer_cast<opset3::Constant>(crops_begin.get_node_shared_ptr());
const auto crops_end_const = std::dynamic_pointer_cast<opset3::Constant>(crops_end.get_node_shared_ptr());
if (!block_const || !crops_begin_const || !crops_end_const) {
return false;
} }
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>(); if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>(); return false;
}
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
// First we have to disperse the data from batch, then rearrange them // First we have to disperse the data from batch, then rearrange them
// so as appropriate chunks of data where close to their destination place. // so as appropriate chunks of data where close to their destination place.
// Finally squeeze data from respective dimensions.ss // Finally squeeze data from respective dimensions
std::vector<int64_t> dispersed_shape;
int64_t b_dim_divider = 1; const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
for (const auto& el : block_values) { const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
b_dim_divider *= el; const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
} const auto block_prod = rg.make<ReduceProd>(block, zero);
const auto batch_div = rg.make<Divide>(batch, block_prod);
// note: B_0 is expected to be 1. // note: B_0 is expected to be 1.
// x' = reshape(`data`, [B_1, ..., B_{N - 1}, batch / (B_1 * ... B_{N - 1}), D_1, D_2, ..., // x' = reshape(`data`, [B_1, ..., B_{N - 1}, batch / (B_1 * ... B_{N - 1}), D_1, D_2, ...,
// D_{N - 1}]), // D_{N - 1}]),
// where B_i = block_shape[i] // where B_i = block_shape[i]
dispersed_shape.insert(dispersed_shape.begin(), block_values.begin() + 1, block_values.end()); const auto one = rg.make<Constant>(i64, Shape{1}, 1);
dispersed_shape.push_back(data_shape.at(0) / b_dim_divider); const auto end = rg.make<Constant>(i64, Shape{1}, block_length);
for (size_t i = 1; i < data_shape.size(); ++i) { const auto block_tail = rg.make<Slice>(block, one, end, one);
dispersed_shape.push_back(data_shape.at(i)); const auto data_shape_tail = rg.make<Slice>(shape_of_data, one, end, one);
} const auto dispersed_shape = rg.make<Concat>(OutputVector{block_tail, batch_div, data_shape_tail}, 0);
const auto out_pattern_1 =
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
const bool special_zero = false; const bool special_zero = false;
std::shared_ptr<Node> flat_node = std::make_shared<ov::opset3::Reshape>(data, out_pattern_1, special_zero); shared_ptr<Node> flat_node = rg.make<Reshape>(data, dispersed_shape, special_zero);
new_ops.push_back(flat_node);
// calculate axes to transpose // calculate axes to transpose
// x'' = transpose(x', [N, N + 1, 0, N + 2, 1, ..., N + N - 1, N - 1]) // x'' = transpose(x', [N, N + 1, 0, N + 2, 1, ..., N + N - 1, N - 1])
std::vector<size_t> axes_order{block_values.size() - 1}; vector<int64_t> axes_order{block_length - 1};
for (size_t i = 0; i < block_values.size() - 1; ++i) { for (int64_t i = 0; i < block_length - 1; ++i) {
axes_order.push_back(i + block_values.size()); axes_order.push_back(i + block_length);
axes_order.push_back(i); axes_order.push_back(i);
} }
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
const auto axes_order_const =
opset3::Constant::create(element::i64,
Shape{axes_order.size()},
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
new_ops.push_back(flat_node);
// x''' = reshape(x'', [batch / (B_1 * ... * B_{N - 1}), D_1 * B_1, D_2 * B_2, ... , D_{N - 1} // x''' = reshape(x'', [batch / (B_1 * ... * B_{N - 1}), D_1 * B_1, D_2 * B_2, ... , D_{N - 1}
// * B_{N - 1}]) // * B_{N - 1}])
std::vector<int64_t> squeezed_shape; const auto squeezed_shape_tail = rg.make<Multiply>(block_tail, data_shape_tail);
squeezed_shape.push_back(data_shape.at(0) / b_dim_divider); const auto squeezed_shape = rg.make<Concat>(OutputVector{batch_div, squeezed_shape_tail}, 0);
for (size_t i = 1; i < block_values.size(); ++i) { flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
squeezed_shape.push_back(data_shape.at(i) * block_values.at(i));
}
const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<opset3::Reshape>(flat_node, out_pattern_2, special_zero);
new_ops.push_back(flat_node);
// Crop the start and end of dimensions according to `crops_begin`, `crops_end` to produce // Crop the start and end of dimensions according to `crops_begin`, `crops_end` to produce
// the output of shape: // the output of shape:
@ -99,129 +87,133 @@ void ov::pass::ConvertBatchToSpace::convert_batch_to_space() {
// `y = [batch / (B_1 * ... * B_{N - 1}), crop(D_1 * B_1, crops_begin[1], crops_end[1]), // `y = [batch / (B_1 * ... * B_{N - 1}), crop(D_1 * B_1, crops_begin[1], crops_end[1]),
// crop(D_2 * B_2, crops_begin[2], crops_end[2]), ... , // crop(D_2 * B_2, crops_begin[2], crops_end[2]), ... ,
// crop(D_{N - 1} * B_{N - 1}, crops_begin[N - 1], crops_end[N - 1])]` // crop(D_{N - 1} * B_{N - 1}, crops_begin[N - 1], crops_end[N - 1])]`
std::vector<int64_t> upperbounds_values; const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
auto flat_node_shape = flat_node->get_shape(); const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
}
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(), const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
Shape{upperbounds_values.size()}, const auto& end_mask = begin_mask;
upperbounds_values); flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
std::vector<int64_t> begin_mask(data_shape.size(), 0);
std::vector<int64_t> end_mask(data_shape.size(), 0);
flat_node =
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
new_ops.push_back(flat_node);
flat_node->set_friendly_name(batch_to_space->get_friendly_name()); flat_node->set_friendly_name(batch_to_space->get_friendly_name());
ngraph::copy_runtime_info(batch_to_space, new_ops); copy_runtime_info(batch_to_space, rg.get());
ngraph::replace_node(batch_to_space, flat_node); replace_node(batch_to_space, flat_node);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name); const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
this->register_matcher(m, callback); this->register_matcher(m, callback);
} }
void ov::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() { void ov::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() {
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space_by_elements); MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space_by_elements);
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>(); const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
matcher_pass_callback callback = [this](pattern::Matcher& m) { matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root()); const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
if (!batch_to_space) { if (!batch_to_space || transformation_callback(batch_to_space)) {
return false; return false;
} }
auto data = batch_to_space->input_value(0); const auto data = batch_to_space->input_value(0);
if (data.get_partial_shape().is_dynamic()) { const auto data_shape_rank = data.get_partial_shape().rank();
return false; if (data_shape_rank.is_dynamic()) {
} return false; // because StridedSlice masks are std::vector
auto data_shape = data.get_shape();
if (transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
return false;
}
auto block = batch_to_space->input_value(1);
auto crops_begin = batch_to_space->input_value(2);
auto crops_end = batch_to_space->input_value(3);
const auto block_const = ov::as_type_ptr<opset3::Constant>(block.get_node_shared_ptr());
const auto crops_begin_const = ov::as_type_ptr<opset3::Constant>(crops_begin.get_node_shared_ptr());
const auto crops_end_const = ov::as_type_ptr<opset3::Constant>(crops_end.get_node_shared_ptr());
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>();
std::vector<int64_t> dispersed_shape(1);
dispersed_shape.insert(dispersed_shape.end(), data_shape.begin(), data_shape.end());
std::vector<size_t> axes_order(block_values.size() + 1);
std::vector<int64_t> squeezed_shape(data_shape.begin(), data_shape.end());
if (squeezed_shape.size() > block_values.size()) {
return false;
} }
NodeVector new_ops; const auto block = batch_to_space->input_value(1);
const auto crops_begin = batch_to_space->input_value(2);
const auto crops_end = batch_to_space->input_value(3);
std::shared_ptr<Node> flat_node = data.get_node_shared_ptr(); if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
for (size_t block_idx = 1; block_idx < block_values.size(); ++block_idx) { return false;
dispersed_shape[0] = block_values[block_idx]; }
dispersed_shape[1] /= block_values[block_idx]; const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
const auto out_pattern_1 =
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
const bool special_zero = false;
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_1, special_zero);
new_ops.push_back(flat_node);
size_t val = 1; NodeRegistry rg;
for (size_t axis_idx = 0; axis_idx <= block_values.size(); ++axis_idx) { const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
if ((block_idx + 1) == axis_idx) { const auto one = rg.make<Constant>(i64, Shape{1}, 1);
const auto two = rg.make<Constant>(i64, Shape{1}, 2);
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
const auto et_zero = rg.make<Constant>(block.get_element_type(), Shape{1}, 0);
shared_ptr<Node> dispersed_shape = rg.make<Concat>(OutputVector{et_zero, shape_of_data}, 0);
shared_ptr<Node> squeezed_shape = shape_of_data;
shared_ptr<Node> flat_node = data.get_node_shared_ptr();
const auto make_concat = [&](OutputVector nodes) {
nodes.erase(remove_if(nodes.begin(),
nodes.end(),
[](const Output<Node>& n) {
return n.get_partial_shape().is_static() && n.get_shape().size() > 0 &&
n.get_shape()[0] == 0;
}),
nodes.end());
return rg.make<Concat>(nodes, 0);
};
shared_ptr<Node> div;
for (int64_t b_idx = 1; b_idx < block_length; ++b_idx) {
const auto block_index = rg.make<Constant>(i64, Shape{1}, b_idx);
const auto block_index_next = rg.make<Constant>(i64, Shape{1}, b_idx + 1);
const auto block_value = rg.make<Gather>(block, block_index, zero);
// dispersed_shape[0] = block[b_idx];
// dispersed_shape[1] /= block[b_idx];
if (!div) {
const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
div = rg.make<Divide>(batch, block_value);
} else {
div = rg.make<Divide>(div, block_value);
}
auto ds_tail = rg.make<Slice>(dispersed_shape, two, int_max, one);
dispersed_shape = make_concat({block_value, div, ds_tail});
constexpr auto special_zero = false;
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, special_zero);
vector<int64_t> axes_order(block_length + 1);
int64_t val = 1;
for (int64_t axis_idx = 0; axis_idx <= block_length; ++axis_idx) {
if ((b_idx + 1) == axis_idx) {
axes_order[axis_idx] = 0; axes_order[axis_idx] = 0;
} else { } else {
axes_order[axis_idx] = val; axes_order[axis_idx] = val;
val++; val++;
} }
} }
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
const auto axes_order_const = // squeezed_shape[0] = dispersed_shape[1];
ov::opset3::Constant::create(element::i64, // squeezed_shape[b_idx] *= block[b_idx];
Shape{axes_order.size()}, const auto sq_slice = rg.make<Slice>(squeezed_shape, one, block_index, one);
std::vector<int64_t>(axes_order.begin(), axes_order.end())); const auto sq_bidx_dim = rg.make<Gather>(squeezed_shape, block_index, zero);
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const); const auto sq_mul = rg.make<Multiply>(sq_bidx_dim, block_value);
new_ops.push_back(flat_node); const auto sq_shape_tail = rg.make<Slice>(squeezed_shape, block_index_next, int_max, one);
squeezed_shape.reset();
squeezed_shape = make_concat({div, sq_slice, sq_mul, sq_shape_tail});
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
squeezed_shape[0] = dispersed_shape[1]; // dispersed_shape[b_idx + 1] = squeezed_shape[b_idx];
squeezed_shape[block_idx] *= block_values[block_idx]; const auto ds_front = rg.make<Slice>(dispersed_shape, zero, block_index_next, one);
dispersed_shape[block_idx + 1] = squeezed_shape[block_idx]; ds_tail = rg.make<Slice>(dispersed_shape, rg.make<Constant>(i64, Shape{1}, b_idx + 2), int_max, one);
const auto out_pattern_2 = dispersed_shape = make_concat({ds_front, sq_mul, ds_tail});
opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, special_zero);
new_ops.push_back(flat_node);
} }
std::vector<int64_t> upperbounds_values; const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
auto flat_node_shape = flat_node->get_shape(); const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
}
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(),
Shape{upperbounds_values.size()},
upperbounds_values);
std::vector<int64_t> begin_mask(data_shape.size(), 0); const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
std::vector<int64_t> end_mask(data_shape.size(), 0); const auto& end_mask = begin_mask;
flat_node = flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
new_ops.push_back(flat_node);
flat_node->set_friendly_name(batch_to_space->get_friendly_name()); flat_node->set_friendly_name(batch_to_space->get_friendly_name());
ngraph::copy_runtime_info(batch_to_space, new_ops); copy_runtime_info(batch_to_space, rg.get());
ngraph::replace_node(batch_to_space, flat_node); replace_node(batch_to_space, flat_node);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name); const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
this->register_matcher(m, callback); this->register_matcher(m, callback);
} }

View File

@ -4,42 +4,43 @@
#include "transformations/op_conversions/convert_space_to_batch.hpp" #include "transformations/op_conversions/convert_space_to_batch.hpp"
#include <climits>
#include <memory> #include <memory>
#include <ngraph/pattern/op/wrap_type.hpp> #include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset3.hpp> #include <openvino/opsets/opset10.hpp>
#include <vector> #include <vector>
#include "itt.hpp" #include "itt.hpp"
using namespace std;
using namespace ov::opset10;
using namespace ov::element;
void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() { void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() {
MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch); MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch);
auto space_to_batch = ngraph::pattern::wrap_type<ov::opset3::SpaceToBatch>(); const auto space_to_batch = pattern::wrap_type<SpaceToBatch>();
matcher_pass_callback callback = [](pattern::Matcher& m) { matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto space_to_batch = std::dynamic_pointer_cast<ov::opset3::SpaceToBatch>(m.get_match_root()); const auto space_to_batch = dynamic_pointer_cast<SpaceToBatch>(m.get_match_root());
if (!space_to_batch) { if (!space_to_batch || transformation_callback(space_to_batch)) {
return false; return false;
} }
NodeVector new_ops; const auto data = space_to_batch->input_value(0);
auto data = space_to_batch->input_value(0); if (data.get_partial_shape().rank().is_dynamic()) {
auto block = space_to_batch->input_value(1);
auto pads_begin = space_to_batch->input_value(2);
auto pads_end = space_to_batch->input_value(3);
if (data.get_partial_shape().is_dynamic()) {
return false; return false;
} }
const auto block_const = std::dynamic_pointer_cast<opset3::Constant>(block.get_node_shared_ptr()); const auto block = space_to_batch->input_value(1);
const auto pads_begin_const = std::dynamic_pointer_cast<opset3::Constant>(pads_begin.get_node_shared_ptr()); const auto pads_begin = space_to_batch->input_value(2);
const auto pads_end_const = std::dynamic_pointer_cast<opset3::Constant>(pads_end.get_node_shared_ptr()); const auto pads_end = space_to_batch->input_value(3);
if (!block_const || !pads_begin_const || !pads_end_const) { if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
return false; return false;
} }
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>(); NodeRegistry rg;
// Zero-pad the start and end of dimensions [D_0, ..., D_{N - 1}] of the input according to // Zero-pad the start and end of dimensions [D_0, ..., D_{N - 1}] of the input according to
// `pads_begin` // `pads_begin`
@ -47,162 +48,159 @@ void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() {
// note: P_0 for batch dimension is expected to be 0 (no-padding). // note: P_0 for batch dimension is expected to be 0 (no-padding).
// x = [batch + P_0, D_1 + P_1, D_2 + P_2, ..., D_{N - 1} + P_{N - 1}], where P_i = // x = [batch + P_0, D_1 + P_1, D_2 + P_2, ..., D_{N - 1} + P_{N - 1}], where P_i =
// pads_begin[i] + pads_end[i] // pads_begin[i] + pads_end[i]
std::shared_ptr<Node> flat_node = shared_ptr<Node> flat_node = rg.make<Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
std::make_shared<opset3::Pad>(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT); const auto out_shape = rg.make<ShapeOf>(flat_node, block.get_element_type());
auto out_shape = flat_node->get_shape();
new_ops.push_back(flat_node); const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
// First we have to disperse the data from spatial dimensions, then // First we have to disperse the data from spatial dimensions, then
// rearrange them so as appropriate chunks of data where close to their // rearrange them so as appropriate chunks of data where close to their
// destination place. Finally squeeze data from respective dimensions. // destination place. Finally squeeze data from respective dimensions.
Shape dispersed_shape{out_shape.at(0)};
// note: B_0 for batch is ignored. // note: B_0 for batch is ignored.
// x' = reshape(x, [batch, (D_1 + P_1) / B_1, B_1, (D_2 + P_2) / B_2, B_2, ..., // x' = reshape(x, [batch, (D_1 + P_1) / B_1, B_1, (D_2 + P_2) / B_2, B_2, ...,
// (D_{N - 1} + P_{N - 1}) / B_{N - 1}, B_{N - 1}]), where B_i = block_shape[i] // (D_{N - 1} + P_{N - 1}) / B_{N - 1}, B_{N - 1}]), where B_i = block_shape[i]
for (size_t i = 1; i < block_values.size(); ++i) { const auto batch = rg.make<Gather>(out_shape, zero, zero);
dispersed_shape.push_back(out_shape.at(i) / block_values.at(i)); const auto out_shape_tail = rg.make<Slice>(out_shape, one, int_max, one);
dispersed_shape.push_back(block_values.at(i)); const auto block_tail = rg.make<Slice>(block, one, int_max, one);
} const auto os_tail_div = rg.make<Divide>(out_shape_tail, block_tail);
const auto out_pattern = opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape); // interleave os_tail_div with block_tail
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern, false); const auto c = rg.make<Concat>(NodeVector{os_tail_div, block_tail}, 0);
new_ops.push_back(flat_node); const auto r =
rg.make<Reshape>(c, rg.make<Constant>(i64, Shape{2}, vector<int64_t>{2, block_length - 1}), false);
const auto t = rg.make<Transpose>(r, rg.make<Constant>(i64, Shape{2}, vector<int64_t>{1, 0}));
const auto interleaved = rg.make<Reshape>(t, rg.make<Constant>(i64, Shape{1}, 2 * (block_length - 1)), false);
const auto dispersed_shape = rg.make<Concat>(NodeVector{batch, interleaved}, 0);
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, false);
// x'' = transpose(x', [2, 4, ..., (N - 1) + (N - 1), 0, 1, 3, ..., N + (N - 1)]) // x'' = transpose(x', [2, 4, ..., (N - 1) + (N - 1), 0, 1, 3, ..., N + (N - 1)])
std::vector<size_t> axes_order; vector<int64_t> axes_order;
for (size_t i = 0, j = 2; i < block_values.size() - 1; ++i, j += 2) { for (int64_t i = 0, j = 2; i < block_length - 1; ++i, j += 2) {
axes_order.push_back(j); axes_order.push_back(j);
} }
axes_order.push_back(0); axes_order.push_back(0);
for (size_t i = 0, j = 1; i < block_values.size() - 1; ++i, j += 2) { for (int64_t i = 0, j = 1; i < block_length - 1; ++i, j += 2) {
axes_order.push_back(j); axes_order.push_back(j);
} }
const auto axes_order_const = const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
opset3::Constant::create(element::i64, flat_node = rg.make<Transpose>(flat_node, axes_order_const);
Shape{axes_order.size()},
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
new_ops.push_back(flat_node);
Shape squeezed_shape; // y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ...,
int64_t prod = 1;
for (const auto& el : block_values) {
prod *= el;
}
// y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ...
// ,
// (D_{N - 1} + P_{N - 1}) / B_{N - 1}]) // (D_{N - 1} + P_{N - 1}) / B_{N - 1}])
squeezed_shape.push_back(out_shape.at(0) * prod); // note: B_0 is assumed to be 1 by op definion
for (size_t i = 1; i < block_values.size(); ++i) { const auto block_prod = rg.make<ReduceProd>(block, zero);
squeezed_shape.push_back(out_shape.at(i) / block_values.at(i)); const auto squeezed_shape = rg.make<Concat>(NodeVector{rg.make<Multiply>(batch, block_prod), os_tail_div}, 0);
} flat_node = rg.make<Reshape>(flat_node, squeezed_shape, false);
const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, false);
new_ops.push_back(flat_node);
flat_node->set_friendly_name(space_to_batch->get_friendly_name()); flat_node->set_friendly_name(space_to_batch->get_friendly_name());
ngraph::copy_runtime_info(space_to_batch, new_ops); copy_runtime_info(space_to_batch, rg.get());
ngraph::replace_node(space_to_batch, flat_node); replace_node(space_to_batch, flat_node);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(space_to_batch, matcher_name); const auto m = make_shared<pattern::Matcher>(space_to_batch, matcher_name);
this->register_matcher(m, callback); this->register_matcher(m, callback);
} }
void ov::pass::ConvertSpaceToBatch::convert_space_to_batch_by_elements() { void ov::pass::ConvertSpaceToBatch::convert_space_to_batch_by_elements() {
MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch_by_elements); MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch_by_elements);
auto space_to_batch = ngraph::pattern::wrap_type<ov::opset3::SpaceToBatch>(); const auto space_to_batch = pattern::wrap_type<SpaceToBatch>();
matcher_pass_callback callback = [this](pattern::Matcher& m) { matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto space_to_batch = std::dynamic_pointer_cast<ov::opset3::SpaceToBatch>(m.get_match_root()); const auto space_to_batch = dynamic_pointer_cast<SpaceToBatch>(m.get_match_root());
if (!space_to_batch) { if (!space_to_batch || transformation_callback(space_to_batch)) {
return false; return false;
} }
auto data = space_to_batch->input_value(0); const auto data = space_to_batch->input_value(0);
if (data.get_partial_shape().rank().is_dynamic()) {
if (data.get_partial_shape().is_dynamic()) {
return false;
}
const auto& data_shape = data.get_shape();
if (transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) {
return false; return false;
} }
auto block = space_to_batch->input_value(1); const auto block = space_to_batch->input_value(1);
auto pads_begin = space_to_batch->input_value(2); const auto pads_begin = space_to_batch->input_value(2);
auto pads_end = space_to_batch->input_value(3); const auto pads_end = space_to_batch->input_value(3);
const auto block_const = ov::as_type_ptr<opset3::Constant>(block.get_node_shared_ptr()); if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
const auto pads_begin_const = ov::as_type_ptr<opset3::Constant>(pads_begin.get_node_shared_ptr());
const auto pads_end_const = ov::as_type_ptr<opset3::Constant>(pads_end.get_node_shared_ptr());
if (!block_const || !pads_begin_const || !pads_end_const) {
return false; return false;
} }
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>(); const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
NodeVector new_ops; NodeRegistry rg;
std::shared_ptr<Node> flat_node = shared_ptr<Node> flat_node = rg.make<Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
std::make_shared<opset3::Pad>(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT);
new_ops.push_back(flat_node);
auto out_shape = flat_node->get_shape();
std::vector<int64_t> dispersed_shape(block_values.size() + 1); shared_ptr<Node> squeezed_shape = rg.make<ShapeOf>(flat_node, block.get_element_type());
std::vector<size_t> axes_order(block_values.size() + 1);
std::vector<int64_t> squeezed_shape(out_shape.begin(), out_shape.end()); const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
for (int64_t block_idx = block_values.size() - 1; block_idx >= 0; --block_idx) { const auto one = rg.make<Constant>(i64, Shape{1}, 1);
int64_t sq_shape_idx = block_values.size() - 1; const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
for (int64_t b_idx = block_length - 1; b_idx >= 0; --b_idx) {
const auto block_index = rg.make<Constant>(i64, Shape{1}, b_idx);
const auto block_index_next = rg.make<Constant>(i64, Shape{1}, b_idx + 1);
const auto block_value = rg.make<Gather>(block, block_index, zero);
NodeVector dispersed_shape_prep;
dispersed_shape_prep.reserve(block_length + 1);
if (b_idx > 0) // avoid addind empty Slice into Concat
dispersed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, zero, block_index, one));
const auto squeezed_element = rg.make<Gather>(squeezed_shape, block_index, zero);
dispersed_shape_prep.push_back(rg.make<Divide>(squeezed_element, block_value));
dispersed_shape_prep.push_back(block_value);
if (b_idx + 1 < block_length) // avoid addind empty Slice into Concat
dispersed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, block_index_next, int_max, one));
const auto dispersed_shape = rg.make<Concat>(dispersed_shape_prep, 0);
constexpr auto special_zero = false;
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, special_zero);
vector<int64_t> axes_order(block_length + 1);
int64_t axis_idx = axes_order.size() - 1; int64_t axis_idx = axes_order.size() - 1;
for (int64_t shape_idx = dispersed_shape.size() - 1; shape_idx >= 0; --shape_idx) { for (int64_t ds_idx = block_length; ds_idx >= 0; --ds_idx) {
if (shape_idx == (block_idx + 1)) { if (ds_idx == (b_idx + 1)) {
dispersed_shape[shape_idx] = block_values[block_idx]; axes_order[0] = ds_idx;
axes_order[0] = shape_idx; } else if (ds_idx == b_idx) {
} else if (shape_idx == block_idx) { axes_order[axis_idx] = ds_idx;
dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx] / block_values[block_idx];
axes_order[axis_idx] = shape_idx;
axis_idx--; axis_idx--;
sq_shape_idx--;
} else { } else {
dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx]; axes_order[axis_idx] = ds_idx;
axes_order[axis_idx] = shape_idx;
axis_idx--; axis_idx--;
sq_shape_idx--;
} }
} }
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
const auto out_pattern_1 = // don't change squeezed_shape at the last iteration, block[0] is assumed to be 1 by op definion
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape); if (b_idx > 0) {
const bool special_zero = false; NodeVector squeezed_shape_prep;
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_1, special_zero); squeezed_shape_prep.reserve(block_length);
new_ops.push_back(flat_node); squeezed_shape_prep.push_back(
rg.make<Multiply>(rg.make<Gather>(squeezed_shape, zero, zero), block_value));
if (b_idx > 1) { // avoid addind empty Slice into Concat
squeezed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, one, block_index, one));
}
squeezed_shape_prep.push_back(
rg.make<Divide>(rg.make<Gather>(squeezed_shape, block_index, zero), block_value));
if (b_idx + 1 < block_length) { // avoid addind empty Slice into Concat
squeezed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, block_index_next, int_max, one));
}
const auto axes_order_const = squeezed_shape = rg.make<Concat>(squeezed_shape_prep, 0);
opset3::Constant::create(element::i64, }
Shape{axes_order.size()}, flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
new_ops.push_back(flat_node);
squeezed_shape[0] *= block_values[block_idx];
squeezed_shape[block_idx] /= block_values[block_idx];
const auto out_pattern_2 =
opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, special_zero);
new_ops.push_back(flat_node);
} }
flat_node->set_friendly_name(space_to_batch->get_friendly_name()); flat_node->set_friendly_name(space_to_batch->get_friendly_name());
ngraph::copy_runtime_info(space_to_batch, new_ops); copy_runtime_info(space_to_batch, rg.get());
ngraph::replace_node(space_to_batch, flat_node); replace_node(space_to_batch, flat_node);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(space_to_batch, matcher_name); const auto m = make_shared<pattern::Matcher>(space_to_batch, matcher_name);
this->register_matcher(m, callback); this->register_matcher(m, callback);
} }

View File

@ -8,6 +8,7 @@
#include <ngraph/function.hpp> #include <ngraph/function.hpp>
#include <ngraph/opsets/opset3.hpp> #include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <queue> #include <queue>
#include <sstream> #include <sstream>
#include <string> #include <string>
@ -19,6 +20,7 @@
#include "common_test_utils/ngraph_test_utils.hpp" #include "common_test_utils/ngraph_test_utils.hpp"
#include "common_test_utils/test_common.hpp" #include "common_test_utils/test_common.hpp"
using namespace std;
using namespace testing; using namespace testing;
using namespace ngraph; using namespace ngraph;
@ -35,6 +37,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecompositionByElements) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertBatchToSpace>(); manager.register_pass<ov::pass::ConvertBatchToSpace>();
manager.register_pass<ov::pass::ConstantFolding>();
} }
{ {
@ -93,6 +96,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecompositionByElements) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertSpaceToBatch>(); manager.register_pass<ov::pass::ConvertSpaceToBatch>();
manager.register_pass<ov::pass::ConstantFolding>();
} }
{ {
@ -159,6 +163,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecomposition) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertSpaceToBatch>(false); manager.register_pass<ov::pass::ConvertSpaceToBatch>(false);
manager.register_pass<ov::pass::ConstantFolding>();
} }
{ {
@ -195,6 +200,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data}); std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertBatchToSpace>(false); manager.register_pass<ov::pass::ConvertBatchToSpace>(false);
manager.register_pass<ov::pass::ConstantFolding>();
} }
{ {
@ -218,3 +224,156 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data}); function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
} }
} }
template <typename Op, typename Conversion, typename Params>
void op_convertion_type_test(const Params& params) {
using namespace ov::opset10;
using namespace ov::pass;
const auto by_elements = get<0>(params);
const auto block_elem_type = get<1>(params);
const auto data = make_shared<Parameter>(element::f32, Shape{1, 1});
const auto block_p = Constant::create(block_elem_type, Shape{2}, {1, 1});
const auto input_2_p = Constant::create(block_elem_type, Shape{2}, {0, 0});
const auto input_3_p = Constant::create(block_elem_type, Shape{2}, {0, 0});
const auto bts_or_stb = make_shared<Op>(data, block_p, input_2_p, input_3_p);
const auto f = make_shared<Function>(NodeVector{bts_or_stb}, ParameterVector{data});
Manager m;
m.register_pass<Conversion>(by_elements);
m.register_pass<ConstantFolding>();
ASSERT_NO_THROW(m.run_passes(f));
EXPECT_EQ(f->get_result()->get_input_shape(0), (Shape{1, 1}));
}
using ElementTypeParams = tuple<bool, // by_elements
element::Type // block element type
>;
class BatchToSpaceDecomposition2D : public testing::WithParamInterface<ElementTypeParams>,
public TransformationTests {};
TEST_P(BatchToSpaceDecomposition2D, BlockElemType) {
op_convertion_type_test<ov::opset10::BatchToSpace, ov::pass::ConvertBatchToSpace>(GetParam());
}
INSTANTIATE_TEST_SUITE_P(TransformationTests,
BatchToSpaceDecomposition2D,
::testing::Combine(::testing::ValuesIn({false, true}),
::testing::ValuesIn({element::i32, element::i64})));
class SpaceToBatchDecomposition2D : public testing::WithParamInterface<ElementTypeParams>,
public TransformationTests {};
TEST_P(SpaceToBatchDecomposition2D, BlockElemType) {
op_convertion_type_test<ov::opset10::SpaceToBatch, ov::pass::ConvertSpaceToBatch>(GetParam());
}
INSTANTIATE_TEST_SUITE_P(TransformationTests,
SpaceToBatchDecomposition2D,
::testing::Combine(::testing::ValuesIn({false, true}),
::testing::ValuesIn({element::i32, element::i64})));
template <typename Op, typename Conversion, typename Params>
void op_convertion_test(const Params& params) {
using namespace ov::opset10;
using namespace ov::pass;
const bool by_elements = get<0>(params);
Shape data_shape;
Shape expected_output_shape;
vector<int64_t> block;
vector<int64_t> input_2; // crops_begin or pads_begin
vector<int64_t> input_3; // crops_end or pads_end
tie(data_shape, block, input_2, input_3, expected_output_shape) = get<1>(params);
const auto data = make_shared<Parameter>(element::f32, PartialShape::dynamic(data_shape.size()));
const auto block_p = Constant::create(element::i64, Shape{block.size()}, block);
const auto input_2_p = Constant::create(element::i64, Shape{input_2.size()}, input_2);
const auto input_3_p = Constant::create(element::i64, Shape{input_3.size()}, input_3);
const auto bts_or_stb = make_shared<Op>(data, block_p, input_2_p, input_3_p);
const auto f = make_shared<Function>(NodeVector{bts_or_stb}, ParameterVector{data});
Manager m;
m.set_per_pass_validation(false);
m.register_pass<Conversion>(by_elements);
m.run_passes(f);
ASSERT_EQ(count_ops_of_type<Op>(f), 0);
EXPECT_TRUE(f->get_result()->get_input_partial_shape(0).is_dynamic());
data->set_partial_shape(data_shape);
f->validate_nodes_and_infer_types();
ASSERT_EQ(f->get_result()->get_input_shape(0), expected_output_shape);
}
template <typename Params>
string get_test_name(testing::TestParamInfo<Params> obj) {
const auto& params = obj.param;
const bool by_elements = get<0>(params);
const auto& data_shape = get<0>(get<1>(params));
ostringstream result;
result << data_shape.size() << "D" << (by_elements ? "_by_elements" : "");
return result.str();
}
using BatchToSpaceParams = tuple<Shape, // data_shape
vector<int64_t>, // block
vector<int64_t>, // crops_begin
vector<int64_t>, // crops_end
Shape // expected_output_shape
>;
using BatchToSpaceDecomposeParams = tuple<bool, // by_elements
BatchToSpaceParams>;
class BatchToSpaceDecompositionWithParams : public testing::WithParamInterface<BatchToSpaceDecomposeParams>,
public TransformationTests {};
TEST_P(BatchToSpaceDecompositionWithParams, DynamicInputs) {
op_convertion_test<ov::opset10::BatchToSpace, ov::pass::ConvertBatchToSpace>(GetParam());
}
static vector<BatchToSpaceParams> batch_to_space_params = {
{{4, 3}, {1, 2}, {0, 0}, {0, 0}, {2, 6}},
{{6, 5, 7}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {1, 8, 17}},
{{30, 4, 1, 1}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {1, 20, 3, 2}},
{{96, 3, 5, 7, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 12, 15, 14, 1}},
};
INSTANTIATE_TEST_SUITE_P(TransformationTests,
BatchToSpaceDecompositionWithParams,
::testing::Combine(::testing::ValuesIn({false, true}),
::testing::ValuesIn(batch_to_space_params)),
get_test_name<BatchToSpaceDecomposeParams>);
using SpaceToBatchParams = tuple<Shape, // data_shape
vector<int64_t>, // block
vector<int64_t>, // pads_begin
vector<int64_t>, // pads_end
Shape // expected_output_shape
>;
using SpaceToBatchDecomposeParams = tuple<bool, // by_elements
SpaceToBatchParams>;
class SpaceToBatchDecompositionWithParams : public testing::WithParamInterface<SpaceToBatchDecomposeParams>,
public TransformationTests {};
TEST_P(SpaceToBatchDecompositionWithParams, DynamicInputs) {
op_convertion_test<ov::opset10::SpaceToBatch, ov::pass::ConvertSpaceToBatch>(GetParam());
}
static vector<SpaceToBatchParams> space_to_batch_params = {
{{2, 6}, {1, 2}, {0, 0}, {0, 0}, {4, 3}},
{{1, 8, 17}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {6, 5, 7}},
{{1, 20, 3, 2}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {30, 4, 1, 1}},
{{4, 12, 15, 14, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {96, 3, 5, 7, 1}},
};
INSTANTIATE_TEST_SUITE_P(TransformationTests,
SpaceToBatchDecompositionWithParams,
::testing::Combine(::testing::ValuesIn({false, true}),
::testing::ValuesIn(space_to_batch_params)),
get_test_name<SpaceToBatchDecomposeParams>);