Remove VariantWrapper for simple attr classes (#7771)
* Remove VariantWrapper for simple attr classes * Code style
This commit is contained in:
parent
faeaf045a9
commit
a16cc81233
@ -123,9 +123,8 @@ std::string NetworkCompilationContext::computeHash(const CNNNetwork& network,
|
|||||||
} else if (auto fNames =
|
} else if (auto fNames =
|
||||||
std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::FusedNames>>(rtMapData.second)) {
|
std::dynamic_pointer_cast<ngraph::VariantWrapper<ngraph::FusedNames>>(rtMapData.second)) {
|
||||||
seed = hash_combine(seed, fNames->get().getNames());
|
seed = hash_combine(seed, fNames->get().getNames());
|
||||||
} else if (auto prim = std::dynamic_pointer_cast<ngraph::VariantWrapper<ov::PrimitivesPriority>>(
|
} else if (auto prim = std::dynamic_pointer_cast<ov::PrimitivesPriority>(rtMapData.second)) {
|
||||||
rtMapData.second)) {
|
seed = hash_combine(seed, prim->get());
|
||||||
seed = hash_combine(seed, prim->get().getPrimitivesPriority());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -22,13 +22,11 @@ inline std::string getRTInfoValue(const std::map<std::string, std::shared_ptr<ng
|
|||||||
|
|
||||||
inline std::string getPrimitivesPriorityValue(const std::shared_ptr<ngraph::Node> &node) {
|
inline std::string getPrimitivesPriorityValue(const std::shared_ptr<ngraph::Node> &node) {
|
||||||
const auto &rtInfo = node->get_rt_info();
|
const auto &rtInfo = node->get_rt_info();
|
||||||
using PrimitivesPriorityWraper = ngraph::VariantWrapper<ov::PrimitivesPriority>;
|
|
||||||
|
|
||||||
if (!rtInfo.count(PrimitivesPriorityWraper::get_type_info_static())) return "";
|
if (!rtInfo.count(ov::PrimitivesPriority::get_type_info_static())) return "";
|
||||||
|
|
||||||
const auto &attr = rtInfo.at(PrimitivesPriorityWraper::get_type_info_static());
|
const auto &attr = rtInfo.at(ov::PrimitivesPriority::get_type_info_static());
|
||||||
ov::PrimitivesPriority pp = ngraph::as_type_ptr<PrimitivesPriorityWraper>(attr)->get();
|
return ngraph::as_type_ptr<ov::PrimitivesPriority>(attr)->get();
|
||||||
return pp.getPrimitivesPriority();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -28,10 +28,10 @@ class TRANSFORMATIONS_API Attributes {
|
|||||||
public:
|
public:
|
||||||
Attributes() {
|
Attributes() {
|
||||||
register_factory<VariantWrapper<ngraph::FusedNames>>();
|
register_factory<VariantWrapper<ngraph::FusedNames>>();
|
||||||
register_factory<VariantWrapper<PrimitivesPriority>>();
|
register_factory<PrimitivesPriority>();
|
||||||
register_factory<VariantWrapper<DisableConstantFolding>>();
|
register_factory<DisableConstantFolding>();
|
||||||
register_factory<VariantWrapper<NmsSelectedIndices>>();
|
register_factory<NmsSelectedIndices>();
|
||||||
register_factory<VariantWrapper<Strides>>();
|
register_factory<StridesPropagation>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) {
|
Variant * create_by_type_info(const ov::DiscreteTypeInfo & type_info) {
|
||||||
|
@ -17,31 +17,19 @@
|
|||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
|
|
||||||
/**
|
|
||||||
* @ingroup ie_runtime_attr_api
|
|
||||||
* @brief DisableConstantFolding disable ConstantFolding for given operation
|
|
||||||
*/
|
|
||||||
class TRANSFORMATIONS_API DisableConstantFolding {
|
|
||||||
public:
|
|
||||||
DisableConstantFolding() = default;
|
|
||||||
};
|
|
||||||
|
|
||||||
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
|
TRANSFORMATIONS_API void disable_constant_folding(const std::shared_ptr<Node>& node);
|
||||||
|
|
||||||
TRANSFORMATIONS_API void enable_constant_folding(const std::shared_ptr<Node>& node);
|
TRANSFORMATIONS_API void enable_constant_folding(const std::shared_ptr<Node>& node);
|
||||||
|
|
||||||
TRANSFORMATIONS_API bool constant_folding_is_disabled(const std::shared_ptr<Node>& node);
|
TRANSFORMATIONS_API bool constant_folding_is_disabled(const std::shared_ptr<Node>& node);
|
||||||
|
|
||||||
extern template class TRANSFORMATIONS_API VariantImpl<DisableConstantFolding>;
|
class TRANSFORMATIONS_API DisableConstantFolding : public VariantImpl<bool> {
|
||||||
|
|
||||||
template<>
|
|
||||||
class TRANSFORMATIONS_API VariantWrapper<DisableConstantFolding> : public VariantImpl<DisableConstantFolding> {
|
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("disabled_constant_folding", "0");
|
OPENVINO_RTTI("disabled_constant_folding", "0");
|
||||||
|
|
||||||
VariantWrapper() = default;
|
DisableConstantFolding() = default;
|
||||||
|
|
||||||
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
DisableConstantFolding(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||||
|
|
||||||
bool is_copyable() const override { return false; }
|
bool is_copyable() const override { return false; }
|
||||||
};
|
};
|
||||||
|
@ -16,25 +16,17 @@
|
|||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
|
|
||||||
class TRANSFORMATIONS_API NmsSelectedIndices {
|
|
||||||
public:
|
|
||||||
NmsSelectedIndices() = default;
|
|
||||||
};
|
|
||||||
|
|
||||||
TRANSFORMATIONS_API bool has_nms_selected_indices(const Node * node);
|
TRANSFORMATIONS_API bool has_nms_selected_indices(const Node * node);
|
||||||
|
|
||||||
TRANSFORMATIONS_API void set_nms_selected_indices(Node * node);
|
TRANSFORMATIONS_API void set_nms_selected_indices(Node * node);
|
||||||
|
|
||||||
extern template class TRANSFORMATIONS_API VariantImpl<NmsSelectedIndices>;
|
class TRANSFORMATIONS_API NmsSelectedIndices : public VariantImpl<bool> {
|
||||||
|
|
||||||
template<>
|
|
||||||
class TRANSFORMATIONS_API VariantWrapper<NmsSelectedIndices> : public VariantImpl<NmsSelectedIndices> {
|
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("nms_selected_indices", "0");
|
OPENVINO_RTTI("nms_selected_indices", "0");
|
||||||
|
|
||||||
VariantWrapper() = default;
|
NmsSelectedIndices() = default;
|
||||||
|
|
||||||
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
NmsSelectedIndices(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||||
|
|
||||||
bool is_copyable() const override { return false; }
|
bool is_copyable() const override { return false; }
|
||||||
};
|
};
|
||||||
|
@ -19,34 +19,6 @@
|
|||||||
#include <ngraph/variant.hpp>
|
#include <ngraph/variant.hpp>
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
|
|
||||||
/**
|
|
||||||
* @ingroup ie_runtime_attr_api
|
|
||||||
* @brief PrimitivesPriority class represents runtime info attribute that
|
|
||||||
* can be used for plugins specific primitive choice.
|
|
||||||
*/
|
|
||||||
class NGRAPH_API PrimitivesPriority {
|
|
||||||
private:
|
|
||||||
std::string primitives_priority;
|
|
||||||
|
|
||||||
public:
|
|
||||||
friend class VariantWrapper<PrimitivesPriority>;
|
|
||||||
/**
|
|
||||||
* A default constructor
|
|
||||||
*/
|
|
||||||
PrimitivesPriority() = default;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Constructs a new object consisting of a single name *
|
|
||||||
* @param[in] name The primitives priority value
|
|
||||||
*/
|
|
||||||
explicit PrimitivesPriority(const std::string &primitives_priority) : primitives_priority(primitives_priority) {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief return string with primitives priority value
|
|
||||||
*/
|
|
||||||
std::string getPrimitivesPriority() const;
|
|
||||||
};
|
|
||||||
/**
|
/**
|
||||||
* @ingroup ie_runtime_attr_api
|
* @ingroup ie_runtime_attr_api
|
||||||
* @brief getPrimitivesPriority return string with primitive priorities value
|
* @brief getPrimitivesPriority return string with primitive priorities value
|
||||||
@ -54,16 +26,13 @@ public:
|
|||||||
*/
|
*/
|
||||||
NGRAPH_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node);
|
NGRAPH_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node);
|
||||||
|
|
||||||
extern template class NGRAPH_API VariantImpl<PrimitivesPriority>;
|
class NGRAPH_API PrimitivesPriority : public VariantImpl<std::string> {
|
||||||
|
|
||||||
template<>
|
|
||||||
class NGRAPH_API VariantWrapper<PrimitivesPriority> : public VariantImpl<PrimitivesPriority> {
|
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("primitives_priority", "0");
|
OPENVINO_RTTI("primitives_priority", "0");
|
||||||
|
|
||||||
VariantWrapper() = default;
|
PrimitivesPriority() = default;
|
||||||
|
|
||||||
VariantWrapper(const value_type &value) : VariantImpl<value_type>(value) {}
|
PrimitivesPriority(const value_type &value) : VariantImpl<value_type>(value) {}
|
||||||
|
|
||||||
std::shared_ptr<ov::Variant> merge(const ngraph::NodeVector & nodes) override;
|
std::shared_ptr<ov::Variant> merge(const ngraph::NodeVector & nodes) override;
|
||||||
|
|
||||||
@ -71,5 +40,4 @@ public:
|
|||||||
|
|
||||||
bool visit_attributes(AttributeVisitor & visitor) override;
|
bool visit_attributes(AttributeVisitor & visitor) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ov
|
} // namespace ov
|
||||||
|
@ -10,14 +10,13 @@
|
|||||||
#include <transformations_visibility.hpp>
|
#include <transformations_visibility.hpp>
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
template <>
|
class TRANSFORMATIONS_API StridesPropagation : public VariantImpl<ngraph::Strides> {
|
||||||
class TRANSFORMATIONS_API VariantWrapper<ngraph::Strides> : public VariantImpl<ngraph::Strides> {
|
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("strides", "0");
|
OPENVINO_RTTI("strides_propagation", "0");
|
||||||
|
|
||||||
VariantWrapper() = default;
|
StridesPropagation() = default;
|
||||||
|
|
||||||
VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
|
StridesPropagation(const value_type& value) : VariantImpl<value_type>(value) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);
|
TRANSFORMATIONS_API bool has_strides_prop(const ngraph::Input<ngraph::Node>& node);
|
||||||
|
@ -4,19 +4,17 @@
|
|||||||
|
|
||||||
#include "transformations/rt_info/disable_constant_folding.hpp"
|
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||||
|
|
||||||
template class ov::VariantImpl<ov::DisableConstantFolding>;
|
|
||||||
|
|
||||||
void ov::disable_constant_folding(const std::shared_ptr<Node>& node) {
|
void ov::disable_constant_folding(const std::shared_ptr<Node>& node) {
|
||||||
auto & rt_info = node->get_rt_info();
|
auto & rt_info = node->get_rt_info();
|
||||||
rt_info[VariantWrapper<DisableConstantFolding>::get_type_info_static()] = make_variant<DisableConstantFolding>({});
|
rt_info[DisableConstantFolding::get_type_info_static()] = std::make_shared<DisableConstantFolding>(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ov::enable_constant_folding(const std::shared_ptr<Node>& node) {
|
void ov::enable_constant_folding(const std::shared_ptr<Node>& node) {
|
||||||
auto & rt_info = node->get_rt_info();
|
auto & rt_info = node->get_rt_info();
|
||||||
rt_info.erase(VariantWrapper<DisableConstantFolding>::get_type_info_static());
|
rt_info.erase(DisableConstantFolding::get_type_info_static());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ov::constant_folding_is_disabled(const std::shared_ptr<Node> &node) {
|
bool ov::constant_folding_is_disabled(const std::shared_ptr<Node> &node) {
|
||||||
const auto & rt_info = node->get_rt_info();
|
const auto & rt_info = node->get_rt_info();
|
||||||
return rt_info.count(VariantWrapper<DisableConstantFolding>::get_type_info_static());
|
return rt_info.count(DisableConstantFolding::get_type_info_static());
|
||||||
}
|
}
|
||||||
|
@ -8,10 +8,10 @@ template class ov::VariantImpl<ov::NmsSelectedIndices>;
|
|||||||
|
|
||||||
void ov::set_nms_selected_indices(Node * node) {
|
void ov::set_nms_selected_indices(Node * node) {
|
||||||
auto & rt_info = node->get_rt_info();
|
auto & rt_info = node->get_rt_info();
|
||||||
rt_info[VariantWrapper<NmsSelectedIndices>::get_type_info_static()] = make_variant<NmsSelectedIndices>({});
|
rt_info[NmsSelectedIndices::get_type_info_static()] = std::make_shared<NmsSelectedIndices>(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ov::has_nms_selected_indices(const Node * node) {
|
bool ov::has_nms_selected_indices(const Node * node) {
|
||||||
const auto & rt_info = node->get_rt_info();
|
const auto & rt_info = node->get_rt_info();
|
||||||
return rt_info.count(VariantWrapper<NmsSelectedIndices>::get_type_info_static());
|
return rt_info.count(NmsSelectedIndices::get_type_info_static());
|
||||||
}
|
}
|
||||||
|
@ -6,17 +6,16 @@
|
|||||||
|
|
||||||
bool ov::has_strides_prop(const ngraph::Input<ngraph::Node>& node) {
|
bool ov::has_strides_prop(const ngraph::Input<ngraph::Node>& node) {
|
||||||
const auto& rt_map = node.get_rt_info();
|
const auto& rt_map = node.get_rt_info();
|
||||||
auto it = rt_map.find(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static());
|
return rt_map.count(StridesPropagation::get_type_info_static());
|
||||||
return it != rt_map.end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
|
ngraph::Strides ov::get_strides_prop(const ngraph::Input<ngraph::Node>& node) {
|
||||||
const auto& rt_map = node.get_rt_info();
|
const auto& rt_map = node.get_rt_info();
|
||||||
const auto& var = rt_map.at(ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static());
|
const auto& var = rt_map.at(StridesPropagation::get_type_info_static());
|
||||||
return ngraph::as_type_ptr<ngraph::VariantWrapper<ngraph::Strides>>(var)->get();
|
return ngraph::as_type_ptr<StridesPropagation>(var)->get();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
|
void ov::insert_strides_prop(ngraph::Input<ngraph::Node>& node, const ngraph::Strides& strides) {
|
||||||
auto& rt_map = node.get_rt_info();
|
auto& rt_map = node.get_rt_info();
|
||||||
rt_map[ngraph::VariantWrapper<ngraph::Strides>::get_type_info_static()] = std::make_shared<ngraph::VariantWrapper<ngraph::Strides>>(strides);
|
rt_map[StridesPropagation::get_type_info_static()] = std::make_shared<StridesPropagation>(strides);
|
||||||
}
|
}
|
||||||
|
@ -32,8 +32,8 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
|
|||||||
auto init_info = [](RTMap & info) {
|
auto init_info = [](RTMap & info) {
|
||||||
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
|
info[VariantWrapper<ngraph::FusedNames>::get_type_info_static()] =
|
||||||
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
|
std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames("add"));
|
||||||
info[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
|
info[ov::PrimitivesPriority::get_type_info_static()] =
|
||||||
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority("priority"));
|
std::make_shared<ov::PrimitivesPriority>("priority");
|
||||||
};
|
};
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> function;
|
std::shared_ptr<ngraph::Function> function;
|
||||||
@ -62,11 +62,11 @@ TEST_F(RTInfoSerializationTest, all_attributes) {
|
|||||||
ASSERT_TRUE(fused_names_attr);
|
ASSERT_TRUE(fused_names_attr);
|
||||||
ASSERT_EQ(fused_names_attr->get().getNames(), "add");
|
ASSERT_EQ(fused_names_attr->get().getNames(), "add");
|
||||||
|
|
||||||
const std::string & pkey = VariantWrapper<ov::PrimitivesPriority>::get_type_info_static();
|
const std::string & pkey = ov::PrimitivesPriority::get_type_info_static();
|
||||||
ASSERT_TRUE(info.count(pkey));
|
ASSERT_TRUE(info.count(pkey));
|
||||||
auto primitives_priority_attr = std::dynamic_pointer_cast<VariantWrapper<ov::PrimitivesPriority>>(info.at(pkey));
|
auto primitives_priority_attr = std::dynamic_pointer_cast<ov::PrimitivesPriority>(info.at(pkey));
|
||||||
ASSERT_TRUE(primitives_priority_attr);
|
ASSERT_TRUE(primitives_priority_attr);
|
||||||
ASSERT_EQ(primitives_priority_attr->get().getPrimitivesPriority(), "priority");
|
ASSERT_EQ(primitives_priority_attr->get(), "priority");
|
||||||
};
|
};
|
||||||
|
|
||||||
auto add = f->get_results()[0]->get_input_node_ptr(0);
|
auto add = f->get_results()[0]->get_input_node_ptr(0);
|
||||||
|
@ -237,12 +237,12 @@ TEST(NetworkContext_CNNNetwork, HashWithFusedNames) {
|
|||||||
|
|
||||||
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
|
TEST(NetworkContext_CNNNetwork, HashWithPrimitivesPriorityType) {
|
||||||
auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
|
auto setPrimEmpty = [&](Node::RTMap& rtInfo) {
|
||||||
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
|
rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
|
||||||
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority());
|
std::make_shared<ov::PrimitivesPriority>("");
|
||||||
};
|
};
|
||||||
auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) {
|
auto setPrim = [&](Node::RTMap& rtInfo, const std::string& name) {
|
||||||
rtInfo[VariantWrapper<ov::PrimitivesPriority>::get_type_info_static()] =
|
rtInfo[ov::PrimitivesPriority::get_type_info_static()] =
|
||||||
std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(name));
|
std::make_shared<ov::PrimitivesPriority>(name);
|
||||||
};
|
};
|
||||||
checkCustomRt(setPrimEmpty, setPrim);
|
checkCustomRt(setPrimEmpty, setPrim);
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ bool ngraph::pass::InitNodeInfo::run_on_function(std::shared_ptr<ngraph::Functio
|
|||||||
using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>;
|
using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>;
|
||||||
std::map<std::string, VariantCreator> update_attributes{
|
std::map<std::string, VariantCreator> update_attributes{
|
||||||
{"PrimitivesPriority", [](const std::string& value) -> std::shared_ptr<Variant> {
|
{"PrimitivesPriority", [](const std::string& value) -> std::shared_ptr<Variant> {
|
||||||
return std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(value));
|
return std::make_shared<ov::PrimitivesPriority>(value);
|
||||||
}}};
|
}}};
|
||||||
|
|
||||||
for (auto& node : f->get_ops()) {
|
for (auto& node : f->get_ops()) {
|
||||||
|
@ -17,25 +17,17 @@
|
|||||||
using namespace ov;
|
using namespace ov;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
std::string PrimitivesPriority::getPrimitivesPriority() const {
|
|
||||||
return primitives_priority;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node) {
|
std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node) {
|
||||||
const auto& rtInfo = node->get_rt_info();
|
const auto& rtInfo = node->get_rt_info();
|
||||||
using PrimitivesPriorityWrapper = VariantWrapper<PrimitivesPriority>;
|
|
||||||
|
|
||||||
if (!rtInfo.count(PrimitivesPriorityWrapper::get_type_info_static()))
|
if (!rtInfo.count(PrimitivesPriority::get_type_info_static()))
|
||||||
return "";
|
return "";
|
||||||
|
|
||||||
const auto& attr = rtInfo.at(PrimitivesPriorityWrapper::get_type_info_static());
|
const auto& attr = rtInfo.at(PrimitivesPriority::get_type_info_static());
|
||||||
PrimitivesPriority pp = ov::as_type_ptr<PrimitivesPriorityWrapper>(attr)->get();
|
return ov::as_type_ptr<PrimitivesPriority>(attr)->get();
|
||||||
return pp.getPrimitivesPriority();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template class ov::VariantImpl<PrimitivesPriority>;
|
std::shared_ptr<ngraph::Variant> PrimitivesPriority::merge(const ngraph::NodeVector& nodes) {
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const ngraph::NodeVector& nodes) {
|
|
||||||
auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
|
auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
|
||||||
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
|
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
|
||||||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
|
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
|
||||||
@ -64,14 +56,14 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const
|
|||||||
if (unique_pp.size() == 1) {
|
if (unique_pp.size() == 1) {
|
||||||
final_primitives_priority = *unique_pp.begin();
|
final_primitives_priority = *unique_pp.begin();
|
||||||
}
|
}
|
||||||
return std::make_shared<VariantWrapper<PrimitivesPriority>>(PrimitivesPriority(final_primitives_priority));
|
return std::make_shared<PrimitivesPriority>(final_primitives_priority);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::init(const std::shared_ptr<ngraph::Node>& node) {
|
std::shared_ptr<ngraph::Variant> PrimitivesPriority::init(const std::shared_ptr<ngraph::Node>& node) {
|
||||||
throw ngraph_error(std::string(get_type_info()) + " has no default initialization.");
|
throw ngraph_error(std::string(get_type_info()) + " has no default initialization.");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VariantWrapper<PrimitivesPriority>::visit_attributes(AttributeVisitor& visitor) {
|
bool PrimitivesPriority::visit_attributes(AttributeVisitor& visitor) {
|
||||||
visitor.on_attribute("value", m_value.primitives_priority);
|
visitor.on_attribute("value", m_value);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user