From bd97d1edc68ffc7abd1ef483c1c6c0b0519226a3 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 14 Jan 2022 13:33:07 +0300 Subject: [PATCH] [LPT] SharedValueAttribute extending (#9534) * [LPT] SharedValueAttribute extending * [LPT] tests + comments --- .../create_precisions_dependent_attribute.hpp | 5 ++++ .../include/low_precision/network_helper.hpp | 2 +- .../propagate_through_precision_preserved.hpp | 4 +-- .../rt_info/shared_value_attribute.hpp | 30 +++++++++++++++++-- .../low_precision/shared_attribute_add.cpp | 25 ++++++++++++++++ 5 files changed, 60 insertions(+), 6 deletions(-) create mode 100644 src/tests/unit/inference_engine/transformations/low_precision/shared_attribute_add.cpp diff --git a/src/common/low_precision_transformations/include/low_precision/create_precisions_dependent_attribute.hpp b/src/common/low_precision_transformations/include/low_precision/create_precisions_dependent_attribute.hpp index e157940b12d..fa9023338cc 100644 --- a/src/common/low_precision_transformations/include/low_precision/create_precisions_dependent_attribute.hpp +++ b/src/common/low_precision_transformations/include/low_precision/create_precisions_dependent_attribute.hpp @@ -54,13 +54,18 @@ public: OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::LPT_LT, "CreatePrecisionsDependentAttribute"); auto &rt = node->get_rt_info(); + // The goal is definition if an operation precision preserved or not. As result here we should make 3 steps: + // Step #1: create PrecisionPreservedAttribute instance obviously, + // which will be used as result (will be used for future precision propagation) const auto precisionPreservedAttribute = PrecisionPreservedAttribute(false); rt[PrecisionPreservedAttribute::get_type_info_static()] = precisionPreservedAttribute; const auto &targetSharedValue = precisionPreservedAttribute.attribute->sharedValue; + // Step #2: create AttributeType attribute instance for OperationType operation to propagate the instance const auto attribute = AttributeType{}; rt[AttributeType::get_type_info_static()] = attribute; + // Step #3: assign the same shared value to enable PrecisionPreservedAttribute update during AttributeType propagation ngraph::pass::low_precision::NetworkHelper::reassign( targetSharedValue, { diff --git a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp index fab0b2e56e2..9a143114ea1 100644 --- a/src/common/low_precision_transformations/include/low_precision/network_helper.hpp +++ b/src/common/low_precision_transformations/include/low_precision/network_helper.hpp @@ -230,7 +230,7 @@ public: continue; } attribute->sharedValue = sharedValue; - sharedValue->attributes.push_back(attribute); + sharedValue->addAttribute(attribute); } } diff --git a/src/common/low_precision_transformations/include/low_precision/propagate_through_precision_preserved.hpp b/src/common/low_precision_transformations/include/low_precision/propagate_through_precision_preserved.hpp index cf2512e0a52..50c58b9e4bf 100644 --- a/src/common/low_precision_transformations/include/low_precision/propagate_through_precision_preserved.hpp +++ b/src/common/low_precision_transformations/include/low_precision/propagate_through_precision_preserved.hpp @@ -66,14 +66,14 @@ public: const_cast(resultAttribute).merge(toMerge); for (size_t index = 1ul; index < parentRestrictions.size(); index++) { - auto& attributes = parentRestrictions[index].template as().attribute->sharedValue->attributes; + auto& attributes = parentRestrictions[index].template as().attribute->sharedValue->getAttributes(); for (auto&& attributeWeakPtr : attributes) { auto attribute = attributeWeakPtr.lock(); if (attribute == nullptr) { continue; } attribute->sharedValue = resultAttribute.attribute->sharedValue; - resultAttribute.attribute->sharedValue->attributes.push_back(attribute); + resultAttribute.attribute->sharedValue->addAttribute(attribute); } } diff --git a/src/common/low_precision_transformations/include/low_precision/rt_info/shared_value_attribute.hpp b/src/common/low_precision_transformations/include/low_precision/rt_info/shared_value_attribute.hpp index 5f829d132b7..0950f56668c 100644 --- a/src/common/low_precision_transformations/include/low_precision/rt_info/shared_value_attribute.hpp +++ b/src/common/low_precision_transformations/include/low_precision/rt_info/shared_value_attribute.hpp @@ -30,6 +30,30 @@ public: SharedValue() = default; SharedValue(const T& value) : value{value} {} T value = {}; + void addAttribute(std::weak_ptr attribute) { + auto attributeLocked = attribute.lock(); + if (attributeLocked == nullptr) { + return; + } + + for (auto& attr : attributes) { + auto attrLocked = attr.lock(); + if (attrLocked == nullptr) { + continue; + } + if (attributeLocked == attrLocked) { + return; + } + } + + attributes.push_back(attribute); + } + + std::vector>& getAttributes() { + return attributes; + } + + private: std::vector> attributes; }; SharedValueAttribute() : sharedValue(std::make_shared()) {} @@ -49,7 +73,7 @@ public: bool firstAttribute = true; ss << ", attributes: {"; - for (auto& attributeWeakPtr : sharedValue->attributes) { + for (auto& attributeWeakPtr : sharedValue->getAttributes()) { auto attribute = attributeWeakPtr.lock(); if (attribute == nullptr) { continue; @@ -67,10 +91,10 @@ public: }; SharedAttribute() : attribute{std::make_shared()} { - attribute->sharedValue->attributes.emplace_back(attribute); + attribute->sharedValue->addAttribute(attribute); } SharedAttribute(const T& value) : attribute{std::make_shared(value)} { - attribute->sharedValue->attributes.emplace_back(attribute); + attribute->sharedValue->addAttribute(attribute); } std::shared_ptr attribute; diff --git a/src/tests/unit/inference_engine/transformations/low_precision/shared_attribute_add.cpp b/src/tests/unit/inference_engine/transformations/low_precision/shared_attribute_add.cpp new file mode 100644 index 00000000000..4a241c9c8e2 --- /dev/null +++ b/src/tests/unit/inference_engine/transformations/low_precision/shared_attribute_add.cpp @@ -0,0 +1,25 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/precision_preserved_attribute.hpp" +#include "low_precision/rt_info/avg_pool_precision_preserved_attribute.hpp" + +using LPT_ReshapeTransformation = ::testing::Test; + +TEST(LPT_SharedAttribute, assign) { + const auto attribute1 = ngraph::PrecisionPreservedAttribute(); + ASSERT_EQ(1ul, attribute1.attribute->sharedValue->getAttributes().size()); + + const auto attribute2 = ngraph::AvgPoolPrecisionPreservedAttribute(); + ASSERT_EQ(1ul, attribute2.attribute->sharedValue->getAttributes().size()); + + ngraph::pass::low_precision::NetworkHelper::reassign( + attribute1.attribute->sharedValue, + { attribute1.attribute, attribute2.attribute }); + + ASSERT_EQ(2ul, attribute1.attribute->sharedValue->getAttributes().size()); + ASSERT_EQ(2ul, attribute2.attribute->sharedValue->getAttributes().size()); +}