[Transformations] Enable dynamic decomposition BTS and STB ops (#15179)

* Add dynamism for BatchToSpace conversion

* Extend dynamism for BatchToSpace conversion

Only block input needs to be const 'cause axes_order_const is freaky

* Enhance dynamism for BatchToSpace conversion

Block input need not be const now.

* Add dynamism for STB by elements conversion

* Remove const need for crops for BTS by_elements

* temp for review

* Try to fix output tensor overwrite

* Make test to reproduce invalid shape inference

* Reproduce the error with template plugin

* Fix code style

* Fix 0D inputs issue

* Remove 0D shape parts before Concat

* Apply nested namespaces

* Enable non-constant STB Block input

* Fix BTS runtime info

* Fix STB by elems runtime info

* Add dynamism for STB conversion

* Add BTS dynamic data test

* Add STB dynamic data test

* Reduce STB concats

* Add tests naming

* Edit

* style

* Consider other block element types

* Enhance type test

* Use opset10 only

* Check block shape
This commit is contained in:
Tomasz Jankowski 2023-02-14 11:45:10 +01:00 committed by GitHub
parent fac03ee5f7
commit c62be51cc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 412 additions and 263 deletions

View File

@ -4,94 +4,82 @@
#include "transformations/op_conversions/convert_batch_to_space.hpp"
#include <algorithm>
#include <climits>
#include <memory>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset10.hpp>
#include <vector>
#include "itt.hpp"
using namespace std;
using namespace ov::opset10;
using namespace ov::element;
void ov::pass::ConvertBatchToSpace::convert_batch_to_space() {
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space);
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>();
matcher_pass_callback callback = [](pattern::Matcher& m) {
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root());
if (!batch_to_space) {
const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
matcher_pass_callback callback = [this](pattern::Matcher& m) {
const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
if (!batch_to_space || transformation_callback(batch_to_space)) {
return false;
}
NodeVector new_ops;
auto data = batch_to_space->input_value(0);
auto block = batch_to_space->input_value(1);
auto crops_begin = batch_to_space->input_value(2);
auto crops_end = batch_to_space->input_value(3);
NodeRegistry rg;
const auto data = batch_to_space->input_value(0);
const auto block = batch_to_space->input_value(1);
const auto crops_begin = batch_to_space->input_value(2);
const auto crops_end = batch_to_space->input_value(3);
if (data.get_partial_shape().is_dynamic()) {
return false;
}
const auto& data_shape = data.get_shape();
const auto block_const = std::dynamic_pointer_cast<opset3::Constant>(block.get_node_shared_ptr());
const auto crops_begin_const = std::dynamic_pointer_cast<opset3::Constant>(crops_begin.get_node_shared_ptr());
const auto crops_end_const = std::dynamic_pointer_cast<opset3::Constant>(crops_end.get_node_shared_ptr());
if (!block_const || !crops_begin_const || !crops_end_const) {
return false;
const auto data_shape_rank = data.get_partial_shape().rank();
if (data_shape_rank.is_dynamic()) {
return false; // because StridedSlice masks are std::vector
}
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>();
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
return false;
}
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
// First we have to disperse the data from batch, then rearrange them
// so as appropriate chunks of data where close to their destination place.
// Finally squeeze data from respective dimensions.
std::vector<int64_t> dispersed_shape;
int64_t b_dim_divider = 1;
for (const auto& el : block_values) {
b_dim_divider *= el;
}
// Finally squeeze data from respective dimensions
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
const auto block_prod = rg.make<ReduceProd>(block, zero);
const auto batch_div = rg.make<Divide>(batch, block_prod);
// note: B_0 is expected to be 1.
// x' = reshape(`data`, [B_1, ..., B_{N - 1}, batch / (B_1 * ... B_{N - 1}), D_1, D_2, ...,
// D_{N - 1}]),
// where B_i = block_shape[i]
dispersed_shape.insert(dispersed_shape.begin(), block_values.begin() + 1, block_values.end());
dispersed_shape.push_back(data_shape.at(0) / b_dim_divider);
for (size_t i = 1; i < data_shape.size(); ++i) {
dispersed_shape.push_back(data_shape.at(i));
}
const auto out_pattern_1 =
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
const auto end = rg.make<Constant>(i64, Shape{1}, block_length);
const auto block_tail = rg.make<Slice>(block, one, end, one);
const auto data_shape_tail = rg.make<Slice>(shape_of_data, one, end, one);
const auto dispersed_shape = rg.make<Concat>(OutputVector{block_tail, batch_div, data_shape_tail}, 0);
const bool special_zero = false;
std::shared_ptr<Node> flat_node = std::make_shared<ov::opset3::Reshape>(data, out_pattern_1, special_zero);
new_ops.push_back(flat_node);
shared_ptr<Node> flat_node = rg.make<Reshape>(data, dispersed_shape, special_zero);
// calculate axes to transpose
// x'' = transpose(x', [N, N + 1, 0, N + 2, 1, ..., N + N - 1, N - 1])
std::vector<size_t> axes_order{block_values.size() - 1};
for (size_t i = 0; i < block_values.size() - 1; ++i) {
axes_order.push_back(i + block_values.size());
vector<int64_t> axes_order{block_length - 1};
for (int64_t i = 0; i < block_length - 1; ++i) {
axes_order.push_back(i + block_length);
axes_order.push_back(i);
}
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
const auto axes_order_const =
opset3::Constant::create(element::i64,
Shape{axes_order.size()},
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
new_ops.push_back(flat_node);
// x''' = reshape(x'', [batch / (B_1 * ... * B_{N - 1}), D_1 * B_1, D_2 * B_2, ... , D_{N - 1}
// * B_{N - 1}])
std::vector<int64_t> squeezed_shape;
squeezed_shape.push_back(data_shape.at(0) / b_dim_divider);
for (size_t i = 1; i < block_values.size(); ++i) {
squeezed_shape.push_back(data_shape.at(i) * block_values.at(i));
}
const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<opset3::Reshape>(flat_node, out_pattern_2, special_zero);
new_ops.push_back(flat_node);
const auto squeezed_shape_tail = rg.make<Multiply>(block_tail, data_shape_tail);
const auto squeezed_shape = rg.make<Concat>(OutputVector{batch_div, squeezed_shape_tail}, 0);
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
// Crop the start and end of dimensions according to `crops_begin`, `crops_end` to produce
// the output of shape:
@ -99,129 +87,133 @@ void ov::pass::ConvertBatchToSpace::convert_batch_to_space() {
// `y = [batch / (B_1 * ... * B_{N - 1}), crop(D_1 * B_1, crops_begin[1], crops_end[1]),
// crop(D_2 * B_2, crops_begin[2], crops_end[2]), ... ,
// crop(D_{N - 1} * B_{N - 1}, crops_begin[N - 1], crops_end[N - 1])]`
std::vector<int64_t> upperbounds_values;
auto flat_node_shape = flat_node->get_shape();
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
}
const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(),
Shape{upperbounds_values.size()},
upperbounds_values);
std::vector<int64_t> begin_mask(data_shape.size(), 0);
std::vector<int64_t> end_mask(data_shape.size(), 0);
flat_node =
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
new_ops.push_back(flat_node);
const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
const auto& end_mask = begin_mask;
flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
flat_node->set_friendly_name(batch_to_space->get_friendly_name());
ngraph::copy_runtime_info(batch_to_space, new_ops);
ngraph::replace_node(batch_to_space, flat_node);
copy_runtime_info(batch_to_space, rg.get());
replace_node(batch_to_space, flat_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name);
const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
this->register_matcher(m, callback);
}
void ov::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() {
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space_by_elements);
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>();
const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root());
if (!batch_to_space) {
const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
if (!batch_to_space || transformation_callback(batch_to_space)) {
return false;
}
auto data = batch_to_space->input_value(0);
const auto data = batch_to_space->input_value(0);
if (data.get_partial_shape().is_dynamic()) {
return false;
}
auto data_shape = data.get_shape();
if (transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
return false;
}
auto block = batch_to_space->input_value(1);
auto crops_begin = batch_to_space->input_value(2);
auto crops_end = batch_to_space->input_value(3);
const auto block_const = ov::as_type_ptr<opset3::Constant>(block.get_node_shared_ptr());
const auto crops_begin_const = ov::as_type_ptr<opset3::Constant>(crops_begin.get_node_shared_ptr());
const auto crops_end_const = ov::as_type_ptr<opset3::Constant>(crops_end.get_node_shared_ptr());
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>();
std::vector<int64_t> dispersed_shape(1);
dispersed_shape.insert(dispersed_shape.end(), data_shape.begin(), data_shape.end());
std::vector<size_t> axes_order(block_values.size() + 1);
std::vector<int64_t> squeezed_shape(data_shape.begin(), data_shape.end());
if (squeezed_shape.size() > block_values.size()) {
return false;
const auto data_shape_rank = data.get_partial_shape().rank();
if (data_shape_rank.is_dynamic()) {
return false; // because StridedSlice masks are std::vector
}
NodeVector new_ops;
const auto block = batch_to_space->input_value(1);
const auto crops_begin = batch_to_space->input_value(2);
const auto crops_end = batch_to_space->input_value(3);
std::shared_ptr<Node> flat_node = data.get_node_shared_ptr();
for (size_t block_idx = 1; block_idx < block_values.size(); ++block_idx) {
dispersed_shape[0] = block_values[block_idx];
dispersed_shape[1] /= block_values[block_idx];
const auto out_pattern_1 =
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
const bool special_zero = false;
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_1, special_zero);
new_ops.push_back(flat_node);
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
return false;
}
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
size_t val = 1;
for (size_t axis_idx = 0; axis_idx <= block_values.size(); ++axis_idx) {
if ((block_idx + 1) == axis_idx) {
NodeRegistry rg;
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
const auto two = rg.make<Constant>(i64, Shape{1}, 2);
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
const auto et_zero = rg.make<Constant>(block.get_element_type(), Shape{1}, 0);
shared_ptr<Node> dispersed_shape = rg.make<Concat>(OutputVector{et_zero, shape_of_data}, 0);
shared_ptr<Node> squeezed_shape = shape_of_data;
shared_ptr<Node> flat_node = data.get_node_shared_ptr();
const auto make_concat = [&](OutputVector nodes) {
nodes.erase(remove_if(nodes.begin(),
nodes.end(),
[](const Output<Node>& n) {
return n.get_partial_shape().is_static() && n.get_shape().size() > 0 &&
n.get_shape()[0] == 0;
}),
nodes.end());
return rg.make<Concat>(nodes, 0);
};
shared_ptr<Node> div;
for (int64_t b_idx = 1; b_idx < block_length; ++b_idx) {
const auto block_index = rg.make<Constant>(i64, Shape{1}, b_idx);
const auto block_index_next = rg.make<Constant>(i64, Shape{1}, b_idx + 1);
const auto block_value = rg.make<Gather>(block, block_index, zero);
// dispersed_shape[0] = block[b_idx];
// dispersed_shape[1] /= block[b_idx];
if (!div) {
const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
div = rg.make<Divide>(batch, block_value);
} else {
div = rg.make<Divide>(div, block_value);
}
auto ds_tail = rg.make<Slice>(dispersed_shape, two, int_max, one);
dispersed_shape = make_concat({block_value, div, ds_tail});
constexpr auto special_zero = false;
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, special_zero);
vector<int64_t> axes_order(block_length + 1);
int64_t val = 1;
for (int64_t axis_idx = 0; axis_idx <= block_length; ++axis_idx) {
if ((b_idx + 1) == axis_idx) {
axes_order[axis_idx] = 0;
} else {
axes_order[axis_idx] = val;
val++;
}
}
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
const auto axes_order_const =
ov::opset3::Constant::create(element::i64,
Shape{axes_order.size()},
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
new_ops.push_back(flat_node);
// squeezed_shape[0] = dispersed_shape[1];
// squeezed_shape[b_idx] *= block[b_idx];
const auto sq_slice = rg.make<Slice>(squeezed_shape, one, block_index, one);
const auto sq_bidx_dim = rg.make<Gather>(squeezed_shape, block_index, zero);
const auto sq_mul = rg.make<Multiply>(sq_bidx_dim, block_value);
const auto sq_shape_tail = rg.make<Slice>(squeezed_shape, block_index_next, int_max, one);
squeezed_shape.reset();
squeezed_shape = make_concat({div, sq_slice, sq_mul, sq_shape_tail});
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
squeezed_shape[0] = dispersed_shape[1];
squeezed_shape[block_idx] *= block_values[block_idx];
dispersed_shape[block_idx + 1] = squeezed_shape[block_idx];
const auto out_pattern_2 =
opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, special_zero);
new_ops.push_back(flat_node);
// dispersed_shape[b_idx + 1] = squeezed_shape[b_idx];
const auto ds_front = rg.make<Slice>(dispersed_shape, zero, block_index_next, one);
ds_tail = rg.make<Slice>(dispersed_shape, rg.make<Constant>(i64, Shape{1}, b_idx + 2), int_max, one);
dispersed_shape = make_concat({ds_front, sq_mul, ds_tail});
}
std::vector<int64_t> upperbounds_values;
auto flat_node_shape = flat_node->get_shape();
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
}
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(),
Shape{upperbounds_values.size()},
upperbounds_values);
const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
std::vector<int64_t> begin_mask(data_shape.size(), 0);
std::vector<int64_t> end_mask(data_shape.size(), 0);
flat_node =
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
new_ops.push_back(flat_node);
const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
const auto& end_mask = begin_mask;
flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
flat_node->set_friendly_name(batch_to_space->get_friendly_name());
ngraph::copy_runtime_info(batch_to_space, new_ops);
ngraph::replace_node(batch_to_space, flat_node);
copy_runtime_info(batch_to_space, rg.get());
replace_node(batch_to_space, flat_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name);
const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -4,42 +4,43 @@
#include "transformations/op_conversions/convert_space_to_batch.hpp"
#include <climits>
#include <memory>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset10.hpp>
#include <vector>
#include "itt.hpp"
using namespace std;
using namespace ov::opset10;
using namespace ov::element;
void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() {
MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch);
auto space_to_batch = ngraph::pattern::wrap_type<ov::opset3::SpaceToBatch>();
matcher_pass_callback callback = [](pattern::Matcher& m) {
auto space_to_batch = std::dynamic_pointer_cast<ov::opset3::SpaceToBatch>(m.get_match_root());
if (!space_to_batch) {
const auto space_to_batch = pattern::wrap_type<SpaceToBatch>();
matcher_pass_callback callback = [this](pattern::Matcher& m) {
const auto space_to_batch = dynamic_pointer_cast<SpaceToBatch>(m.get_match_root());
if (!space_to_batch || transformation_callback(space_to_batch)) {
return false;
}
NodeVector new_ops;
auto data = space_to_batch->input_value(0);
auto block = space_to_batch->input_value(1);
auto pads_begin = space_to_batch->input_value(2);
auto pads_end = space_to_batch->input_value(3);
if (data.get_partial_shape().is_dynamic()) {
const auto data = space_to_batch->input_value(0);
if (data.get_partial_shape().rank().is_dynamic()) {
return false;
}
const auto block_const = std::dynamic_pointer_cast<opset3::Constant>(block.get_node_shared_ptr());
const auto pads_begin_const = std::dynamic_pointer_cast<opset3::Constant>(pads_begin.get_node_shared_ptr());
const auto pads_end_const = std::dynamic_pointer_cast<opset3::Constant>(pads_end.get_node_shared_ptr());
const auto block = space_to_batch->input_value(1);
const auto pads_begin = space_to_batch->input_value(2);
const auto pads_end = space_to_batch->input_value(3);
if (!block_const || !pads_begin_const || !pads_end_const) {
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
return false;
}
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
NodeRegistry rg;
// Zero-pad the start and end of dimensions [D_0, ..., D_{N - 1}] of the input according to
// `pads_begin`
@ -47,162 +48,159 @@ void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() {
// note: P_0 for batch dimension is expected to be 0 (no-padding).
// x = [batch + P_0, D_1 + P_1, D_2 + P_2, ..., D_{N - 1} + P_{N - 1}], where P_i =
// pads_begin[i] + pads_end[i]
std::shared_ptr<Node> flat_node =
std::make_shared<opset3::Pad>(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT);
auto out_shape = flat_node->get_shape();
new_ops.push_back(flat_node);
shared_ptr<Node> flat_node = rg.make<Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
const auto out_shape = rg.make<ShapeOf>(flat_node, block.get_element_type());
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
// First we have to disperse the data from spatial dimensions, then
// rearrange them so as appropriate chunks of data where close to their
// destination place. Finally squeeze data from respective dimensions.
Shape dispersed_shape{out_shape.at(0)};
// note: B_0 for batch is ignored.
// x' = reshape(x, [batch, (D_1 + P_1) / B_1, B_1, (D_2 + P_2) / B_2, B_2, ...,
// (D_{N - 1} + P_{N - 1}) / B_{N - 1}, B_{N - 1}]), where B_i = block_shape[i]
for (size_t i = 1; i < block_values.size(); ++i) {
dispersed_shape.push_back(out_shape.at(i) / block_values.at(i));
dispersed_shape.push_back(block_values.at(i));
}
const auto batch = rg.make<Gather>(out_shape, zero, zero);
const auto out_shape_tail = rg.make<Slice>(out_shape, one, int_max, one);
const auto block_tail = rg.make<Slice>(block, one, int_max, one);
const auto os_tail_div = rg.make<Divide>(out_shape_tail, block_tail);
const auto out_pattern = opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern, false);
new_ops.push_back(flat_node);
// interleave os_tail_div with block_tail
const auto c = rg.make<Concat>(NodeVector{os_tail_div, block_tail}, 0);
const auto r =
rg.make<Reshape>(c, rg.make<Constant>(i64, Shape{2}, vector<int64_t>{2, block_length - 1}), false);
const auto t = rg.make<Transpose>(r, rg.make<Constant>(i64, Shape{2}, vector<int64_t>{1, 0}));
const auto interleaved = rg.make<Reshape>(t, rg.make<Constant>(i64, Shape{1}, 2 * (block_length - 1)), false);
const auto dispersed_shape = rg.make<Concat>(NodeVector{batch, interleaved}, 0);
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, false);
// x'' = transpose(x', [2, 4, ..., (N - 1) + (N - 1), 0, 1, 3, ..., N + (N - 1)])
std::vector<size_t> axes_order;
for (size_t i = 0, j = 2; i < block_values.size() - 1; ++i, j += 2) {
vector<int64_t> axes_order;
for (int64_t i = 0, j = 2; i < block_length - 1; ++i, j += 2) {
axes_order.push_back(j);
}
axes_order.push_back(0);
for (size_t i = 0, j = 1; i < block_values.size() - 1; ++i, j += 2) {
for (int64_t i = 0, j = 1; i < block_length - 1; ++i, j += 2) {
axes_order.push_back(j);
}
const auto axes_order_const =
opset3::Constant::create(element::i64,
Shape{axes_order.size()},
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
new_ops.push_back(flat_node);
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
Shape squeezed_shape;
int64_t prod = 1;
for (const auto& el : block_values) {
prod *= el;
}
// y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ...
// ,
// y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ...,
// (D_{N - 1} + P_{N - 1}) / B_{N - 1}])
squeezed_shape.push_back(out_shape.at(0) * prod);
for (size_t i = 1; i < block_values.size(); ++i) {
squeezed_shape.push_back(out_shape.at(i) / block_values.at(i));
}
const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, false);
new_ops.push_back(flat_node);
// note: B_0 is assumed to be 1 by op definition
const auto block_prod = rg.make<ReduceProd>(block, zero);
const auto squeezed_shape = rg.make<Concat>(NodeVector{rg.make<Multiply>(batch, block_prod), os_tail_div}, 0);
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, false);
flat_node->set_friendly_name(space_to_batch->get_friendly_name());
ngraph::copy_runtime_info(space_to_batch, new_ops);
ngraph::replace_node(space_to_batch, flat_node);
copy_runtime_info(space_to_batch, rg.get());
replace_node(space_to_batch, flat_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(space_to_batch, matcher_name);
const auto m = make_shared<pattern::Matcher>(space_to_batch, matcher_name);
this->register_matcher(m, callback);
}
void ov::pass::ConvertSpaceToBatch::convert_space_to_batch_by_elements() {
MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch_by_elements);
auto space_to_batch = ngraph::pattern::wrap_type<ov::opset3::SpaceToBatch>();
const auto space_to_batch = pattern::wrap_type<SpaceToBatch>();
matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto space_to_batch = std::dynamic_pointer_cast<ov::opset3::SpaceToBatch>(m.get_match_root());
if (!space_to_batch) {
const auto space_to_batch = dynamic_pointer_cast<SpaceToBatch>(m.get_match_root());
if (!space_to_batch || transformation_callback(space_to_batch)) {
return false;
}
auto data = space_to_batch->input_value(0);
if (data.get_partial_shape().is_dynamic()) {
return false;
}
const auto& data_shape = data.get_shape();
if (transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) {
const auto data = space_to_batch->input_value(0);
if (data.get_partial_shape().rank().is_dynamic()) {
return false;
}
auto block = space_to_batch->input_value(1);
auto pads_begin = space_to_batch->input_value(2);
auto pads_end = space_to_batch->input_value(3);
const auto block = space_to_batch->input_value(1);
const auto pads_begin = space_to_batch->input_value(2);
const auto pads_end = space_to_batch->input_value(3);
const auto block_const = ov::as_type_ptr<opset3::Constant>(block.get_node_shared_ptr());
const auto pads_begin_const = ov::as_type_ptr<opset3::Constant>(pads_begin.get_node_shared_ptr());
const auto pads_end_const = ov::as_type_ptr<opset3::Constant>(pads_end.get_node_shared_ptr());
if (!block_const || !pads_begin_const || !pads_end_const) {
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
return false;
}
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
NodeVector new_ops;
NodeRegistry rg;
std::shared_ptr<Node> flat_node =
std::make_shared<opset3::Pad>(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT);
new_ops.push_back(flat_node);
auto out_shape = flat_node->get_shape();
shared_ptr<Node> flat_node = rg.make<Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
std::vector<int64_t> dispersed_shape(block_values.size() + 1);
std::vector<size_t> axes_order(block_values.size() + 1);
std::vector<int64_t> squeezed_shape(out_shape.begin(), out_shape.end());
for (int64_t block_idx = block_values.size() - 1; block_idx >= 0; --block_idx) {
int64_t sq_shape_idx = block_values.size() - 1;
shared_ptr<Node> squeezed_shape = rg.make<ShapeOf>(flat_node, block.get_element_type());
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
for (int64_t b_idx = block_length - 1; b_idx >= 0; --b_idx) {
const auto block_index = rg.make<Constant>(i64, Shape{1}, b_idx);
const auto block_index_next = rg.make<Constant>(i64, Shape{1}, b_idx + 1);
const auto block_value = rg.make<Gather>(block, block_index, zero);
NodeVector dispersed_shape_prep;
dispersed_shape_prep.reserve(block_length + 1);
if (b_idx > 0) // avoid adding empty Slice into Concat
dispersed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, zero, block_index, one));
const auto squeezed_element = rg.make<Gather>(squeezed_shape, block_index, zero);
dispersed_shape_prep.push_back(rg.make<Divide>(squeezed_element, block_value));
dispersed_shape_prep.push_back(block_value);
if (b_idx + 1 < block_length) // avoid adding empty Slice into Concat
dispersed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, block_index_next, int_max, one));
const auto dispersed_shape = rg.make<Concat>(dispersed_shape_prep, 0);
constexpr auto special_zero = false;
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, special_zero);
vector<int64_t> axes_order(block_length + 1);
int64_t axis_idx = axes_order.size() - 1;
for (int64_t shape_idx = dispersed_shape.size() - 1; shape_idx >= 0; --shape_idx) {
if (shape_idx == (block_idx + 1)) {
dispersed_shape[shape_idx] = block_values[block_idx];
axes_order[0] = shape_idx;
} else if (shape_idx == block_idx) {
dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx] / block_values[block_idx];
axes_order[axis_idx] = shape_idx;
for (int64_t ds_idx = block_length; ds_idx >= 0; --ds_idx) {
if (ds_idx == (b_idx + 1)) {
axes_order[0] = ds_idx;
} else if (ds_idx == b_idx) {
axes_order[axis_idx] = ds_idx;
axis_idx--;
sq_shape_idx--;
} else {
dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx];
axes_order[axis_idx] = shape_idx;
axes_order[axis_idx] = ds_idx;
axis_idx--;
sq_shape_idx--;
}
}
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
const auto out_pattern_1 =
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
const bool special_zero = false;
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_1, special_zero);
new_ops.push_back(flat_node);
// don't change squeezed_shape at the last iteration, block[0] is assumed to be 1 by op definition
if (b_idx > 0) {
NodeVector squeezed_shape_prep;
squeezed_shape_prep.reserve(block_length);
squeezed_shape_prep.push_back(
rg.make<Multiply>(rg.make<Gather>(squeezed_shape, zero, zero), block_value));
if (b_idx > 1) { // avoid adding empty Slice into Concat
squeezed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, one, block_index, one));
}
squeezed_shape_prep.push_back(
rg.make<Divide>(rg.make<Gather>(squeezed_shape, block_index, zero), block_value));
if (b_idx + 1 < block_length) { // avoid adding empty Slice into Concat
squeezed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, block_index_next, int_max, one));
}
const auto axes_order_const =
opset3::Constant::create(element::i64,
Shape{axes_order.size()},
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
new_ops.push_back(flat_node);
squeezed_shape[0] *= block_values[block_idx];
squeezed_shape[block_idx] /= block_values[block_idx];
const auto out_pattern_2 =
opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, special_zero);
new_ops.push_back(flat_node);
squeezed_shape = rg.make<Concat>(squeezed_shape_prep, 0);
}
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
}
flat_node->set_friendly_name(space_to_batch->get_friendly_name());
ngraph::copy_runtime_info(space_to_batch, new_ops);
ngraph::replace_node(space_to_batch, flat_node);
copy_runtime_info(space_to_batch, rg.get());
replace_node(space_to_batch, flat_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(space_to_batch, matcher_name);
const auto m = make_shared<pattern::Matcher>(space_to_batch, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -8,6 +8,7 @@
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pass/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <queue>
#include <sstream>
#include <string>
@ -19,6 +20,7 @@
#include "common_test_utils/ngraph_test_utils.hpp"
#include "common_test_utils/test_common.hpp"
using namespace std;
using namespace testing;
using namespace ngraph;
@ -35,6 +37,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecompositionByElements) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertBatchToSpace>();
manager.register_pass<ov::pass::ConstantFolding>();
}
{
@ -93,6 +96,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecompositionByElements) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertSpaceToBatch>();
manager.register_pass<ov::pass::ConstantFolding>();
}
{
@ -159,6 +163,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecomposition) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertSpaceToBatch>(false);
manager.register_pass<ov::pass::ConstantFolding>();
}
{
@ -195,6 +200,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) {
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
manager.register_pass<ov::pass::ConvertBatchToSpace>(false);
manager.register_pass<ov::pass::ConstantFolding>();
}
{
@ -218,3 +224,156 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
}
}
template <typename Op, typename Conversion, typename Params>
void op_convertion_type_test(const Params& params) {
using namespace ov::opset10;
using namespace ov::pass;
const auto by_elements = get<0>(params);
const auto block_elem_type = get<1>(params);
const auto data = make_shared<Parameter>(element::f32, Shape{1, 1});
const auto block_p = Constant::create(block_elem_type, Shape{2}, {1, 1});
const auto input_2_p = Constant::create(block_elem_type, Shape{2}, {0, 0});
const auto input_3_p = Constant::create(block_elem_type, Shape{2}, {0, 0});
const auto bts_or_stb = make_shared<Op>(data, block_p, input_2_p, input_3_p);
const auto f = make_shared<Function>(NodeVector{bts_or_stb}, ParameterVector{data});
Manager m;
m.register_pass<Conversion>(by_elements);
m.register_pass<ConstantFolding>();
ASSERT_NO_THROW(m.run_passes(f));
EXPECT_EQ(f->get_result()->get_input_shape(0), (Shape{1, 1}));
}
// Parameter pack for the element-type tests: conversion mode (by elements or
// whole-op) plus the integer type used for the block and crops/pads constants.
using ElementTypeParams = tuple<bool,          // by_elements
                                element::Type  // block element type
                                >;

