Python API for Assign, ReadValue and ExtractImagePatches (#719)

This commit is contained in:
Tomasz Dołbniak 2020-06-03 15:01:43 +02:00 committed by GitHub
parent 63a77bb4a1
commit 53927034da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 94 additions and 5 deletions

View File

@ -29,6 +29,7 @@ from ngraph.ops import absolute as abs
from ngraph.ops import acos
from ngraph.ops import add
from ngraph.ops import asin
from ngraph.ops import assign
from ngraph.ops import atan
from ngraph.ops import avg_pool
from ngraph.ops import batch_norm_inference
@ -59,6 +60,7 @@ from ngraph.ops import elu
from ngraph.ops import embedding_bag_offsets_sum
from ngraph.ops import embedding_bag_packed_sum
from ngraph.ops import embedding_segments_sum
from ngraph.ops import extract_image_patches
from ngraph.ops import equal
from ngraph.ops import erf
from ngraph.ops import exp
@ -108,6 +110,7 @@ from ngraph.ops import prior_box
from ngraph.ops import prior_box_clustered
from ngraph.ops import psroi_pooling
from ngraph.ops import proposal
from ngraph.ops import read_value
from ngraph.ops import reduce_logical_and
from ngraph.ops import reduce_logical_or
from ngraph.ops import reduce_max

View File

@ -3438,3 +3438,53 @@ def proposal(
return _get_node_factory().create(
"Proposal", [class_probs, box_logits, as_node(image_shape)], attrs
)
@nameable_op
def assign(new_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
"""Return a node which produces the Assign operation.
:param new_value: Node producing a value to be assigned to a variable.
:param variable_id: Id of a variable to be updated.
:param name: Optional name for output node.
:return: Assign node
"""
return _get_node_factory().create("Assign", [as_node(new_value)], {"variable_id": variable_id})
@nameable_op
def read_value(init_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
"""Return a node which produces the Assign operation.
:param init_value: Node producing a value to be returned instead of an unassigned variable.
:param variable_id: Id of a variable to be read.
:param name: Optional name for output node.
:return: ReadValue node
"""
return _get_node_factory().create("ReadValue", [as_node(init_value)], {"variable_id": variable_id})
@nameable_op
def extract_image_patches(
image: NodeInput,
sizes: TensorShape,
strides: List[int],
rates: TensorShape,
auto_pad: str,
name: Optional[str] = None,
) -> Node:
"""Return a node which produces the ExtractImagePatches operation.
:param image: 4-D Input data to extract image patches.
:param sizes: Patch size in the format of [size_rows, size_cols].
:param strides: Patch movement stride in the format of [stride_rows, stride_cols]
:param rates: Element seleciton rate for creating a patch.
:param auto_pad: Padding type.
:param name: Optional name for output node.
:return: ExtractImagePatches node
"""
return _get_node_factory().create(
"ExtractImagePatches",
[as_node(image)],
{"sizes": sizes, "strides": strides, "rates": rates, "auto_pad": auto_pad},
)

View File

@ -845,3 +845,39 @@ def test_proposal(int_dtype, fp_dtype):
assert node.get_type_name() == "Proposal"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5]
def test_read_value():
init_value = ng.parameter([2, 2], name="init_value", dtype=np.int32)
node = ng.read_value(init_value, "var_id_667")
assert node.get_type_name() == "ReadValue"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [2, 2]
assert node.get_output_element_type(0) == Type.i32
def test_assign():
input_data = ng.parameter([5, 7], name="input_data", dtype=np.int32)
rv = ng.read_value(input_data, "var_id_667")
node = ng.assign(rv, "var_id_667")
assert node.get_type_name() == "Assign"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [5, 7]
assert node.get_output_element_type(0) == Type.i32
def test_extract_image_patches():
image = ng.parameter([64, 3, 10, 10], name="image", dtype=np.int32)
sizes = [3, 3];
strides = [5, 5];
rates = [1, 1];
padding = "VALID";
node = ng.extract_image_patches(image, sizes, strides, rates, padding)
assert node.get_type_name() == "ExtractImagePatches"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [64, 27, 2, 2]
assert node.get_output_element_type(0) == Type.i32

View File

@ -21,8 +21,8 @@ using namespace ngraph;
constexpr NodeTypeInfo op::ReadValue::type_info;
op::ReadValue::ReadValue(const Output<Node>& new_value, const std::string& variable_id)
: Op({new_value})
op::ReadValue::ReadValue(const Output<Node>& init_value, const std::string& variable_id)
: Op({init_value})
, m_variable_id(variable_id)
{
constructor_validate_and_infer_types();

View File

@ -36,9 +36,9 @@ namespace ngraph
/// \brief Constructs a ReadValue operation.
///
/// \param new_value Node that produces the input tensor.
/// \param variable_id identificator of the variable to create.
ReadValue(const Output<Node>& new_value, const std::string& variable_id);
/// \param init_value Node that produces the input tensor.
/// \param variable_id identificator of the variable to create.
ReadValue(const Output<Node>& init_value, const std::string& variable_id);
void validate_and_infer_types() override;