[LPT] SharedValueAttribute extending (#9534)
* [LPT] SharedValueAttribute extending * [LPT] tests + comments
This commit is contained in:
parent
0ff88458f9
commit
bd97d1edc6
@ -54,13 +54,18 @@ public:
|
||||
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::LPT_LT, "CreatePrecisionsDependentAttribute");
|
||||
auto &rt = node->get_rt_info();
|
||||
|
||||
// The goal is definition if an operation precision preserved or not. As result here we should make 3 steps:
|
||||
// Step #1: create PrecisionPreservedAttribute instance obviously,
|
||||
// which will be used as result (will be used for future precision propagation)
|
||||
const auto precisionPreservedAttribute = PrecisionPreservedAttribute(false);
|
||||
rt[PrecisionPreservedAttribute::get_type_info_static()] = precisionPreservedAttribute;
|
||||
const auto &targetSharedValue = precisionPreservedAttribute.attribute->sharedValue;
|
||||
|
||||
// Step #2: create AttributeType attribute instance for OperationType operation to propagate the instance
|
||||
const auto attribute = AttributeType{};
|
||||
rt[AttributeType::get_type_info_static()] = attribute;
|
||||
|
||||
// Step #3: assign the same shared value to enable PrecisionPreservedAttribute update during AttributeType propagation
|
||||
ngraph::pass::low_precision::NetworkHelper::reassign<AttributeType>(
|
||||
targetSharedValue,
|
||||
{
|
||||
|
@ -230,7 +230,7 @@ public:
|
||||
continue;
|
||||
}
|
||||
attribute->sharedValue = sharedValue;
|
||||
sharedValue->attributes.push_back(attribute);
|
||||
sharedValue->addAttribute(attribute);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,14 +66,14 @@ public:
|
||||
const_cast<AttributeType&>(resultAttribute).merge(toMerge);
|
||||
|
||||
for (size_t index = 1ul; index < parentRestrictions.size(); index++) {
|
||||
auto& attributes = parentRestrictions[index].template as<AttributeType>().attribute->sharedValue->attributes;
|
||||
auto& attributes = parentRestrictions[index].template as<AttributeType>().attribute->sharedValue->getAttributes();
|
||||
for (auto&& attributeWeakPtr : attributes) {
|
||||
auto attribute = attributeWeakPtr.lock();
|
||||
if (attribute == nullptr) {
|
||||
continue;
|
||||
}
|
||||
attribute->sharedValue = resultAttribute.attribute->sharedValue;
|
||||
resultAttribute.attribute->sharedValue->attributes.push_back(attribute);
|
||||
resultAttribute.attribute->sharedValue->addAttribute(attribute);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -30,6 +30,30 @@ public:
|
||||
SharedValue() = default;
|
||||
SharedValue(const T& value) : value{value} {}
|
||||
T value = {};
|
||||
void addAttribute(std::weak_ptr<SharedValueAttribute> attribute) {
|
||||
auto attributeLocked = attribute.lock();
|
||||
if (attributeLocked == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto& attr : attributes) {
|
||||
auto attrLocked = attr.lock();
|
||||
if (attrLocked == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (attributeLocked == attrLocked) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
attributes.push_back(attribute);
|
||||
}
|
||||
|
||||
std::vector<std::weak_ptr<SharedValueAttribute>>& getAttributes() {
|
||||
return attributes;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::weak_ptr<SharedValueAttribute>> attributes;
|
||||
};
|
||||
SharedValueAttribute() : sharedValue(std::make_shared<SharedValue>()) {}
|
||||
@ -49,7 +73,7 @@ public:
|
||||
|
||||
bool firstAttribute = true;
|
||||
ss << ", attributes: {";
|
||||
for (auto& attributeWeakPtr : sharedValue->attributes) {
|
||||
for (auto& attributeWeakPtr : sharedValue->getAttributes()) {
|
||||
auto attribute = attributeWeakPtr.lock();
|
||||
if (attribute == nullptr) {
|
||||
continue;
|
||||
@ -67,10 +91,10 @@ public:
|
||||
};
|
||||
|
||||
SharedAttribute() : attribute{std::make_shared<SharedValueAttribute>()} {
|
||||
attribute->sharedValue->attributes.emplace_back(attribute);
|
||||
attribute->sharedValue->addAttribute(attribute);
|
||||
}
|
||||
SharedAttribute(const T& value) : attribute{std::make_shared<SharedValueAttribute>(value)} {
|
||||
attribute->sharedValue->attributes.emplace_back(attribute);
|
||||
attribute->sharedValue->addAttribute(attribute);
|
||||
}
|
||||
|
||||
std::shared_ptr<SharedValueAttribute> attribute;
|
||||
|
@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
|
||||
#include "low_precision/rt_info/avg_pool_precision_preserved_attribute.hpp"
|
||||
|
||||
using LPT_ReshapeTransformation = ::testing::Test;
|
||||
|
||||
TEST(LPT_SharedAttribute, assign) {
|
||||
const auto attribute1 = ngraph::PrecisionPreservedAttribute();
|
||||
ASSERT_EQ(1ul, attribute1.attribute->sharedValue->getAttributes().size());
|
||||
|
||||
const auto attribute2 = ngraph::AvgPoolPrecisionPreservedAttribute();
|
||||
ASSERT_EQ(1ul, attribute2.attribute->sharedValue->getAttributes().size());
|
||||
|
||||
ngraph::pass::low_precision::NetworkHelper::reassign<ngraph::AvgPoolPrecisionPreservedAttribute>(
|
||||
attribute1.attribute->sharedValue,
|
||||
{ attribute1.attribute, attribute2.attribute });
|
||||
|
||||
ASSERT_EQ(2ul, attribute1.attribute->sharedValue->getAttributes().size());
|
||||
ASSERT_EQ(2ul, attribute2.attribute->sharedValue->getAttributes().size());
|
||||
}
|
Loading…
Reference in New Issue
Block a user