// Checks that ConvertBatchToSpace handles both i32 and i64 block/crops inputs.
class BatchToSpaceDecomposition2D : public testing::WithParamInterface<ElementTypeParams>,
                                    public TransformationTests {};

TEST_P(BatchToSpaceDecomposition2D, BlockElemType) {
    op_convertion_type_test<ov::opset10::BatchToSpace, ov::pass::ConvertBatchToSpace>(GetParam());
}

// Cartesian product: {false, true} conversion modes x {i32, i64} element types.
INSTANTIATE_TEST_SUITE_P(TransformationTests,
                         BatchToSpaceDecomposition2D,
                         ::testing::Combine(::testing::ValuesIn({false, true}),
                                            ::testing::ValuesIn({element::i32, element::i64})));
// Checks that ConvertSpaceToBatch handles both i32 and i64 block/pads inputs.
class SpaceToBatchDecomposition2D : public testing::WithParamInterface<ElementTypeParams>,
                                    public TransformationTests {};

TEST_P(SpaceToBatchDecomposition2D, BlockElemType) {
    op_convertion_type_test<ov::opset10::SpaceToBatch, ov::pass::ConvertSpaceToBatch>(GetParam());
}

// Cartesian product: {false, true} conversion modes x {i32, i64} element types.
INSTANTIATE_TEST_SUITE_P(TransformationTests,
                         SpaceToBatchDecomposition2D,
                         ::testing::Combine(::testing::ValuesIn({false, true}),
                                            ::testing::ValuesIn({element::i32, element::i64})));
