Remove VariantWrapper for simple attr classes (#7771)

* Remove VariantWrapper for simple attr classes

* Code style
This commit is contained in:
Gleb Kazantaev 2021-10-01 08:03:37 +03:00 committed by GitHub
parent faeaf045a9
commit a16cc81233
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 49 additions and 116 deletions

View File

@ -123,9 +123,8 @@ std::string NetworkCompilationContext::computeHash(const CNNNetwork& network,
} else if (auto fNames = } else if (auto fNames =
std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::FusedNames>>(rtMapData.second)) { std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::FusedNames>>(rtMapData.second)) {
seed = hash_combine(seed, fNames->get().getNames()); seed = hash_combine(seed, fNames->get().getNames());
} else if (auto prim = std::dynamic_pointer_cast<ngraph::VariantWrapper<ov::PrimitivesPriority>>( } else if (auto prim = std::dynamic_pointer_cast<ov::PrimitivesPriority>(rtMapData.second)) {
rtMapData.second)) { seed = hash_combine(seed, prim->get());
seed = hash_combine(seed, prim->get().getPrimitivesPriority());
} }
} }
} }

View File

@ -22,13 +22,11 @@ inline std::string getRTInfoValue(const std::map<std::string, std::shared_ptr<ng
inline std::string getPrimitivesPriorityValue(const std::shared_ptr<ngraph::Node> &node) { inline std::string getPrimitivesPriorityValue(const std::shared_ptr<ngraph::Node> &node) {
const auto &rtInfo = node->get_rt_info(); const auto &rtInfo = node->get_rt_info();
using PrimitivesPriorityWraper = ngraph::VariantWrapper<ov::PrimitivesPriority>;
if (!rtInfo.count(PrimitivesPriorityWraper::get_type_info_static())) return ""; if (!rtInfo.count(ov::PrimitivesPriority::get_type_info_static())) return "";
const auto &attr = rtInfo.at(PrimitivesPriorityWraper::get_type_info_static()); const auto &attr = rtInfo.at(ov::PrimitivesPriority::get_type_info_static());
ov::PrimitivesPriority pp = ngraph::as_type_ptr<PrimitivesPriorityWraper>(attr)->get(); return ngraph::as_type_ptr<ov::PrimitivesPriority>(attr)->get();
return pp.getPrimitivesPriority();
} }
template <typename T> template <typename T>

View File

@ -28,10 +28,10 @@ class TRANSFORMATIONS_API Attributes {
public: public:
Attributes() { Attributes() {
register_factory<VariantWrapper<ngraph::FusedNames>>(); register_factory<VariantWrapper<ngraph::FusedNames>>();
register_factory<VariantWrapper<PrimitivesPriority>>(); register_factory<PrimitivesPriority>();
register_factory<VariantWrapper<DisableConstantFolding>>(); register_factory<DisableConstantFolding>();
register_factory<VariantWrapper<NmsSelectedIndices>>(); register_factory<NmsSelectedIndices>();
register_factory<VariantWrapper<Strides>>(); register_factory<StridesPropagation>();
} }
Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) { Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) {

View File

@ -17,31 +17,19 @@
namespace ov { namespace ov {
/**
* @ingroup ie_runtime_attr_api
* @brief DisableConstantFolding disable ConstantFolding for given operation
*/
class TRANSFORMATIONS_API DisableConstantFolding {
public:
DisableConstantFolding() = default;
};
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node); TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API void enable_constant_folding(const std::shared_ptr<Node>& node); TRANSFORMATIONS_API void enable_constant_folding(const std::shared_ptr<Node>& node);
TRANSFORMATIONS_API bool constant_folding_is_disabled(const std::shared_ptr<Node>& node); TRANSFORMATIONS_API bool constant_folding_is_disabled(const std::shared_ptr<Node>& node);
extern template class TRANSFORMATIONS_API VariantImpl<DisableConstantFolding>; class TRANSFORMATIONS_API DisableConstantFolding : public VariantImpl<bool> {
template<>
class TRANSFORMATIONS_API VariantWrapper<DisableConstantFolding> : public VariantImpl<DisableConstantFolding> {
public: public:
OPENVINO_RTTI("disabled_constant_folding", "0"); OPENVINO_RTTI("disabled_constant_folding", "0");
VariantWrapper() = default; DisableConstantFolding() = default;
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {} DisableConstantFolding(const value_type &value) : VariantImpl<value_type>(value) {}
bool is_copyable() const override { return false; } bool is_copyable() const override { return false; }
}; };

View File

@ -16,25 +16,17 @@
namespace ov { namespace ov {
class TRANSFORMATIONS_API NmsSelectedIndices {
public:
NmsSelectedIndices() = default;
};
TRANSFORMATIONS_API bool has_nms_selected_indices(const Node * node); TRANSFORMATIONS_API bool has_nms_selected_indices(const Node * node);
TRANSFORMATIONS_API void set_nms_selected_indices(Node * node); TRANSFORMATIONS_API void set_nms_selected_indices(Node * node);
extern template class TRANSFORMATIONS_API VariantImpl<NmsSelectedIndices>; class TRANSFORMATIONS_API NmsSelectedIndices : public VariantImpl<bool> {
template<>
class TRANSFORMATIONS_API VariantWrapper<NmsSelectedIndices> : public VariantImpl<NmsSelectedIndices> {
public: public:
OPENVINO_RTTI("nms_selected_indices", "0"); OPENVINO_RTTI("nms_selected_indices", "0");
VariantWrapper() = default; NmsSelectedIndices() = default;
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {} NmsSelectedIndices(const value_type &value) : VariantImpl<value_type>(value) {}
bool is_copyable() const override { return false; } bool is_copyable() const override { return false; }
}; };

View File

@ -19,34 +19,6 @@
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
namespace ov { namespace ov {
/**
* @ingroup ie_runtime_attr_api
* @brief PrimitivesPriority class represents runtime info attribute that
* can be used for plugins specific primitive choice.
*/
class NGRAPH_API PrimitivesPriority {
private:
std::string primitives_priority;
public:
friend class VariantWrapper<PrimitivesPriority>;
/**
* A default constructor
*/
PrimitivesPriority() = default;
/**
* @brief Constructs a new object consisting of a single name *
* @param[in] name The primitives priority value
*/
explicit PrimitivesPriority(const std::string &primitives_priority) : primitives_priority(primitives_priority) {}
/**
* @brief return string with primitives priority value
*/
std::string getPrimitivesPriority() const;
};
/** /**
* @ingroup ie_runtime_attr_api * @ingroup ie_runtime_attr_api
* @brief getPrimitivesPriority return string with primitive priorities value * @brief getPrimitivesPriority return string with primitive priorities value
@ -54,16 +26,13 @@ public:
*/ */
NGRAPH_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node); NGRAPH_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node);
extern template class NGRAPH_API VariantImpl<PrimitivesPriority>; class NGRAPH_API PrimitivesPriority : public VariantImpl<std::string> {
template<>
class NGRAPH_API VariantWrapper<PrimitivesPriority> : public VariantImpl<PrimitivesPriority> {
public: public:
OPENVINO_RTTI("primitives_priority", "0"); OPENVINO_RTTI("primitives_priority", "0");
VariantWrapper() = default; PrimitivesPriority() = default;
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {} PrimitivesPriority(const value_type &value) : VariantImpl<value_type>(value) {}
std::shared_ptr<ov::Variant> merge(const ngraph::NodeVector & nodes) override; std::shared_ptr<ov::Variant> merge(const ngraph::NodeVector & nodes) override;
@ -71,5 +40,4 @@ public:
bool visit_attributes(AttributeVisitor & visitor) override; bool visit_attributes(AttributeVisitor & visitor) override;
}; };
} // namespace ov } // namespace ov

View File

@ -10,14 +10,13 @@
#include <transformations_visibility.hpp> #include <transformations_visibility.hpp>
namespace ov { namespace ov {
template <> class TRANSFORMATIONS_API StridesPropagation : public VariantImpl<ngraph::Strides> {
class TRANSFORMATIONS_API VariantWrapper<ngraph::Strides> : public VariantImpl<ngraph::Strides> {
public: public:
OPENVINO_RTTI("strides", "0"); OPENVINO_RTTI("strides_propagation", "0");
VariantWrapper() = default; StridesPropagation() = default;
VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {} StridesPropagation(const value_type& value) : VariantImpl<value_type>(value) {}
}; };
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node); TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);

View File

@ -4,19 +4,17 @@
#include "transformations/rt_info/disable_constant_folding.hpp" #include "transformations/rt_info/disable_constant_folding.hpp"
template class ov::VariantImpl<ov::DisableConstantFolding>;
void ov::disable_constant_folding(const std::shared_ptr<Node>& node) { void ov::disable_constant_folding(const std::shared_ptr<Node>& node) {
auto & rt_info = node->get_rt_info(); auto & rt_info = node->get_rt_info();
rt_info[VariantWrapper<DisableConstantFolding>::get_type_info_static()] = make_variant<DisableConstantFolding>({}); rt_info[DisableConstantFolding::get_type_info_static()] = std::make_shared<DisableConstantFolding>(true);
} }
void ov::enable_constant_folding(const std::shared_ptr<Node>& node) { void ov::enable_constant_folding(const std::shared_ptr<Node>& node) {
auto & rt_info = node->get_rt_info(); auto & rt_info = node->get_rt_info();
rt_info.erase(VariantWrapper<DisableConstantFolding>::get_type_info_static()); rt_info.erase(DisableConstantFolding::get_type_info_static());
} }
bool ov::constant_folding_is_disabled(const std::shared_ptr<Node> &node) { bool ov::constant_folding_is_disabled(const std::shared_ptr<Node> &node) {
const auto & rt_info = node->get_rt_info(); const auto & rt_info = node->get_rt_info();
return rt_info.count(VariantWrapper<DisableConstantFolding>::get_type_info_static()); return rt_info.count(DisableConstantFolding::get_type_info_static());
} }

View File

@ -8,10 +8,10 @@ template class ov::VariantImpl<ov::NmsSelectedIndices>;
void ov::set_nms_selected_indices(Node * node) { void ov::set_nms_selected_indices(Node * node) {
auto & rt_info = node->get_rt_info(); auto & rt_info = node->get_rt_info();
rt_info[VariantWrapper<NmsSelectedIndices>::get_type_info_static()] = make_variant<NmsSelectedIndices>({}); rt_info[NmsSelectedIndices::get_type_info_static()] = std::make_shared<NmsSelectedIndices>(true);
} }
bool ov::has_nms_selected_indices(const Node * node) { bool ov::has_nms_selected_indices(const Node * node) {
const auto & rt_info = node->get_rt_info(); const auto & rt_info = node->get_rt_info();
return rt_info.count(VariantWrapper<NmsSelectedIndices>::get_type_info_static()); return rt_info.count(NmsSelectedIndices::get_type_info_static());
} }

View File

@ -6,17 +6,16 @@
bool ov::has_strides_prop(const ngraph::Input<ngraph::Node>& node) { bool ov::has_strides_prop(const ngraph::Input<ngraph::Node>& node) {
const auto& rt_map = node.get_rt_info(); const auto& rt_map = node.get_rt_info();
auto it = rt_map.find(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static()); return rt_map.count(StridesPropagation::get_type_info_static());
return it != rt_map.end();
} }
ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) { ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
const auto& rt_map = node.get_rt_info(); const auto& rt_map = node.get_rt_info();
const auto& var = rt_map.at(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static()); const auto& var = rt_map.at(StridesPropagation::get_type_info_static());
return ngraph::as_type_ptr<ngraph::VariantWrapper<ngraph::Strides>>(var)->get(); return ngraph::as_type_ptr<StridesPropagation>(var)->get();
} }
void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) { void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
auto& rt_map = node.get_rt_info(); auto& rt_map = node.get_rt_info();
rt_map[ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static()] = std::make_shared<ngraph::VariantWrapper<ngraph::Strides>>(strides); rt_map[StridesPropagation::get_type_info_static()] = std::make_shared<StridesPropagation>(strides);
} }

