Dimension equality fix (#18806)
This commit is contained in:
parent
609a7d7716
commit
812d11cf8c
@ -44,6 +44,8 @@ label_t TableOfEquivalence::get_next_label() {
|
||||
|
||||
bool TableOfEquivalence::are_equal(const Dimension& lhs, const Dimension& rhs) {
|
||||
const auto &l_label = DimensionTracker::get_label(lhs), r_label = DimensionTracker::get_label(rhs);
|
||||
if (l_label == r_label)
|
||||
return true;
|
||||
if (dimension_table_of_equivalence.count(l_label) && dimension_table_of_equivalence[l_label])
|
||||
return dimension_table_of_equivalence[l_label]->count(r_label);
|
||||
return false;
|
||||
|
@ -260,13 +260,13 @@ Dimension resolve_minus_one(const Node* reshape_node,
|
||||
Dimension input_const_part(1), output_const_part(1);
|
||||
|
||||
for (const auto& dim : output_product)
|
||||
if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
|
||||
if (dim.is_static()) {
|
||||
output_const_part *= dim;
|
||||
to_delete_from_output.push_back(dim);
|
||||
}
|
||||
|
||||
for (const auto& dim : input_product)
|
||||
if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
|
||||
if (dim.is_static()) {
|
||||
input_const_part *= dim;
|
||||
to_delete_from_input.push_back(dim);
|
||||
}
|
||||
|
@ -139,6 +139,9 @@ TEST(dimension, dimension_equality) {
|
||||
EXPECT_NE(DimensionTracker::get_label(A), DimensionTracker::get_label(B));
|
||||
EXPECT_NE(DimensionTracker::get_label(B), DimensionTracker::get_label(C));
|
||||
EXPECT_NE(DimensionTracker::get_label(A), DimensionTracker::get_label(C));
|
||||
EXPECT_EQ(DimensionTracker::get_label(A), DimensionTracker::get_label(A));
|
||||
EXPECT_EQ(DimensionTracker::get_label(B), DimensionTracker::get_label(B));
|
||||
EXPECT_EQ(DimensionTracker::get_label(C), DimensionTracker::get_label(C));
|
||||
|
||||
// setting A == B and B == C
|
||||
te->set_as_equal(A, B);
|
||||
|
Loading…
Reference in New Issue
Block a user