Python API for Assign, ReadValue and ExtractImagePatches (#719)

This commit is contained in:
Tomasz Dołbniak 2020-06-03 15:01:43 +02:00 committed by GitHub
parent 63a77bb4a1
commit 53927034da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 94 additions and 5 deletions

View File

@ -29,6 +29,7 @@ from ngraph.ops import absolute as abs
from ngraph.ops import acos from ngraph.ops import acos
from ngraph.ops import add from ngraph.ops import add
from ngraph.ops import asin from ngraph.ops import asin
from ngraph.ops import assign
from ngraph.ops import atan from ngraph.ops import atan
from ngraph.ops import avg_pool from ngraph.ops import avg_pool
from ngraph.ops import batch_norm_inference from ngraph.ops import batch_norm_inference
@ -59,6 +60,7 @@ from ngraph.ops import elu
from ngraph.ops import embedding_bag_offsets_sum from ngraph.ops import embedding_bag_offsets_sum
from ngraph.ops import embedding_bag_packed_sum from ngraph.ops import embedding_bag_packed_sum
from ngraph.ops import embedding_segments_sum from ngraph.ops import embedding_segments_sum
from ngraph.ops import extract_image_patches
from ngraph.ops import equal from ngraph.ops import equal
from ngraph.ops import erf from ngraph.ops import erf
from ngraph.ops import exp from ngraph.ops import exp
@ -108,6 +110,7 @@ from ngraph.ops import prior_box
from ngraph.ops import prior_box_clustered from ngraph.ops import prior_box_clustered
from ngraph.ops import psroi_pooling from ngraph.ops import psroi_pooling
from ngraph.ops import proposal from ngraph.ops import proposal
from ngraph.ops import read_value
from ngraph.ops import reduce_logical_and from ngraph.ops import reduce_logical_and
from ngraph.ops import reduce_logical_or from ngraph.ops import reduce_logical_or
from ngraph.ops import reduce_max from ngraph.ops import reduce_max

View File

@ -3438,3 +3438,53 @@ def proposal(
return _get_node_factory().create( return _get_node_factory().create(
"Proposal", [class_probs, box_logits, as_node(image_shape)], attrs "Proposal", [class_probs, box_logits, as_node(image_shape)], attrs
) )
@nameable_op
def assign(new_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
"""Return a node which produces the Assign operation.
:param new_value: Node producing a value to be assigned to a variable.
:param variable_id: Id of a variable to be updated.
:param name: Optional name for output node.
:return: Assign node
"""
return _get_node_factory().create("Assign", [as_node(new_value)], {"variable_id": variable_id})
@nameable_op
def read_value(init_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
"""Return a node which produces the Assign operation.
:param init_value: Node producing a value to be returned instead of an unassigned variable.
:param variable_id: Id of a variable to be read.
:param name: Optional name for output node.
:return: ReadValue node
"""
return _get_node_factory().create("ReadValue", [as_node(init_value)], {"variable_id": variable_id})
@nameable_op
def extract_image_patches(
image: NodeInput,
sizes: TensorShape,
strides: List[int],
rates: TensorShape,
auto_pad: str,
name: Optional[str] = None,
) -> Node:
"""Return a node which produces the ExtractImagePatches operation.
:param image: 4-D Input data to extract image patches.
:param sizes: Patch size in the format of [size_rows, size_cols].
:param strides: Patch movement stride in the format of [stride_rows, stride_cols]
:param rates: Element seleciton rate for creating a patch.
:param auto_pad: Padding type.
:param name: Optional name for output node.
:return: ExtractImagePatches node
"""
return _get_node_factory().create(
"ExtractImagePatches",
[as_node(image)],
{"sizes": sizes, "strides": strides, "rates": rates, "auto_pad": auto_pad},
)

View File

@ -845,3 +845,39 @@ def test_proposal(int_dtype, fp_dtype):
assert node.get_type_name() == "Proposal" assert node.get_type_name() == "Proposal"
assert node.get_output_size() == 1 assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5] assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5]
def test_read_value():
init_value = ng.parameter([2, 2], name="init_value", dtype=np.int32)
node = ng.read_value(init_value, "var_id_667")
assert node.get_type_name() == "ReadValue"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [2, 2]
assert node.get_output_element_type(0) == Type.i32
def test_assign():
input_data = ng.parameter([5, 7], name="input_data", dtype=np.int32)
rv = ng.read_value(input_data, "var_id_667")
node = ng.assign(rv, "var_id_667")
assert node.get_type_name() == "Assign"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [5, 7]
assert node.get_output_element_type(0) == Type.i32
def test_extract_image_patches():
image = ng.parameter([64, 3, 10, 10], name="image", dtype=np.int32)
sizes = [3, 3];
strides = [5, 5];
rates = [1, 1];
padding = "VALID";
node = ng.extract_image_patches(image, sizes, strides, rates, padding)
assert node.get_type_name() == "ExtractImagePatches"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [64, 27, 2, 2]
assert node.get_output_element_type(0) == Type.i32

View File

@ -21,8 +21,8 @@ using namespace ngraph;
constexpr NodeTypeInfo op::ReadValue::type_info; constexpr NodeTypeInfo op::ReadValue::type_info;
op::ReadValue::ReadValue(const Output<Node>& new_value, const std::string& variable_id) op::ReadValue::ReadValue(const Output<Node>& init_value, const std::string& variable_id)
: Op({new_value}) : Op({init_value})
, m_variable_id(variable_id) , m_variable_id(variable_id)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();

View File

@ -36,9 +36,9 @@ namespace ngraph
/// \brief Constructs a ReadValue operation. /// \brief Constructs a ReadValue operation.
/// ///
/// \param new_value Node that produces the input tensor. /// \param init_value Node that produces the input tensor.
/// \param variable_id identificator of the variable to create. /// \param variable_id identificator of the variable to create.
ReadValue(const Output<Node>& new_value, const std::string& variable_id); ReadValue(const Output<Node>& init_value, const std::string& variable_id);
void validate_and_infer_types() override; void validate_and_infer_types() override;