View File

@ -32,8 +32,8 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
auto init_info = [](RTMap & info) { auto init_info = [](RTMap & info) {
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] = info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add")); std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
info[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] = info[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority("priority")); std::make_shared<ov::PrimitivesPriority>("priority");
}; };
std::shared_ptr<ngraph::Function> function; std::shared_ptr<ngraph::Function> function;
@ -62,11 +62,11 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
ASSERT_TRUE(fused_names_attr); ASSERT_TRUE(fused_names_attr);
ASSERT_EQ(fused_names_attr->get().getNames(), "add"); ASSERT_EQ(fused_names_attr->get().getNames(), "add");
const std::string & pkey = VariantWrapper<ov::PrimitivesPriority>::get_type_info_static(); const std::string & pkey = ov::PrimitivesPriority::get_type_info_static();
ASSERT_TRUE(info.count(pkey)); ASSERT_TRUE(info.count(pkey));
auto primitives_priority_attr = std::dynamic_pointer_cast<VariantWrapper<ov::PrimitivesPriority>>(info.at(pkey)); auto primitives_priority_attr = std::dynamic_pointer_cast<ov::PrimitivesPriority>(info.at(pkey));
ASSERT_TRUE(primitives_priority_attr); ASSERT_TRUE(primitives_priority_attr);
ASSERT_EQ(primitives_priority_attr->get().getPrimitivesPriority(), "priority"); ASSERT_EQ(primitives_priority_attr->get(), "priority");
}; };
auto add = f->get_results()[0]->get_input_node_ptr(0); auto add = f->get_results()[0]->get_input_node_ptr(0);

