Fix python API for Loop/TensorIterator/Assign/ReadValue operations (#3275)
* update python API for the Loop operation (attempt 1) * fix code style * update visitor API, pep8 * codestyle * delete TensorItertorBuilder * update visitor API for Loop and TensorIterator * fix code issues (python) * update python api:assign/read value * fix unit tests * fix python code style * fix loop unit test * fix loop unit test * fix build on centOS
This commit is contained in:
parent
56a179c6a3
commit
b792214d04
@ -1497,17 +1497,17 @@ def lstm_sequence(
|
||||
Shape: [batch_size]. Integer type.
|
||||
@param W: Tensor with weights for matrix multiplication operation with input portion of data.
|
||||
Shape: [num_directions, 4*hidden_size, input_size].
|
||||
:param R: The tensor with weights for matrix multiplication operation with hidden state.
|
||||
@param R: The tensor with weights for matrix multiplication operation with hidden state.
|
||||
Shape: [num_directions, 4*hidden_size, hidden_size].
|
||||
:param B: The tensor with biases.
|
||||
@param B: The tensor with biases.
|
||||
Shape: [num_directions, 4*hidden_size].
|
||||
:param hidden_size: Specifies hidden state size.
|
||||
:param direction: Specifies if the RNN is forward, reverse, or bidirectional.
|
||||
:param activations: The list of three activation functions for gates.
|
||||
:param activations_alpha: The list of alpha parameters for activation functions.
|
||||
:param activations_beta: The list of beta parameters for activation functions.
|
||||
:param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
|
||||
:param name: An optional name of the output node.
|
||||
@param hidden_size: Specifies hidden state size.
|
||||
@param direction: Specifies if the RNN is forward, reverse, or bidirectional.
|
||||
@param activations: The list of three activation functions for gates.
|
||||
@param activations_alpha: The list of alpha parameters for activation functions.
|
||||
@param activations_beta: The list of beta parameters for activation functions.
|
||||
@param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
|
||||
@param name: An optional name of the output node.
|
||||
|
||||
@return The new node represents LSTMSequence. Node outputs count: 3.
|
||||
"""
|
||||
@ -2800,11 +2800,11 @@ def tensor_iterator(
|
||||
"""
|
||||
attributes = {
|
||||
"body": graph_body.serialize(),
|
||||
"slice_input_desc": [desc.serialize() for desc in slice_input_desc],
|
||||
"merged_input_desc": [desc.serialize() for desc in merged_input_desc],
|
||||
"invariant_input_desc": [desc.serialize() for desc in invariant_input_desc],
|
||||
"body_output_desc": [desc.serialize() for desc in body_output_desc],
|
||||
"concat_output_desc": [desc.serialize() for desc in concat_output_desc],
|
||||
"input_descriptions": {"slice_input_desc": [desc.serialize() for desc in slice_input_desc],
|
||||
"merged_input_desc": [desc.serialize() for desc in merged_input_desc],
|
||||
"invariant_input_desc": [desc.serialize() for desc in invariant_input_desc]},
|
||||
"output_descriptions": {"body_output_desc": [desc.serialize() for desc in body_output_desc],
|
||||
"concat_output_desc": [desc.serialize() for desc in concat_output_desc]}
|
||||
}
|
||||
|
||||
return _get_node_factory_opset1().create("TensorIterator", as_nodes(*inputs), attributes)
|
||||
|
@ -385,16 +385,56 @@ def rnn_sequence(
|
||||
def loop(
|
||||
trip_count: NodeInput,
|
||||
execution_condition: NodeInput,
|
||||
inputs: List[Node],
|
||||
graph_body: GraphBody,
|
||||
slice_input_desc: List[TensorIteratorSliceInputDesc],
|
||||
merged_input_desc: List[TensorIteratorMergedInputDesc],
|
||||
invariant_input_desc: List[TensorIteratorInvariantInputDesc],
|
||||
body_output_desc: List[TensorIteratorBodyOutputDesc],
|
||||
concat_output_desc: List[TensorIteratorConcatOutputDesc],
|
||||
body_condition_output_idx: int,
|
||||
current_iteration_input_idx: int = -1,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a node which performs Loop.
|
||||
"""Perform recurrent execution of the network described in the body, iterating through the data.
|
||||
|
||||
@param trip_count: A scalar or 1D tensor with 1 element specifying
|
||||
maximum number of iterations.
|
||||
@param execution_condition: A scalar or 1D tensor with 1 element
|
||||
specifying whether to execute the first iteration or not.
|
||||
@param inputs: The provided to TensorIterator operator.
|
||||
@param graph_body: The graph representing the body we execute.
|
||||
@param slice_input_desc: The descriptors describing sliced inputs, that is nodes
|
||||
representing tensors we iterate through, processing single
|
||||
data slice in one iteration.
|
||||
@param merged_input_desc: The descriptors describing merged inputs, that is nodes
|
||||
representing variables with initial value at first iteration,
|
||||
which may be changing through iterations.
|
||||
@param invariant_input_desc: The descriptors describing invariant inputs, that is nodes
|
||||
representing variable with persistent value through all
|
||||
iterations.
|
||||
@param body_output_desc: The descriptors describing body outputs from specified
|
||||
iteration.
|
||||
@param concat_output_desc: The descriptors describing specified output values through
|
||||
all the iterations concatenated into one node.
|
||||
@param body_condition_output_idx: Determines the purpose of the corresponding result in
|
||||
the graph_body. This result will determine the dynamic
|
||||
exit condition. If the value of this result is False,
|
||||
then iterations stop.
|
||||
@param current_iteration_input_idx: Determines the purpose of the corresponding parameter
|
||||
in the graph_body. This parameter will be used as
|
||||
an iteration counter. Optional.
|
||||
@return: The new node which performs Loop.
|
||||
"""
|
||||
inputs = as_nodes(trip_count, execution_condition)
|
||||
|
||||
return _get_node_factory_opset5().create("Loop", inputs)
|
||||
attributes = {
|
||||
"body": graph_body.serialize(),
|
||||
"input_descriptions": {"slice_input_desc": [desc.serialize() for desc in slice_input_desc],
|
||||
"merged_input_desc": [desc.serialize() for desc in merged_input_desc],
|
||||
"invariant_input_desc": [desc.serialize() for desc in invariant_input_desc]},
|
||||
"output_descriptions": {"body_output_desc": [desc.serialize() for desc in body_output_desc],
|
||||
"concat_output_desc": [desc.serialize() for desc in concat_output_desc]},
|
||||
"special_body_ports": {"body_condition_output_idx": body_condition_output_idx,
|
||||
"current_iteration_input_idx": current_iteration_input_idx}
|
||||
}
|
||||
return _get_node_factory_opset5().create("Loop", as_nodes(trip_count, execution_condition, *inputs),
|
||||
attributes)
|
||||
|
@ -22,7 +22,7 @@ from ngraph.opset4.ops import acosh
|
||||
from ngraph.opset1.ops import add
|
||||
from ngraph.opset1.ops import asin
|
||||
from ngraph.opset4.ops import asinh
|
||||
from ngraph.opset3.ops import assign
|
||||
from ngraph.opset6.ops import assign
|
||||
from ngraph.opset1.ops import atan
|
||||
from ngraph.opset4.ops import atanh
|
||||
from ngraph.opset1.ops import avg_pool
|
||||
@ -114,7 +114,7 @@ from ngraph.opset1.ops import prior_box_clustered
|
||||
from ngraph.opset1.ops import psroi_pooling
|
||||
from ngraph.opset4.ops import proposal
|
||||
from ngraph.opset1.ops import range
|
||||
from ngraph.opset3.ops import read_value
|
||||
from ngraph.opset6.ops import read_value
|
||||
from ngraph.opset4.ops import reduce_l1
|
||||
from ngraph.opset4.ops import reduce_l2
|
||||
from ngraph.opset1.ops import reduce_logical_and
|
||||
|
@ -142,3 +142,35 @@ def mvn(
|
||||
}
|
||||
|
||||
return _get_node_factory_opset6().create("MVN", inputs, attributes)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def assign(new_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
|
||||
"""Return a node which produces the Assign operation.
|
||||
|
||||
@param new_value: Node producing a value to be assigned to a variable.
|
||||
@param variable_id: Id of a variable to be updated.
|
||||
@param name: Optional name for output node.
|
||||
@return Assign node
|
||||
"""
|
||||
return _get_node_factory_opset6().create(
|
||||
"Assign",
|
||||
[as_node(new_value)],
|
||||
{"variable_id": variable_id}
|
||||
)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def read_value(init_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
|
||||
"""Return a node which produces the Assign operation.
|
||||
|
||||
@param init_value: Node producing a value to be returned instead of an unassigned variable.
|
||||
@param variable_id: Id of a variable to be read.
|
||||
@param name: Optional name for output node.
|
||||
@return ReadValue node
|
||||
"""
|
||||
return _get_node_factory_opset6().create(
|
||||
"ReadValue",
|
||||
[as_node(init_value)],
|
||||
{"variable_id": variable_id}
|
||||
)
|
||||
|
@ -21,11 +21,16 @@
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "dict_attribute_visitor.hpp"
|
||||
#include "ngraph/op/loop.hpp"
|
||||
#include "ngraph/op/util/sub_graph_base.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
util::DictAttributeDeserializer::DictAttributeDeserializer(const py::dict& attributes)
|
||||
util::DictAttributeDeserializer::DictAttributeDeserializer(
|
||||
const py::dict& attributes,
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>>& variables)
|
||||
: m_attributes(attributes)
|
||||
, m_variables(variables)
|
||||
{
|
||||
}
|
||||
|
||||
@ -34,7 +39,116 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
NGRAPH_CHECK(false, "No AttributeVisitor support for accessing attribute named: ", name);
|
||||
if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>>(
|
||||
&adapter))
|
||||
{
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>
|
||||
input_descs;
|
||||
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
|
||||
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
|
||||
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
|
||||
for (py::handle h : slice_input_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto slice_in =
|
||||
std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>(),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
input_descs.push_back(slice_in);
|
||||
}
|
||||
|
||||
for (py::handle h : merged_input_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto merged_in =
|
||||
std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>(),
|
||||
desc["body_value_idx"].cast<int64_t>());
|
||||
input_descs.push_back(merged_in);
|
||||
}
|
||||
|
||||
for (py::handle h : invariant_input_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto invariant_in =
|
||||
std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
|
||||
desc["input_idx"].cast<int64_t>(),
|
||||
desc["body_parameter_idx"].cast<int64_t>());
|
||||
input_descs.push_back(invariant_in);
|
||||
}
|
||||
a->set(input_descs);
|
||||
}
|
||||
else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<
|
||||
std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter))
|
||||
{
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>
|
||||
output_descs;
|
||||
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
|
||||
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
|
||||
for (py::handle h : body_output_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto body_output =
|
||||
std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
|
||||
desc["body_value_idx"].cast<int64_t>(),
|
||||
desc["output_idx"].cast<int64_t>(),
|
||||
desc["iteration"].cast<int64_t>());
|
||||
output_descs.push_back(body_output);
|
||||
}
|
||||
|
||||
for (py::handle h : concat_output_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
auto concat_output =
|
||||
std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
|
||||
desc["body_value_idx"].cast<int64_t>(),
|
||||
desc["output_idx"].cast<int64_t>(),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
output_descs.push_back(concat_output);
|
||||
}
|
||||
a->set(output_descs);
|
||||
}
|
||||
else if (const auto& a = ngraph::as_type<
|
||||
ngraph::AttributeAdapter<ngraph::op::v5::Loop::SpecialBodyPorts>>(&adapter))
|
||||
{
|
||||
ngraph::op::v5::Loop::SpecialBodyPorts special_body_ports;
|
||||
const py::dict& special_ports_dict = m_attributes[name.c_str()].cast<py::dict>();
|
||||
special_body_ports.body_condition_output_idx =
|
||||
special_ports_dict["body_condition_output_idx"].cast<int64_t>();
|
||||
special_body_ports.current_iteration_input_idx =
|
||||
special_ports_dict["current_iteration_input_idx"].cast<int64_t>();
|
||||
a->set(special_body_ports);
|
||||
}
|
||||
else if (const auto& a =
|
||||
ngraph::as_type<ngraph::AttributeAdapter<std::shared_ptr<ngraph::Variable>>>(
|
||||
&adapter))
|
||||
{
|
||||
std::string variable_id = m_attributes[name.c_str()].cast<std::string>();
|
||||
if (!m_variables.count(variable_id))
|
||||
{
|
||||
m_variables[variable_id] = std::make_shared<ngraph::Variable>(ngraph::VariableInfo{
|
||||
ngraph::PartialShape::dynamic(), ngraph::element::dynamic, variable_id});
|
||||
}
|
||||
a->set(m_variables[variable_id]);
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(
|
||||
false, "No AttributeVisitor support for accessing attribute named: ", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||
@ -222,6 +336,28 @@ void util::DictAttributeDeserializer::on_adapter(
|
||||
}
|
||||
}
|
||||
|
||||
void util::DictAttributeDeserializer::on_adapter(
|
||||
const std::string& name, ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter)
|
||||
{
|
||||
if (m_attributes.contains(name))
|
||||
{
|
||||
if (name == "body")
|
||||
{
|
||||
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
|
||||
const auto& body_outputs =
|
||||
as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
|
||||
const auto& body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
|
||||
auto body = std::make_shared<ngraph::Function>(body_outputs, body_parameters);
|
||||
adapter.set(body);
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(
|
||||
false, "No AttributeVisitor support for accessing attribute named: ", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
util::DictAttributeSerializer::DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
node->visit_attributes(*this);
|
||||
|
@ -22,6 +22,7 @@
|
||||
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/util/variable.hpp"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
@ -32,114 +33,96 @@ namespace util
|
||||
class DictAttributeDeserializer : public ngraph::AttributeVisitor
|
||||
{
|
||||
public:
|
||||
DictAttributeDeserializer(const py::dict& attributes);
|
||||
DictAttributeDeserializer(
|
||||
const py::dict& attributes,
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>>& variables);
|
||||
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<void>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<bool>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<float>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<double>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<bool>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int8_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int16_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int32_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int64_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint8_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint16_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint32_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint64_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<float>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
|
||||
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) override;
|
||||
|
||||
protected:
|
||||
const py::dict& m_attributes;
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>>& m_variables;
|
||||
};
|
||||
|
||||
class DictAttributeSerializer : public ngraph::AttributeVisitor
|
||||
{
|
||||
public:
|
||||
DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node);
|
||||
explicit DictAttributeSerializer(const std::shared_ptr<ngraph::Node>& node);
|
||||
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<void>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<bool>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<int64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint8_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint16_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint32_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<uint64_t>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<float>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<double>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<bool>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::string>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int8_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int16_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int32_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<int64_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint8_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint16_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint32_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<uint64_t>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<float>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<std::string>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int8_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int16_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<int64_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint8_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint16_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint32_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<uint64_t>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<float>>& adapter) override;
|
||||
void on_adapter(const std::string& name,
|
||||
ngraph::ValueAccessor<std::vector<double>>& adapter) override;
|
||||
|
||||
template <typename T>
|
||||
T get_attribute(const std::string& name)
|
||||
|
@ -117,7 +117,8 @@ void regclass_pyngraph_Node(py::module m)
|
||||
[](std::shared_ptr<ngraph::Node>& self, const std::string& atr_name, py::object value) {
|
||||
py::dict attr_dict;
|
||||
attr_dict[atr_name.c_str()] = value;
|
||||
util::DictAttributeDeserializer dict_deserializer(attr_dict);
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> variables;
|
||||
util::DictAttributeDeserializer dict_deserializer(attr_dict, variables);
|
||||
self->visit_attributes(dict_deserializer);
|
||||
});
|
||||
}
|
||||
|
@ -31,9 +31,9 @@
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/op/util/variable.hpp"
|
||||
#include "ngraph/opsets/opset.hpp"
|
||||
#include "node_factory.hpp"
|
||||
#include "tensor_iterator_builder.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
@ -60,14 +60,7 @@ namespace
|
||||
"Currently NodeFactory doesn't support Constant node: ",
|
||||
op_type_name);
|
||||
|
||||
if (op_type_name == "TensorIterator")
|
||||
{
|
||||
// XXX: How to differentiate opsets?
|
||||
return util::TensorIteratorBuilder(as_node_vector(arguments), attributes)
|
||||
.configure(std::static_pointer_cast<ngraph::op::TensorIterator>(op_node));
|
||||
}
|
||||
|
||||
util::DictAttributeDeserializer visitor(attributes);
|
||||
util::DictAttributeDeserializer visitor(attributes, m_variables);
|
||||
|
||||
op_node->set_arguments(arguments);
|
||||
op_node->visit_attributes(visitor);
|
||||
@ -105,6 +98,7 @@ namespace
|
||||
}
|
||||
|
||||
const ngraph::OpSet& m_opset = ngraph::get_opset7();
|
||||
std::unordered_map<std::string, std::shared_ptr<ngraph::Variable>> m_variables;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -1,224 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/except.hpp"
|
||||
#include "tensor_iterator_builder.hpp"
|
||||
|
||||
util::TensorIteratorBuilder::TensorIteratorBuilder(const ngraph::NodeVector& arguments,
|
||||
const py::dict& attributes)
|
||||
: m_arguments(arguments)
|
||||
, m_attributes(attributes)
|
||||
{
|
||||
get_graph_body();
|
||||
// Set-up TI inputs.
|
||||
NGRAPH_CHECK(m_attributes.contains("slice_input_desc"),
|
||||
"The required \"slice_input_desc\" attribute is missing. Can't build "
|
||||
"TensorIterator operator.");
|
||||
m_slice_input_desc = m_attributes["slice_input_desc"].cast<py::list>();
|
||||
|
||||
if (m_attributes.contains("merged_input_desc"))
|
||||
{
|
||||
m_merged_input_desc = m_attributes["merged_input_desc"].cast<py::list>();
|
||||
}
|
||||
|
||||
if (m_attributes.contains("invariant_input_desc"))
|
||||
{
|
||||
m_invariant_input_desc = m_attributes["invariant_input_desc"].cast<py::list>();
|
||||
}
|
||||
|
||||
if (m_attributes.contains("body_output_desc"))
|
||||
{
|
||||
py::list body_output_desc = m_attributes["body_output_desc"].cast<py::list>();
|
||||
for (py::handle h : body_output_desc)
|
||||
{
|
||||
py::dict desc = h.cast<py::dict>();
|
||||
desc["type"] = "BodyOutputDesc";
|
||||
check_attribute(desc, "output_idx", "BodyOutputDesc");
|
||||
m_outputs.emplace(desc["output_idx"].cast<int64_t>(), desc);
|
||||
}
|
||||
}
|
||||
if (m_attributes.contains("concat_output_desc"))
|
||||
{
|
||||
py::list concat_output_desc = m_attributes["concat_output_desc"].cast<py::list>();
|
||||
for (py::handle h : concat_output_desc)
|
||||
{
|
||||
py::dict desc = h.cast<py::dict>();
|
||||
desc["type"] = "ConcatOutputDesc";
|
||||
check_attribute(desc, "output_idx", "ConcatOutputDesc");
|
||||
m_outputs.emplace(desc["output_idx"].cast<int64_t>(), desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::op::TensorIterator>
|
||||
util::TensorIteratorBuilder::configure(std::shared_ptr<ngraph::op::TensorIterator>&& ti_node)
|
||||
{
|
||||
ti_node->set_body(m_body);
|
||||
set_tensor_iterator_sliced_inputs(ti_node);
|
||||
set_tensor_iterator_merged_inputs(ti_node);
|
||||
set_tensor_iterator_invariant_inputs(ti_node);
|
||||
set_tensor_iterator_outputs(ti_node);
|
||||
ti_node->constructor_validate_and_infer_types();
|
||||
|
||||
return std::move(ti_node);
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::check_attribute(const py::dict& attrs,
|
||||
std::string attr_name,
|
||||
std::string desc_name) const
|
||||
{
|
||||
NGRAPH_CHECK(attrs.contains(attr_name),
|
||||
"The required \"",
|
||||
attr_name,
|
||||
"\" attribute is missing. Can't build TensorIterator's ",
|
||||
desc_name,
|
||||
".");
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::get_graph_body()
|
||||
{
|
||||
NGRAPH_CHECK(m_attributes.contains("body"),
|
||||
"The required \"body\" attribute is missing. Can't build TensorIterator "
|
||||
"operator.");
|
||||
|
||||
const py::dict& body_attrs = m_attributes["body"].cast<py::dict>();
|
||||
|
||||
NGRAPH_CHECK(body_attrs.contains("parameters"),
|
||||
"The required body's \"parameters\" "
|
||||
"attribute is missing. Can't build TensorIterator's body.");
|
||||
NGRAPH_CHECK(body_attrs.contains("results"),
|
||||
"The required body's \"results\" "
|
||||
"attribute is missing. Can't build TensorIterator's body.");
|
||||
|
||||
m_body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
|
||||
m_body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
|
||||
m_body = std::make_shared<ngraph::Function>(m_body_outputs, m_body_parameters);
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::set_tensor_iterator_sliced_inputs(
|
||||
std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
|
||||
{
|
||||
for (py::handle h : m_slice_input_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
check_attribute(desc, "input_idx", "SliceInputDesc");
|
||||
check_attribute(desc, "body_parameter_idx", "SliceInputDesc");
|
||||
check_attribute(desc, "start", "SliceInputDesc");
|
||||
check_attribute(desc, "stride", "SliceInputDesc");
|
||||
check_attribute(desc, "part_size", "SliceInputDesc");
|
||||
check_attribute(desc, "end", "SliceInputDesc");
|
||||
check_attribute(desc, "axis", "SliceInputDesc");
|
||||
|
||||
ti_node->set_sliced_input(m_body_parameters.at(desc["body_parameter_idx"].cast<int64_t>()),
|
||||
m_arguments.at(desc["input_idx"].cast<int64_t>()),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
}
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::set_tensor_iterator_merged_inputs(
|
||||
std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
|
||||
{
|
||||
for (py::handle h : m_merged_input_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
check_attribute(desc, "input_idx", "MergedInputDesc");
|
||||
check_attribute(desc, "body_parameter_idx", "MergedInputDesc");
|
||||
check_attribute(desc, "body_value_idx", "MergedInputDesc");
|
||||
|
||||
ti_node->set_merged_input(m_body_parameters.at(desc["body_parameter_idx"].cast<int64_t>()),
|
||||
m_arguments.at(desc["input_idx"].cast<int64_t>()),
|
||||
m_body_outputs.at(desc["body_value_idx"].cast<int64_t>()));
|
||||
}
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::set_tensor_iterator_invariant_inputs(
|
||||
std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
|
||||
{
|
||||
for (py::handle h : m_invariant_input_desc)
|
||||
{
|
||||
const py::dict& desc = h.cast<py::dict>();
|
||||
check_attribute(desc, "input_idx", "InvariantInputDesc");
|
||||
check_attribute(desc, "body_parameter_idx", "InvariantInputDesc");
|
||||
|
||||
ti_node->set_invariant_input(
|
||||
m_body_parameters.at(desc["body_parameter_idx"].cast<int64_t>()),
|
||||
m_arguments.at(desc["input_idx"].cast<int64_t>()));
|
||||
}
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::set_tensor_iterator_outputs(
|
||||
std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
|
||||
{
|
||||
for (const auto& elem : m_outputs)
|
||||
{
|
||||
const py::dict& desc = elem.second.cast<py::dict>();
|
||||
if (desc["type"].cast<std::string>() == "BodyOutputDesc")
|
||||
{
|
||||
set_tensor_iterator_body_output(desc, ti_node);
|
||||
}
|
||||
else if (desc["type"].cast<std::string>() == "ConcatOutputDesc")
|
||||
{
|
||||
set_tensor_iterator_concatenated_body_output(desc, ti_node);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph::ngraph_error("Unrecognized TensorIterator output type.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::set_tensor_iterator_body_output(
|
||||
const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
|
||||
{
|
||||
check_attribute(desc, "body_value_idx", "BodyOutputDesc");
|
||||
check_attribute(desc, "iteration", "BodyOutputDesc");
|
||||
|
||||
NGRAPH_CHECK(desc["output_idx"].cast<size_t>() == ti_node->get_output_size(),
|
||||
"Descriptor output idx value is different from currently configured "
|
||||
"TensorIterator output.");
|
||||
|
||||
ti_node->get_iter_value(m_body_outputs.at(desc["body_value_idx"].cast<int64_t>()),
|
||||
desc["iteration"].cast<int64_t>());
|
||||
}
|
||||
|
||||
void util::TensorIteratorBuilder::set_tensor_iterator_concatenated_body_output(
|
||||
const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const
|
||||
{
|
||||
check_attribute(desc, "body_value_idx", "ConcatOutputDesc");
|
||||
check_attribute(desc, "start", "ConcatOutputDesc");
|
||||
check_attribute(desc, "stride", "ConcatOutputDesc");
|
||||
check_attribute(desc, "part_size", "ConcatOutputDesc");
|
||||
check_attribute(desc, "end", "ConcatOutputDesc");
|
||||
check_attribute(desc, "axis", "ConcatOutputDesc");
|
||||
|
||||
NGRAPH_CHECK(desc["output_idx"].cast<size_t>() == ti_node->get_output_size(),
|
||||
"Descriptor output idx value is different from currently configured "
|
||||
"TensorIterator output.");
|
||||
|
||||
ti_node->get_concatenated_slices(m_body_outputs.at(desc["body_value_idx"].cast<int64_t>()),
|
||||
desc["start"].cast<int64_t>(),
|
||||
desc["stride"].cast<int64_t>(),
|
||||
desc["part_size"].cast<int64_t>(),
|
||||
desc["end"].cast<int64_t>(),
|
||||
desc["axis"].cast<int64_t>());
|
||||
}
|
@ -1,135 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cctype>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/tensor_iterator.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace util
|
||||
{
|
||||
class TensorIteratorBuilder
|
||||
{
|
||||
public:
|
||||
///
|
||||
/// \brief Initialize TensorIterator node builder.
|
||||
///
|
||||
/// \param[in] arguments The arguments passed to TensorIterator node.
|
||||
/// \param[in] attributes The TensorIterator's attributes. This
|
||||
/// py::dict contains all descriptors for
|
||||
/// plethora of TensorIterator available inputs
|
||||
/// and outputs.
|
||||
///
|
||||
TensorIteratorBuilder(const ngraph::NodeVector& arguments, const py::dict& attributes);
|
||||
|
||||
///
|
||||
/// \brief Configure instance of TensorIterator node with set-up parameters.
|
||||
///
|
||||
/// \param ti_node The TensorIterator node instance to configure.
|
||||
///
|
||||
/// \return TensorIterator node.
|
||||
///
|
||||
std::shared_ptr<ngraph::op::TensorIterator>
|
||||
configure(std::shared_ptr<ngraph::op::TensorIterator>&& ti_node);
|
||||
|
||||
private:
|
||||
///
|
||||
/// \brief Helper to conduct attribute presence.
|
||||
///
|
||||
/// \param[in] attrs The attributes
|
||||
/// \param[in] attr_name The attribute name
|
||||
/// \param[in] desc_name The description name
|
||||
///
|
||||
inline void check_attribute(const py::dict& attrs,
|
||||
std::string attr_name,
|
||||
std::string desc_name) const;
|
||||
|
||||
///
|
||||
/// \brief Retrieve the TI graph body.
|
||||
///
|
||||
void get_graph_body();
|
||||
|
||||
///
|
||||
/// \brief Sets the tensor iterator sliced inputs.
|
||||
///
|
||||
/// \param ti_node The TI node we will set input to.
|
||||
///
|
||||
void set_tensor_iterator_sliced_inputs(
|
||||
std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
|
||||
|
||||
///
|
||||
/// \brief Sets the tensor iterator merged inputs.
|
||||
///
|
||||
/// \param ti_node The TI node we will set inputs to.
|
||||
///
|
||||
void set_tensor_iterator_merged_inputs(
|
||||
std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
|
||||
|
||||
///
|
||||
/// \brief Sets the tensor iterator invariant inputs.
|
||||
///
|
||||
/// \param ti_node The TI node we will set inputs to.
|
||||
///
|
||||
void set_tensor_iterator_invariant_inputs(
|
||||
std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
|
||||
|
||||
///
|
||||
/// \brief Sets the tensor iterator outputs.
|
||||
///
|
||||
/// \param ti_node The TI node we will set outputs to.
|
||||
///
|
||||
void
|
||||
set_tensor_iterator_outputs(std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
|
||||
|
||||
///
|
||||
/// \brief Sets the tensor iterator body output.
|
||||
///
|
||||
/// \param[in] desc The descriptor of the TI body output.
|
||||
/// \param ti_node The TI node we will set output to.
|
||||
///
|
||||
void set_tensor_iterator_body_output(
|
||||
const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
|
||||
|
||||
///
|
||||
/// \brief Sets the tensor iterator concatenated body output.
|
||||
///
|
||||
/// \param[in] desc The descriptor of the TI body output.
|
||||
/// \param ti_node The TI node we will set output to.
|
||||
///
|
||||
void set_tensor_iterator_concatenated_body_output(
|
||||
const py::dict& desc, std::shared_ptr<ngraph::op::TensorIterator>& ti_node) const;
|
||||
|
||||
const ngraph::NodeVector& m_arguments;
|
||||
const py::dict& m_attributes;
|
||||
ngraph::OutputVector m_body_outputs;
|
||||
ngraph::ParameterVector m_body_parameters;
|
||||
std::shared_ptr<ngraph::Function> m_body;
|
||||
py::list m_slice_input_desc;
|
||||
py::list m_merged_input_desc;
|
||||
py::list m_invariant_input_desc;
|
||||
std::map<int64_t, const py::dict> m_outputs;
|
||||
};
|
||||
} // namespace util
|
@ -19,10 +19,9 @@ from _pyngraph import PartialShape
|
||||
|
||||
import ngraph as ng
|
||||
import ngraph.opset1 as ng_opset1
|
||||
import ngraph.opset5 as ng_opset5
|
||||
from ngraph.impl import Type
|
||||
|
||||
from tests import skip_segfault
|
||||
|
||||
np_types = [np.float32, np.int32]
|
||||
integral_np_types = [
|
||||
np.int8,
|
||||
@ -718,14 +717,89 @@ def test_rnn_sequence():
|
||||
assert list(node_param.get_output_shape(1)) == expected_shape_h
|
||||
|
||||
|
||||
@skip_segfault
|
||||
def test_loop():
|
||||
trip_count = 8
|
||||
condition = True
|
||||
from ngraph.utils.tensor_iterator_types import (
|
||||
GraphBody,
|
||||
TensorIteratorSliceInputDesc,
|
||||
TensorIteratorMergedInputDesc,
|
||||
TensorIteratorInvariantInputDesc,
|
||||
TensorIteratorBodyOutputDesc,
|
||||
TensorIteratorConcatOutputDesc,
|
||||
)
|
||||
|
||||
node_default = ng.loop(trip_count, condition)
|
||||
condition = ng.constant(True, dtype=np.bool)
|
||||
trip_count = ng.constant(16, dtype=np.int32)
|
||||
# Body parameters
|
||||
body_timestep = ng.parameter([], np.int32, "timestep")
|
||||
body_data_in = ng.parameter([1, 2, 2], np.float32, "body_in")
|
||||
body_prev_cma = ng.parameter([2, 2], np.float32, "body_prev_cma")
|
||||
body_const_one = ng.parameter([], np.int32, "body_const_one")
|
||||
|
||||
assert node_default.get_type_name() == "Loop"
|
||||
# CMA = cumulative moving average
|
||||
prev_cum_sum = ng.multiply(ng.convert(body_timestep, "f32"), body_prev_cma)
|
||||
curr_cum_sum = ng.add(prev_cum_sum, ng.squeeze(body_data_in, [0]))
|
||||
elem_cnt = ng.add(body_const_one, body_timestep)
|
||||
curr_cma = ng.divide(curr_cum_sum, ng.convert(elem_cnt, "f32"))
|
||||
cma_hist = ng.unsqueeze(curr_cma, [0])
|
||||
|
||||
# TI inputs
|
||||
data = ng.parameter([16, 2, 2], np.float32, "data")
|
||||
# Iterations count
|
||||
zero = ng.constant(0, dtype=np.int32)
|
||||
one = ng.constant(1, dtype=np.int32)
|
||||
initial_cma = ng.constant(np.zeros([2, 2], dtype=np.float32), dtype=np.float32)
|
||||
iter_cnt = ng.range(zero, np.int32(16), np.int32(1))
|
||||
ti_inputs = [iter_cnt, data, initial_cma, one]
|
||||
body_const_condition = ng.constant(True, dtype=np.bool)
|
||||
|
||||
graph_body = GraphBody([body_timestep, body_data_in, body_prev_cma, body_const_one],
|
||||
[curr_cma, cma_hist, body_const_condition])
|
||||
ti_slice_input_desc = [
|
||||
# timestep
|
||||
# input_idx, body_param_idx, start, stride, part_size, end, axis
|
||||
TensorIteratorSliceInputDesc(2, 0, 0, 1, 1, -1, 0),
|
||||
# data
|
||||
TensorIteratorSliceInputDesc(3, 1, 0, 1, 1, -1, 0),
|
||||
]
|
||||
ti_merged_input_desc = [
|
||||
# body prev/curr_cma
|
||||
TensorIteratorMergedInputDesc(4, 2, 0),
|
||||
]
|
||||
ti_invariant_input_desc = [
|
||||
# body const one
|
||||
TensorIteratorInvariantInputDesc(5, 3),
|
||||
]
|
||||
|
||||
# TI outputs
|
||||
ti_body_output_desc = [
|
||||
# final average
|
||||
TensorIteratorBodyOutputDesc(0, 0, -1),
|
||||
]
|
||||
ti_concat_output_desc = [
|
||||
# history of cma
|
||||
TensorIteratorConcatOutputDesc(1, 1, 0, 1, 1, -1, 0),
|
||||
]
|
||||
|
||||
node = ng.loop(
|
||||
trip_count,
|
||||
condition,
|
||||
ti_inputs,
|
||||
graph_body,
|
||||
ti_slice_input_desc,
|
||||
ti_merged_input_desc,
|
||||
ti_invariant_input_desc,
|
||||
ti_body_output_desc,
|
||||
ti_concat_output_desc,
|
||||
2,
|
||||
-1,
|
||||
)
|
||||
|
||||
assert node.get_type_name() == "Loop"
|
||||
assert node.get_output_size() == 2
|
||||
# final average
|
||||
assert list(node.get_output_shape(0)) == [2, 2]
|
||||
# cma history
|
||||
assert list(node.get_output_shape(1)) == [16, 2, 2]
|
||||
|
||||
|
||||
def test_roi_pooling():
|
||||
@ -1096,6 +1170,28 @@ def test_tensor_iterator():
|
||||
assert list(node.get_output_shape(1)) == [16, 2, 2]
|
||||
|
||||
|
||||
def test_read_value_opset5():
|
||||
init_value = ng_opset5.parameter([2, 2], name="init_value", dtype=np.int32)
|
||||
|
||||
node = ng_opset5.read_value(init_value, "var_id_667")
|
||||
|
||||
assert node.get_type_name() == "ReadValue"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == [2, 2]
|
||||
assert node.get_output_element_type(0) == Type.i32
|
||||
|
||||
|
||||
def test_assign_opset5():
|
||||
input_data = ng_opset5.parameter([5, 7], name="input_data", dtype=np.int32)
|
||||
rv = ng_opset5.read_value(input_data, "var_id_667")
|
||||
node = ng_opset5.assign(rv, "var_id_667")
|
||||
|
||||
assert node.get_type_name() == "Assign"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == [5, 7]
|
||||
assert node.get_output_element_type(0) == Type.i32
|
||||
|
||||
|
||||
def test_read_value():
|
||||
init_value = ng.parameter([2, 2], name="init_value", dtype=np.int32)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user