[Transformations] Enable dynamic decomposition BTS and STB ops (#15179)
* Add dynamism for BatchToSpace conversion * Extend dynamism for BatchToSpace conversion Only block input needs to be const 'cse axes_order_const is freaky * Enhace dynamism for BatchToSpace conversion Block input need not be const now. * Add dynamism for STB by elements conversion * Remove const need for crops for BTS by_elements * temp for review * Try to fix output tensor overwrite * Make test to reproduce invalid shape inference * Reproduce the error with template plugin * Fix code style * Fix 0D inputs issue * Remove 0D shape parts before Concat * Apply nested namespaces * Enable non-constant STB Block input * Fix BTS runtime info * Fix STB by elems runtime info * Add dynamism for STB conversion * Add BTS dynamic data test * Add STB dynamic data test * Reduce STB concats * Add tests naming * Edit * style * Consider other block element types * Enhance type test * Use opset10 only * Check block shape
This commit is contained in:
parent
fac03ee5f7
commit
c62be51cc1
@ -4,94 +4,82 @@
|
||||
|
||||
#include "transformations/op_conversions/convert_batch_to_space.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <memory>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/opsets/opset3.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::element;
|
||||
|
||||
void ov::pass::ConvertBatchToSpace::convert_batch_to_space() {
|
||||
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space);
|
||||
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>();
|
||||
matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root());
|
||||
if (!batch_to_space) {
|
||||
const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
|
||||
matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
|
||||
if (!batch_to_space || transformation_callback(batch_to_space)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeVector new_ops;
|
||||
auto data = batch_to_space->input_value(0);
|
||||
auto block = batch_to_space->input_value(1);
|
||||
auto crops_begin = batch_to_space->input_value(2);
|
||||
auto crops_end = batch_to_space->input_value(3);
|
||||
NodeRegistry rg;
|
||||
const auto data = batch_to_space->input_value(0);
|
||||
const auto block = batch_to_space->input_value(1);
|
||||
const auto crops_begin = batch_to_space->input_value(2);
|
||||
const auto crops_end = batch_to_space->input_value(3);
|
||||
|
||||
if (data.get_partial_shape().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const auto& data_shape = data.get_shape();
|
||||
|
||||
const auto block_const = std::dynamic_pointer_cast<opset3::Constant>(block.get_node_shared_ptr());
|
||||
const auto crops_begin_const = std::dynamic_pointer_cast<opset3::Constant>(crops_begin.get_node_shared_ptr());
|
||||
const auto crops_end_const = std::dynamic_pointer_cast<opset3::Constant>(crops_end.get_node_shared_ptr());
|
||||
|
||||
if (!block_const || !crops_begin_const || !crops_end_const) {
|
||||
return false;
|
||||
const auto data_shape_rank = data.get_partial_shape().rank();
|
||||
if (data_shape_rank.is_dynamic()) {
|
||||
return false; // because StridedSlice masks are std::vector
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
|
||||
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>();
|
||||
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
|
||||
return false;
|
||||
}
|
||||
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
|
||||
|
||||
// First we have to disperse the data from batch, then rearrange them
|
||||
// so as appropriate chunks of data where close to their destination place.
|
||||
// Finally squeeze data from respective dimensions.ss
|
||||
std::vector<int64_t> dispersed_shape;
|
||||
int64_t b_dim_divider = 1;
|
||||
for (const auto& el : block_values) {
|
||||
b_dim_divider *= el;
|
||||
}
|
||||
// Finally squeeze data from respective dimensions
|
||||
|
||||
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
|
||||
const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
|
||||
const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
|
||||
const auto block_prod = rg.make<ReduceProd>(block, zero);
|
||||
const auto batch_div = rg.make<Divide>(batch, block_prod);
|
||||
|
||||
// note: B_0 is expected to be 1.
|
||||
// x' = reshape(`data`, [B_1, ..., B_{N - 1}, batch / (B_1 * ... B_{N - 1}), D_1, D_2, ...,
|
||||
// D_{N - 1}]),
|
||||
// where B_i = block_shape[i]
|
||||
dispersed_shape.insert(dispersed_shape.begin(), block_values.begin() + 1, block_values.end());
|
||||
dispersed_shape.push_back(data_shape.at(0) / b_dim_divider);
|
||||
for (size_t i = 1; i < data_shape.size(); ++i) {
|
||||
dispersed_shape.push_back(data_shape.at(i));
|
||||
}
|
||||
|
||||
const auto out_pattern_1 =
|
||||
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
|
||||
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
|
||||
const auto end = rg.make<Constant>(i64, Shape{1}, block_length);
|
||||
const auto block_tail = rg.make<Slice>(block, one, end, one);
|
||||
const auto data_shape_tail = rg.make<Slice>(shape_of_data, one, end, one);
|
||||
const auto dispersed_shape = rg.make<Concat>(OutputVector{block_tail, batch_div, data_shape_tail}, 0);
|
||||
const bool special_zero = false;
|
||||
std::shared_ptr<Node> flat_node = std::make_shared<ov::opset3::Reshape>(data, out_pattern_1, special_zero);
|
||||
new_ops.push_back(flat_node);
|
||||
shared_ptr<Node> flat_node = rg.make<Reshape>(data, dispersed_shape, special_zero);
|
||||
|
||||
// calculate axes to transpose
|
||||
// x'' = transpose(x', [N, N + 1, 0, N + 2, 1, ..., N + N - 1, N - 1])
|
||||
std::vector<size_t> axes_order{block_values.size() - 1};
|
||||
for (size_t i = 0; i < block_values.size() - 1; ++i) {
|
||||
axes_order.push_back(i + block_values.size());
|
||||
vector<int64_t> axes_order{block_length - 1};
|
||||
for (int64_t i = 0; i < block_length - 1; ++i) {
|
||||
axes_order.push_back(i + block_length);
|
||||
axes_order.push_back(i);
|
||||
}
|
||||
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
|
||||
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
|
||||
|
||||
const auto axes_order_const =
|
||||
opset3::Constant::create(element::i64,
|
||||
Shape{axes_order.size()},
|
||||
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
|
||||
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
|
||||
new_ops.push_back(flat_node);
|
||||
// x''' = reshape(x'', [batch / (B_1 * ... * B_{N - 1}), D_1 * B_1, D_2 * B_2, ... , D_{N - 1}
|
||||
// * B_{N - 1}])
|
||||
std::vector<int64_t> squeezed_shape;
|
||||
squeezed_shape.push_back(data_shape.at(0) / b_dim_divider);
|
||||
for (size_t i = 1; i < block_values.size(); ++i) {
|
||||
squeezed_shape.push_back(data_shape.at(i) * block_values.at(i));
|
||||
}
|
||||
|
||||
const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
|
||||
flat_node = std::make_shared<opset3::Reshape>(flat_node, out_pattern_2, special_zero);
|
||||
new_ops.push_back(flat_node);
|
||||
const auto squeezed_shape_tail = rg.make<Multiply>(block_tail, data_shape_tail);
|
||||
const auto squeezed_shape = rg.make<Concat>(OutputVector{batch_div, squeezed_shape_tail}, 0);
|
||||
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
|
||||
|
||||
// Crop the start and end of dimensions according to `crops_begin`, `crops_end` to produce
|
||||
// the output of shape:
|
||||
@ -99,129 +87,133 @@ void ov::pass::ConvertBatchToSpace::convert_batch_to_space() {
|
||||
// `y = [batch / (B_1 * ... * B_{N - 1}), crop(D_1 * B_1, crops_begin[1], crops_end[1]),
|
||||
// crop(D_2 * B_2, crops_begin[2], crops_end[2]), ... ,
|
||||
// crop(D_{N - 1} * B_{N - 1}, crops_begin[N - 1], crops_end[N - 1])]`
|
||||
std::vector<int64_t> upperbounds_values;
|
||||
auto flat_node_shape = flat_node->get_shape();
|
||||
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
|
||||
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
|
||||
}
|
||||
const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
|
||||
const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
|
||||
|
||||
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(),
|
||||
Shape{upperbounds_values.size()},
|
||||
upperbounds_values);
|
||||
|
||||
std::vector<int64_t> begin_mask(data_shape.size(), 0);
|
||||
std::vector<int64_t> end_mask(data_shape.size(), 0);
|
||||
flat_node =
|
||||
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
|
||||
new_ops.push_back(flat_node);
|
||||
const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
|
||||
const auto& end_mask = begin_mask;
|
||||
flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
|
||||
|
||||
flat_node->set_friendly_name(batch_to_space->get_friendly_name());
|
||||
ngraph::copy_runtime_info(batch_to_space, new_ops);
|
||||
ngraph::replace_node(batch_to_space, flat_node);
|
||||
copy_runtime_info(batch_to_space, rg.get());
|
||||
replace_node(batch_to_space, flat_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name);
|
||||
const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
void ov::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() {
|
||||
MATCHER_SCOPE(ConvertBatchToSpace_convert_batch_to_space_by_elements);
|
||||
auto batch_to_space = ngraph::pattern::wrap_type<ov::opset3::BatchToSpace>();
|
||||
const auto batch_to_space = pattern::wrap_type<BatchToSpace>();
|
||||
matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto batch_to_space = std::dynamic_pointer_cast<ov::opset3::BatchToSpace>(m.get_match_root());
|
||||
if (!batch_to_space) {
|
||||
const auto batch_to_space = dynamic_pointer_cast<BatchToSpace>(m.get_match_root());
|
||||
if (!batch_to_space || transformation_callback(batch_to_space)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto data = batch_to_space->input_value(0);
|
||||
const auto data = batch_to_space->input_value(0);
|
||||
|
||||
if (data.get_partial_shape().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
auto data_shape = data.get_shape();
|
||||
|
||||
if (transformation_callback(batch_to_space) && (data_shape.size() == 4 || data_shape.size() == 5)) {
|
||||
return false;
|
||||
}
|
||||
auto block = batch_to_space->input_value(1);
|
||||
auto crops_begin = batch_to_space->input_value(2);
|
||||
auto crops_end = batch_to_space->input_value(3);
|
||||
|
||||
const auto block_const = ov::as_type_ptr<opset3::Constant>(block.get_node_shared_ptr());
|
||||
const auto crops_begin_const = ov::as_type_ptr<opset3::Constant>(crops_begin.get_node_shared_ptr());
|
||||
const auto crops_end_const = ov::as_type_ptr<opset3::Constant>(crops_end.get_node_shared_ptr());
|
||||
|
||||
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
|
||||
const std::vector<int64_t>& crops_end_values = crops_end_const->cast_vector<int64_t>();
|
||||
|
||||
std::vector<int64_t> dispersed_shape(1);
|
||||
dispersed_shape.insert(dispersed_shape.end(), data_shape.begin(), data_shape.end());
|
||||
std::vector<size_t> axes_order(block_values.size() + 1);
|
||||
std::vector<int64_t> squeezed_shape(data_shape.begin(), data_shape.end());
|
||||
if (squeezed_shape.size() > block_values.size()) {
|
||||
return false;
|
||||
const auto data_shape_rank = data.get_partial_shape().rank();
|
||||
if (data_shape_rank.is_dynamic()) {
|
||||
return false; // because StridedSlice masks are std::vector
|
||||
}
|
||||
|
||||
NodeVector new_ops;
|
||||
const auto block = batch_to_space->input_value(1);
|
||||
const auto crops_begin = batch_to_space->input_value(2);
|
||||
const auto crops_end = batch_to_space->input_value(3);
|
||||
|
||||
std::shared_ptr<Node> flat_node = data.get_node_shared_ptr();
|
||||
for (size_t block_idx = 1; block_idx < block_values.size(); ++block_idx) {
|
||||
dispersed_shape[0] = block_values[block_idx];
|
||||
dispersed_shape[1] /= block_values[block_idx];
|
||||
const auto out_pattern_1 =
|
||||
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
|
||||
const bool special_zero = false;
|
||||
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_1, special_zero);
|
||||
new_ops.push_back(flat_node);
|
||||
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
|
||||
return false;
|
||||
}
|
||||
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
|
||||
|
||||
size_t val = 1;
|
||||
for (size_t axis_idx = 0; axis_idx <= block_values.size(); ++axis_idx) {
|
||||
if ((block_idx + 1) == axis_idx) {
|
||||
NodeRegistry rg;
|
||||
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
|
||||
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
|
||||
const auto two = rg.make<Constant>(i64, Shape{1}, 2);
|
||||
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
|
||||
|
||||
const auto shape_of_data = rg.make<ShapeOf>(data, block.get_element_type());
|
||||
const auto et_zero = rg.make<Constant>(block.get_element_type(), Shape{1}, 0);
|
||||
shared_ptr<Node> dispersed_shape = rg.make<Concat>(OutputVector{et_zero, shape_of_data}, 0);
|
||||
shared_ptr<Node> squeezed_shape = shape_of_data;
|
||||
|
||||
shared_ptr<Node> flat_node = data.get_node_shared_ptr();
|
||||
|
||||
const auto make_concat = [&](OutputVector nodes) {
|
||||
nodes.erase(remove_if(nodes.begin(),
|
||||
nodes.end(),
|
||||
[](const Output<Node>& n) {
|
||||
return n.get_partial_shape().is_static() && n.get_shape().size() > 0 &&
|
||||
n.get_shape()[0] == 0;
|
||||
}),
|
||||
nodes.end());
|
||||
return rg.make<Concat>(nodes, 0);
|
||||
};
|
||||
|
||||
shared_ptr<Node> div;
|
||||
for (int64_t b_idx = 1; b_idx < block_length; ++b_idx) {
|
||||
const auto block_index = rg.make<Constant>(i64, Shape{1}, b_idx);
|
||||
const auto block_index_next = rg.make<Constant>(i64, Shape{1}, b_idx + 1);
|
||||
const auto block_value = rg.make<Gather>(block, block_index, zero);
|
||||
|
||||
// dispersed_shape[0] = block[b_idx];
|
||||
// dispersed_shape[1] /= block[b_idx];
|
||||
if (!div) {
|
||||
const auto batch = rg.make<Gather>(shape_of_data, zero, zero);
|
||||
div = rg.make<Divide>(batch, block_value);
|
||||
} else {
|
||||
div = rg.make<Divide>(div, block_value);
|
||||
}
|
||||
auto ds_tail = rg.make<Slice>(dispersed_shape, two, int_max, one);
|
||||
dispersed_shape = make_concat({block_value, div, ds_tail});
|
||||
constexpr auto special_zero = false;
|
||||
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, special_zero);
|
||||
|
||||
vector<int64_t> axes_order(block_length + 1);
|
||||
int64_t val = 1;
|
||||
for (int64_t axis_idx = 0; axis_idx <= block_length; ++axis_idx) {
|
||||
if ((b_idx + 1) == axis_idx) {
|
||||
axes_order[axis_idx] = 0;
|
||||
} else {
|
||||
axes_order[axis_idx] = val;
|
||||
val++;
|
||||
}
|
||||
}
|
||||
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
|
||||
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
|
||||
|
||||
const auto axes_order_const =
|
||||
ov::opset3::Constant::create(element::i64,
|
||||
Shape{axes_order.size()},
|
||||
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
|
||||
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
|
||||
new_ops.push_back(flat_node);
|
||||
// squeezed_shape[0] = dispersed_shape[1];
|
||||
// squeezed_shape[b_idx] *= block[b_idx];
|
||||
const auto sq_slice = rg.make<Slice>(squeezed_shape, one, block_index, one);
|
||||
const auto sq_bidx_dim = rg.make<Gather>(squeezed_shape, block_index, zero);
|
||||
const auto sq_mul = rg.make<Multiply>(sq_bidx_dim, block_value);
|
||||
const auto sq_shape_tail = rg.make<Slice>(squeezed_shape, block_index_next, int_max, one);
|
||||
squeezed_shape.reset();
|
||||
squeezed_shape = make_concat({div, sq_slice, sq_mul, sq_shape_tail});
|
||||
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
|
||||
|
||||
squeezed_shape[0] = dispersed_shape[1];
|
||||
squeezed_shape[block_idx] *= block_values[block_idx];
|
||||
dispersed_shape[block_idx + 1] = squeezed_shape[block_idx];
|
||||
const auto out_pattern_2 =
|
||||
opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
|
||||
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, special_zero);
|
||||
new_ops.push_back(flat_node);
|
||||
// dispersed_shape[b_idx + 1] = squeezed_shape[b_idx];
|
||||
const auto ds_front = rg.make<Slice>(dispersed_shape, zero, block_index_next, one);
|
||||
ds_tail = rg.make<Slice>(dispersed_shape, rg.make<Constant>(i64, Shape{1}, b_idx + 2), int_max, one);
|
||||
dispersed_shape = make_concat({ds_front, sq_mul, ds_tail});
|
||||
}
|
||||
|
||||
std::vector<int64_t> upperbounds_values;
|
||||
auto flat_node_shape = flat_node->get_shape();
|
||||
for (size_t i = 0; i < flat_node_shape.size(); ++i) {
|
||||
upperbounds_values.push_back(flat_node_shape.at(i) - crops_end_values.at(i));
|
||||
}
|
||||
const auto upperbounds = opset3::Constant::create(crops_end.get_element_type(),
|
||||
Shape{upperbounds_values.size()},
|
||||
upperbounds_values);
|
||||
const auto shape_of_flat_node = rg.make<ShapeOf>(flat_node, crops_end.get_element_type());
|
||||
const auto upperbounds = rg.make<Subtract>(shape_of_flat_node, crops_end);
|
||||
|
||||
std::vector<int64_t> begin_mask(data_shape.size(), 0);
|
||||
std::vector<int64_t> end_mask(data_shape.size(), 0);
|
||||
flat_node =
|
||||
std::make_shared<opset3::StridedSlice>(flat_node, crops_begin_const, upperbounds, begin_mask, end_mask);
|
||||
new_ops.push_back(flat_node);
|
||||
const auto begin_mask = vector<int64_t>(data_shape_rank.get_length(), 0);
|
||||
const auto& end_mask = begin_mask;
|
||||
flat_node = rg.make<StridedSlice>(flat_node, crops_begin, upperbounds, begin_mask, end_mask);
|
||||
|
||||
flat_node->set_friendly_name(batch_to_space->get_friendly_name());
|
||||
ngraph::copy_runtime_info(batch_to_space, new_ops);
|
||||
ngraph::replace_node(batch_to_space, flat_node);
|
||||
copy_runtime_info(batch_to_space, rg.get());
|
||||
replace_node(batch_to_space, flat_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(batch_to_space, matcher_name);
|
||||
const auto m = make_shared<pattern::Matcher>(batch_to_space, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
@ -4,42 +4,43 @@
|
||||
|
||||
#include "transformations/op_conversions/convert_space_to_batch.hpp"
|
||||
|
||||
#include <climits>
|
||||
#include <memory>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <openvino/opsets/opset3.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::element;
|
||||
|
||||
void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() {
|
||||
MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch);
|
||||
auto space_to_batch = ngraph::pattern::wrap_type<ov::opset3::SpaceToBatch>();
|
||||
matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto space_to_batch = std::dynamic_pointer_cast<ov::opset3::SpaceToBatch>(m.get_match_root());
|
||||
if (!space_to_batch) {
|
||||
const auto space_to_batch = pattern::wrap_type<SpaceToBatch>();
|
||||
matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
const auto space_to_batch = dynamic_pointer_cast<SpaceToBatch>(m.get_match_root());
|
||||
if (!space_to_batch || transformation_callback(space_to_batch)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeVector new_ops;
|
||||
auto data = space_to_batch->input_value(0);
|
||||
auto block = space_to_batch->input_value(1);
|
||||
auto pads_begin = space_to_batch->input_value(2);
|
||||
auto pads_end = space_to_batch->input_value(3);
|
||||
|
||||
if (data.get_partial_shape().is_dynamic()) {
|
||||
const auto data = space_to_batch->input_value(0);
|
||||
if (data.get_partial_shape().rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto block_const = std::dynamic_pointer_cast<opset3::Constant>(block.get_node_shared_ptr());
|
||||
const auto pads_begin_const = std::dynamic_pointer_cast<opset3::Constant>(pads_begin.get_node_shared_ptr());
|
||||
const auto pads_end_const = std::dynamic_pointer_cast<opset3::Constant>(pads_end.get_node_shared_ptr());
|
||||
const auto block = space_to_batch->input_value(1);
|
||||
const auto pads_begin = space_to_batch->input_value(2);
|
||||
const auto pads_end = space_to_batch->input_value(3);
|
||||
|
||||
if (!block_const || !pads_begin_const || !pads_end_const) {
|
||||
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
|
||||
return false;
|
||||
}
|
||||
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
|
||||
|
||||
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
|
||||
NodeRegistry rg;
|
||||
|
||||
// Zero-pad the start and end of dimensions [D_0, ..., D_{N - 1}] of the input according to
|
||||
// `pads_begin`
|
||||
@ -47,162 +48,159 @@ void ov::pass::ConvertSpaceToBatch::convert_space_to_batch() {
|
||||
// note: P_0 for batch dimension is expected to be 0 (no-padding).
|
||||
// x = [batch + P_0, D_1 + P_1, D_2 + P_2, ..., D_{N - 1} + P_{N - 1}], where P_i =
|
||||
// pads_begin[i] + pads_end[i]
|
||||
std::shared_ptr<Node> flat_node =
|
||||
std::make_shared<opset3::Pad>(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT);
|
||||
auto out_shape = flat_node->get_shape();
|
||||
new_ops.push_back(flat_node);
|
||||
shared_ptr<Node> flat_node = rg.make<Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
const auto out_shape = rg.make<ShapeOf>(flat_node, block.get_element_type());
|
||||
|
||||
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
|
||||
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
|
||||
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
|
||||
|
||||
// First we have to disperse the data from spatial dimensions, then
|
||||
// rearrange them so as appropriate chunks of data where close to their
|
||||
// destination place. Finally squeeze data from respective dimensions.
|
||||
Shape dispersed_shape{out_shape.at(0)};
|
||||
|
||||
// note: B_0 for batch is ignored.
|
||||
// x' = reshape(x, [batch, (D_1 + P_1) / B_1, B_1, (D_2 + P_2) / B_2, B_2, ...,
|
||||
// (D_{N - 1} + P_{N - 1}) / B_{N - 1}, B_{N - 1}]), where B_i = block_shape[i]
|
||||
for (size_t i = 1; i < block_values.size(); ++i) {
|
||||
dispersed_shape.push_back(out_shape.at(i) / block_values.at(i));
|
||||
dispersed_shape.push_back(block_values.at(i));
|
||||
}
|
||||
const auto batch = rg.make<Gather>(out_shape, zero, zero);
|
||||
const auto out_shape_tail = rg.make<Slice>(out_shape, one, int_max, one);
|
||||
const auto block_tail = rg.make<Slice>(block, one, int_max, one);
|
||||
const auto os_tail_div = rg.make<Divide>(out_shape_tail, block_tail);
|
||||
|
||||
const auto out_pattern = opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
|
||||
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern, false);
|
||||
new_ops.push_back(flat_node);
|
||||
// interleave os_tail_div with block_tail
|
||||
const auto c = rg.make<Concat>(NodeVector{os_tail_div, block_tail}, 0);
|
||||
const auto r =
|
||||
rg.make<Reshape>(c, rg.make<Constant>(i64, Shape{2}, vector<int64_t>{2, block_length - 1}), false);
|
||||
const auto t = rg.make<Transpose>(r, rg.make<Constant>(i64, Shape{2}, vector<int64_t>{1, 0}));
|
||||
const auto interleaved = rg.make<Reshape>(t, rg.make<Constant>(i64, Shape{1}, 2 * (block_length - 1)), false);
|
||||
|
||||
const auto dispersed_shape = rg.make<Concat>(NodeVector{batch, interleaved}, 0);
|
||||
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, false);
|
||||
|
||||
// x'' = transpose(x', [2, 4, ..., (N - 1) + (N - 1), 0, 1, 3, ..., N + (N - 1)])
|
||||
std::vector<size_t> axes_order;
|
||||
for (size_t i = 0, j = 2; i < block_values.size() - 1; ++i, j += 2) {
|
||||
vector<int64_t> axes_order;
|
||||
for (int64_t i = 0, j = 2; i < block_length - 1; ++i, j += 2) {
|
||||
axes_order.push_back(j);
|
||||
}
|
||||
axes_order.push_back(0);
|
||||
for (size_t i = 0, j = 1; i < block_values.size() - 1; ++i, j += 2) {
|
||||
for (int64_t i = 0, j = 1; i < block_length - 1; ++i, j += 2) {
|
||||
axes_order.push_back(j);
|
||||
}
|
||||
|
||||
const auto axes_order_const =
|
||||
opset3::Constant::create(element::i64,
|
||||
Shape{axes_order.size()},
|
||||
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
|
||||
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
|
||||
new_ops.push_back(flat_node);
|
||||
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
|
||||
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
|
||||
|
||||
Shape squeezed_shape;
|
||||
int64_t prod = 1;
|
||||
for (const auto& el : block_values) {
|
||||
prod *= el;
|
||||
}
|
||||
|
||||
// y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ...
|
||||
// ,
|
||||
// y = reshape(x'', [batch * B_1 * ... * B_{N - 1}, (D_1 + P_1) / B_1, (D_2 + P_2) / B_2, ...,
|
||||
// (D_{N - 1} + P_{N - 1}) / B_{N - 1}])
|
||||
squeezed_shape.push_back(out_shape.at(0) * prod);
|
||||
for (size_t i = 1; i < block_values.size(); ++i) {
|
||||
squeezed_shape.push_back(out_shape.at(i) / block_values.at(i));
|
||||
}
|
||||
|
||||
const auto out_pattern_2 = opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
|
||||
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, false);
|
||||
new_ops.push_back(flat_node);
|
||||
// note: B_0 is assumed to be 1 by op definion
|
||||
const auto block_prod = rg.make<ReduceProd>(block, zero);
|
||||
const auto squeezed_shape = rg.make<Concat>(NodeVector{rg.make<Multiply>(batch, block_prod), os_tail_div}, 0);
|
||||
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, false);
|
||||
|
||||
flat_node->set_friendly_name(space_to_batch->get_friendly_name());
|
||||
ngraph::copy_runtime_info(space_to_batch, new_ops);
|
||||
ngraph::replace_node(space_to_batch, flat_node);
|
||||
copy_runtime_info(space_to_batch, rg.get());
|
||||
replace_node(space_to_batch, flat_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(space_to_batch, matcher_name);
|
||||
const auto m = make_shared<pattern::Matcher>(space_to_batch, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
void ov::pass::ConvertSpaceToBatch::convert_space_to_batch_by_elements() {
|
||||
MATCHER_SCOPE(ConvertSpaceToBatch_convert_space_to_batch_by_elements);
|
||||
auto space_to_batch = ngraph::pattern::wrap_type<ov::opset3::SpaceToBatch>();
|
||||
const auto space_to_batch = pattern::wrap_type<SpaceToBatch>();
|
||||
matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto space_to_batch = std::dynamic_pointer_cast<ov::opset3::SpaceToBatch>(m.get_match_root());
|
||||
if (!space_to_batch) {
|
||||
const auto space_to_batch = dynamic_pointer_cast<SpaceToBatch>(m.get_match_root());
|
||||
if (!space_to_batch || transformation_callback(space_to_batch)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto data = space_to_batch->input_value(0);
|
||||
|
||||
if (data.get_partial_shape().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const auto& data_shape = data.get_shape();
|
||||
|
||||
if (transformation_callback(space_to_batch) && (data_shape.size() == 4 || data_shape.size() == 5)) {
|
||||
const auto data = space_to_batch->input_value(0);
|
||||
if (data.get_partial_shape().rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto block = space_to_batch->input_value(1);
|
||||
auto pads_begin = space_to_batch->input_value(2);
|
||||
auto pads_end = space_to_batch->input_value(3);
|
||||
const auto block = space_to_batch->input_value(1);
|
||||
const auto pads_begin = space_to_batch->input_value(2);
|
||||
const auto pads_end = space_to_batch->input_value(3);
|
||||
|
||||
const auto block_const = ov::as_type_ptr<opset3::Constant>(block.get_node_shared_ptr());
|
||||
const auto pads_begin_const = ov::as_type_ptr<opset3::Constant>(pads_begin.get_node_shared_ptr());
|
||||
const auto pads_end_const = ov::as_type_ptr<opset3::Constant>(pads_end.get_node_shared_ptr());
|
||||
|
||||
if (!block_const || !pads_begin_const || !pads_end_const) {
|
||||
if (block.get_partial_shape().is_dynamic() || block.get_shape().size() == 0) {
|
||||
return false;
|
||||
}
|
||||
const std::vector<int64_t>& block_values = block_const->cast_vector<int64_t>();
|
||||
const auto block_length = static_cast<int64_t>(block.get_shape()[0]);
|
||||
|
||||
NodeVector new_ops;
|
||||
NodeRegistry rg;
|
||||
|
||||
std::shared_ptr<Node> flat_node =
|
||||
std::make_shared<opset3::Pad>(data, pads_begin_const, pads_end_const, ngraph::op::PadMode::CONSTANT);
|
||||
new_ops.push_back(flat_node);
|
||||
auto out_shape = flat_node->get_shape();
|
||||
shared_ptr<Node> flat_node = rg.make<Pad>(data, pads_begin, pads_end, op::PadMode::CONSTANT);
|
||||
|
||||
std::vector<int64_t> dispersed_shape(block_values.size() + 1);
|
||||
std::vector<size_t> axes_order(block_values.size() + 1);
|
||||
std::vector<int64_t> squeezed_shape(out_shape.begin(), out_shape.end());
|
||||
for (int64_t block_idx = block_values.size() - 1; block_idx >= 0; --block_idx) {
|
||||
int64_t sq_shape_idx = block_values.size() - 1;
|
||||
shared_ptr<Node> squeezed_shape = rg.make<ShapeOf>(flat_node, block.get_element_type());
|
||||
|
||||
const auto zero = rg.make<Constant>(i64, Shape{1}, 0);
|
||||
const auto one = rg.make<Constant>(i64, Shape{1}, 1);
|
||||
const auto int_max = rg.make<Constant>(i64, Shape{1}, INT_MAX);
|
||||
|
||||
for (int64_t b_idx = block_length - 1; b_idx >= 0; --b_idx) {
|
||||
const auto block_index = rg.make<Constant>(i64, Shape{1}, b_idx);
|
||||
const auto block_index_next = rg.make<Constant>(i64, Shape{1}, b_idx + 1);
|
||||
const auto block_value = rg.make<Gather>(block, block_index, zero);
|
||||
|
||||
NodeVector dispersed_shape_prep;
|
||||
dispersed_shape_prep.reserve(block_length + 1);
|
||||
if (b_idx > 0) // avoid addind empty Slice into Concat
|
||||
dispersed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, zero, block_index, one));
|
||||
const auto squeezed_element = rg.make<Gather>(squeezed_shape, block_index, zero);
|
||||
dispersed_shape_prep.push_back(rg.make<Divide>(squeezed_element, block_value));
|
||||
dispersed_shape_prep.push_back(block_value);
|
||||
if (b_idx + 1 < block_length) // avoid addind empty Slice into Concat
|
||||
dispersed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, block_index_next, int_max, one));
|
||||
|
||||
const auto dispersed_shape = rg.make<Concat>(dispersed_shape_prep, 0);
|
||||
constexpr auto special_zero = false;
|
||||
flat_node = rg.make<Reshape>(flat_node, dispersed_shape, special_zero);
|
||||
|
||||
vector<int64_t> axes_order(block_length + 1);
|
||||
int64_t axis_idx = axes_order.size() - 1;
|
||||
for (int64_t shape_idx = dispersed_shape.size() - 1; shape_idx >= 0; --shape_idx) {
|
||||
if (shape_idx == (block_idx + 1)) {
|
||||
dispersed_shape[shape_idx] = block_values[block_idx];
|
||||
axes_order[0] = shape_idx;
|
||||
} else if (shape_idx == block_idx) {
|
||||
dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx] / block_values[block_idx];
|
||||
axes_order[axis_idx] = shape_idx;
|
||||
for (int64_t ds_idx = block_length; ds_idx >= 0; --ds_idx) {
|
||||
if (ds_idx == (b_idx + 1)) {
|
||||
axes_order[0] = ds_idx;
|
||||
} else if (ds_idx == b_idx) {
|
||||
axes_order[axis_idx] = ds_idx;
|
||||
axis_idx--;
|
||||
sq_shape_idx--;
|
||||
} else {
|
||||
dispersed_shape[shape_idx] = squeezed_shape[sq_shape_idx];
|
||||
axes_order[axis_idx] = shape_idx;
|
||||
axes_order[axis_idx] = ds_idx;
|
||||
axis_idx--;
|
||||
sq_shape_idx--;
|
||||
}
|
||||
}
|
||||
const auto axes_order_const = rg.make<Constant>(i64, Shape{axes_order.size()}, axes_order);
|
||||
flat_node = rg.make<Transpose>(flat_node, axes_order_const);
|
||||
|
||||
const auto out_pattern_1 =
|
||||
opset3::Constant::create(element::i64, Shape{dispersed_shape.size()}, dispersed_shape);
|
||||
const bool special_zero = false;
|
||||
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_1, special_zero);
|
||||
new_ops.push_back(flat_node);
|
||||
// don't change squeezed_shape at the last iteration, block[0] is assumed to be 1 by op definion
|
||||
if (b_idx > 0) {
|
||||
NodeVector squeezed_shape_prep;
|
||||
squeezed_shape_prep.reserve(block_length);
|
||||
squeezed_shape_prep.push_back(
|
||||
rg.make<Multiply>(rg.make<Gather>(squeezed_shape, zero, zero), block_value));
|
||||
if (b_idx > 1) { // avoid addind empty Slice into Concat
|
||||
squeezed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, one, block_index, one));
|
||||
}
|
||||
squeezed_shape_prep.push_back(
|
||||
rg.make<Divide>(rg.make<Gather>(squeezed_shape, block_index, zero), block_value));
|
||||
if (b_idx + 1 < block_length) { // avoid addind empty Slice into Concat
|
||||
squeezed_shape_prep.push_back(rg.make<Slice>(squeezed_shape, block_index_next, int_max, one));
|
||||
}
|
||||
|
||||
const auto axes_order_const =
|
||||
opset3::Constant::create(element::i64,
|
||||
Shape{axes_order.size()},
|
||||
std::vector<int64_t>(axes_order.begin(), axes_order.end()));
|
||||
flat_node = std::make_shared<ov::opset3::Transpose>(flat_node, axes_order_const);
|
||||
new_ops.push_back(flat_node);
|
||||
squeezed_shape[0] *= block_values[block_idx];
|
||||
squeezed_shape[block_idx] /= block_values[block_idx];
|
||||
const auto out_pattern_2 =
|
||||
opset3::Constant::create(element::i64, Shape{squeezed_shape.size()}, squeezed_shape);
|
||||
flat_node = std::make_shared<ov::opset3::Reshape>(flat_node, out_pattern_2, special_zero);
|
||||
new_ops.push_back(flat_node);
|
||||
squeezed_shape = rg.make<Concat>(squeezed_shape_prep, 0);
|
||||
}
|
||||
flat_node = rg.make<Reshape>(flat_node, squeezed_shape, special_zero);
|
||||
}
|
||||
|
||||
flat_node->set_friendly_name(space_to_batch->get_friendly_name());
|
||||
ngraph::copy_runtime_info(space_to_batch, new_ops);
|
||||
ngraph::replace_node(space_to_batch, flat_node);
|
||||
copy_runtime_info(space_to_batch, rg.get());
|
||||
replace_node(space_to_batch, flat_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(space_to_batch, matcher_name);
|
||||
const auto m = make_shared<pattern::Matcher>(space_to_batch, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <openvino/opsets/opset10.hpp>
|
||||
#include <queue>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@ -19,6 +20,7 @@
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
@ -35,6 +37,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecompositionByElements) {
|
||||
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::ConvertBatchToSpace>();
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
}
|
||||
|
||||
{
|
||||
@ -93,6 +96,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecompositionByElements) {
|
||||
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::ConvertSpaceToBatch>();
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
}
|
||||
|
||||
{
|
||||
@ -159,6 +163,7 @@ TEST_F(TransformationTestsF, SpaceToBatchDecomposition) {
|
||||
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::ConvertSpaceToBatch>(false);
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
}
|
||||
|
||||
{
|
||||
@ -195,6 +200,7 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) {
|
||||
std::make_shared<ngraph::Function>(ngraph::NodeVector{batch_to_space}, ngraph::ParameterVector{data});
|
||||
|
||||
manager.register_pass<ov::pass::ConvertBatchToSpace>(false);
|
||||
manager.register_pass<ov::pass::ConstantFolding>();
|
||||
}
|
||||
|
||||
{
|
||||
@ -218,3 +224,156 @@ TEST_F(TransformationTestsF, BatchToSpaceDecomposition) {
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ss}, ngraph::ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Op, typename Conversion, typename Params>
|
||||
void op_convertion_type_test(const Params& params) {
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass;
|
||||
|
||||
const auto by_elements = get<0>(params);
|
||||
const auto block_elem_type = get<1>(params);
|
||||
|
||||
const auto data = make_shared<Parameter>(element::f32, Shape{1, 1});
|
||||
const auto block_p = Constant::create(block_elem_type, Shape{2}, {1, 1});
|
||||
const auto input_2_p = Constant::create(block_elem_type, Shape{2}, {0, 0});
|
||||
const auto input_3_p = Constant::create(block_elem_type, Shape{2}, {0, 0});
|
||||
const auto bts_or_stb = make_shared<Op>(data, block_p, input_2_p, input_3_p);
|
||||
const auto f = make_shared<Function>(NodeVector{bts_or_stb}, ParameterVector{data});
|
||||
|
||||
Manager m;
|
||||
m.register_pass<Conversion>(by_elements);
|
||||
m.register_pass<ConstantFolding>();
|
||||
ASSERT_NO_THROW(m.run_passes(f));
|
||||
EXPECT_EQ(f->get_result()->get_input_shape(0), (Shape{1, 1}));
|
||||
}
|
||||
|
||||
using ElementTypeParams = tuple<bool, // by_elements
|
||||
element::Type // block element type
|
||||
>;
|
||||
|
||||
class BatchToSpaceDecomposition2D : public testing::WithParamInterface<ElementTypeParams>,
|
||||
public TransformationTests {};
|
||||
|
||||
TEST_P(BatchToSpaceDecomposition2D, BlockElemType) {
|
||||
op_convertion_type_test<ov::opset10::BatchToSpace, ov::pass::ConvertBatchToSpace>(GetParam());
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransformationTests,
|
||||
BatchToSpaceDecomposition2D,
|
||||
::testing::Combine(::testing::ValuesIn({false, true}),
|
||||
::testing::ValuesIn({element::i32, element::i64})));
|
||||
|
||||
class SpaceToBatchDecomposition2D : public testing::WithParamInterface<ElementTypeParams>,
|
||||
public TransformationTests {};
|
||||
|
||||
TEST_P(SpaceToBatchDecomposition2D, BlockElemType) {
|
||||
op_convertion_type_test<ov::opset10::SpaceToBatch, ov::pass::ConvertSpaceToBatch>(GetParam());
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransformationTests,
|
||||
SpaceToBatchDecomposition2D,
|
||||
::testing::Combine(::testing::ValuesIn({false, true}),
|
||||
::testing::ValuesIn({element::i32, element::i64})));
|
||||
|
||||
template <typename Op, typename Conversion, typename Params>
|
||||
void op_convertion_test(const Params& params) {
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass;
|
||||
|
||||
const bool by_elements = get<0>(params);
|
||||
Shape data_shape;
|
||||
Shape expected_output_shape;
|
||||
vector<int64_t> block;
|
||||
vector<int64_t> input_2; // crops_begin or pads_begin
|
||||
vector<int64_t> input_3; // crops_end or pads_end
|
||||
tie(data_shape, block, input_2, input_3, expected_output_shape) = get<1>(params);
|
||||
|
||||
const auto data = make_shared<Parameter>(element::f32, PartialShape::dynamic(data_shape.size()));
|
||||
const auto block_p = Constant::create(element::i64, Shape{block.size()}, block);
|
||||
const auto input_2_p = Constant::create(element::i64, Shape{input_2.size()}, input_2);
|
||||
const auto input_3_p = Constant::create(element::i64, Shape{input_3.size()}, input_3);
|
||||
const auto bts_or_stb = make_shared<Op>(data, block_p, input_2_p, input_3_p);
|
||||
const auto f = make_shared<Function>(NodeVector{bts_or_stb}, ParameterVector{data});
|
||||
|
||||
Manager m;
|
||||
m.set_per_pass_validation(false);
|
||||
m.register_pass<Conversion>(by_elements);
|
||||
m.run_passes(f);
|
||||
ASSERT_EQ(count_ops_of_type<Op>(f), 0);
|
||||
EXPECT_TRUE(f->get_result()->get_input_partial_shape(0).is_dynamic());
|
||||
|
||||
data->set_partial_shape(data_shape);
|
||||
f->validate_nodes_and_infer_types();
|
||||
ASSERT_EQ(f->get_result()->get_input_shape(0), expected_output_shape);
|
||||
}
|
||||
|
||||
template <typename Params>
|
||||
string get_test_name(testing::TestParamInfo<Params> obj) {
|
||||
const auto& params = obj.param;
|
||||
const bool by_elements = get<0>(params);
|
||||
const auto& data_shape = get<0>(get<1>(params));
|
||||
|
||||
ostringstream result;
|
||||
result << data_shape.size() << "D" << (by_elements ? "_by_elements" : "");
|
||||
return result.str();
|
||||
}
|
||||
|
||||
using BatchToSpaceParams = tuple<Shape, // data_shape
|
||||
vector<int64_t>, // block
|
||||
vector<int64_t>, // crops_begin
|
||||
vector<int64_t>, // crops_end
|
||||
Shape // expected_output_shape
|
||||
>;
|
||||
|
||||
using BatchToSpaceDecomposeParams = tuple<bool, // by_elements
|
||||
BatchToSpaceParams>;
|
||||
|
||||
class BatchToSpaceDecompositionWithParams : public testing::WithParamInterface<BatchToSpaceDecomposeParams>,
|
||||
public TransformationTests {};
|
||||
|
||||
TEST_P(BatchToSpaceDecompositionWithParams, DynamicInputs) {
|
||||
op_convertion_test<ov::opset10::BatchToSpace, ov::pass::ConvertBatchToSpace>(GetParam());
|
||||
}
|
||||
|
||||
static vector<BatchToSpaceParams> batch_to_space_params = {
|
||||
{{4, 3}, {1, 2}, {0, 0}, {0, 0}, {2, 6}},
|
||||
{{6, 5, 7}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {1, 8, 17}},
|
||||
{{30, 4, 1, 1}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {1, 20, 3, 2}},
|
||||
{{96, 3, 5, 7, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 12, 15, 14, 1}},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransformationTests,
|
||||
BatchToSpaceDecompositionWithParams,
|
||||
::testing::Combine(::testing::ValuesIn({false, true}),
|
||||
::testing::ValuesIn(batch_to_space_params)),
|
||||
get_test_name<BatchToSpaceDecomposeParams>);
|
||||
|
||||
using SpaceToBatchParams = tuple<Shape, // data_shape
|
||||
vector<int64_t>, // block
|
||||
vector<int64_t>, // pads_begin
|
||||
vector<int64_t>, // pads_end
|
||||
Shape // expected_output_shape
|
||||
>;
|
||||
|
||||
using SpaceToBatchDecomposeParams = tuple<bool, // by_elements
|
||||
SpaceToBatchParams>;
|
||||
|
||||
class SpaceToBatchDecompositionWithParams : public testing::WithParamInterface<SpaceToBatchDecomposeParams>,
|
||||
public TransformationTests {};
|
||||
|
||||
TEST_P(SpaceToBatchDecompositionWithParams, DynamicInputs) {
|
||||
op_convertion_test<ov::opset10::SpaceToBatch, ov::pass::ConvertSpaceToBatch>(GetParam());
|
||||
}
|
||||
|
||||
static vector<SpaceToBatchParams> space_to_batch_params = {
|
||||
{{2, 6}, {1, 2}, {0, 0}, {0, 0}, {4, 3}},
|
||||
{{1, 8, 17}, {1, 2, 3}, {0, 1, 2}, {0, 1, 2}, {6, 5, 7}},
|
||||
{{1, 20, 3, 2}, {1, 5, 3, 2}, {0, 0, 0, 0}, {0, 0, 0, 0}, {30, 4, 1, 1}},
|
||||
{{4, 12, 15, 14, 1}, {1, 4, 3, 2, 1}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {96, 3, 5, 7, 1}},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransformationTests,
|
||||
SpaceToBatchDecompositionWithParams,
|
||||
::testing::Combine(::testing::ValuesIn({false, true}),
|
||||
::testing::ValuesIn(space_to_batch_params)),
|
||||
get_test_name<SpaceToBatchDecomposeParams>);
|
||||
|
Loading…
Reference in New Issue
Block a user