Python API for Assign, ReadValue and ExtractImagePatches (#719)
This commit is contained in:
parent
63a77bb4a1
commit
53927034da
@ -29,6 +29,7 @@ from ngraph.ops import absolute as abs
|
||||
from ngraph.ops import acos
|
||||
from ngraph.ops import add
|
||||
from ngraph.ops import asin
|
||||
from ngraph.ops import assign
|
||||
from ngraph.ops import atan
|
||||
from ngraph.ops import avg_pool
|
||||
from ngraph.ops import batch_norm_inference
|
||||
@ -59,6 +60,7 @@ from ngraph.ops import elu
|
||||
from ngraph.ops import embedding_bag_offsets_sum
|
||||
from ngraph.ops import embedding_bag_packed_sum
|
||||
from ngraph.ops import embedding_segments_sum
|
||||
from ngraph.ops import extract_image_patches
|
||||
from ngraph.ops import equal
|
||||
from ngraph.ops import erf
|
||||
from ngraph.ops import exp
|
||||
@ -108,6 +110,7 @@ from ngraph.ops import prior_box
|
||||
from ngraph.ops import prior_box_clustered
|
||||
from ngraph.ops import psroi_pooling
|
||||
from ngraph.ops import proposal
|
||||
from ngraph.ops import read_value
|
||||
from ngraph.ops import reduce_logical_and
|
||||
from ngraph.ops import reduce_logical_or
|
||||
from ngraph.ops import reduce_max
|
||||
|
@ -3438,3 +3438,53 @@ def proposal(
|
||||
return _get_node_factory().create(
|
||||
"Proposal", [class_probs, box_logits, as_node(image_shape)], attrs
|
||||
)
|
||||
|
||||
|
||||
@nameable_op
|
||||
def assign(new_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
|
||||
"""Return a node which produces the Assign operation.
|
||||
|
||||
:param new_value: Node producing a value to be assigned to a variable.
|
||||
:param variable_id: Id of a variable to be updated.
|
||||
:param name: Optional name for output node.
|
||||
:return: Assign node
|
||||
"""
|
||||
return _get_node_factory().create("Assign", [as_node(new_value)], {"variable_id": variable_id})
|
||||
|
||||
|
||||
@nameable_op
|
||||
def read_value(init_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
|
||||
"""Return a node which produces the Assign operation.
|
||||
|
||||
:param init_value: Node producing a value to be returned instead of an unassigned variable.
|
||||
:param variable_id: Id of a variable to be read.
|
||||
:param name: Optional name for output node.
|
||||
:return: ReadValue node
|
||||
"""
|
||||
return _get_node_factory().create("ReadValue", [as_node(init_value)], {"variable_id": variable_id})
|
||||
|
||||
|
||||
@nameable_op
|
||||
def extract_image_patches(
|
||||
image: NodeInput,
|
||||
sizes: TensorShape,
|
||||
strides: List[int],
|
||||
rates: TensorShape,
|
||||
auto_pad: str,
|
||||
name: Optional[str] = None,
|
||||
) -> Node:
|
||||
"""Return a node which produces the ExtractImagePatches operation.
|
||||
|
||||
:param image: 4-D Input data to extract image patches.
|
||||
:param sizes: Patch size in the format of [size_rows, size_cols].
|
||||
:param strides: Patch movement stride in the format of [stride_rows, stride_cols]
|
||||
:param rates: Element seleciton rate for creating a patch.
|
||||
:param auto_pad: Padding type.
|
||||
:param name: Optional name for output node.
|
||||
:return: ExtractImagePatches node
|
||||
"""
|
||||
return _get_node_factory().create(
|
||||
"ExtractImagePatches",
|
||||
[as_node(image)],
|
||||
{"sizes": sizes, "strides": strides, "rates": rates, "auto_pad": auto_pad},
|
||||
)
|
||||
|
@ -845,3 +845,39 @@ def test_proposal(int_dtype, fp_dtype):
|
||||
assert node.get_type_name() == "Proposal"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5]
|
||||
|
||||
|
||||
def test_read_value():
|
||||
init_value = ng.parameter([2, 2], name="init_value", dtype=np.int32)
|
||||
|
||||
node = ng.read_value(init_value, "var_id_667")
|
||||
|
||||
assert node.get_type_name() == "ReadValue"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == [2, 2]
|
||||
assert node.get_output_element_type(0) == Type.i32
|
||||
|
||||
|
||||
def test_assign():
|
||||
input_data = ng.parameter([5, 7], name="input_data", dtype=np.int32)
|
||||
rv = ng.read_value(input_data, "var_id_667")
|
||||
node = ng.assign(rv, "var_id_667")
|
||||
|
||||
assert node.get_type_name() == "Assign"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == [5, 7]
|
||||
assert node.get_output_element_type(0) == Type.i32
|
||||
|
||||
|
||||
def test_extract_image_patches():
|
||||
image = ng.parameter([64, 3, 10, 10], name="image", dtype=np.int32)
|
||||
sizes = [3, 3];
|
||||
strides = [5, 5];
|
||||
rates = [1, 1];
|
||||
padding = "VALID";
|
||||
node = ng.extract_image_patches(image, sizes, strides, rates, padding)
|
||||
|
||||
assert node.get_type_name() == "ExtractImagePatches"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == [64, 27, 2, 2]
|
||||
assert node.get_output_element_type(0) == Type.i32
|
||||
|
@ -21,8 +21,8 @@ using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::ReadValue::type_info;
|
||||
|
||||
op::ReadValue::ReadValue(const Output<Node>& new_value, const std::string& variable_id)
|
||||
: Op({new_value})
|
||||
op::ReadValue::ReadValue(const Output<Node>& init_value, const std::string& variable_id)
|
||||
: Op({init_value})
|
||||
, m_variable_id(variable_id)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
|
@ -36,9 +36,9 @@ namespace ngraph
|
||||
|
||||
/// \brief Constructs a ReadValue operation.
|
||||
///
|
||||
/// \param new_value Node that produces the input tensor.
|
||||
/// \param variable_id identificator of the variable to create.
|
||||
ReadValue(const Output<Node>& new_value, const std::string& variable_id);
|
||||
/// \param init_value Node that produces the input tensor.
|
||||
/// \param variable_id identificator of the variable to create.
|
||||
ReadValue(const Output<Node>& init_value, const std::string& variable_id);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user