template <typename Op, typename Conversion, typename Params>
void op_convertion_test(const Params& params) {
using namespace ov::opset10;
using namespace ov::pass;
const bool by_elements = get<0>(params);
Shape data_shape;
Shape expected_output_shape;
vector<int64_t> block;
vector<int64_t> input_2; // crops_begin or pads_begin
vector<int64_t> input_3; // crops_end or pads_end
tie(data_shape, block, input_2, input_3, expected_output_shape) = get<1>(params);
const auto data = make_shared<Parameter>(element::f32, PartialShape::dynamic(data_shape.size()));
const auto block_p = Constant::create(element::i64, Shape{block.size()}, block);
const auto input_2_p = Constant::create(element::i64, Shape{input_2.size()}, input_2);
const auto input_3_p = Constant::create(element::i64, Shape{input_3.size()}, input_3);
const auto bts_or_stb = make_shared<Op>(data, block_p, input_2_p, input_3_p);
const auto f = make_shared<Function>(NodeVector{bts_or_stb}, ParameterVector{data});
Manager m;
m.set_per_pass_validation(false);
m.register_pass<Conversion>(by_elements);
m.run_passes(f);
ASSERT_EQ(count_ops_of_type<Op>(f), 0);
EXPECT_TRUE(f->get_result()->get_input_partial_shape(0).is_dynamic());
data->set_partial_shape(data_shape);
f->validate_nodes_and_infer_types();
ASSERT_EQ(f->get_result()->get_input_shape(0), expected_output_shape);
}
// Builds a readable gtest parameter name from the data rank and conversion
// mode, e.g. "3D" or "3D_by_elements". Relies on each parameter list holding
// at most one entry per rank; duplicated ranks would produce clashing names.
template <typename Params>
string get_test_name(testing::TestParamInfo<Params> obj) {
    const bool by_elements = get<0>(obj.param);
    const auto rank = get<0>(get<1>(obj.param)).size();
    string name = to_string(rank) + "D";
    if (by_elements)
        name += "_by_elements";
    return name;
}
// One BatchToSpace test case: input shape, op attributes, and the output
// shape expected after decomposition + shape inference.
using BatchToSpaceParams = tuple<Shape,            // data_shape
                                 vector<int64_t>,  // block
                                 vector<int64_t>,  // crops_begin
                                 vector<int64_t>,  // crops_end
                                 Shape             // expected_output_shape
                                 >;
