Layout::find_permutation - support of dynamic layouts (#8766)

Covered the case of a 'trivial convert' where no permutation is needed at all.
This is needed for the Model Optimizer logic that guesses a model's layout, e.g. "?c??"
This commit is contained in:
Mikhail Nosov 2021-11-30 12:40:38 +03:00 committed by GitHub
parent e2172cd38a
commit 84a16513df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 198 additions and 16 deletions

View File

@ -275,6 +275,45 @@ Layout apply_permutation(const Layout& src_layout, const std::vector<uint64_t>&
}
std::vector<int64_t> find_permutation(const Layout& src_layout, const Rank& rank, const Layout& dst) {
auto check_trivial = [](std::vector<int64_t>& res) -> std::vector<int64_t>& {
size_t i = 0;
while (res[i] == i && i < res.size()) {
i++;
}
if (i == res.size()) {
// Array is [0,1,2,...,n], so permutation is not needed at all
res = {};
}
return res;
};
auto to_static = [](const Layout& layout, const Rank& rank) -> Layout {
OPENVINO_ASSERT(!layout.m_dynamic || !rank.is_dynamic(),
"Conversion is not supported for dynamic layouts with fully dynamic shapes");
if (!layout.m_dynamic) {
return layout;
}
Layout res = layout;
auto len = rank.get_length();
res.m_dynamic = false;
res.m_left_size = rank.get_length();
res.m_right_size = 0;
for (auto& item : res.m_names) {
if (item.second < 0) {
item.second += len;
}
}
std::unordered_map<std::int64_t, std::string> new_index_map;
for (const auto& item : res.m_index_map) {
auto new_ind = item.first;
if (new_ind < 0) {
new_ind += len;
}
new_index_map[new_ind] = item.second;
}
res.m_index_map = new_index_map;
return res;
};
// Basic implementation so far, can support partially-specified layouts later (shape rank will be needed for dynamic
// layouts)
if (src_layout == dst) {
@ -283,24 +322,54 @@ std::vector<int64_t> find_permutation(const Layout& src_layout, const Rank& rank
if (src_layout.empty() || dst.empty()) {
return {};
}
OPENVINO_ASSERT(!src_layout.m_dynamic && !dst.m_dynamic, "Conversion is not supported for dynamic layouts");
OPENVINO_ASSERT(src_layout.m_left_size == src_layout.m_left_size,
auto src_static = to_static(src_layout, rank);
auto dst_static = to_static(dst, rank);
OPENVINO_ASSERT(src_static.m_left_size == dst_static.m_left_size,
"Conversion is not supported for layouts with different sizes");
std::vector<int64_t> res(src_layout.m_left_size);
for (int64_t i = 0; i < src_layout.m_left_size; i++) {
auto it = src_layout.m_index_map.find(i);
OPENVINO_ASSERT(it != src_layout.m_index_map.end(),
"Conversion is not supported for partially specified source layout: ",
src_layout.to_string());
auto name = it->second;
OPENVINO_ASSERT(dst.has_name(name),
"Source dimension name '",
name,
"' is not found in destination layout: ",
dst.to_string());
res[dst.get_index_by_name(name)] = i;
std::vector<int64_t> res(src_static.m_left_size, -1);
if (src_static.m_names.size() > dst_static.m_names.size()) {
// find inverted permutation from least specified layout to most one
auto inverted = find_permutation(dst_static, rank, src_static);
if (inverted.empty()) {
return {};
}
for (size_t i = 0; i < inverted.size(); i++) {
res[inverted[i]] = i;
}
return check_trivial(res);
}
return res;
std::vector<bool> mapped(src_static.m_left_size, false);
// Fill known names (??c? -> nc??) will produce res=[-1,2,-1,-1], mapped=[false,false,true,false]
for (auto src_item : src_static.m_index_map) {
OPENVINO_ASSERT(dst.has_name(src_item.second),
"Dimension name '",
src_item.second,
"' is not found in layout: ",
dst_static.to_string());
auto dst_ind = dst_static.get_index_by_name(src_item.second);
res[dst_ind] = src_item.first;
mapped[src_item.first] = true;
}
// Fill the rest
int dst_pos = 0;
auto find_free_pos = [&]() {
while (mapped[dst_pos] && dst_pos < src_static.m_left_size) {
dst_pos++;
}
OPENVINO_ASSERT(dst_pos < src_static.m_left_size,
"Internal unexpected error: can't map layout ",
src_static.to_string(),
" to ",
dst_static.to_string());
mapped[dst_pos] = true;
return dst_pos;
};
for (int64_t i = 0; i < src_static.m_left_size; i++) {
if (res[i] < 0) {
res[i] = find_free_pos();
}
}
return check_trivial(res);
}
// Helper functions

View File

@ -740,6 +740,119 @@ TEST(pre_post_process, preprocess_convert_layout_invalid_dims_dyn_shape) {
p.build(), ov::AssertFailure);
}
TEST(pre_post_process, preprocess_convert_layout_partially_defined) {
    // Each input pairs a partially-defined tensor layout with a network layout and
    // verifies that the guessed permutation produces the expected input shape.
    auto model = create_n_inputs<8>(element::f32, Shape{1, 2, 3, 4, 5});
    auto preproc = PrePostProcessor(model);
    // Fully-ranked layouts with unknown dims on either side
    preproc.input(0).tensor().set_layout("nc???");
    preproc.input(0).network().set_layout("????c");
    // Dynamic ('...') layouts combined with named dims
    preproc.input(1).tensor().set_layout("...c??");
    preproc.input(1).network().set_layout("ndhwc");
    preproc.input(2).tensor().set_layout("?cwh...");
    preproc.input(2).network().set_layout("...hwc");
    preproc.input(3).tensor().set_layout("...c");
    preproc.input(3).network().set_layout("c...");
    // One side carries no information at all
    preproc.input(4).tensor().set_layout("...");
    preproc.input(4).network().set_layout("c...");
    preproc.input(5).tensor().set_layout("...c");
    preproc.input(5).network().set_layout("...");
    // '?' in otherwise matching layouts
    preproc.input(6).tensor().set_layout("ndhwc");
    preproc.input(6).network().set_layout("ndh?c");
    preproc.input(7).tensor().set_layout("ndh?c");
    preproc.input(7).network().set_layout("ndhwc");
    model = preproc.build();
    EXPECT_EQ(model->input(0).get_partial_shape(), (PartialShape{1, 5, 2, 3, 4}));
    EXPECT_EQ(model->input(1).get_partial_shape(), (PartialShape{1, 2, 5, 3, 4}));
    EXPECT_EQ(model->input(2).get_partial_shape(), (PartialShape{1, 5, 4, 3, 2}));
    EXPECT_EQ(model->input(3).get_partial_shape(), (PartialShape{2, 3, 4, 5, 1}));
    EXPECT_EQ(model->input(4).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
    EXPECT_EQ(model->input(5).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
    EXPECT_EQ(model->input(6).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
    EXPECT_EQ(model->input(7).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
}
TEST(pre_post_process, preprocess_convert_layout_partially_defined_trivial) {
    // All tensor/network layout pairs below resolve to an identity permutation,
    // so preprocessing must not insert any extra nodes into the model.
    auto model = create_n_inputs<4>(element::f32, Shape{1, 2, 3, 4, 5});
    const auto ops_before = model->get_ordered_ops().size();
    auto preproc = PrePostProcessor(model);
    preproc.input(0).tensor().set_layout("...");
    preproc.input(0).network().set_layout("c...");
    preproc.input(1).tensor().set_layout("...c");
    preproc.input(1).network().set_layout("...");
    preproc.input(2).tensor().set_layout("ndhwc");
    preproc.input(2).network().set_layout("ndh?c");
    preproc.input(3).tensor().set_layout("ndh?c");
    preproc.input(3).network().set_layout("ndhwc");
    model = preproc.build();
    // Shapes are unchanged for every input
    EXPECT_EQ(model->input(0).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
    EXPECT_EQ(model->input(1).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
    EXPECT_EQ(model->input(2).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
    EXPECT_EQ(model->input(3).get_partial_shape(), (PartialShape{1, 2, 3, 4, 5}));
    // Verify that no preprocessing Nodes are inserted
    EXPECT_EQ(ops_before, model->get_ordered_ops().size());
}
TEST(pre_post_process, preprocess_convert_layout_partially_defined_error) {
    auto model = create_simple_function(element::f32, Shape{1, 2, 3, 4, 5});
    // Named source dims that don't exist in the destination must be rejected.
    EXPECT_THROW(
        {
            auto preproc = PrePostProcessor(model);
            preproc.input().tensor().set_layout("nch??");
            preproc.input().network().set_layout("???wc");
            model = preproc.build();
        },
        ov::AssertFailure);
    // Layouts of different sizes must be rejected as well.
    EXPECT_THROW(
        {
            auto preproc = PrePostProcessor(model);
            preproc.input().tensor().set_layout("nch??");
            preproc.input().network().set_layout("???wc?");
            model = preproc.build();
        },
        ov::AssertFailure);
}
TEST(pre_post_process, preprocess_convert_layout_partially_defined_error_diff_rank) {
    // NOTE(review): this test body looks truncated - a fixture is created but nothing
    // is exercised or asserted. Presumably EXPECT_THROW cases for layouts whose rank
    // differs from the 5D shape were intended here; confirm against the original change.
    auto f = create_simple_function(element::f32, Shape{1, 2, 3, 4, 5});
}
TEST(pre_post_process, preprocess_convert_layout_partially_defined_error_dyn_rank) {
    // With a fully dynamic shape the rank is unknown, so dynamic/partial layouts
    // cannot be converted to static form and the build must fail.
    auto model = create_simple_function(element::f32, PartialShape::dynamic());
    EXPECT_THROW(
        {
            auto preproc = PrePostProcessor(model);
            preproc.input().tensor().set_layout("nchw");
            preproc.input().network().set_layout("...wc");
            model = preproc.build();
        },
        ov::AssertFailure);
    EXPECT_THROW(
        {
            auto preproc = PrePostProcessor(model);
            preproc.input().tensor().set_layout("nchw");
            preproc.input().network().set_layout("??wc?");
            model = preproc.build();
        },
        ov::AssertFailure);
}
TEST(pre_post_process, preprocess_reverse_channels_multiple_planes) {
auto f = create_simple_function(element::f32, Shape{1, 3, 2, 2});
auto p = PrePostProcessor(f);