add if in python_api
This commit is contained in:
parent
ef33e3052c
commit
fca6e5a449
@ -3,7 +3,7 @@
|
||||
|
||||
"""Factory functions for all ngraph ops."""
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, List, Optional, Set, Union
|
||||
from typing import Callable, Iterable, List, Optional, Set, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
from ngraph.impl import Node, Shape
|
||||
@ -367,3 +367,54 @@ def random_uniform(
|
||||
"op_seed": op_seed,
|
||||
}
|
||||
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
|
||||
|
||||
@nameable_op
|
||||
def if_op(
|
||||
condition: NodeInput,
|
||||
inputs: List[Node],
|
||||
bodies: Tuple(GraphBody, GraphBody)
|
||||
input_desc: Tuple(List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]),
|
||||
output_desc: Tuple(List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]),
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Perform recurrent execution of the network described in the body, iterating through the data.
|
||||
|
||||
@param trip_count: A scalar or 1D tensor with 1 element specifying
|
||||
maximum number of iterations.
|
||||
@param execution_condition: A scalar or 1D tensor with 1 element
|
||||
specifying whether to execute the first iteration or not.
|
||||
@param inputs: The provided to TensorIterator operator.
|
||||
@param graph_body: The graph representing the body we execute.
|
||||
@param slice_input_desc: The descriptors describing sliced inputs, that is nodes
|
||||
representing tensors we iterate through, processing single
|
||||
data slice in one iteration.
|
||||
@param merged_input_desc: The descriptors describing merged inputs, that is nodes
|
||||
representing variables with initial value at first iteration,
|
||||
which may be changing through iterations.
|
||||
@param invariant_input_desc: The descriptors describing invariant inputs, that is nodes
|
||||
representing variable with persistent value through all
|
||||
iterations.
|
||||
@param body_output_desc: The descriptors describing body outputs from specified
|
||||
iteration.
|
||||
@param concat_output_desc: The descriptors describing specified output values through
|
||||
all the iterations concatenated into one node.
|
||||
@param body_condition_output_idx: Determines the purpose of the corresponding result in
|
||||
the graph_body. This result will determine the dynamic
|
||||
exit condition. If the value of this result is False,
|
||||
then iterations stop.
|
||||
@param current_iteration_input_idx: Determines the purpose of the corresponding parameter
|
||||
in the graph_body. This parameter will be used as
|
||||
an iteration counter. Optional.
|
||||
@return: The new node which performs Loop.
|
||||
"""
|
||||
attributes = {
|
||||
"then_body": bodies[0].serialize(),
|
||||
"else_body": bodies[1].serialize(),
|
||||
"then_inputs": [desc.serialize() for desc in input_desc[0]],
|
||||
"else_inputs": [desc.serialize() for desc in input_desc[1]],
|
||||
"then_outputs": [desc.serialize() for desc in output_desc[0]],
|
||||
"else_outputs": [desc.serialize() for desc in output_desc[1]]
|
||||
}
|
||||
return _get_node_factory_opset8().create("If", as_nodes(condition, *inputs),
|
||||
attributes)
|
||||
|
||||
|
@ -24,69 +24,96 @@ util::DictAttributeDeserializer::DictAttributeDeserializer(
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||
if (m_attributes.contains(name)) {
|
||||
if (const auto& a = ngraph::as_type<
|
||||
ngraph::AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>>(
|
||||
ngraph::AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>>>>(
|
||||
&adapter)) {
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> input_descs;
|
||||
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
|
||||
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
|
||||
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
|
||||
for (py::handle h : slice_input_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>(),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
input_descs.push_back(slice_in);
|
||||
}
|
||||
std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>> input_descs;
|
||||
|
||||
for (py::handle h : merged_input_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>(),
|
||||
desc["body_value_idx"].cast<int64_t>());
|
||||
input_descs.push_back(merged_in);
|
||||
}
|
||||
if (name == "input_descriptions") {
|
||||
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
|
||||
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
|
||||
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
|
||||
for (py::handle h : slice_input_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>(),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
input_descs.push_back(slice_in);
|
||||
}
|
||||
|
||||
for (py::handle h : invariant_input_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>());
|
||||
input_descs.push_back(invariant_in);
|
||||
for (py::handle h : merged_input_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>(),
|
||||
desc["body_value_idx"].cast<int64_t>());
|
||||
input_descs.push_back(merged_in);
|
||||
}
|
||||
|
||||
for (py::handle h : invariant_input_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>());
|
||||
input_descs.push_back(invariant_in);
|
||||
}
|
||||
} else if (name == "then_inputs" || name == "else_inputs") {
|
||||
const py::list& input_desc = m_attributes[name.c_str()].cast<py::list>();
|
||||
for (py::handle h : input_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto invariant_in = std::make_shared<ngraph::op::util::MultiSubGraphOp::InvariantInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>());
|
||||
input_descs.push_back(invariant_in);
|
||||
}
|
||||
} else {
|
||||
NGRAPH_CHECK(false, "Input descriptions is not supported with name ", name);
|
||||
}
|
||||
a->set(input_descs);
|
||||
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> output_descs;
|
||||
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
|
||||
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
|
||||
for (py::handle h : body_output_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
|
||||
desc["body_value_idx"].cast<int64_t>(),
|
||||
desc["output_idx"].cast<int64_t>(),
|
||||
desc["iteration"].cast<int64_t>());
|
||||
output_descs.push_back(body_output);
|
||||
}
|
||||
std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::OutputDescription>> output_descs;
|
||||
if (name == "output_descriptions") {
|
||||
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
|
||||
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
|
||||
for (py::handle h : body_output_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
|
||||
desc["body_value_idx"].cast<int64_t>(),
|
||||
desc["output_idx"].cast<int64_t>(),
|
||||
desc["iteration"].cast<int64_t>());
|
||||
output_descs.push_back(body_output);
|
||||
}
|
||||
|
||||
for (py::handle h : concat_output_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
|
||||
desc["body_value_idx"].cast<int64_t>(),
|
||||
desc["output_idx"].cast<int64_t>(),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
output_descs.push_back(concat_output);
|
||||
for (py::handle h : concat_output_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
|
||||
desc["body_value_idx"].cast<int64_t>(),
|
||||
desc["output_idx"].cast<int64_t>(),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
output_descs.push_back(concat_output);
|
||||
}
|
||||
} else if (name == "then_outputs" || name == "else_outputs") {
|
||||
const py::list& output_desc = m_attributes[name.c_str()].cast<py::list>();
|
||||
for (py::handle h : output_desc) {
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto body_output = std::make_shared<ngraph::op::util::MultiSubGraphOp::BodyOutputDescription>(
|
||||
desc["body_value_idx"].cast<int64_t>(),
|
||||
desc["output_idx"].cast<int64_t>());
|
||||
output_descs.push_back(body_output);
|
||||
}
|
||||
} else {
|
||||
NGRAPH_CHECK(false, "Output descriptions is not supported with name ", name);
|
||||
}
|
||||
a->set(output_descs);
|
||||
} else if (const auto& a =
|
||||
@ -241,7 +268,7 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) {
|
||||
if (m_attributes.contains(name)) {
|
||||
if (name == "body") {
|
||||
if (name == "body" || name == "then_body" || name == "else_body") {
|
||||
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
|
||||
const auto& body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
|
||||
|
Loading…
Reference in New Issue
Block a user