using BatchToSpaceDecomposeParams = tuple<bool,  // by_elements
                                          BatchToSpaceParams>;

// Runs ConvertBatchToSpace on data with dynamic dimensions (static rank only),
// then sets the static shape to validate the decomposed subgraph's inference.
class BatchToSpaceDecompositionWithParams : public testing::WithParamInterface<BatchToSpaceDecomposeParams>,
                                            public TransformationTests {};

TEST_P(BatchToSpaceDecompositionWithParams, DynamicInputs) {
    op_convertion_test<ov::opset10::BatchToSpace, ov::pass::ConvertBatchToSpace>(GetParam());
}

// 2D..5D cases, including non-zero crops and non-uniform block values.
static vector<BatchToSpaceParams> batch_to_space_params = {
    {{4, 3}, {1, 2}, {0, 0}, {0, 0}, {2, 6}},
    {{6, 5, 7}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {1, 8, 17}},
    {{30, 4, 1, 1}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {1, 20, 3, 2}},
    {{96, 3, 5, 7, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 12, 15, 14, 1}},
};

// Each case runs in both whole-op and by-elements decomposition modes;
// get_test_name derives names like "3D_by_elements" from the parameters.
INSTANTIATE_TEST_SUITE_P(TransformationTests,
                         BatchToSpaceDecompositionWithParams,
                         ::testing::Combine(::testing::ValuesIn({false, true}),
                                            ::testing::ValuesIn(batch_to_space_params)),
                         get_test_name<BatchToSpaceDecomposeParams>);
