[GNA] convertFunctionToICNNNetwork operation performance improovement (#17685)

* make CNNLayerCreator be persistent accross single convertFunctionToICNNNetwork operation

* [GNA] RR comments applied

* [GNA] RR comments applied
This commit is contained in:
Maciej Kwapulinski 2023-06-01 13:21:54 +02:00 committed by GitHub
parent edf089bf22
commit 4b1d0fbc37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -353,11 +353,13 @@ CNNLayer::Ptr createSubGraphLayer(const std::shared_ptr<ngraph::Node>& layer) {
*/ */
class CNNLayerCreator : public ::ngraph::AttributeVisitor { class CNNLayerCreator : public ::ngraph::AttributeVisitor {
public: public:
explicit CNNLayerCreator();
CNNLayerPtr create(const std::shared_ptr<::ngraph::Node>& origin);
protected:
using CreatorFor = std::function<CNNLayerPtr(const std::shared_ptr<::ngraph::Node>& node, using CreatorFor = std::function<CNNLayerPtr(const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string>& param)>; const std::map<std::string, std::string>& param)>;
explicit CNNLayerCreator(const std::shared_ptr<::ngraph::Node>& node);
CNNLayerPtr create();
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<bool>& value) override { void on_adapter(const std::string& name, ::ngraph::ValueAccessor<bool>& value) override {
params[name] = value.get() ? "true" : "false"; params[name] = value.get() ? "true" : "false";
@ -481,7 +483,7 @@ void CNNLayerCreator::on_adapter(const std::string& name, ::ngraph::ValueAccesso
} }
} }
CNNLayerCreator::CNNLayerCreator(const std::shared_ptr<::ngraph::Node>& node) : node(node) { CNNLayerCreator::CNNLayerCreator() {
addSpecificCreator({"Parameter"}, addSpecificCreator({"Parameter"},
[](const std::shared_ptr<::ngraph::Node>& node, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string>& params) -> CNNLayerPtr { const std::map<std::string, std::string>& params) -> CNNLayerPtr {
@ -1968,15 +1970,28 @@ CNNLayerCreator::CNNLayerCreator(const std::shared_ptr<::ngraph::Node>& node) :
}); });
} }
CNNLayerPtr CNNLayerCreator::create() { CNNLayerPtr CNNLayerCreator::create(const std::shared_ptr<::ngraph::Node>& origin) {
node = origin; // node used by node->visit_attributes(..) > AttributeVisitor::on_attribute(..)
if (!node->visit_attributes(*this))
return nullptr;
LayerParams attrs = {node->get_friendly_name(), LayerParams attrs = {node->get_friendly_name(),
node->description(), node->description(),
details::convertPrecision(node->get_output_element_type(0))}; details::convertPrecision(node->get_output_element_type(0))};
if (creators.find(node->description()) != creators.end())
return creators[node->description()](node, params);
auto res = std::make_shared<CNNLayer>(attrs); CNNLayerPtr res;
auto creator = creators.find(node->description());
if (creator != creators.end())
res = creator->second(node, params);
else {
res = std::make_shared<CNNLayer>(attrs);
res->params = params; res->params = params;
}
node = nullptr;
return res; return res;
} }
@ -1989,19 +2004,17 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
bool keep_constant_inputs) { bool keep_constant_inputs) {
OV_ITT_SCOPED_TASK(itt::domains::IELegacy, "details::convertFunctionToICNNNetwork"); OV_ITT_SCOPED_TASK(itt::domains::IELegacy, "details::convertFunctionToICNNNetwork");
const auto createCNNLayer = [](const std::shared_ptr<::ngraph::Node>& node) -> CNNLayerPtr { CNNLayerCreator visitor;
const auto createCNNLayer = [&visitor](const std::shared_ptr<::ngraph::Node>& node) -> CNNLayerPtr {
class NGraphCNNLayer : public CNNLayer { class NGraphCNNLayer : public CNNLayer {
public: public:
void setNode(const std::shared_ptr<::ngraph::Node>& node) { void setNode(const std::shared_ptr<::ngraph::Node>& node) {
this->node = node; this->node = node;
} }
}; };
CNNLayerPtr result;
CNNLayerCreator visitor(node); CNNLayerPtr result = visitor.create(node);
if (node->visit_attributes(visitor)) {
result = visitor.create();
}
if (!result) if (!result)
IE_THROW() << "Cannot cast ngraph node " << node->get_friendly_name() << " to CNNLayer!"; IE_THROW() << "Cannot cast ngraph node " << node->get_friendly_name() << " to CNNLayer!";