[LPT] SharedValueAttribute extending (#9534)

* [LPT] SharedValueAttribute extending

* [LPT] tests + comments
This commit is contained in:
Edward Shogulin 2022-01-14 13:33:07 +03:00 committed by GitHub
parent 0ff88458f9
commit bd97d1edc6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 6 deletions

View File

@ -54,13 +54,18 @@ public:
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::LPT_LT, "CreatePrecisionsDependentAttribute");
auto &rt = node->get_rt_info();
// The goal is definition if an operation precision preserved or not. As result here we should make 3 steps:
// Step #1: create PrecisionPreservedAttribute instance obviously,
// which will be used as result (will be used for future precision propagation)
const auto precisionPreservedAttribute = PrecisionPreservedAttribute(false);
rt[PrecisionPreservedAttribute::get_type_info_static()] = precisionPreservedAttribute;
const auto &targetSharedValue = precisionPreservedAttribute.attribute->sharedValue;
// Step #2: create AttributeType attribute instance for OperationType operation to propagate the instance
const auto attribute = AttributeType{};
rt[AttributeType::get_type_info_static()] = attribute;
// Step #3: assign the same shared value to enable PrecisionPreservedAttribute update during AttributeType propagation
ngraph::pass::low_precision::NetworkHelper::reassign<AttributeType>(
targetSharedValue,
{

View File

@ -230,7 +230,7 @@ public:
continue;
}
attribute->sharedValue = sharedValue;
sharedValue->attributes.push_back(attribute);
sharedValue->addAttribute(attribute);
}
}

View File

@ -66,14 +66,14 @@ public:
const_cast<AttributeType&>(resultAttribute).merge(toMerge);
for (size_t index = 1ul; index < parentRestrictions.size(); index++) {
auto& attributes = parentRestrictions[index].template as<AttributeType>().attribute->sharedValue->attributes;
auto& attributes = parentRestrictions[index].template as<AttributeType>().attribute->sharedValue->getAttributes();
for (auto&& attributeWeakPtr : attributes) {
auto attribute = attributeWeakPtr.lock();
if (attribute == nullptr) {
continue;
}
attribute->sharedValue = resultAttribute.attribute->sharedValue;
resultAttribute.attribute->sharedValue->attributes.push_back(attribute);
resultAttribute.attribute->sharedValue->addAttribute(attribute);
}
}

View File

@ -30,6 +30,30 @@ public:
SharedValue() = default;
SharedValue(const T& value) : value{value} {}
T value = {};
void addAttribute(std::weak_ptr<SharedValueAttribute> attribute) {
auto attributeLocked = attribute.lock();
if (attributeLocked == nullptr) {
return;
}
for (auto& attr : attributes) {
auto attrLocked = attr.lock();
if (attrLocked == nullptr) {
continue;
}
if (attributeLocked == attrLocked) {
return;
}
}
attributes.push_back(attribute);
}
std::vector<std::weak_ptr<SharedValueAttribute>>& getAttributes() {
return attributes;
}
private:
std::vector<std::weak_ptr<SharedValueAttribute>> attributes;
};
SharedValueAttribute() : sharedValue(std::make_shared<SharedValue>()) {}
@ -49,7 +73,7 @@ public:
bool firstAttribute = true;
ss << ", attributes: {";
for (auto& attributeWeakPtr : sharedValue->attributes) {
for (auto& attributeWeakPtr : sharedValue->getAttributes()) {
auto attribute = attributeWeakPtr.lock();
if (attribute == nullptr) {
continue;
@ -67,10 +91,10 @@ public:
};
SharedAttribute() : attribute{std::make_shared<SharedValueAttribute>()} {
attribute->sharedValue->attributes.emplace_back(attribute);
attribute->sharedValue->addAttribute(attribute);
}
SharedAttribute(const T& value) : attribute{std::make_shared<SharedValueAttribute>(value)} {
attribute->sharedValue->attributes.emplace_back(attribute);
attribute->sharedValue->addAttribute(attribute);
}
std::shared_ptr<SharedValueAttribute> attribute;

View File

@ -0,0 +1,25 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
#include "low_precision/rt_info/avg_pool_precision_preserved_attribute.hpp"
using LPT_ReshapeTransformation = ::testing::Test;
TEST(LPT_SharedAttribute, assign) {
const auto attribute1 = ngraph::PrecisionPreservedAttribute();
ASSERT_EQ(1ul, attribute1.attribute->sharedValue->getAttributes().size());
const auto attribute2 = ngraph::AvgPoolPrecisionPreservedAttribute();
ASSERT_EQ(1ul, attribute2.attribute->sharedValue->getAttributes().size());
ngraph::pass::low_precision::NetworkHelper::reassign<ngraph::AvgPoolPrecisionPreservedAttribute>(
attribute1.attribute->sharedValue,
{ attribute1.attribute, attribute2.attribute });
ASSERT_EQ(2ul, attribute1.attribute->sharedValue->getAttributes().size());
ASSERT_EQ(2ul, attribute2.attribute->sharedValue->getAttributes().size());
}