Added default constructor for RNNCellBase, fix conversions (#1370)
This commit is contained in:
parent
06119efdf2
commit
4037613db3
@ -91,6 +91,24 @@ public:
|
|||||||
params[name] = std::to_string(adapter.get());
|
params[name] = std::to_string(adapter.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<std::string>>& adapter) override {
|
||||||
|
std::vector<std::string> data = adapter.get();
|
||||||
|
for (auto& str : data) {
|
||||||
|
std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) {
|
||||||
|
return std::tolower(c);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
std::copy(data.begin(), data.end(), std::ostream_iterator<std::string>(ss, ","));
|
||||||
|
params[name] = ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<float>>& adapter) override {
|
||||||
|
auto data = adapter.get();
|
||||||
|
params[name] = joinVec(data);
|
||||||
|
}
|
||||||
|
|
||||||
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
|
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -118,6 +136,9 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na
|
|||||||
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<::ngraph::Strides>>(&adapter)) {
|
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<::ngraph::Strides>>(&adapter)) {
|
||||||
auto shape = static_cast<::ngraph::Strides&>(*a);
|
auto shape = static_cast<::ngraph::Strides&>(*a);
|
||||||
params[name] = joinVec(shape);
|
params[name] = joinVec(shape);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
|
||||||
|
"Attribute adapter can not be found for " << name << " parameter";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,6 +218,9 @@ private:
|
|||||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
|
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
|
||||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||||
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
|
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
|
||||||
|
} else {
|
||||||
|
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
|
||||||
|
<< " parameter";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {
|
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {
|
||||||
|
@ -37,6 +37,12 @@ static vector<string> to_lower_case(const vector<string>& vs)
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
op::util::RNNCellBase::RNNCellBase()
|
||||||
|
: m_clip(0.f)
|
||||||
|
, m_hidden_size(0)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
|
op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
|
||||||
float clip,
|
float clip,
|
||||||
const vector<string>& activations,
|
const vector<string>& activations,
|
||||||
|
@ -56,7 +56,7 @@ namespace ngraph
|
|||||||
const std::vector<float>& activations_alpha,
|
const std::vector<float>& activations_alpha,
|
||||||
const std::vector<float>& activations_beta);
|
const std::vector<float>& activations_beta);
|
||||||
|
|
||||||
RNNCellBase() = default;
|
RNNCellBase();
|
||||||
virtual ~RNNCellBase() = default;
|
virtual ~RNNCellBase() = default;
|
||||||
|
|
||||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||||
|
Loading…
Reference in New Issue
Block a user