Remove VariantWrapper for simple attr classes (#7771)
* Remove VariantWrapper for simple attr classes * Code style
This commit is contained in:
parent
faeaf045a9
commit
a16cc81233
@ -123,9 +123,8 @@ std::string NetworkCompilationContext::computeHash(const CNNNetwork& network,
|
||||
} else if (auto fNames =
|
||||
std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::FusedNames>>(rtMapData.second)) {
|
||||
seed = hash_combine(seed, fNames->get().getNames());
|
||||
} else if (auto prim = std::dynamic_pointer_cast<ngraph::VariantWrapper<ov::PrimitivesPriority>>(
|
||||
rtMapData.second)) {
|
||||
seed = hash_combine(seed, prim->get().getPrimitivesPriority());
|
||||
} else if (auto prim = std::dynamic_pointer_cast<ov::PrimitivesPriority>(rtMapData.second)) {
|
||||
seed = hash_combine(seed, prim->get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -22,13 +22,11 @@ inline std::string getRTInfoValue(const std::map<std::string, std::shared_ptr<ng
|
||||
|
||||
inline std::string getPrimitivesPriorityValue(const std::shared_ptr<ngraph::Node> &node) {
|
||||
const auto &rtInfo = node->get_rt_info();
|
||||
using PrimitivesPriorityWraper = ngraph::VariantWrapper<ov::PrimitivesPriority>;
|
||||
|
||||
if (!rtInfo.count(PrimitivesPriorityWraper::get_type_info_static())) return "";
|
||||
if (!rtInfo.count(ov::PrimitivesPriority::get_type_info_static())) return "";
|
||||
|
||||
const auto &attr = rtInfo.at(PrimitivesPriorityWraper::get_type_info_static());
|
||||
ov::PrimitivesPriority pp = ngraph::as_type_ptr<PrimitivesPriorityWraper>(attr)->get();
|
||||
return pp.getPrimitivesPriority();
|
||||
const auto &attr = rtInfo.at(ov::PrimitivesPriority::get_type_info_static());
|
||||
return ngraph::as_type_ptr<ov::PrimitivesPriority>(attr)->get();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -28,10 +28,10 @@ class TRANSFORMATIONS_API Attributes {
|
||||
public:
|
||||
Attributes() {
|
||||
register_factory<VariantWrapper<ngraph::FusedNames>>();
|
||||
register_factory<VariantWrapper<PrimitivesPriority>>();
|
||||
register_factory<VariantWrapper<DisableConstantFolding>>();
|
||||
register_factory<VariantWrapper<NmsSelectedIndices>>();
|
||||
register_factory<VariantWrapper<Strides>>();
|
||||
register_factory<PrimitivesPriority>();
|
||||
register_factory<DisableConstantFolding>();
|
||||
register_factory<NmsSelectedIndices>();
|
||||
register_factory<StridesPropagation>();
|
||||
}
|
||||
|
||||
Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) {
|
||||
|
@ -17,31 +17,19 @@
|
||||
|
||||
namespace ov {
|
||||
|
||||
/**
|
||||
* @ingroup ie_runtime_attr_api
|
||||
* @brief DisableConstantFolding disable ConstantFolding for given operation
|
||||
*/
|
||||
class TRANSFORMATIONS_API DisableConstantFolding {
|
||||
public:
|
||||
DisableConstantFolding() = default;
|
||||
};
|
||||
|
||||
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
|
||||
|
||||
TRANSFORMATIONS_API void enable_constant_folding(const std::shared_ptr<Node>& node);
|
||||
|
||||
TRANSFORMATIONS_API bool constant_folding_is_disabled(const std::shared_ptr<Node>& node);
|
||||
|
||||
extern template class TRANSFORMATIONS_API VariantImpl<DisableConstantFolding>;
|
||||
|
||||
template<>
|
||||
class TRANSFORMATIONS_API VariantWrapper<DisableConstantFolding> : public VariantImpl<DisableConstantFolding> {
|
||||
class TRANSFORMATIONS_API DisableConstantFolding : public VariantImpl<bool> {
|
||||
public:
|
||||
OPENVINO_RTTI("disabled_constant_folding", "0");
|
||||
|
||||
VariantWrapper() = default;
|
||||
DisableConstantFolding() = default;
|
||||
|
||||
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||
DisableConstantFolding(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||
|
||||
bool is_copyable() const override { return false; }
|
||||
};
|
||||
|
@ -16,25 +16,17 @@
|
||||
|
||||
namespace ov {
|
||||
|
||||
class TRANSFORMATIONS_API NmsSelectedIndices {
|
||||
public:
|
||||
NmsSelectedIndices() = default;
|
||||
};
|
||||
|
||||
TRANSFORMATIONS_API bool has_nms_selected_indices(const Node * node);
|
||||
|
||||
TRANSFORMATIONS_API void set_nms_selected_indices(Node * node);
|
||||
|
||||
extern template class TRANSFORMATIONS_API VariantImpl<NmsSelectedIndices>;
|
||||
|
||||
template<>
|
||||
class TRANSFORMATIONS_API VariantWrapper<NmsSelectedIndices> : public VariantImpl<NmsSelectedIndices> {
|
||||
class TRANSFORMATIONS_API NmsSelectedIndices : public VariantImpl<bool> {
|
||||
public:
|
||||
OPENVINO_RTTI("nms_selected_indices", "0");
|
||||
|
||||
VariantWrapper() = default;
|
||||
NmsSelectedIndices() = default;
|
||||
|
||||
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||
NmsSelectedIndices(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||
|
||||
bool is_copyable() const override { return false; }
|
||||
};
|
||||
|
@ -19,34 +19,6 @@
|
||||
#include <ngraph/variant.hpp>
|
||||
|
||||
namespace ov {
|
||||
|
||||
/**
|
||||
* @ingroup ie_runtime_attr_api
|
||||
* @brief PrimitivesPriority class represents runtime info attribute that
|
||||
* can be used for plugins specific primitive choice.
|
||||
*/
|
||||
class NGRAPH_API PrimitivesPriority {
|
||||
private:
|
||||
std::string primitives_priority;
|
||||
|
||||
public:
|
||||
friend class VariantWrapper<PrimitivesPriority>;
|
||||
/**
|
||||
* A default constructor
|
||||
*/
|
||||
PrimitivesPriority() = default;
|
||||
|
||||
/**
|
||||
* @brief Constructs a new object consisting of a single name *
|
||||
* @param[in] name The primitives priority value
|
||||
*/
|
||||
explicit PrimitivesPriority(const std::string &primitives_priority) : primitives_priority(primitives_priority) {}
|
||||
|
||||
/**
|
||||
* @brief return string with primitives priority value
|
||||
*/
|
||||
std::string getPrimitivesPriority() const;
|
||||
};
|
||||
/**
|
||||
* @ingroup ie_runtime_attr_api
|
||||
* @brief getPrimitivesPriority return string with primitive priorities value
|
||||
@ -54,16 +26,13 @@ public:
|
||||
*/
|
||||
NGRAPH_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node);
|
||||
|
||||
extern template class NGRAPH_API VariantImpl<PrimitivesPriority>;
|
||||
|
||||
template<>
|
||||
class NGRAPH_API VariantWrapper<PrimitivesPriority> : public VariantImpl<PrimitivesPriority> {
|
||||
class NGRAPH_API PrimitivesPriority : public VariantImpl<std::string> {
|
||||
public:
|
||||
OPENVINO_RTTI("primitives_priority", "0");
|
||||
|
||||
VariantWrapper() = default;
|
||||
PrimitivesPriority() = default;
|
||||
|
||||
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||
PrimitivesPriority(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||
|
||||
std::shared_ptr<ov::Variant> merge(const ngraph::NodeVector & nodes) override;
|
||||
|
||||
@ -71,5 +40,4 @@ public:
|
||||
|
||||
bool visit_attributes(AttributeVisitor & visitor) override;
|
||||
};
|
||||
|
||||
} // namespace ov
|
||||
|
@ -10,14 +10,13 @@
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
namespace ov {
|
||||
template <>
|
||||
class TRANSFORMATIONS_API VariantWrapper<ngraph::Strides> : public VariantImpl<ngraph::Strides> {
|
||||
class TRANSFORMATIONS_API StridesPropagation : public VariantImpl<ngraph::Strides> {
|
||||
public:
|
||||
OPENVINO_RTTI("strides", "0");
|
||||
OPENVINO_RTTI("strides_propagation", "0");
|
||||
|
||||
VariantWrapper() = default;
|
||||
StridesPropagation() = default;
|
||||
|
||||
VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
|
||||
StridesPropagation(const value_type& value) : VariantImpl<value_type>(value) {}
|
||||
};
|
||||
|
||||
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);
|
||||
|
@ -4,19 +4,17 @@
|
||||
|
||||
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||
|
||||
template class ov::VariantImpl<ov::DisableConstantFolding>;
|
||||
|
||||
void ov::disable_constant_folding(const std::shared_ptr<Node>& node) {
|
||||
auto & rt_info = node->get_rt_info();
|
||||
rt_info[VariantWrapper<DisableConstantFolding>::get_type_info_static()] = make_variant<DisableConstantFolding>({});
|
||||
rt_info[DisableConstantFolding::get_type_info_static()] = std::make_shared<DisableConstantFolding>(true);
|
||||
}
|
||||
|
||||
void ov::enable_constant_folding(const std::shared_ptr<Node>& node) {
|
||||
auto & rt_info = node->get_rt_info();
|
||||
rt_info.erase(VariantWrapper<DisableConstantFolding>::get_type_info_static());
|
||||
rt_info.erase(DisableConstantFolding::get_type_info_static());
|
||||
}
|
||||
|
||||
bool ov::constant_folding_is_disabled(const std::shared_ptr<Node> &node) {
|
||||
const auto & rt_info = node->get_rt_info();
|
||||
return rt_info.count(VariantWrapper<DisableConstantFolding>::get_type_info_static());
|
||||
return rt_info.count(DisableConstantFolding::get_type_info_static());
|
||||
}
|
||||
|
@ -8,10 +8,10 @@ template class ov::VariantImpl<ov::NmsSelectedIndices>;
|
||||
|
||||
void ov::set_nms_selected_indices(Node * node) {
|
||||
auto & rt_info = node->get_rt_info();
|
||||
rt_info[VariantWrapper<NmsSelectedIndices>::get_type_info_static()] = make_variant<NmsSelectedIndices>({});
|
||||
rt_info[NmsSelectedIndices::get_type_info_static()] = std::make_shared<NmsSelectedIndices>(true);
|
||||
}
|
||||
|
||||
bool ov::has_nms_selected_indices(const Node * node) {
|
||||
const auto & rt_info = node->get_rt_info();
|
||||
return rt_info.count(VariantWrapper<NmsSelectedIndices>::get_type_info_static());
|
||||
return rt_info.count(NmsSelectedIndices::get_type_info_static());
|
||||
}
|
||||
|
@ -6,17 +6,16 @@
|
||||
|
||||
bool ov::has_strides_prop(const ngraph::Input<ngraph::Node>& node) {
|
||||
const auto& rt_map = node.get_rt_info();
|
||||
auto it = rt_map.find(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static());
|
||||
return it != rt_map.end();
|
||||
return rt_map.count(StridesPropagation::get_type_info_static());
|
||||
}
|
||||
|
||||
ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
|
||||
const auto& rt_map = node.get_rt_info();
|
||||
const auto& var = rt_map.at(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static());
|
||||
return ngraph::as_type_ptr<ngraph::VariantWrapper<ngraph::Strides>>(var)->get();
|
||||
const auto& var = rt_map.at(StridesPropagation::get_type_info_static());
|
||||
return ngraph::as_type_ptr<StridesPropagation>(var)->get();
|
||||
}
|
||||
|
||||
void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
|
||||
auto& rt_map = node.get_rt_info();
|
||||
rt_map[ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static()] = std::make_shared<ngraph::VariantWrapper<ngraph::Strides>>(strides);
|
||||
rt_map[StridesPropagation::get_type_info_static()] = std::make_shared<StridesPropagation>(strides);
|
||||
}
|
||||
|
@ -32,8 +32,8 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
|
||||
auto init_info = [](RTMap & info) {
|
||||
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
|
||||
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
|
||||
info[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
|
||||
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority("priority"));
|
||||
info[ov::PrimitivesPriority::get_type_info_static()] =
|
||||
std::make_shared<ov::PrimitivesPriority>("priority");
|
||||
};
|
||||
|
||||
std::shared_ptr<ngraph::Function> function;
|
||||
@ -62,11 +62,11 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
|
||||
ASSERT_TRUE(fused_names_attr);
|
||||
ASSERT_EQ(fused_names_attr->get().getNames(), "add");
|
||||
|
||||
const std::string & pkey = VariantWrapper<ov::PrimitivesPriority>::get_type_info_static();
|
||||
const std::string & pkey = ov::PrimitivesPriority::get_type_info_static();
|
||||
ASSERT_TRUE(info.count(pkey));
|
||||
auto primitives_priority_attr = std::dynamic_pointer_cast<VariantWrapper<ov::PrimitivesPriority>>(info.at(pkey));
|
||||
auto primitives_priority_attr = std::dynamic_pointer_cast<ov::PrimitivesPriority>(info.at(pkey));
|
||||
ASSERT_TRUE(primitives_priority_attr);
|
||||
ASSERT_EQ(primitives_priority_attr->get().getPrimitivesPriority(), "priority");
|
||||
ASSERT_EQ(primitives_priority_attr->get(), "priority");
|
||||
};
|
||||
|
||||
auto add = f->get_results()[0]->get_input_node_ptr(0);
|
||||
|
@ -237,12 +237,12 @@ TEST(NetworkContext_CNNNetwork, HashWithFusedNames) {
|
||||
|
||||
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
|
||||
auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
|
||||
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
|
||||
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority());
|
||||
rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
|
||||
std::make_shared<ov::PrimitivesPriority>("");
|
||||
};
|
||||
auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) {
|
||||
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
|
||||
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(name));
|
||||
rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
|
||||
std::make_shared<ov::PrimitivesPriority>(name);
|
||||
};
|
||||
checkCustomRt(setPrimEmpty, setPrim);
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ bool ngraph::pass::InitNodeInfo::run_on_function(std::shared_ptr<ngraph::Functio
|
||||
using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>;
|
||||
std::map<std::string, VariantCreator> update_attributes{
|
||||
{"PrimitivesPriority", [](const std::string& value) -> std::shared_ptr<Variant> {
|
||||
return std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(value));
|
||||
return std::make_shared<ov::PrimitivesPriority>(value);
|
||||
}}};
|
||||
|
||||
for (auto& node : f->get_ops()) {
|
||||
|
@ -17,25 +17,17 @@
|
||||
using namespace ov;
|
||||
using namespace ngraph;
|
||||
|
||||
std::string PrimitivesPriority::getPrimitivesPriority() const {
|
||||
return primitives_priority;
|
||||
}
|
||||
|
||||
std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node) {
|
||||
const auto& rtInfo = node->get_rt_info();
|
||||
using PrimitivesPriorityWrapper = VariantWrapper<PrimitivesPriority>;
|
||||
|
||||
if (!rtInfo.count(PrimitivesPriorityWrapper::get_type_info_static()))
|
||||
if (!rtInfo.count(PrimitivesPriority::get_type_info_static()))
|
||||
return "";
|
||||
|
||||
const auto& attr = rtInfo.at(PrimitivesPriorityWrapper::get_type_info_static());
|
||||
PrimitivesPriority pp = ov::as_type_ptr<PrimitivesPriorityWrapper>(attr)->get();
|
||||
return pp.getPrimitivesPriority();
|
||||
const auto& attr = rtInfo.at(PrimitivesPriority::get_type_info_static());
|
||||
return ov::as_type_ptr<PrimitivesPriority>(attr)->get();
|
||||
}
|
||||
|
||||
template class ov::VariantImpl<PrimitivesPriority>;
|
||||
|
||||
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const ngraph::NodeVector& nodes) {
|
||||
std::shared_ptr<ngraph::Variant> PrimitivesPriority::merge(const ngraph::NodeVector& nodes) {
|
||||
auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
|
||||
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
|
||||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
|
||||
@ -64,14 +56,14 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const
|
||||
if (unique_pp.size() == 1) {
|
||||
final_primitives_priority = *unique_pp.begin();
|
||||
}
|
||||
return std::make_shared<VariantWrapper<PrimitivesPriority>>(PrimitivesPriority(final_primitives_priority));
|
||||
return std::make_shared<PrimitivesPriority>(final_primitives_priority);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::init(const std::shared_ptr<ngraph::Node>& node) {
|
||||
std::shared_ptr<ngraph::Variant> PrimitivesPriority::init(const std::shared_ptr<ngraph::Node>& node) {
|
||||
throw ngraph_error(std::string(get_type_info()) + " has no default initialization.");
|
||||
}
|
||||
|
||||
bool VariantWrapper<PrimitivesPriority>::visit_attributes(AttributeVisitor& visitor) {
|
||||
visitor.on_attribute("value", m_value.primitives_priority);
|
||||
bool PrimitivesPriority::visit_attributes(AttributeVisitor& visitor) {
|
||||
visitor.on_attribute("value", m_value);
|
||||
return true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user