[LPT] SharedValueAttribute extending (#9534)
* [LPT] SharedValueAttribute extending * [LPT] tests + comments
This commit is contained in:
parent
0ff88458f9
commit
bd97d1edc6
@ -54,13 +54,18 @@ public:
|
|||||||
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::LPT_LT, "CreatePrecisionsDependentAttribute");
|
OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::LPT_LT, "CreatePrecisionsDependentAttribute");
|
||||||
auto &rt = node->get_rt_info();
|
auto &rt = node->get_rt_info();
|
||||||
|
|
||||||
|
// The goal is definition if an operation precision preserved or not. As result here we should make 3 steps:
|
||||||
|
// Step #1: create PrecisionPreservedAttribute instance obviously,
|
||||||
|
// which will be used as result (will be used for future precision propagation)
|
||||||
const auto precisionPreservedAttribute = PrecisionPreservedAttribute(false);
|
const auto precisionPreservedAttribute = PrecisionPreservedAttribute(false);
|
||||||
rt[PrecisionPreservedAttribute::get_type_info_static()] = precisionPreservedAttribute;
|
rt[PrecisionPreservedAttribute::get_type_info_static()] = precisionPreservedAttribute;
|
||||||
const auto &targetSharedValue = precisionPreservedAttribute.attribute->sharedValue;
|
const auto &targetSharedValue = precisionPreservedAttribute.attribute->sharedValue;
|
||||||
|
|
||||||
|
// Step #2: create AttributeType attribute instance for OperationType operation to propagate the instance
|
||||||
const auto attribute = AttributeType{};
|
const auto attribute = AttributeType{};
|
||||||
rt[AttributeType::get_type_info_static()] = attribute;
|
rt[AttributeType::get_type_info_static()] = attribute;
|
||||||
|
|
||||||
|
// Step #3: assign the same shared value to enable PrecisionPreservedAttribute update during AttributeType propagation
|
||||||
ngraph::pass::low_precision::NetworkHelper::reassign<AttributeType>(
|
ngraph::pass::low_precision::NetworkHelper::reassign<AttributeType>(
|
||||||
targetSharedValue,
|
targetSharedValue,
|
||||||
{
|
{
|
||||||
|
@ -230,7 +230,7 @@ public:
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
attribute->sharedValue = sharedValue;
|
attribute->sharedValue = sharedValue;
|
||||||
sharedValue->attributes.push_back(attribute);
|
sharedValue->addAttribute(attribute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,14 +66,14 @@ public:
|
|||||||
const_cast<AttributeType&>(resultAttribute).merge(toMerge);
|
const_cast<AttributeType&>(resultAttribute).merge(toMerge);
|
||||||
|
|
||||||
for (size_t index = 1ul; index < parentRestrictions.size(); index++) {
|
for (size_t index = 1ul; index < parentRestrictions.size(); index++) {
|
||||||
auto& attributes = parentRestrictions[index].template as<AttributeType>().attribute->sharedValue->attributes;
|
auto& attributes = parentRestrictions[index].template as<AttributeType>().attribute->sharedValue->getAttributes();
|
||||||
for (auto&& attributeWeakPtr : attributes) {
|
for (auto&& attributeWeakPtr : attributes) {
|
||||||
auto attribute = attributeWeakPtr.lock();
|
auto attribute = attributeWeakPtr.lock();
|
||||||
if (attribute == nullptr) {
|
if (attribute == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
attribute->sharedValue = resultAttribute.attribute->sharedValue;
|
attribute->sharedValue = resultAttribute.attribute->sharedValue;
|
||||||
resultAttribute.attribute->sharedValue->attributes.push_back(attribute);
|
resultAttribute.attribute->sharedValue->addAttribute(attribute);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -30,6 +30,30 @@ public:
|
|||||||
SharedValue() = default;
|
SharedValue() = default;
|
||||||
SharedValue(const T& value) : value{value} {}
|
SharedValue(const T& value) : value{value} {}
|
||||||
T value = {};
|
T value = {};
|
||||||
|
void addAttribute(std::weak_ptr<SharedValueAttribute> attribute) {
|
||||||
|
auto attributeLocked = attribute.lock();
|
||||||
|
if (attributeLocked == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& attr : attributes) {
|
||||||
|
auto attrLocked = attr.lock();
|
||||||
|
if (attrLocked == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (attributeLocked == attrLocked) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
attributes.push_back(attribute);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::weak_ptr<SharedValueAttribute>>& getAttributes() {
|
||||||
|
return attributes;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
std::vector<std::weak_ptr<SharedValueAttribute>> attributes;
|
std::vector<std::weak_ptr<SharedValueAttribute>> attributes;
|
||||||
};
|
};
|
||||||
SharedValueAttribute() : sharedValue(std::make_shared<SharedValue>()) {}
|
SharedValueAttribute() : sharedValue(std::make_shared<SharedValue>()) {}
|
||||||
@ -49,7 +73,7 @@ public:
|
|||||||
|
|
||||||
bool firstAttribute = true;
|
bool firstAttribute = true;
|
||||||
ss << ", attributes: {";
|
ss << ", attributes: {";
|
||||||
for (auto& attributeWeakPtr : sharedValue->attributes) {
|
for (auto& attributeWeakPtr : sharedValue->getAttributes()) {
|
||||||
auto attribute = attributeWeakPtr.lock();
|
auto attribute = attributeWeakPtr.lock();
|
||||||
if (attribute == nullptr) {
|
if (attribute == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
@ -67,10 +91,10 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
SharedAttribute() : attribute{std::make_shared<SharedValueAttribute>()} {
|
SharedAttribute() : attribute{std::make_shared<SharedValueAttribute>()} {
|
||||||
attribute->sharedValue->attributes.emplace_back(attribute);
|
attribute->sharedValue->addAttribute(attribute);
|
||||||
}
|
}
|
||||||
SharedAttribute(const T& value) : attribute{std::make_shared<SharedValueAttribute>(value)} {
|
SharedAttribute(const T& value) : attribute{std::make_shared<SharedValueAttribute>(value)} {
|
||||||
attribute->sharedValue->attributes.emplace_back(attribute);
|
attribute->sharedValue->addAttribute(attribute);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<SharedValueAttribute> attribute;
|
std::shared_ptr<SharedValueAttribute> attribute;
|
||||||
|
@ -0,0 +1,25 @@
|
|||||||
|
// Copyright (C) 2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "low_precision/network_helper.hpp"
|
||||||
|
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
|
||||||
|
#include "low_precision/rt_info/avg_pool_precision_preserved_attribute.hpp"
|
||||||
|
|
||||||
|
using LPT_ReshapeTransformation = ::testing::Test;
|
||||||
|
|
||||||
|
TEST(LPT_SharedAttribute, assign) {
|
||||||
|
const auto attribute1 = ngraph::PrecisionPreservedAttribute();
|
||||||
|
ASSERT_EQ(1ul, attribute1.attribute->sharedValue->getAttributes().size());
|
||||||
|
|
||||||
|
const auto attribute2 = ngraph::AvgPoolPrecisionPreservedAttribute();
|
||||||
|
ASSERT_EQ(1ul, attribute2.attribute->sharedValue->getAttributes().size());
|
||||||
|
|
||||||
|
ngraph::pass::low_precision::NetworkHelper::reassign<ngraph::AvgPoolPrecisionPreservedAttribute>(
|
||||||
|
attribute1.attribute->sharedValue,
|
||||||
|
{ attribute1.attribute, attribute2.attribute });
|
||||||
|
|
||||||
|
ASSERT_EQ(2ul, attribute1.attribute->sharedValue->getAttributes().size());
|
||||||
|
ASSERT_EQ(2ul, attribute2.attribute->sharedValue->getAttributes().size());
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user