View File

@ -237,12 +237,12 @@ TEST(NetworkContext_CNNNetwork, HashWithFusedNames) {
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) { TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
auto setPrimEmpty = [&](Node::RTMap& rtInfo) { auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] = rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority()); std::make_shared<ov::PrimitivesPriority>("");
}; };
auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) { auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) {
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] = rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(name)); std::make_shared<ov::PrimitivesPriority>(name);
}; };
checkCustomRt(setPrimEmpty, setPrim); checkCustomRt(setPrimEmpty, setPrim);
} }

View File

@ -25,7 +25,7 @@ bool ngraph::pass::InitNodeInfo::run_on_function(std::shared_ptr<ngraph::Functio
using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>; using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>;
std::map<std::string, VariantCreator> update_attributes{ std::map<std::string, VariantCreator> update_attributes{
{"PrimitivesPriority", [](const std::string& value) -> std::shared_ptr<Variant> { {"PrimitivesPriority", [](const std::string& value) -> std::shared_ptr<Variant> {
return std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(value)); return std::make_shared<ov::PrimitivesPriority>(value);
}}}; }}};
for (auto& node : f->get_ops()) { for (auto& node : f->get_ops()) {

View File

@ -17,25 +17,17 @@
using namespace ov; using namespace ov;
using namespace ngraph; using namespace ngraph;
std::string PrimitivesPriority::getPrimitivesPriority() const {
return primitives_priority;
}
std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node) { std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node) {
const auto& rtInfo = node->get_rt_info(); const auto& rtInfo = node->get_rt_info();
using PrimitivesPriorityWrapper = VariantWrapper<PrimitivesPriority>;
if (!rtInfo.count(PrimitivesPriorityWrapper::get_type_info_static())) if (!rtInfo.count(PrimitivesPriority::get_type_info_static()))
return ""; return "";
const auto& attr = rtInfo.at(PrimitivesPriorityWrapper::get_type_info_static()); const auto& attr = rtInfo.at(PrimitivesPriority::get_type_info_static());
PrimitivesPriority pp = ov::as_type_ptr<PrimitivesPriorityWrapper>(attr)->get(); return ov::as_type_ptr<PrimitivesPriority>(attr)->get();
return pp.getPrimitivesPriority();
} }
template class ov::VariantImpl<PrimitivesPriority>; std::shared_ptr<ngraph::Variant> PrimitivesPriority::merge(const ngraph::NodeVector& nodes) {
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const ngraph::NodeVector& nodes) {
auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool { auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) || if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) || std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
@ -64,14 +56,14 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const
if (unique_pp.size() == 1) { if (unique_pp.size() == 1) {
final_primitives_priority = *unique_pp.begin(); final_primitives_priority = *unique_pp.begin();
} }
return std::make_shared<VariantWrapper<PrimitivesPriority>>(PrimitivesPriority(final_primitives_priority)); return std::make_shared<PrimitivesPriority>(final_primitives_priority);
} }
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::init(const std::shared_ptr<ngraph::Node>& node) { std::shared_ptr<ngraph::Variant> PrimitivesPriority::init(const std::shared_ptr<ngraph::Node>& node) {
throw ngraph_error(std::string(get_type_info()) + " has no default initialization."); throw ngraph_error(std::string(get_type_info()) + " has no default initialization.");
} }
bool VariantWrapper<PrimitivesPriority>::visit_attributes(AttributeVisitor& visitor) { bool PrimitivesPriority::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("value", m_value.primitives_priority); visitor.on_attribute("value", m_value);
return true; return true;
} }