Python API for If operation (#7934)

* add if in python_api

* add test for if python api

* fix code style

* fix type of output_desc

* move to ne api

* Revert "add if in python_api"

This reverts commit fca6e5a449.

* Revert compatibility if_op

* fix codestyle

* add new tests, disable tests where the bug appears

* fix codestyle

* fix test

* rewrite interface if python_api

* fix dict_attribute_visitor.cpp

* fix codestyle

* fix codestyle

* update dict_attribute_visitor.cpp

* add compatibility

* add compatibility tests

* fix if_op description, and paths

* add compatible test

* fix comp opset

* fix opset 8 whitespace

* fix codestyle

* fix tests

* delete old tests

* Revert fixes in test_reduction.py, test_node_factory.py

* Revert fixes in test_reduction.py, test_node_factory.py

* fix tuple

* fix tuple

* fix tuple

* fix tuple

* fix test path

* fix test path
This commit is contained in:
Eugeny Volosenkov 2021-12-01 18:07:02 +03:00 committed by GitHub
parent caf1f22f63
commit 81b003e688
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 537 additions and 103 deletions

View File

@ -95,6 +95,7 @@ from ngraph.opset8 import hard_sigmoid
from ngraph.opset8 import hsigmoid
from ngraph.opset8 import hswish
from ngraph.opset8 import idft
from ngraph.opset8 import if_op
from ngraph.opset8 import interpolate
from ngraph.opset8 import less
from ngraph.opset8 import less_equal

View File

@ -69,6 +69,7 @@ from ngraph.opset1.ops import hard_sigmoid
from ngraph.opset5.ops import hsigmoid
from ngraph.opset4.ops import hswish
from ngraph.opset7.ops import idft
from ngraph.opset8.ops import if_op
from ngraph.opset1.ops import interpolate
from ngraph.opset1.ops import less
from ngraph.opset1.ops import less_equal

View File

@ -3,7 +3,7 @@
"""Factory functions for all ngraph ops."""
from functools import partial
from typing import Callable, Iterable, List, Optional, Set, Union
from typing import Callable, Iterable, List, Optional, Set, Union, Tuple
import numpy as np
from ngraph.impl import Node, Shape
@ -369,6 +369,42 @@ def random_uniform(
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
@nameable_op
def if_op(
    condition: NodeInput,
    inputs: List[Node],
    bodies: Tuple[GraphBody, GraphBody],
    input_desc: Tuple[List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]],
    output_desc: Tuple[List[TensorIteratorBodyOutputDesc], List[TensorIteratorBodyOutputDesc]],
    name: Optional[str] = None,
) -> Node:
    """Execute one of the two bodies depending on the condition value.

    @param condition: A scalar or 1D tensor with 1 element specifying which body will be
                      executed. If condition is True, then_body is executed, otherwise
                      else_body.
    @param inputs: The provided inputs to the If operation.
    @param bodies: Two graphs (then_body, else_body), one of which is executed depending
                   on the condition value.
    @param input_desc: Two lists (for then_body and else_body) with rules describing how
                       If inputs are connected with the body parameters.
    @param output_desc: Two lists (for then_body and else_body) with rules describing how
                        If outputs are connected with the body results.
    @param name: The optional name for the created output node.
    @return: The new node which performs the If operation.
    """
    # Serialize both bodies and their port descriptors into the attribute dict
    # consumed by DictAttributeDeserializer on the C++ side (then_body/else_body
    # plus invariant input and body output descriptions per branch).
    attributes = {
        "then_body": bodies[0].serialize(),
        "else_body": bodies[1].serialize(),
        "then_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[0]]},
        "else_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[1]]},
        "then_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[0]]},
        "else_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[1]]}
    }
    # Condition is input port 0; the remaining inputs follow in order.
    return _get_node_factory_opset8().create("If", as_nodes(condition, *inputs),
                                             attributes)
@nameable_op
def slice(
data: NodeInput,

View File

@ -112,7 +112,7 @@ class TensorIteratorOutputDesc(object):
class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
"""Represents an output from a specific iteration."""
def __init__(self, body_value_idx: int, output_idx: int, iteration: int,) -> None:
def __init__(self, body_value_idx: int, output_idx: int, iteration: int = -1,) -> None:
super().__init__(body_value_idx, output_idx)
self.iteration = iteration

View File

@ -28,65 +28,71 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name, ngraph
&adapter)) {
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> input_descs;
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
for (py::handle h : slice_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
input_descs.push_back(slice_in);
if (input_desc.contains("slice_input_desc") && !input_desc["slice_input_desc"].is_none()) {
for (py::handle h : input_desc["slice_input_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
input_descs.push_back(slice_in);
}
}
for (py::handle h : merged_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["body_value_idx"].cast<int64_t>());
input_descs.push_back(merged_in);
if (input_desc.contains("merged_input_desc") && !input_desc["merged_input_desc"].is_none()) {
for (py::handle h : input_desc["merged_input_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["body_value_idx"].cast<int64_t>());
input_descs.push_back(merged_in);
}
}
for (py::handle h : invariant_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>());
input_descs.push_back(invariant_in);
if (input_desc.contains("invariant_input_desc") && !input_desc["invariant_input_desc"].is_none()) {
for (py::handle h : input_desc["invariant_input_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>());
input_descs.push_back(invariant_in);
}
}
a->set(input_descs);
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> output_descs;
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
for (py::handle h : body_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["iteration"].cast<int64_t>());
output_descs.push_back(body_output);
if (output_desc.contains("body_output_desc") && !output_desc["body_output_desc"].is_none()) {
for (py::handle h : output_desc["body_output_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["iteration"].cast<int64_t>());
output_descs.push_back(body_output);
}
}
for (py::handle h : concat_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
output_descs.push_back(concat_output);
if (output_desc.contains("concat_output_desc") && !output_desc["concat_output_desc"].is_none()) {
for (py::handle h : output_desc["concat_output_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
output_descs.push_back(concat_output);
}
}
a->set(output_descs);
} else if (const auto& a =
@ -241,7 +247,7 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name,
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) {
if (m_attributes.contains(name)) {
if (name == "body") {
if (name == "body" || name == "then_body" || name == "else_body") {
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
const auto& body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();

View File

@ -69,6 +69,7 @@ from openvino.runtime.opset1.ops import hard_sigmoid
from openvino.runtime.opset5.ops import hsigmoid
from openvino.runtime.opset4.ops import hswish
from openvino.runtime.opset7.ops import idft
from openvino.runtime.opset8.ops import if_op
from openvino.runtime.opset1.ops import interpolate
from openvino.runtime.opset1.ops import less
from openvino.runtime.opset1.ops import less_equal

View File

@ -3,7 +3,7 @@
"""Factory functions for all openvino ops."""
from functools import partial
from typing import Callable, Iterable, List, Optional, Set, Union
from typing import Callable, Iterable, List, Optional, Set, Union, Tuple
import numpy as np
from openvino.runtime.impl import Node, Shape
@ -369,6 +369,42 @@ def random_uniform(
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
@nameable_op
def if_op(
    condition: NodeInput,
    inputs: List[Node],
    bodies: Tuple[GraphBody, GraphBody],
    input_desc: Tuple[List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]],
    output_desc: Tuple[List[TensorIteratorBodyOutputDesc], List[TensorIteratorBodyOutputDesc]],
    name: Optional[str] = None,
) -> Node:
    """Execute one of the two bodies depending on the condition value.

    @param condition: A scalar or 1D tensor with 1 element specifying which body will be
                      executed. If condition is True, then_body is executed, otherwise
                      else_body.
    @param inputs: The provided inputs to the If operation.
    @param bodies: Two graphs (then_body, else_body), one of which is executed depending
                   on the condition value.
    @param input_desc: Two lists (for then_body and else_body) with rules describing how
                       If inputs are connected with the body parameters.
    @param output_desc: Two lists (for then_body and else_body) with rules describing how
                        If outputs are connected with the body results.
    @param name: The optional name for the created output node.
    @return: The new node which performs the If operation.
    """
    # Serialize both bodies and their port descriptors into the attribute dict
    # consumed by DictAttributeDeserializer on the C++ side (then_body/else_body
    # plus invariant input and body output descriptions per branch).
    attributes = {
        "then_body": bodies[0].serialize(),
        "else_body": bodies[1].serialize(),
        "then_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[0]]},
        "else_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[1]]},
        "then_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[0]]},
        "else_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[1]]}
    }
    # Condition is input port 0; the remaining inputs follow in order.
    return _get_node_factory_opset8().create("If", as_nodes(condition, *inputs),
                                             attributes)
@nameable_op
def slice(
data: NodeInput,

View File

@ -112,7 +112,7 @@ class TensorIteratorOutputDesc(object):
class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
"""Represents an output from a specific iteration."""
def __init__(self, body_value_idx: int, output_idx: int, iteration: int,) -> None:
def __init__(self, body_value_idx: int, output_idx: int, iteration: int = -1) -> None:
super().__init__(body_value_idx, output_idx)
self.iteration = iteration

View File

@ -28,37 +28,39 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name, ov::Va
&adapter)) {
std::vector<std::shared_ptr<ov::op::util::SubGraphOp::InputDescription>> input_descs;
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
for (py::handle h : slice_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto slice_in = std::make_shared<ov::op::util::SubGraphOp::SliceInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
input_descs.push_back(slice_in);
if (input_desc.contains("slice_input_desc") && !input_desc["slice_input_desc"].is_none()) {
for (py::handle h : input_desc["slice_input_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto slice_in = std::make_shared<ov::op::util::SubGraphOp::SliceInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
input_descs.push_back(slice_in);
}
}
if (input_desc.contains("merged_input_desc") && !input_desc["merged_input_desc"].is_none()) {
for (py::handle h : input_desc["merged_input_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto merged_in = std::make_shared<ov::op::util::SubGraphOp::MergedInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["body_value_idx"].cast<int64_t>());
input_descs.push_back(merged_in);
}
}
for (py::handle h : merged_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto merged_in = std::make_shared<ov::op::util::SubGraphOp::MergedInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>(),
desc["body_value_idx"].cast<int64_t>());
input_descs.push_back(merged_in);
}
for (py::handle h : invariant_input_desc) {
const py::dict& desc = h.cast<py::dict>();
auto invariant_in = std::make_shared<ov::op::util::SubGraphOp::InvariantInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>());
input_descs.push_back(invariant_in);
if (input_desc.contains("invariant_input_desc") && !input_desc["invariant_input_desc"].is_none()) {
for (py::handle h : input_desc["invariant_input_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto invariant_in = std::make_shared<ov::op::util::SubGraphOp::InvariantInputDescription>(
desc["input_idx"].cast<int64_t>(),
desc["body_parameter_idx"].cast<int64_t>());
input_descs.push_back(invariant_in);
}
}
a->set(input_descs);
} else if (const auto& a = ov::as_type<
@ -66,28 +68,29 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name, ov::Va
&adapter)) {
std::vector<std::shared_ptr<ov::op::util::SubGraphOp::OutputDescription>> output_descs;
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
for (py::handle h : body_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto body_output = std::make_shared<ov::op::util::SubGraphOp::BodyOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["iteration"].cast<int64_t>());
output_descs.push_back(body_output);
if (output_desc.contains("body_output_desc") && !output_desc["body_output_desc"].is_none()) {
for (py::handle h : output_desc["body_output_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto body_output = std::make_shared<ov::op::util::SubGraphOp::BodyOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["iteration"].cast<int64_t>());
output_descs.push_back(body_output);
}
}
for (py::handle h : concat_output_desc) {
const py::dict& desc = h.cast<py::dict>();
auto concat_output = std::make_shared<ov::op::util::SubGraphOp::ConcatOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
output_descs.push_back(concat_output);
if (output_desc.contains("concat_output_desc") && !output_desc["concat_output_desc"].is_none()) {
for (py::handle h : output_desc["concat_output_desc"].cast<py::list>()) {
const py::dict& desc = h.cast<py::dict>();
auto concat_output = std::make_shared<ov::op::util::SubGraphOp::ConcatOutputDescription>(
desc["body_value_idx"].cast<int64_t>(),
desc["output_idx"].cast<int64_t>(),
desc["start"].cast<int64_t>(),
desc["stride"].cast<int64_t>(),
desc["part_size"].cast<int64_t>(),
desc["end"].cast<int64_t>(),
desc["axis"].cast<int64_t>());
output_descs.push_back(concat_output);
}
}
a->set(output_descs);
} else if (const auto& a = ov::as_type<ov::AttributeAdapter<ov::op::v5::Loop::SpecialBodyPorts>>(&adapter)) {
@ -241,7 +244,7 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name,
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
ov::ValueAccessor<std::shared_ptr<ov::Function>>& adapter) {
if (m_attributes.contains(name)) {
if (name == "body") {
if (name == "body" || name == "then_body" || name == "else_body") {
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ov::NodeVector>());
const auto& body_parameters = body_attrs["parameters"].cast<ov::ParameterVector>();

View File

@ -0,0 +1,175 @@
import numpy as np
import openvino.runtime.opset8 as ov
import pytest
from openvino.runtime.utils.tensor_iterator_types import (
GraphBody,
TensorIteratorInvariantInputDesc,
TensorIteratorBodyOutputDesc,
)
from tests.runtime import get_runtime
def create_simple_if_with_two_outputs(condition_val):
    """Build an If node over scalar inputs whose bodies each produce two outputs.

    then_body computes (X + Y, Y * Z); else_body computes (X + W, W ** Z).

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The constructed If node.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24);
    # use the builtin dtype instead.
    condition = ov.constant(condition_val, dtype=bool)

    # then_body: (X + Y, Y * Z)
    X_t = ov.parameter([], np.float32, "X")
    Y_t = ov.parameter([], np.float32, "Y")
    Z_t = ov.parameter([], np.float32, "Z")
    add_t = ov.add(X_t, Y_t)
    mul_t = ov.multiply(Y_t, Z_t)
    then_body_res_1 = ov.result(add_t)
    then_body_res_2 = ov.result(mul_t)
    then_body = GraphBody([X_t, Y_t, Z_t], [then_body_res_1, then_body_res_2])
    # Descriptors map If input ports (port 0 is the condition) to body parameters.
    then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1),
                        TensorIteratorInvariantInputDesc(3, 2)]
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    # else_body: (X + W, W ** Z)
    X_e = ov.parameter([], np.float32, "X")
    Z_e = ov.parameter([], np.float32, "Z")
    W_e = ov.parameter([], np.float32, "W")
    add_e = ov.add(X_e, W_e)
    pow_e = ov.power(W_e, Z_e)
    else_body_res_1 = ov.result(add_e)
    else_body_res_2 = ov.result(pow_e)
    else_body = GraphBody([X_e, Z_e, W_e], [else_body_res_1, else_body_res_2])
    else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1),
                        TensorIteratorInvariantInputDesc(4, 2)]
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    X = ov.constant(15.0, dtype=np.float32)
    Y = ov.constant(-5.0, dtype=np.float32)
    Z = ov.constant(4.0, dtype=np.float32)
    W = ov.constant(2.0, dtype=np.float32)
    if_node = ov.if_op(condition, [X, Y, Z, W], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    return if_node
def create_diff_if_with_two_outputs(condition_val):
    """Build an If node whose two bodies take differently-shaped inputs.

    then_body computes (X @ Y, Y * X) on 1D inputs; else_body returns (Z, X * Z)
    with a scalar Z.

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The constructed If node.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24).
    condition = ov.constant(condition_val, dtype=bool)

    # then_body: (X @ Y, Y * X)
    X_t = ov.parameter([2], np.float32, "X")
    Y_t = ov.parameter([2], np.float32, "Y")
    mmul_t = ov.matmul(X_t, Y_t, False, False)
    mul_t = ov.multiply(Y_t, X_t)
    then_body_res_1 = ov.result(mmul_t)
    then_body_res_2 = ov.result(mul_t)
    then_body = GraphBody([X_t, Y_t], [then_body_res_1, then_body_res_2])
    then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    # else_body: (Z, X * Z)
    X_e = ov.parameter([2], np.float32, "X")
    Z_e = ov.parameter([], np.float32, "Z")
    mul_e = ov.multiply(X_e, Z_e)
    else_body_res_1 = ov.result(Z_e)
    else_body_res_2 = ov.result(mul_e)
    else_body = GraphBody([X_e, Z_e], [else_body_res_1, else_body_res_2])
    else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1)]
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    X = ov.constant([3, 4], dtype=np.float32)
    Y = ov.constant([2, 1], dtype=np.float32)
    Z = ov.constant(4.0, dtype=np.float32)
    if_node = ov.if_op(condition, [X, Y, Z], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    return if_node
def simple_if(condition_val):
    """Build relu(If) where then_body multiplies and else_body adds two 1D inputs.

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The Relu node applied to the single If output.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24).
    condition = ov.constant(condition_val, dtype=bool)

    # then_body: X * Y
    X_t = ov.parameter([2], np.float32, "X")
    Y_t = ov.parameter([2], np.float32, "Y")
    then_mul = ov.multiply(X_t, Y_t)
    then_body_res_1 = ov.result(then_mul)
    then_body = GraphBody([X_t, Y_t], [then_body_res_1])
    then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    # else_body: X + Y
    X_e = ov.parameter([2], np.float32, "X")
    Y_e = ov.parameter([2], np.float32, "Y")
    add_e = ov.add(X_e, Y_e)
    else_body_res_1 = ov.result(add_e)
    else_body = GraphBody([X_e, Y_e], [else_body_res_1])
    else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    X = ov.constant([3, 4], dtype=np.float32)
    Y = ov.constant([2, 1], dtype=np.float32)
    if_node = ov.if_op(condition, [X, Y], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    relu = ov.relu(if_node)
    return relu
def simple_if_without_parameters(condition_val):
    """Build relu(If) where both bodies are parameterless constants.

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The Relu node applied to the single If output.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24).
    condition = ov.constant(condition_val, dtype=bool)

    # then_body: constant 0.7. np.float was a deprecated alias of the builtin
    # float (float64); use the builtin instead.
    then_constant = ov.constant(0.7, dtype=float)
    then_body_res_1 = ov.result(then_constant)
    then_body = GraphBody([], [then_body_res_1])
    then_body_inputs = []
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    # else_body: constant 9.0
    else_const = ov.constant(9.0, dtype=float)
    else_body_res_1 = ov.result(else_const)
    else_body = GraphBody([], [else_body_res_1])
    else_body_inputs = []
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    if_node = ov.if_op(condition, [], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    relu = ov.relu(if_node)
    return relu
def check_results(results, expected_results):
    """Assert that *results* matches *expected_results* element-wise (within tolerance)."""
    assert len(results) == len(expected_results)
    for actual, expected in zip(results, expected_results):
        assert np.allclose(actual, expected)
def check_if(if_model, cond_val, exp_results):
    """Build the model for *cond_val*, execute it and compare against *exp_results*."""
    output_node = if_model(cond_val)
    computed = get_runtime().computation(output_node)()
    check_results(computed, exp_results)
# After the evaluate method for If was removed, constant folding stopped
# working; as a result the bug with id 67255 began to appear.
@pytest.mark.xfail(reason="bug 67255")
def test_if_with_two_outputs():
    expected_then = [np.array([10], dtype=np.float32), np.array([-20], dtype=np.float32)]
    expected_else = [np.array([17], dtype=np.float32), np.array([16], dtype=np.float32)]
    check_if(create_simple_if_with_two_outputs, True, expected_then)
    check_if(create_simple_if_with_two_outputs, False, expected_else)
@pytest.mark.xfail(reason="bug 67255")
def test_diff_if_with_two_outputs():
    expected_then = [np.array([10], dtype=np.float32), np.array([6, 4], dtype=np.float32)]
    expected_else = [np.array([4], dtype=np.float32), np.array([12, 16], dtype=np.float32)]
    check_if(create_diff_if_with_two_outputs, True, expected_then)
    check_if(create_diff_if_with_two_outputs, False, expected_else)
def test_simple_if():
    # True -> then_body (X * Y); False -> else_body (X + Y).
    for cond_val, expected in ((True, [np.array([6, 4], dtype=np.float32)]),
                               (False, [np.array([5, 5], dtype=np.float32)])):
        check_if(simple_if, cond_val, expected)
def test_simple_if_without_body_parameters():
    # True -> constant 0.7; False -> constant 9.0.
    for cond_val, expected in ((True, [np.array([0.7], dtype=np.float32)]),
                               (False, [np.array([9.0], dtype=np.float32)])):
        check_if(simple_if_without_parameters, cond_val, expected)

View File

@ -0,0 +1,175 @@
import ngraph as ng
import numpy as np
import pytest
from ngraph.utils.tensor_iterator_types import (
GraphBody,
TensorIteratorInvariantInputDesc,
TensorIteratorBodyOutputDesc,
)
from tests_compatibility.runtime import get_runtime
def create_simple_if_with_two_outputs(condition_val):
    """Build an If node over scalar inputs whose bodies each produce two outputs.

    then_body computes (X + Y, Y * Z); else_body computes (X + W, W ** Z).

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The constructed If node.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24).
    condition = ng.constant(condition_val, dtype=bool)

    # then_body: (X + Y, Y * Z)
    X_t = ng.parameter([], np.float32, "X")
    Y_t = ng.parameter([], np.float32, "Y")
    Z_t = ng.parameter([], np.float32, "Z")
    add_t = ng.add(X_t, Y_t)
    mul_t = ng.multiply(Y_t, Z_t)
    then_body_res_1 = ng.result(add_t)
    then_body_res_2 = ng.result(mul_t)
    then_body = GraphBody([X_t, Y_t, Z_t], [then_body_res_1, then_body_res_2])
    # Descriptors map If input ports (port 0 is the condition) to body parameters.
    then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1),
                        TensorIteratorInvariantInputDesc(3, 2)]
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    # else_body: (X + W, W ** Z)
    X_e = ng.parameter([], np.float32, "X")
    Z_e = ng.parameter([], np.float32, "Z")
    W_e = ng.parameter([], np.float32, "W")
    add_e = ng.add(X_e, W_e)
    pow_e = ng.power(W_e, Z_e)
    else_body_res_1 = ng.result(add_e)
    else_body_res_2 = ng.result(pow_e)
    else_body = GraphBody([X_e, Z_e, W_e], [else_body_res_1, else_body_res_2])
    else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1),
                        TensorIteratorInvariantInputDesc(4, 2)]
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    X = ng.constant(15.0, dtype=np.float32)
    Y = ng.constant(-5.0, dtype=np.float32)
    Z = ng.constant(4.0, dtype=np.float32)
    W = ng.constant(2.0, dtype=np.float32)
    if_node = ng.if_op(condition, [X, Y, Z, W], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    return if_node
def create_diff_if_with_two_outputs(condition_val):
    """Build an If node whose two bodies take differently-shaped inputs.

    then_body computes (X @ Y, Y * X) on 1D inputs; else_body returns (Z, X * Z)
    with a scalar Z.

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The constructed If node.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24).
    condition = ng.constant(condition_val, dtype=bool)

    # then_body: (X @ Y, Y * X)
    X_t = ng.parameter([2], np.float32, "X")
    Y_t = ng.parameter([2], np.float32, "Y")
    mmul_t = ng.matmul(X_t, Y_t, False, False)
    mul_t = ng.multiply(Y_t, X_t)
    then_body_res_1 = ng.result(mmul_t)
    then_body_res_2 = ng.result(mul_t)
    then_body = GraphBody([X_t, Y_t], [then_body_res_1, then_body_res_2])
    then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    # else_body: (Z, X * Z)
    X_e = ng.parameter([2], np.float32, "X")
    Z_e = ng.parameter([], np.float32, "Z")
    mul_e = ng.multiply(X_e, Z_e)
    else_body_res_1 = ng.result(Z_e)
    else_body_res_2 = ng.result(mul_e)
    else_body = GraphBody([X_e, Z_e], [else_body_res_1, else_body_res_2])
    else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1)]
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]

    X = ng.constant([3, 4], dtype=np.float32)
    Y = ng.constant([2, 1], dtype=np.float32)
    Z = ng.constant(4.0, dtype=np.float32)
    if_node = ng.if_op(condition, [X, Y, Z], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    return if_node
def simple_if(condition_val):
    """Build relu(If) where then_body multiplies and else_body adds two 1D inputs.

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The Relu node applied to the single If output.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24).
    condition = ng.constant(condition_val, dtype=bool)

    # then_body: X * Y
    X_t = ng.parameter([2], np.float32, "X")
    Y_t = ng.parameter([2], np.float32, "Y")
    then_mul = ng.multiply(X_t, Y_t)
    then_body_res_1 = ng.result(then_mul)
    then_body = GraphBody([X_t, Y_t], [then_body_res_1])
    then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    # else_body: X + Y
    X_e = ng.parameter([2], np.float32, "X")
    Y_e = ng.parameter([2], np.float32, "Y")
    add_e = ng.add(X_e, Y_e)
    else_body_res_1 = ng.result(add_e)
    else_body = GraphBody([X_e, Y_e], [else_body_res_1])
    else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    X = ng.constant([3, 4], dtype=np.float32)
    Y = ng.constant([2, 1], dtype=np.float32)
    if_node = ng.if_op(condition, [X, Y], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    relu = ng.relu(if_node)
    return relu
def simple_if_without_parameters(condition_val):
    """Build relu(If) where both bodies are parameterless constants.

    @param condition_val: Python bool selecting then_body (True) or else_body (False).
    @return: The Relu node applied to the single If output.
    """
    # np.bool was a deprecated alias of the builtin bool (removed in NumPy 1.24).
    condition = ng.constant(condition_val, dtype=bool)

    # then_body: constant 0.7. np.float was a deprecated alias of the builtin
    # float (float64); use the builtin instead.
    then_constant = ng.constant(0.7, dtype=float)
    then_body_res_1 = ng.result(then_constant)
    then_body = GraphBody([], [then_body_res_1])
    then_body_inputs = []
    then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    # else_body: constant 9.0
    else_const = ng.constant(9.0, dtype=float)
    else_body_res_1 = ng.result(else_const)
    else_body = GraphBody([], [else_body_res_1])
    else_body_inputs = []
    else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]

    if_node = ng.if_op(condition, [], (then_body, else_body), (then_body_inputs, else_body_inputs),
                       (then_body_outputs, else_body_outputs))
    relu = ng.relu(if_node)
    return relu
def check_results(results, expected_results):
    """Assert that *results* matches *expected_results* element-wise (within tolerance)."""
    assert len(results) == len(expected_results)
    for actual, expected in zip(results, expected_results):
        assert np.allclose(actual, expected)
def check_if(if_model, cond_val, exp_results):
    """Build the model for *cond_val*, execute it and compare against *exp_results*."""
    output_node = if_model(cond_val)
    computed = get_runtime().computation(output_node)()
    check_results(computed, exp_results)
# After the evaluate method for If was removed, constant folding stopped
# working; as a result the bug with id 67255 began to appear.
@pytest.mark.xfail(reason="bug 67255")
def test_if_with_two_outputs():
    expected_then = [np.array([10], dtype=np.float32), np.array([-20], dtype=np.float32)]
    expected_else = [np.array([17], dtype=np.float32), np.array([16], dtype=np.float32)]
    check_if(create_simple_if_with_two_outputs, True, expected_then)
    check_if(create_simple_if_with_two_outputs, False, expected_else)
@pytest.mark.xfail(reason="bug 67255")
def test_diff_if_with_two_outputs():
    expected_then = [np.array([10], dtype=np.float32), np.array([6, 4], dtype=np.float32)]
    expected_else = [np.array([4], dtype=np.float32), np.array([12, 16], dtype=np.float32)]
    check_if(create_diff_if_with_two_outputs, True, expected_then)
    check_if(create_diff_if_with_two_outputs, False, expected_else)
def test_simple_if():
    # True -> then_body (X * Y); False -> else_body (X + Y).
    for cond_val, expected in ((True, [np.array([6, 4], dtype=np.float32)]),
                               (False, [np.array([5, 5], dtype=np.float32)])):
        check_if(simple_if, cond_val, expected)
def test_simple_if_without_body_parameters():
    # True -> constant 0.7; False -> constant 9.0.
    for cond_val, expected in ((True, [np.array([0.7], dtype=np.float32)]),
                               (False, [np.array([9.0], dtype=np.float32)])):
        check_if(simple_if_without_parameters, cond_val, expected)