Dimension equality fix (#18806)

This commit is contained in:
Evgenya Stepyreva 2023-07-27 14:01:45 +04:00 committed by GitHub
parent 609a7d7716
commit 812d11cf8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 2 deletions

View File

@ -44,6 +44,8 @@ label_t TableOfEquivalence::get_next_label() {
bool TableOfEquivalence::are_equal(const Dimension& lhs, const Dimension& rhs) {
const auto &l_label = DimensionTracker::get_label(lhs), r_label = DimensionTracker::get_label(rhs);
if (l_label == r_label)
return true;
if (dimension_table_of_equivalence.count(l_label) && dimension_table_of_equivalence[l_label])
return dimension_table_of_equivalence[l_label]->count(r_label);
return false;

View File

@ -260,13 +260,13 @@ Dimension resolve_minus_one(const Node* reshape_node,
Dimension input_const_part(1), output_const_part(1);
for (const auto& dim : output_product)
if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
if (dim.is_static()) {
output_const_part *= dim;
to_delete_from_output.push_back(dim);
}
for (const auto& dim : input_product)
if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
if (dim.is_static()) {
input_const_part *= dim;
to_delete_from_input.push_back(dim);
}

View File

@ -139,6 +139,9 @@ TEST(dimension, dimension_equality) {
EXPECT_NE(DimensionTracker::get_label(A), DimensionTracker::get_label(B));
EXPECT_NE(DimensionTracker::get_label(B), DimensionTracker::get_label(C));
EXPECT_NE(DimensionTracker::get_label(A), DimensionTracker::get_label(C));
EXPECT_EQ(DimensionTracker::get_label(A), DimensionTracker::get_label(A));
EXPECT_EQ(DimensionTracker::get_label(B), DimensionTracker::get_label(B));
EXPECT_EQ(DimensionTracker::get_label(C), DimensionTracker::get_label(C));
// setting A == B and B == C
te->set_as_equal(A, B);