Python API for Assign, ReadValue and ExtractImagePatches (#719)
This commit is contained in:
parent
63a77bb4a1
commit
53927034da
@ -29,6 +29,7 @@ from ngraph.ops import absolute as abs
|
|||||||
from ngraph.ops import acos
|
from ngraph.ops import acos
|
||||||
from ngraph.ops import add
|
from ngraph.ops import add
|
||||||
from ngraph.ops import asin
|
from ngraph.ops import asin
|
||||||
|
from ngraph.ops import assign
|
||||||
from ngraph.ops import atan
|
from ngraph.ops import atan
|
||||||
from ngraph.ops import avg_pool
|
from ngraph.ops import avg_pool
|
||||||
from ngraph.ops import batch_norm_inference
|
from ngraph.ops import batch_norm_inference
|
||||||
@ -59,6 +60,7 @@ from ngraph.ops import elu
|
|||||||
from ngraph.ops import embedding_bag_offsets_sum
|
from ngraph.ops import embedding_bag_offsets_sum
|
||||||
from ngraph.ops import embedding_bag_packed_sum
|
from ngraph.ops import embedding_bag_packed_sum
|
||||||
from ngraph.ops import embedding_segments_sum
|
from ngraph.ops import embedding_segments_sum
|
||||||
|
from ngraph.ops import extract_image_patches
|
||||||
from ngraph.ops import equal
|
from ngraph.ops import equal
|
||||||
from ngraph.ops import erf
|
from ngraph.ops import erf
|
||||||
from ngraph.ops import exp
|
from ngraph.ops import exp
|
||||||
@ -108,6 +110,7 @@ from ngraph.ops import prior_box
|
|||||||
from ngraph.ops import prior_box_clustered
|
from ngraph.ops import prior_box_clustered
|
||||||
from ngraph.ops import psroi_pooling
|
from ngraph.ops import psroi_pooling
|
||||||
from ngraph.ops import proposal
|
from ngraph.ops import proposal
|
||||||
|
from ngraph.ops import read_value
|
||||||
from ngraph.ops import reduce_logical_and
|
from ngraph.ops import reduce_logical_and
|
||||||
from ngraph.ops import reduce_logical_or
|
from ngraph.ops import reduce_logical_or
|
||||||
from ngraph.ops import reduce_max
|
from ngraph.ops import reduce_max
|
||||||
|
@ -3438,3 +3438,53 @@ def proposal(
|
|||||||
return _get_node_factory().create(
|
return _get_node_factory().create(
|
||||||
"Proposal", [class_probs, box_logits, as_node(image_shape)], attrs
|
"Proposal", [class_probs, box_logits, as_node(image_shape)], attrs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@nameable_op
|
||||||
|
def assign(new_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
|
||||||
|
"""Return a node which produces the Assign operation.
|
||||||
|
|
||||||
|
:param new_value: Node producing a value to be assigned to a variable.
|
||||||
|
:param variable_id: Id of a variable to be updated.
|
||||||
|
:param name: Optional name for output node.
|
||||||
|
:return: Assign node
|
||||||
|
"""
|
||||||
|
return _get_node_factory().create("Assign", [as_node(new_value)], {"variable_id": variable_id})
|
||||||
|
|
||||||
|
|
||||||
|
@nameable_op
|
||||||
|
def read_value(init_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node:
|
||||||
|
"""Return a node which produces the Assign operation.
|
||||||
|
|
||||||
|
:param init_value: Node producing a value to be returned instead of an unassigned variable.
|
||||||
|
:param variable_id: Id of a variable to be read.
|
||||||
|
:param name: Optional name for output node.
|
||||||
|
:return: ReadValue node
|
||||||
|
"""
|
||||||
|
return _get_node_factory().create("ReadValue", [as_node(init_value)], {"variable_id": variable_id})
|
||||||
|
|
||||||
|
|
||||||
|
@nameable_op
|
||||||
|
def extract_image_patches(
|
||||||
|
image: NodeInput,
|
||||||
|
sizes: TensorShape,
|
||||||
|
strides: List[int],
|
||||||
|
rates: TensorShape,
|
||||||
|
auto_pad: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> Node:
|
||||||
|
"""Return a node which produces the ExtractImagePatches operation.
|
||||||
|
|
||||||
|
:param image: 4-D Input data to extract image patches.
|
||||||
|
:param sizes: Patch size in the format of [size_rows, size_cols].
|
||||||
|
:param strides: Patch movement stride in the format of [stride_rows, stride_cols]
|
||||||
|
:param rates: Element seleciton rate for creating a patch.
|
||||||
|
:param auto_pad: Padding type.
|
||||||
|
:param name: Optional name for output node.
|
||||||
|
:return: ExtractImagePatches node
|
||||||
|
"""
|
||||||
|
return _get_node_factory().create(
|
||||||
|
"ExtractImagePatches",
|
||||||
|
[as_node(image)],
|
||||||
|
{"sizes": sizes, "strides": strides, "rates": rates, "auto_pad": auto_pad},
|
||||||
|
)
|
||||||
|
@ -845,3 +845,39 @@ def test_proposal(int_dtype, fp_dtype):
|
|||||||
assert node.get_type_name() == "Proposal"
|
assert node.get_type_name() == "Proposal"
|
||||||
assert node.get_output_size() == 1
|
assert node.get_output_size() == 1
|
||||||
assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5]
|
assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_value():
|
||||||
|
init_value = ng.parameter([2, 2], name="init_value", dtype=np.int32)
|
||||||
|
|
||||||
|
node = ng.read_value(init_value, "var_id_667")
|
||||||
|
|
||||||
|
assert node.get_type_name() == "ReadValue"
|
||||||
|
assert node.get_output_size() == 1
|
||||||
|
assert list(node.get_output_shape(0)) == [2, 2]
|
||||||
|
assert node.get_output_element_type(0) == Type.i32
|
||||||
|
|
||||||
|
|
||||||
|
def test_assign():
|
||||||
|
input_data = ng.parameter([5, 7], name="input_data", dtype=np.int32)
|
||||||
|
rv = ng.read_value(input_data, "var_id_667")
|
||||||
|
node = ng.assign(rv, "var_id_667")
|
||||||
|
|
||||||
|
assert node.get_type_name() == "Assign"
|
||||||
|
assert node.get_output_size() == 1
|
||||||
|
assert list(node.get_output_shape(0)) == [5, 7]
|
||||||
|
assert node.get_output_element_type(0) == Type.i32
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_image_patches():
|
||||||
|
image = ng.parameter([64, 3, 10, 10], name="image", dtype=np.int32)
|
||||||
|
sizes = [3, 3];
|
||||||
|
strides = [5, 5];
|
||||||
|
rates = [1, 1];
|
||||||
|
padding = "VALID";
|
||||||
|
node = ng.extract_image_patches(image, sizes, strides, rates, padding)
|
||||||
|
|
||||||
|
assert node.get_type_name() == "ExtractImagePatches"
|
||||||
|
assert node.get_output_size() == 1
|
||||||
|
assert list(node.get_output_shape(0)) == [64, 27, 2, 2]
|
||||||
|
assert node.get_output_element_type(0) == Type.i32
|
||||||
|
@ -21,8 +21,8 @@ using namespace ngraph;
|
|||||||
|
|
||||||
constexpr NodeTypeInfo op::ReadValue::type_info;
|
constexpr NodeTypeInfo op::ReadValue::type_info;
|
||||||
|
|
||||||
op::ReadValue::ReadValue(const Output<Node>& new_value, const std::string& variable_id)
|
op::ReadValue::ReadValue(const Output<Node>& init_value, const std::string& variable_id)
|
||||||
: Op({new_value})
|
: Op({init_value})
|
||||||
, m_variable_id(variable_id)
|
, m_variable_id(variable_id)
|
||||||
{
|
{
|
||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
|
@ -36,9 +36,9 @@ namespace ngraph
|
|||||||
|
|
||||||
/// \brief Constructs a ReadValue operation.
|
/// \brief Constructs a ReadValue operation.
|
||||||
///
|
///
|
||||||
/// \param new_value Node that produces the input tensor.
|
/// \param init_value Node that produces the input tensor.
|
||||||
/// \param variable_id identificator of the variable to create.
|
/// \param variable_id identificator of the variable to create.
|
||||||
ReadValue(const Output<Node>& new_value, const std::string& variable_id);
|
ReadValue(const Output<Node>& init_value, const std::string& variable_id);
|
||||||
|
|
||||||
void validate_and_infer_types() override;
|
void validate_and_infer_types() override;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user