diff --git a/ngraph/python/src/ngraph/__init__.py b/ngraph/python/src/ngraph/__init__.py index 6055ba92e59..d644fda3758 100644 --- a/ngraph/python/src/ngraph/__init__.py +++ b/ngraph/python/src/ngraph/__init__.py @@ -29,6 +29,7 @@ from ngraph.ops import absolute as abs from ngraph.ops import acos from ngraph.ops import add from ngraph.ops import asin +from ngraph.ops import assign from ngraph.ops import atan from ngraph.ops import avg_pool from ngraph.ops import batch_norm_inference @@ -59,6 +60,7 @@ from ngraph.ops import elu from ngraph.ops import embedding_bag_offsets_sum from ngraph.ops import embedding_bag_packed_sum from ngraph.ops import embedding_segments_sum +from ngraph.ops import extract_image_patches from ngraph.ops import equal from ngraph.ops import erf from ngraph.ops import exp @@ -108,6 +110,7 @@ from ngraph.ops import prior_box from ngraph.ops import prior_box_clustered from ngraph.ops import psroi_pooling from ngraph.ops import proposal +from ngraph.ops import read_value from ngraph.ops import reduce_logical_and from ngraph.ops import reduce_logical_or from ngraph.ops import reduce_max diff --git a/ngraph/python/src/ngraph/ops.py b/ngraph/python/src/ngraph/ops.py index 51a299ef8f9..8d5fe41dfb4 100644 --- a/ngraph/python/src/ngraph/ops.py +++ b/ngraph/python/src/ngraph/ops.py @@ -3438,3 +3438,53 @@ def proposal( return _get_node_factory().create( "Proposal", [class_probs, box_logits, as_node(image_shape)], attrs ) + + +@nameable_op +def assign(new_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node: + """Return a node which produces the Assign operation. + + :param new_value: Node producing a value to be assigned to a variable. + :param variable_id: Id of a variable to be updated. + :param name: Optional name for output node. + :return: Assign node + """ + return _get_node_factory().create("Assign", [as_node(new_value)], {"variable_id": variable_id}) + + +@nameable_op +def read_value(init_value: NodeInput, variable_id: str, name: Optional[str] = None) -> Node: + """Return a node which produces the Assign operation. + + :param init_value: Node producing a value to be returned instead of an unassigned variable. + :param variable_id: Id of a variable to be read. + :param name: Optional name for output node. + :return: ReadValue node + """ + return _get_node_factory().create("ReadValue", [as_node(init_value)], {"variable_id": variable_id}) + + +@nameable_op +def extract_image_patches( + image: NodeInput, + sizes: TensorShape, + strides: List[int], + rates: TensorShape, + auto_pad: str, + name: Optional[str] = None, +) -> Node: + """Return a node which produces the ExtractImagePatches operation. + + :param image: 4-D Input data to extract image patches. + :param sizes: Patch size in the format of [size_rows, size_cols]. + :param strides: Patch movement stride in the format of [stride_rows, stride_cols] + :param rates: Element seleciton rate for creating a patch. + :param auto_pad: Padding type. + :param name: Optional name for output node. + :return: ExtractImagePatches node + """ + return _get_node_factory().create( + "ExtractImagePatches", + [as_node(image)], + {"sizes": sizes, "strides": strides, "rates": rates, "auto_pad": auto_pad}, + ) diff --git a/ngraph/python/test/ngraph/test_create_op.py b/ngraph/python/test/ngraph/test_create_op.py index abb50adce9e..9b041d8fc8d 100644 --- a/ngraph/python/test/ngraph/test_create_op.py +++ b/ngraph/python/test/ngraph/test_create_op.py @@ -845,3 +845,39 @@ def test_proposal(int_dtype, fp_dtype): assert node.get_type_name() == "Proposal" assert node.get_output_size() == 1 assert list(node.get_output_shape(0)) == [batch_size * attributes["attrs.post_nms_topn"], 5] + + +def test_read_value(): + init_value = ng.parameter([2, 2], name="init_value", dtype=np.int32) + + node = ng.read_value(init_value, "var_id_667") + + assert node.get_type_name() == "ReadValue" + assert node.get_output_size() == 1 + assert list(node.get_output_shape(0)) == [2, 2] + assert node.get_output_element_type(0) == Type.i32 + + +def test_assign(): + input_data = ng.parameter([5, 7], name="input_data", dtype=np.int32) + rv = ng.read_value(input_data, "var_id_667") + node = ng.assign(rv, "var_id_667") + + assert node.get_type_name() == "Assign" + assert node.get_output_size() == 1 + assert list(node.get_output_shape(0)) == [5, 7] + assert node.get_output_element_type(0) == Type.i32 + + +def test_extract_image_patches(): + image = ng.parameter([64, 3, 10, 10], name="image", dtype=np.int32) + sizes = [3, 3]; + strides = [5, 5]; + rates = [1, 1]; + padding = "VALID"; + node = ng.extract_image_patches(image, sizes, strides, rates, padding) + + assert node.get_type_name() == "ExtractImagePatches" + assert node.get_output_size() == 1 + assert list(node.get_output_shape(0)) == [64, 27, 2, 2] + assert node.get_output_element_type(0) == Type.i32 diff --git a/ngraph/src/ngraph/op/read_value.cpp b/ngraph/src/ngraph/op/read_value.cpp index 9f7abb6ed7d..f6581a6b6ed 100644 --- a/ngraph/src/ngraph/op/read_value.cpp +++ b/ngraph/src/ngraph/op/read_value.cpp @@ -21,8 +21,8 @@ using namespace ngraph; constexpr NodeTypeInfo op::ReadValue::type_info; -op::ReadValue::ReadValue(const Output& new_value, const std::string& variable_id) - : Op({new_value}) +op::ReadValue::ReadValue(const Output& init_value, const std::string& variable_id) + : Op({init_value}) , m_variable_id(variable_id) { constructor_validate_and_infer_types(); diff --git a/ngraph/src/ngraph/op/read_value.hpp b/ngraph/src/ngraph/op/read_value.hpp index ea451f3ed45..ca3f5325f0d 100644 --- a/ngraph/src/ngraph/op/read_value.hpp +++ b/ngraph/src/ngraph/op/read_value.hpp @@ -36,9 +36,9 @@ namespace ngraph /// \brief Constructs a ReadValue operation. /// - /// \param new_value Node that produces the input tensor. - /// \param variable_id identificator of the variable to create. - ReadValue(const Output& new_value, const std::string& variable_id); + /// \param init_value Node that produces the input tensor. + /// \param variable_id identificator of the variable to create. + ReadValue(const Output& init_value, const std::string& variable_id); void validate_and_infer_types() override;