From 4037613db31475e79577ab5dfea69d11a2a2949d Mon Sep 17 00:00:00 2001 From: Ivan Tikhonov Date: Mon, 20 Jul 2020 14:15:37 +0300 Subject: [PATCH] Added default constructor for RNNCellBase, fix conversions (#1370) --- .../src/convert_function_to_cnn_network.cpp | 21 +++++++++++++++++++ .../src/readers/ir_reader/ie_ir_parser.hpp | 3 +++ ngraph/src/ngraph/op/util/rnn_cell_base.cpp | 6 ++++++ ngraph/src/ngraph/op/util/rnn_cell_base.hpp | 2 +- 4 files changed, 31 insertions(+), 1 deletion(-) diff --git a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp index d334c8a5063..cd7883a1efd 100644 --- a/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp +++ b/inference-engine/src/legacy_api/src/convert_function_to_cnn_network.cpp @@ -91,6 +91,24 @@ public: params[name] = std::to_string(adapter.get()); } + void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { + std::vector data = adapter.get(); + for (auto& str : data) { + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { + return std::tolower(c); + }); + } + + std::stringstream ss; + std::copy(data.begin(), data.end(), std::ostream_iterator(ss, ",")); + params[name] = ss.str(); + } + + void on_adapter(const std::string& name, ngraph::ValueAccessor>& adapter) override { + auto data = adapter.get(); + params[name] = joinVec(data); + } + void on_adapter(const std::string& name, ::ngraph::ValueAccessor& adapter) override; private: @@ -118,6 +136,9 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na } else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<::ngraph::Strides>>(&adapter)) { auto shape = static_cast<::ngraph::Strides&>(*a); params[name] = joinVec(shape); + } else { + THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. " + "Attribute adapter can not be found for " << name << " parameter"; } } diff --git a/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp b/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp index 2f7513c1833..862622d5a06 100644 --- a/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp +++ b/inference-engine/src/readers/ir_reader/ie_ir_parser.hpp @@ -218,6 +218,9 @@ private: } else if (auto a = ngraph::as_type>(&adapter)) { if (!getStrAttribute(node.child("data"), name, val)) return; static_cast(*a) = ngraph::as_enum(val); + } else { + THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name + << " parameter"; } } void on_adapter(const std::string& name, ngraph::ValueAccessor& adapter) override { diff --git a/ngraph/src/ngraph/op/util/rnn_cell_base.cpp b/ngraph/src/ngraph/op/util/rnn_cell_base.cpp index d931e588242..276187fc380 100644 --- a/ngraph/src/ngraph/op/util/rnn_cell_base.cpp +++ b/ngraph/src/ngraph/op/util/rnn_cell_base.cpp @@ -37,6 +37,12 @@ static vector to_lower_case(const vector& vs) return res; } +op::util::RNNCellBase::RNNCellBase() + : m_clip(0.f) + , m_hidden_size(0) +{ +} + op::util::RNNCellBase::RNNCellBase(size_t hidden_size, float clip, const vector& activations, diff --git a/ngraph/src/ngraph/op/util/rnn_cell_base.hpp b/ngraph/src/ngraph/op/util/rnn_cell_base.hpp index ad02dfb1ca0..9034ddf6c01 100644 --- a/ngraph/src/ngraph/op/util/rnn_cell_base.hpp +++ b/ngraph/src/ngraph/op/util/rnn_cell_base.hpp @@ -56,7 +56,7 @@ namespace ngraph const std::vector& activations_alpha, const std::vector& activations_beta); - RNNCellBase() = default; + RNNCellBase(); virtual ~RNNCellBase() = default; virtual bool visit_attributes(AttributeVisitor& visitor);