Added default constructor for RNNCellBase, fix conversions (#1370)

This commit is contained in:
Ivan Tikhonov 2020-07-20 14:15:37 +03:00 committed by GitHub
parent 06119efdf2
commit 4037613db3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 1 deletions

View File

@ -91,6 +91,24 @@ public:
params[name] = std::to_string(adapter.get()); params[name] = std::to_string(adapter.get());
} }
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter) override {
std::vector<std::string> data = adapter.get();
for (auto& str : data) {
std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
return std::tolower(c);
});
}
std::stringstream ss;
std::copy(data.begin(), data.end(), std::ostream_iterator<std::string>(ss, ","));
params[name] = ss.str();
}
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<float>>& adapter) override {
auto data = adapter.get();
params[name] = joinVec(data);
}
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override; void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
private: private:
@ -118,6 +136,9 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<::ngraph::Strides>>(&adapter)) { } else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<::ngraph::Strides>>(&adapter)) {
auto shape = static_cast<::ngraph::Strides&>(*a); auto shape = static_cast<::ngraph::Strides&>(*a);
params[name] = joinVec(shape); params[name] = joinVec(shape);
} else {
THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
"Attribute adapter can not be found for " << name << " parameter";
} }
} }

View File

@ -218,6 +218,9 @@ private:
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) { } else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
if (!getStrAttribute(node.child("data"), name, val)) return; if (!getStrAttribute(node.child("data"), name, val)) return;
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val); static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
} else {
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
<< " parameter";
} }
} }
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override { void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {

View File

@ -37,6 +37,12 @@ static vector<string> to_lower_case(const vector<string>& vs)
return res; return res;
} }
op::util::RNNCellBase::RNNCellBase()
: m_clip(0.f)
, m_hidden_size(0)
{
}
op::util::RNNCellBase::RNNCellBase(size_t hidden_size, op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
float clip, float clip,
const vector<string>& activations, const vector<string>& activations,

View File

@ -56,7 +56,7 @@ namespace ngraph
const std::vector<float>& activations_alpha, const std::vector<float>& activations_alpha,
const std::vector<float>& activations_beta); const std::vector<float>& activations_beta);
RNNCellBase() = default; RNNCellBase();
virtual ~RNNCellBase() = default; virtual ~RNNCellBase() = default;
virtual bool visit_attributes(AttributeVisitor& visitor); virtual bool visit_attributes(AttributeVisitor& visitor);