// One SpaceToBatch test case: input shape, op attributes, and the output
// shape expected after decomposition + shape inference.
using SpaceToBatchParams = tuple<Shape,            // data_shape
                                 vector<int64_t>,  // block
                                 vector<int64_t>,  // pads_begin
                                 vector<int64_t>,  // pads_end
                                 Shape             // expected_output_shape
                                 >;
using SpaceToBatchDecomposeParams = tuple<bool,  // by_elements
                                          SpaceToBatchParams>;

// Runs ConvertSpaceToBatch on data with dynamic dimensions (static rank only),
// then sets the static shape to validate the decomposed subgraph's inference.
class SpaceToBatchDecompositionWithParams : public testing::WithParamInterface<SpaceToBatchDecomposeParams>,
                                            public TransformationTests {};

TEST_P(SpaceToBatchDecompositionWithParams, DynamicInputs) {
    op_convertion_test<ov::opset10::SpaceToBatch, ov::pass::ConvertSpaceToBatch>(GetParam());
}

// 2D..5D cases; these are the inverse shapes of the BatchToSpace cases above.
static vector<SpaceToBatchParams> space_to_batch_params = {
    {{2, 6}, {1, 2}, {0, 0}, {0, 0}, {4, 3}},
    {{1, 8, 17}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {6, 5, 7}},
    {{1, 20, 3, 2}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {30, 4, 1, 1}},
    {{4, 12, 15, 14, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {96, 3, 5, 7, 1}},
};

// Each case runs in both whole-op and by-elements decomposition modes;
// get_test_name derives names like "3D_by_elements" from the parameters.
INSTANTIATE_TEST_SUITE_P(TransformationTests,
                         SpaceToBatchDecompositionWithParams,
                         ::testing::Combine(::testing::ValuesIn({false, true}),
                                            ::testing::ValuesIn(space_to_batch_params)),
                         get_test_name<SpaceToBatchDecomposeParams>);