Remove VariantWrapper for simple attr classes (#7771)

* Remove VariantWrapper for simple attr classes

* Code style
This commit is contained in:
Gleb Kazantaev 2021-10-01 08:03:37 +03:00 committed by GitHub
parent faeaf045a9
commit a16cc81233
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 49 additions and 116 deletions

View File

@ -123,9 +123,8 @@ std::string NetworkCompilationContext::computeHash(const CNNNetwork& network,
} else if (auto fNames =
std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::FusedNames>>(rtMapData.second)) {
seed = hash_combine(seed, fNames->get().getNames());
} else if (auto prim = std::dynamic_pointer_cast<ngraph::VariantWrapper<ov::PrimitivesPriority>>(
rtMapData.second)) {
seed = hash_combine(seed, prim->get().getPrimitivesPriority());
} else if (auto prim = std::dynamic_pointer_cast<ov::PrimitivesPriority>(rtMapData.second)) {
seed = hash_combine(seed, prim->get());
}
}
}

View File

@ -22,13 +22,11 @@ inline std::string getRTInfoValue(const std::map<std::string, std::shared_ptr<ng
inline std::string getPrimitivesPriorityValue(const std::shared_ptr<ngraph::Node> &node) {
const auto &rtInfo = node->get_rt_info();
using PrimitivesPriorityWraper = ngraph::VariantWrapper<ov::PrimitivesPriority>;
if (!rtInfo.count(PrimitivesPriorityWraper::get_type_info_static())) return "";
if (!rtInfo.count(ov::PrimitivesPriority::get_type_info_static())) return "";
const auto &attr = rtInfo.at(PrimitivesPriorityWraper::get_type_info_static());
ov::PrimitivesPriority pp = ngraph::as_type_ptr<PrimitivesPriorityWraper>(attr)->get();
return pp.getPrimitivesPriority();
const auto &attr = rtInfo.at(ov::PrimitivesPriority::get_type_info_static());
return ngraph::as_type_ptr<ov::PrimitivesPriority>(attr)->get();
}
template <typename T>

View File

@ -28,10 +28,10 @@ class TRANSFORMATIONS_API Attributes {
public:
Attributes() {
register_factory<VariantWrapper<ngraph::FusedNames>>();
register_factory<VariantWrapper<PrimitivesPriority>>();
register_factory<VariantWrapper<DisableConstantFolding>>();
register_factory<VariantWrapper<NmsSelectedIndices>>();
register_factory<VariantWrapper<Strides>>();
register_factory<PrimitivesPriority>();
register_factory<DisableConstantFolding>();
register_factory<NmsSelectedIndices>();
register_factory<StridesPropagation>();
}
Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) {

View File

@ -17,31 +17,19 @@
namespace ov {
/**
* @ingroup ie_runtime_attr_api
* @brief DisableConstantFolding disable ConstantFolding for given operation
*/
class TRANSFORMATIONS_API DisableConstantFolding {
public:
DisableConstantFolding() = default;
};
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API void enable_constant_folding(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API bool constant_folding_is_disabled(const std::shared_ptr<Node>& node);
extern template class TRANSFORMATIONS_API VariantImpl<DisableConstantFolding>;
template<>
class TRANSFORMATIONS_API VariantWrapper<DisableConstantFolding> : public VariantImpl<DisableConstantFolding> {
class TRANSFORMATIONS_API DisableConstantFolding : public VariantImpl<bool> {
public:
OPENVINO_RTTI("disabled_constant_folding", "0");
VariantWrapper() = default;
DisableConstantFolding() = default;
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
DisableConstantFolding(const value_type &value) : VariantImpl<value_type>(value) {}
bool is_copyable() const override { return false; }
};

View File

@ -16,25 +16,17 @@
namespace ov {
class TRANSFORMATIONS_API NmsSelectedIndices {
public:
NmsSelectedIndices() = default;
};
TRANSFORMATIONS_API bool has_nms_selected_indices(const Node * node);
TRANSFORMATIONS_API void set_nms_selected_indices(Node * node);
extern template class TRANSFORMATIONS_API VariantImpl<NmsSelectedIndices>;
template<>
class TRANSFORMATIONS_API VariantWrapper<NmsSelectedIndices> : public VariantImpl<NmsSelectedIndices> {
class TRANSFORMATIONS_API NmsSelectedIndices : public VariantImpl<bool> {
public:
OPENVINO_RTTI("nms_selected_indices", "0");
VariantWrapper() = default;
NmsSelectedIndices() = default;
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
NmsSelectedIndices(const value_type &value) : VariantImpl<value_type>(value) {}
bool is_copyable() const override { return false; }
};

View File

@ -19,34 +19,6 @@
#include <ngraph/variant.hpp>
namespace ov {
/**
* @ingroup ie_runtime_attr_api
* @brief PrimitivesPriority class represents runtime info attribute that
* can be used for plugins specific primitive choice.
*/
class NGRAPH_API PrimitivesPriority {
private:
std::string primitives_priority;
public:
friend class VariantWrapper<PrimitivesPriority>;
/**
* A default constructor
*/
PrimitivesPriority() = default;
/**
* @brief Constructs a new object consisting of a single name *
* @param[in] name The primitives priority value
*/
explicit PrimitivesPriority(const std::string &primitives_priority) : primitives_priority(primitives_priority) {}
/**
* @brief return string with primitives priority value
*/
std::string getPrimitivesPriority() const;
};
/**
* @ingroup ie_runtime_attr_api
* @brief getPrimitivesPriority return string with primitive priorities value
@ -54,16 +26,13 @@ public:
*/
NGRAPH_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node);
extern template class NGRAPH_API VariantImpl<PrimitivesPriority>;
template<>
class NGRAPH_API VariantWrapper<PrimitivesPriority> : public VariantImpl<PrimitivesPriority> {
class NGRAPH_API PrimitivesPriority : public VariantImpl<std::string> {
public:
OPENVINO_RTTI("primitives_priority", "0");
VariantWrapper() = default;
PrimitivesPriority() = default;
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
PrimitivesPriority(const value_type &value) : VariantImpl<value_type>(value) {}
std::shared_ptr<ov::Variant> merge(const ngraph::NodeVector & nodes) override;
@ -71,5 +40,4 @@ public:
bool visit_attributes(AttributeVisitor & visitor) override;
};
} // namespace ov

View File

@ -10,14 +10,13 @@
#include <transformations_visibility.hpp>
namespace ov {
template <>
class TRANSFORMATIONS_API VariantWrapper<ngraph::Strides> : public VariantImpl<ngraph::Strides> {
class TRANSFORMATIONS_API StridesPropagation : public VariantImpl<ngraph::Strides> {
public:
OPENVINO_RTTI("strides", "0");
OPENVINO_RTTI("strides_propagation", "0");
VariantWrapper() = default;
StridesPropagation() = default;
VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
StridesPropagation(const value_type& value) : VariantImpl<value_type>(value) {}
};
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);

View File

@ -4,19 +4,17 @@
#include "transformations/rt_info/disable_constant_folding.hpp"
template class ov::VariantImpl<ov::DisableConstantFolding>;
void ov::disable_constant_folding(const std::shared_ptr<Node>& node) {
auto & rt_info = node->get_rt_info();
rt_info[VariantWrapper<DisableConstantFolding>::get_type_info_static()] = make_variant<DisableConstantFolding>({});
rt_info[DisableConstantFolding::get_type_info_static()] = std::make_shared<DisableConstantFolding>(true);
}
void ov::enable_constant_folding(const std::shared_ptr<Node>& node) {
auto & rt_info = node->get_rt_info();
rt_info.erase(VariantWrapper<DisableConstantFolding>::get_type_info_static());
rt_info.erase(DisableConstantFolding::get_type_info_static());
}
bool ov::constant_folding_is_disabled(const std::shared_ptr<Node> &node) {
const auto & rt_info = node->get_rt_info();
return rt_info.count(VariantWrapper<DisableConstantFolding>::get_type_info_static());
return rt_info.count(DisableConstantFolding::get_type_info_static());
}

View File

@ -8,10 +8,10 @@ template class ov::VariantImpl<ov::NmsSelectedIndices>;
void ov::set_nms_selected_indices(Node * node) {
auto & rt_info = node->get_rt_info();
rt_info[VariantWrapper<NmsSelectedIndices>::get_type_info_static()] = make_variant<NmsSelectedIndices>({});
rt_info[NmsSelectedIndices::get_type_info_static()] = std::make_shared<NmsSelectedIndices>(true);
}
bool ov::has_nms_selected_indices(const Node * node) {
const auto & rt_info = node->get_rt_info();
return rt_info.count(VariantWrapper<NmsSelectedIndices>::get_type_info_static());
return rt_info.count(NmsSelectedIndices::get_type_info_static());
}

View File

@ -6,17 +6,16 @@
bool ov::has_strides_prop(const ngraph::Input<ngraph::Node>& node) {
const auto& rt_map = node.get_rt_info();
auto it = rt_map.find(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static());
return it != rt_map.end();
return rt_map.count(StridesPropagation::get_type_info_static());
}
ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
const auto& rt_map = node.get_rt_info();
const auto& var = rt_map.at(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static());
return ngraph::as_type_ptr<ngraph::VariantWrapper<ngraph::Strides>>(var)->get();
const auto& var = rt_map.at(StridesPropagation::get_type_info_static());
return ngraph::as_type_ptr<StridesPropagation>(var)->get();
}
void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
auto& rt_map = node.get_rt_info();
rt_map[ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static()] = std::make_shared<ngraph::VariantWrapper<ngraph::Strides>>(strides);
rt_map[StridesPropagation::get_type_info_static()] = std::make_shared<StridesPropagation>(strides);
}

View File

@ -32,8 +32,8 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
auto init_info = [](RTMap & info) {
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority("priority"));
info[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<ov::PrimitivesPriority>("priority");
};
std::shared_ptr<ngraph::Function> function;
@ -62,11 +62,11 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
ASSERT_TRUE(fused_names_attr);
ASSERT_EQ(fused_names_attr->get().getNames(), "add");
const std::string & pkey = VariantWrapper<ov::PrimitivesPriority>::get_type_info_static();
const std::string & pkey = ov::PrimitivesPriority::get_type_info_static();
ASSERT_TRUE(info.count(pkey));
auto primitives_priority_attr = std::dynamic_pointer_cast<VariantWrapper<ov::PrimitivesPriority>>(info.at(pkey));
auto primitives_priority_attr = std::dynamic_pointer_cast<ov::PrimitivesPriority>(info.at(pkey));
ASSERT_TRUE(primitives_priority_attr);
ASSERT_EQ(primitives_priority_attr->get().getPrimitivesPriority(), "priority");
ASSERT_EQ(primitives_priority_attr->get(), "priority");
};
auto add = f->get_results()[0]->get_input_node_ptr(0);

View File

@ -237,12 +237,12 @@ TEST(NetworkContext_CNNNetwork, HashWithFusedNames) {
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority());
rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<ov::PrimitivesPriority>("");
};
auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) {
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(name));
rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<ov::PrimitivesPriority>(name);
};
checkCustomRt(setPrimEmpty, setPrim);
}

View File

@ -25,7 +25,7 @@ bool ngraph::pass::InitNodeInfo::run_on_function(std::shared_ptr<ngraph::Functio
using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>;
std::map<std::string, VariantCreator> update_attributes{
{"PrimitivesPriority", [](const std::string& value) -> std::shared_ptr<Variant> {
return std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(value));
return std::make_shared<ov::PrimitivesPriority>(value);
}}};
for (auto& node : f->get_ops()) {

View File

@ -17,25 +17,17 @@
using namespace ov;
using namespace ngraph;
std::string PrimitivesPriority::getPrimitivesPriority() const {
return primitives_priority;
}
std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node) {
const auto& rtInfo = node->get_rt_info();
using PrimitivesPriorityWrapper = VariantWrapper<PrimitivesPriority>;
if (!rtInfo.count(PrimitivesPriorityWrapper::get_type_info_static()))
if (!rtInfo.count(PrimitivesPriority::get_type_info_static()))
return "";
const auto& attr = rtInfo.at(PrimitivesPriorityWrapper::get_type_info_static());
PrimitivesPriority pp = ov::as_type_ptr<PrimitivesPriorityWrapper>(attr)->get();
return pp.getPrimitivesPriority();
const auto& attr = rtInfo.at(PrimitivesPriority::get_type_info_static());
return ov::as_type_ptr<PrimitivesPriority>(attr)->get();
}
template class ov::VariantImpl<PrimitivesPriority>;
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const ngraph::NodeVector& nodes) {
std::shared_ptr<ngraph::Variant> PrimitivesPriority::merge(const ngraph::NodeVector& nodes) {
auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
@ -64,14 +56,14 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const
if (unique_pp.size() == 1) {
final_primitives_priority = *unique_pp.begin();
}
return std::make_shared<VariantWrapper<PrimitivesPriority>>(PrimitivesPriority(final_primitives_priority));
return std::make_shared<PrimitivesPriority>(final_primitives_priority);
}
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::init(const std::shared_ptr<ngraph::Node>& node) {
std::shared_ptr<ngraph::Variant> PrimitivesPriority::init(const std::shared_ptr<ngraph::Node>& node) {
throw ngraph_error(std::string(get_type_info()) + " has no default initialization.");
}
bool VariantWrapper<PrimitivesPriority>::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("value", m_value.primitives_priority);
bool PrimitivesPriority::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("value", m_value);
return true;
}