Added default constructor for RNNCellBase, fix conversions (#1370)
This commit is contained in:
parent
06119efdf2
commit
4037613db3
@ -91,6 +91,24 @@ public:
|
||||
params[name] = std::to_string(adapter.get());
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter) override {
|
||||
std::vector<std::string> data = adapter.get();
|
||||
for (auto& str : data) {
|
||||
std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
|
||||
return std::tolower(c);
|
||||
});
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
std::copy(data.begin(), data.end(), std::ostream_iterator<std::string>(ss, ","));
|
||||
params[name] = ss.str();
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<float>>& adapter) override {
|
||||
auto data = adapter.get();
|
||||
params[name] = joinVec(data);
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
|
||||
|
||||
private:
|
||||
@ -118,6 +136,9 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na
|
||||
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<::ngraph::Strides>>(&adapter)) {
|
||||
auto shape = static_cast<::ngraph::Strides&>(*a);
|
||||
params[name] = joinVec(shape);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
|
||||
"Attribute adapter can not be found for " << name << " parameter";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -218,6 +218,9 @@ private:
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
|
||||
<< " parameter";
|
||||
}
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {
|
||||
|
@ -37,6 +37,12 @@ static vector<string> to_lower_case(const vector<string>& vs)
|
||||
return res;
|
||||
}
|
||||
|
||||
op::util::RNNCellBase::RNNCellBase()
|
||||
: m_clip(0.f)
|
||||
, m_hidden_size(0)
|
||||
{
|
||||
}
|
||||
|
||||
op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
|
||||
float clip,
|
||||
const vector<string>& activations,
|
||||
|
@ -56,7 +56,7 @@ namespace ngraph
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta);
|
||||
|
||||
RNNCellBase() = default;
|
||||
RNNCellBase();
|
||||
virtual ~RNNCellBase() = default;
|
||||
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||
|
Loading…
Reference in New Issue
Block a user