Python API for If operation (#7934)
* add if in python_api
* add test for if python api
* fix code style
* fix typr of output_desc
* move to ne api
* Revert "add if in python_api"
This reverts commit fca6e5a449
.
* Revert compatibility if_op
* fix codestyle
* add new tests, disable test where bug is appear
* fix codestyle
* fix test
* rewrite interface if python_api
* fix dict_attribute_visitor.cpp
* fix codestyle
* fix codestyle
* update dict_attribute_visitor.cpp
* add comapti
* add compatibilty tests
* fix if_op description, and paths
* add compatible test
* fix comp opset
* fix opset 8 whitespace
* fix codestyle
* fix tests
* delete old tests
* Revert fixes in test_reduction.py, test_node_factory.py
* Revert fixes in test_reduction.py, test_node_factory.py
* fix tuple
* fix tuple
* fix tuple
* fix tuple
* fix test path
* fix test path
This commit is contained in:
parent
caf1f22f63
commit
81b003e688
@ -95,6 +95,7 @@ from ngraph.opset8 import hard_sigmoid
|
|||||||
from ngraph.opset8 import hsigmoid
|
from ngraph.opset8 import hsigmoid
|
||||||
from ngraph.opset8 import hswish
|
from ngraph.opset8 import hswish
|
||||||
from ngraph.opset8 import idft
|
from ngraph.opset8 import idft
|
||||||
|
from ngraph.opset8 import if_op
|
||||||
from ngraph.opset8 import interpolate
|
from ngraph.opset8 import interpolate
|
||||||
from ngraph.opset8 import less
|
from ngraph.opset8 import less
|
||||||
from ngraph.opset8 import less_equal
|
from ngraph.opset8 import less_equal
|
||||||
|
@ -69,6 +69,7 @@ from ngraph.opset1.ops import hard_sigmoid
|
|||||||
from ngraph.opset5.ops import hsigmoid
|
from ngraph.opset5.ops import hsigmoid
|
||||||
from ngraph.opset4.ops import hswish
|
from ngraph.opset4.ops import hswish
|
||||||
from ngraph.opset7.ops import idft
|
from ngraph.opset7.ops import idft
|
||||||
|
from ngraph.opset8.ops import if_op
|
||||||
from ngraph.opset1.ops import interpolate
|
from ngraph.opset1.ops import interpolate
|
||||||
from ngraph.opset1.ops import less
|
from ngraph.opset1.ops import less
|
||||||
from ngraph.opset1.ops import less_equal
|
from ngraph.opset1.ops import less_equal
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
"""Factory functions for all ngraph ops."""
|
"""Factory functions for all ngraph ops."""
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Iterable, List, Optional, Set, Union
|
from typing import Callable, Iterable, List, Optional, Set, Union, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ngraph.impl import Node, Shape
|
from ngraph.impl import Node, Shape
|
||||||
@ -369,6 +369,42 @@ def random_uniform(
|
|||||||
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
|
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
|
||||||
|
|
||||||
|
|
||||||
|
@nameable_op
|
||||||
|
def if_op(
|
||||||
|
condition: NodeInput,
|
||||||
|
inputs: List[Node],
|
||||||
|
bodies: Tuple[GraphBody, GraphBody],
|
||||||
|
input_desc: Tuple[List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]],
|
||||||
|
output_desc: Tuple[List[TensorIteratorBodyOutputDesc], List[TensorIteratorBodyOutputDesc]],
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> Node:
|
||||||
|
"""Execute one of the bodies depending on condtion value.
|
||||||
|
|
||||||
|
@param condition: A scalar or 1D tensor with 1 element specifying body will be executed.
|
||||||
|
If condition is True, then body will be executed, False - else_body.
|
||||||
|
@param inputs: The provided inputs to If operation.
|
||||||
|
@param bodies: Two graphs (then_body, else_body) which will be executed depending on
|
||||||
|
condition value.
|
||||||
|
@param input_desc Two lists (for then_body and else_body) which contain rules how If
|
||||||
|
inputs are connected with body parameters.
|
||||||
|
@param output_desc: Two lists (for then_body and else_body) which contain rules how If
|
||||||
|
outputs are connected with body results.
|
||||||
|
@param name: The optional name for the created output node.
|
||||||
|
|
||||||
|
@return: The new node which performs If operation.
|
||||||
|
"""
|
||||||
|
attributes = {
|
||||||
|
"then_body": bodies[0].serialize(),
|
||||||
|
"else_body": bodies[1].serialize(),
|
||||||
|
"then_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[0]]},
|
||||||
|
"else_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[1]]},
|
||||||
|
"then_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[0]]},
|
||||||
|
"else_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[1]]}
|
||||||
|
}
|
||||||
|
return _get_node_factory_opset8().create("If", as_nodes(condition, *inputs),
|
||||||
|
attributes)
|
||||||
|
|
||||||
|
|
||||||
@nameable_op
|
@nameable_op
|
||||||
def slice(
|
def slice(
|
||||||
data: NodeInput,
|
data: NodeInput,
|
||||||
|
@ -112,7 +112,7 @@ class TensorIteratorOutputDesc(object):
|
|||||||
class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
|
class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
|
||||||
"""Represents an output from a specific iteration."""
|
"""Represents an output from a specific iteration."""
|
||||||
|
|
||||||
def __init__(self, body_value_idx: int, output_idx: int, iteration: int,) -> None:
|
def __init__(self, body_value_idx: int, output_idx: int, iteration: int = -1,) -> None:
|
||||||
super().__init__(body_value_idx, output_idx)
|
super().__init__(body_value_idx, output_idx)
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
|
|
||||||
|
@ -28,65 +28,71 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name, ngraph
|
|||||||
&adapter)) {
|
&adapter)) {
|
||||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> input_descs;
|
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> input_descs;
|
||||||
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
|
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||||
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
|
|
||||||
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
|
if (input_desc.contains("slice_input_desc") && !input_desc["slice_input_desc"].is_none()) {
|
||||||
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
|
for (py::handle h : input_desc["slice_input_desc"].cast<py::list>()) {
|
||||||
for (py::handle h : slice_input_desc) {
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
|
||||||
auto slice_in = std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>(
|
desc["input_idx"].cast<int64_t>(),
|
||||||
desc["input_idx"].cast<int64_t>(),
|
desc["body_parameter_idx"].cast<int64_t>(),
|
||||||
desc["body_parameter_idx"].cast<int64_t>(),
|
desc["start"].cast<int64_t>(),
|
||||||
desc["start"].cast<int64_t>(),
|
desc["stride"].cast<int64_t>(),
|
||||||
desc["stride"].cast<int64_t>(),
|
desc["part_size"].cast<int64_t>(),
|
||||||
desc["part_size"].cast<int64_t>(),
|
desc["end"].cast<int64_t>(),
|
||||||
desc["end"].cast<int64_t>(),
|
desc["axis"].cast<int64_t>());
|
||||||
desc["axis"].cast<int64_t>());
|
input_descs.push_back(slice_in);
|
||||||
input_descs.push_back(slice_in);
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (py::handle h : merged_input_desc) {
|
if (input_desc.contains("merged_input_desc") && !input_desc["merged_input_desc"].is_none()) {
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
for (py::handle h : input_desc["merged_input_desc"].cast<py::list>()) {
|
||||||
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
desc["input_idx"].cast<int64_t>(),
|
auto merged_in = std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>(
|
||||||
desc["body_parameter_idx"].cast<int64_t>(),
|
desc["input_idx"].cast<int64_t>(),
|
||||||
desc["body_value_idx"].cast<int64_t>());
|
desc["body_parameter_idx"].cast<int64_t>(),
|
||||||
input_descs.push_back(merged_in);
|
desc["body_value_idx"].cast<int64_t>());
|
||||||
|
input_descs.push_back(merged_in);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (py::handle h : invariant_input_desc) {
|
if (input_desc.contains("invariant_input_desc") && !input_desc["invariant_input_desc"].is_none()) {
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
for (py::handle h : input_desc["invariant_input_desc"].cast<py::list>()) {
|
||||||
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
desc["input_idx"].cast<int64_t>(),
|
auto invariant_in = std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>(
|
||||||
desc["body_parameter_idx"].cast<int64_t>());
|
desc["input_idx"].cast<int64_t>(),
|
||||||
input_descs.push_back(invariant_in);
|
desc["body_parameter_idx"].cast<int64_t>());
|
||||||
|
input_descs.push_back(invariant_in);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
a->set(input_descs);
|
a->set(input_descs);
|
||||||
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<
|
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<
|
||||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
|
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
|
||||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> output_descs;
|
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> output_descs;
|
||||||
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
|
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||||
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
|
if (output_desc.contains("body_output_desc") && !output_desc["body_output_desc"].is_none()) {
|
||||||
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
|
for (py::handle h : output_desc["body_output_desc"].cast<py::list>()) {
|
||||||
for (py::handle h : body_output_desc) {
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
|
||||||
auto body_output = std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>(
|
desc["body_value_idx"].cast<int64_t>(),
|
||||||
desc["body_value_idx"].cast<int64_t>(),
|
desc["output_idx"].cast<int64_t>(),
|
||||||
desc["output_idx"].cast<int64_t>(),
|
desc["iteration"].cast<int64_t>());
|
||||||
desc["iteration"].cast<int64_t>());
|
output_descs.push_back(body_output);
|
||||||
output_descs.push_back(body_output);
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (py::handle h : concat_output_desc) {
|
if (output_desc.contains("concat_output_desc") && !output_desc["concat_output_desc"].is_none()) {
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
for (py::handle h : output_desc["concat_output_desc"].cast<py::list>()) {
|
||||||
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
desc["body_value_idx"].cast<int64_t>(),
|
auto concat_output = std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(
|
||||||
desc["output_idx"].cast<int64_t>(),
|
desc["body_value_idx"].cast<int64_t>(),
|
||||||
desc["start"].cast<int64_t>(),
|
desc["output_idx"].cast<int64_t>(),
|
||||||
desc["stride"].cast<int64_t>(),
|
desc["start"].cast<int64_t>(),
|
||||||
desc["part_size"].cast<int64_t>(),
|
desc["stride"].cast<int64_t>(),
|
||||||
desc["end"].cast<int64_t>(),
|
desc["part_size"].cast<int64_t>(),
|
||||||
desc["axis"].cast<int64_t>());
|
desc["end"].cast<int64_t>(),
|
||||||
output_descs.push_back(concat_output);
|
desc["axis"].cast<int64_t>());
|
||||||
|
output_descs.push_back(concat_output);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
a->set(output_descs);
|
a->set(output_descs);
|
||||||
} else if (const auto& a =
|
} else if (const auto& a =
|
||||||
@ -241,7 +247,7 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
|||||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||||
ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) {
|
ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) {
|
||||||
if (m_attributes.contains(name)) {
|
if (m_attributes.contains(name)) {
|
||||||
if (name == "body") {
|
if (name == "body" || name == "then_body" || name == "else_body") {
|
||||||
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
|
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
|
||||||
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
|
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ngraph::NodeVector>());
|
||||||
const auto& body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
|
const auto& body_parameters = body_attrs["parameters"].cast<ngraph::ParameterVector>();
|
||||||
|
@ -69,6 +69,7 @@ from openvino.runtime.opset1.ops import hard_sigmoid
|
|||||||
from openvino.runtime.opset5.ops import hsigmoid
|
from openvino.runtime.opset5.ops import hsigmoid
|
||||||
from openvino.runtime.opset4.ops import hswish
|
from openvino.runtime.opset4.ops import hswish
|
||||||
from openvino.runtime.opset7.ops import idft
|
from openvino.runtime.opset7.ops import idft
|
||||||
|
from openvino.runtime.opset8.ops import if_op
|
||||||
from openvino.runtime.opset1.ops import interpolate
|
from openvino.runtime.opset1.ops import interpolate
|
||||||
from openvino.runtime.opset1.ops import less
|
from openvino.runtime.opset1.ops import less
|
||||||
from openvino.runtime.opset1.ops import less_equal
|
from openvino.runtime.opset1.ops import less_equal
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
"""Factory functions for all openvino ops."""
|
"""Factory functions for all openvino ops."""
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Iterable, List, Optional, Set, Union
|
from typing import Callable, Iterable, List, Optional, Set, Union, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from openvino.runtime.impl import Node, Shape
|
from openvino.runtime.impl import Node, Shape
|
||||||
@ -369,6 +369,42 @@ def random_uniform(
|
|||||||
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
|
return _get_node_factory_opset8().create("RandomUniform", inputs, attributes)
|
||||||
|
|
||||||
|
|
||||||
|
@nameable_op
|
||||||
|
def if_op(
|
||||||
|
condition: NodeInput,
|
||||||
|
inputs: List[Node],
|
||||||
|
bodies: Tuple[GraphBody, GraphBody],
|
||||||
|
input_desc: Tuple[List[TensorIteratorInvariantInputDesc], List[TensorIteratorInvariantInputDesc]],
|
||||||
|
output_desc: Tuple[List[TensorIteratorBodyOutputDesc], List[TensorIteratorBodyOutputDesc]],
|
||||||
|
name: Optional[str] = None,
|
||||||
|
) -> Node:
|
||||||
|
"""Execute one of the bodies depending on condtion value.
|
||||||
|
|
||||||
|
@param condition: A scalar or 1D tensor with 1 element specifying body will be executed.
|
||||||
|
If condition is True, then body will be executed, False - else_body.
|
||||||
|
@param inputs: The provided inputs to If operation.
|
||||||
|
@param bodies: Two graphs (then_body, else_body) which will be executed depending on
|
||||||
|
condition value.
|
||||||
|
@param input_desc Two lists (for then_body and else_body) which contain rules how If
|
||||||
|
inputs are connected with body parameters.
|
||||||
|
@param output_desc: Two lists (for then_body and else_body) which contain rules how If
|
||||||
|
outputs are connected with body results.
|
||||||
|
@param name: The optional name for the created output node.
|
||||||
|
|
||||||
|
@return: The new node which performs If operation.
|
||||||
|
"""
|
||||||
|
attributes = {
|
||||||
|
"then_body": bodies[0].serialize(),
|
||||||
|
"else_body": bodies[1].serialize(),
|
||||||
|
"then_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[0]]},
|
||||||
|
"else_inputs": {"invariant_input_desc": [desc.serialize() for desc in input_desc[1]]},
|
||||||
|
"then_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[0]]},
|
||||||
|
"else_outputs": {"body_output_desc": [desc.serialize() for desc in output_desc[1]]}
|
||||||
|
}
|
||||||
|
return _get_node_factory_opset8().create("If", as_nodes(condition, *inputs),
|
||||||
|
attributes)
|
||||||
|
|
||||||
|
|
||||||
@nameable_op
|
@nameable_op
|
||||||
def slice(
|
def slice(
|
||||||
data: NodeInput,
|
data: NodeInput,
|
||||||
|
@ -112,7 +112,7 @@ class TensorIteratorOutputDesc(object):
|
|||||||
class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
|
class TensorIteratorBodyOutputDesc(TensorIteratorOutputDesc):
|
||||||
"""Represents an output from a specific iteration."""
|
"""Represents an output from a specific iteration."""
|
||||||
|
|
||||||
def __init__(self, body_value_idx: int, output_idx: int, iteration: int,) -> None:
|
def __init__(self, body_value_idx: int, output_idx: int, iteration: int = -1) -> None:
|
||||||
super().__init__(body_value_idx, output_idx)
|
super().__init__(body_value_idx, output_idx)
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
|
|
||||||
|
@ -28,37 +28,39 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name, ov::Va
|
|||||||
&adapter)) {
|
&adapter)) {
|
||||||
std::vector<std::shared_ptr<ov::op::util::SubGraphOp::InputDescription>> input_descs;
|
std::vector<std::shared_ptr<ov::op::util::SubGraphOp::InputDescription>> input_descs;
|
||||||
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
|
const py::dict& input_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||||
const auto& merged_input_desc = input_desc["merged_input_desc"].cast<py::list>();
|
if (input_desc.contains("slice_input_desc") && !input_desc["slice_input_desc"].is_none()) {
|
||||||
const auto& slice_input_desc = input_desc["slice_input_desc"].cast<py::list>();
|
for (py::handle h : input_desc["slice_input_desc"].cast<py::list>()) {
|
||||||
const auto& invariant_input_desc = input_desc["invariant_input_desc"].cast<py::list>();
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
for (py::handle h : slice_input_desc) {
|
auto slice_in = std::make_shared<ov::op::util::SubGraphOp::SliceInputDescription>(
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
desc["input_idx"].cast<int64_t>(),
|
||||||
auto slice_in = std::make_shared<ov::op::util::SubGraphOp::SliceInputDescription>(
|
desc["body_parameter_idx"].cast<int64_t>(),
|
||||||
desc["input_idx"].cast<int64_t>(),
|
desc["start"].cast<int64_t>(),
|
||||||
desc["body_parameter_idx"].cast<int64_t>(),
|
desc["stride"].cast<int64_t>(),
|
||||||
desc["start"].cast<int64_t>(),
|
desc["part_size"].cast<int64_t>(),
|
||||||
desc["stride"].cast<int64_t>(),
|
desc["end"].cast<int64_t>(),
|
||||||
desc["part_size"].cast<int64_t>(),
|
desc["axis"].cast<int64_t>());
|
||||||
desc["end"].cast<int64_t>(),
|
input_descs.push_back(slice_in);
|
||||||
desc["axis"].cast<int64_t>());
|
}
|
||||||
input_descs.push_back(slice_in);
|
}
|
||||||
|
if (input_desc.contains("merged_input_desc") && !input_desc["merged_input_desc"].is_none()) {
|
||||||
|
for (py::handle h : input_desc["merged_input_desc"].cast<py::list>()) {
|
||||||
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
|
auto merged_in = std::make_shared<ov::op::util::SubGraphOp::MergedInputDescription>(
|
||||||
|
desc["input_idx"].cast<int64_t>(),
|
||||||
|
desc["body_parameter_idx"].cast<int64_t>(),
|
||||||
|
desc["body_value_idx"].cast<int64_t>());
|
||||||
|
input_descs.push_back(merged_in);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (py::handle h : merged_input_desc) {
|
if (input_desc.contains("invariant_input_desc") && !input_desc["invariant_input_desc"].is_none()) {
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
for (py::handle h : input_desc["invariant_input_desc"].cast<py::list>()) {
|
||||||
auto merged_in = std::make_shared<ov::op::util::SubGraphOp::MergedInputDescription>(
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
desc["input_idx"].cast<int64_t>(),
|
auto invariant_in = std::make_shared<ov::op::util::SubGraphOp::InvariantInputDescription>(
|
||||||
desc["body_parameter_idx"].cast<int64_t>(),
|
desc["input_idx"].cast<int64_t>(),
|
||||||
desc["body_value_idx"].cast<int64_t>());
|
desc["body_parameter_idx"].cast<int64_t>());
|
||||||
input_descs.push_back(merged_in);
|
input_descs.push_back(invariant_in);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (py::handle h : invariant_input_desc) {
|
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
|
||||||
auto invariant_in = std::make_shared<ov::op::util::SubGraphOp::InvariantInputDescription>(
|
|
||||||
desc["input_idx"].cast<int64_t>(),
|
|
||||||
desc["body_parameter_idx"].cast<int64_t>());
|
|
||||||
input_descs.push_back(invariant_in);
|
|
||||||
}
|
}
|
||||||
a->set(input_descs);
|
a->set(input_descs);
|
||||||
} else if (const auto& a = ov::as_type<
|
} else if (const auto& a = ov::as_type<
|
||||||
@ -66,28 +68,29 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name, ov::Va
|
|||||||
&adapter)) {
|
&adapter)) {
|
||||||
std::vector<std::shared_ptr<ov::op::util::SubGraphOp::OutputDescription>> output_descs;
|
std::vector<std::shared_ptr<ov::op::util::SubGraphOp::OutputDescription>> output_descs;
|
||||||
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
|
const py::dict& output_desc = m_attributes[name.c_str()].cast<py::dict>();
|
||||||
const auto& body_output_desc = output_desc["body_output_desc"].cast<py::list>();
|
if (output_desc.contains("body_output_desc") && !output_desc["body_output_desc"].is_none()) {
|
||||||
const auto& concat_output_desc = output_desc["concat_output_desc"].cast<py::list>();
|
for (py::handle h : output_desc["body_output_desc"].cast<py::list>()) {
|
||||||
for (py::handle h : body_output_desc) {
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
auto body_output = std::make_shared<ov::op::util::SubGraphOp::BodyOutputDescription>(
|
||||||
auto body_output = std::make_shared<ov::op::util::SubGraphOp::BodyOutputDescription>(
|
desc["body_value_idx"].cast<int64_t>(),
|
||||||
desc["body_value_idx"].cast<int64_t>(),
|
desc["output_idx"].cast<int64_t>(),
|
||||||
desc["output_idx"].cast<int64_t>(),
|
desc["iteration"].cast<int64_t>());
|
||||||
desc["iteration"].cast<int64_t>());
|
output_descs.push_back(body_output);
|
||||||
output_descs.push_back(body_output);
|
}
|
||||||
}
|
}
|
||||||
|
if (output_desc.contains("concat_output_desc") && !output_desc["concat_output_desc"].is_none()) {
|
||||||
for (py::handle h : concat_output_desc) {
|
for (py::handle h : output_desc["concat_output_desc"].cast<py::list>()) {
|
||||||
const py::dict& desc = h.cast<py::dict>();
|
const py::dict& desc = h.cast<py::dict>();
|
||||||
auto concat_output = std::make_shared<ov::op::util::SubGraphOp::ConcatOutputDescription>(
|
auto concat_output = std::make_shared<ov::op::util::SubGraphOp::ConcatOutputDescription>(
|
||||||
desc["body_value_idx"].cast<int64_t>(),
|
desc["body_value_idx"].cast<int64_t>(),
|
||||||
desc["output_idx"].cast<int64_t>(),
|
desc["output_idx"].cast<int64_t>(),
|
||||||
desc["start"].cast<int64_t>(),
|
desc["start"].cast<int64_t>(),
|
||||||
desc["stride"].cast<int64_t>(),
|
desc["stride"].cast<int64_t>(),
|
||||||
desc["part_size"].cast<int64_t>(),
|
desc["part_size"].cast<int64_t>(),
|
||||||
desc["end"].cast<int64_t>(),
|
desc["end"].cast<int64_t>(),
|
||||||
desc["axis"].cast<int64_t>());
|
desc["axis"].cast<int64_t>());
|
||||||
output_descs.push_back(concat_output);
|
output_descs.push_back(concat_output);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
a->set(output_descs);
|
a->set(output_descs);
|
||||||
} else if (const auto& a = ov::as_type<ov::AttributeAdapter<ov::op::v5::Loop::SpecialBodyPorts>>(&adapter)) {
|
} else if (const auto& a = ov::as_type<ov::AttributeAdapter<ov::op::v5::Loop::SpecialBodyPorts>>(&adapter)) {
|
||||||
@ -241,7 +244,7 @@ void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
|||||||
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
void util::DictAttributeDeserializer::on_adapter(const std::string& name,
|
||||||
ov::ValueAccessor<std::shared_ptr<ov::Function>>& adapter) {
|
ov::ValueAccessor<std::shared_ptr<ov::Function>>& adapter) {
|
||||||
if (m_attributes.contains(name)) {
|
if (m_attributes.contains(name)) {
|
||||||
if (name == "body") {
|
if (name == "body" || name == "then_body" || name == "else_body") {
|
||||||
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
|
const py::dict& body_attrs = m_attributes[name.c_str()].cast<py::dict>();
|
||||||
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ov::NodeVector>());
|
const auto& body_outputs = as_output_vector(body_attrs["results"].cast<ov::NodeVector>());
|
||||||
const auto& body_parameters = body_attrs["parameters"].cast<ov::ParameterVector>();
|
const auto& body_parameters = body_attrs["parameters"].cast<ov::ParameterVector>();
|
||||||
|
175
src/bindings/python/tests/test_ngraph/test_if.py
Normal file
175
src/bindings/python/tests/test_ngraph/test_if.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
import numpy as np
|
||||||
|
import openvino.runtime.opset8 as ov
|
||||||
|
import pytest
|
||||||
|
from openvino.runtime.utils.tensor_iterator_types import (
|
||||||
|
GraphBody,
|
||||||
|
TensorIteratorInvariantInputDesc,
|
||||||
|
TensorIteratorBodyOutputDesc,
|
||||||
|
)
|
||||||
|
from tests.runtime import get_runtime
|
||||||
|
|
||||||
|
|
||||||
|
def create_simple_if_with_two_outputs(condition_val):
|
||||||
|
condition = ov.constant(condition_val, dtype=np.bool)
|
||||||
|
|
||||||
|
# then_body
|
||||||
|
X_t = ov.parameter([], np.float32, "X")
|
||||||
|
Y_t = ov.parameter([], np.float32, "Y")
|
||||||
|
Z_t = ov.parameter([], np.float32, "Z")
|
||||||
|
|
||||||
|
add_t = ov.add(X_t, Y_t)
|
||||||
|
mul_t = ov.multiply(Y_t, Z_t)
|
||||||
|
then_body_res_1 = ov.result(add_t)
|
||||||
|
then_body_res_2 = ov.result(mul_t)
|
||||||
|
then_body = GraphBody([X_t, Y_t, Z_t], [then_body_res_1, then_body_res_2])
|
||||||
|
then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1),
|
||||||
|
TensorIteratorInvariantInputDesc(3, 2)]
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
X_e = ov.parameter([], np.float32, "X")
|
||||||
|
Z_e = ov.parameter([], np.float32, "Z")
|
||||||
|
W_e = ov.parameter([], np.float32, "W")
|
||||||
|
|
||||||
|
add_e = ov.add(X_e, W_e)
|
||||||
|
pow_e = ov.power(W_e, Z_e)
|
||||||
|
else_body_res_1 = ov.result(add_e)
|
||||||
|
else_body_res_2 = ov.result(pow_e)
|
||||||
|
else_body = GraphBody([X_e, Z_e, W_e], [else_body_res_1, else_body_res_2])
|
||||||
|
else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1),
|
||||||
|
TensorIteratorInvariantInputDesc(4, 2)]
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
X = ov.constant(15.0, dtype=np.float32)
|
||||||
|
Y = ov.constant(-5.0, dtype=np.float32)
|
||||||
|
Z = ov.constant(4.0, dtype=np.float32)
|
||||||
|
W = ov.constant(2.0, dtype=np.float32)
|
||||||
|
if_node = ov.if_op(condition, [X, Y, Z, W], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
return if_node
|
||||||
|
|
||||||
|
|
||||||
|
def create_diff_if_with_two_outputs(condition_val):
|
||||||
|
condition = ov.constant(condition_val, dtype=np.bool)
|
||||||
|
|
||||||
|
# then_body
|
||||||
|
X_t = ov.parameter([2], np.float32, "X")
|
||||||
|
Y_t = ov.parameter([2], np.float32, "Y")
|
||||||
|
mmul_t = ov.matmul(X_t, Y_t, False, False)
|
||||||
|
mul_t = ov.multiply(Y_t, X_t)
|
||||||
|
then_body_res_1 = ov.result(mmul_t)
|
||||||
|
then_body_res_2 = ov.result(mul_t)
|
||||||
|
then_body = GraphBody([X_t, Y_t], [then_body_res_1, then_body_res_2])
|
||||||
|
then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
X_e = ov.parameter([2], np.float32, "X")
|
||||||
|
Z_e = ov.parameter([], np.float32, "Z")
|
||||||
|
mul_e = ov.multiply(X_e, Z_e)
|
||||||
|
else_body_res_1 = ov.result(Z_e)
|
||||||
|
else_body_res_2 = ov.result(mul_e)
|
||||||
|
else_body = GraphBody([X_e, Z_e], [else_body_res_1, else_body_res_2])
|
||||||
|
else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1)]
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
X = ov.constant([3, 4], dtype=np.float32)
|
||||||
|
Y = ov.constant([2, 1], dtype=np.float32)
|
||||||
|
Z = ov.constant(4.0, dtype=np.float32)
|
||||||
|
if_node = ov.if_op(condition, [X, Y, Z], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
return if_node
|
||||||
|
|
||||||
|
|
||||||
|
def simple_if(condition_val):
|
||||||
|
condition = ov.constant(condition_val, dtype=np.bool)
|
||||||
|
# then_body
|
||||||
|
X_t = ov.parameter([2], np.float32, "X")
|
||||||
|
Y_t = ov.parameter([2], np.float32, "Y")
|
||||||
|
|
||||||
|
then_mul = ov.multiply(X_t, Y_t)
|
||||||
|
then_body_res_1 = ov.result(then_mul)
|
||||||
|
then_body = GraphBody([X_t, Y_t], [then_body_res_1])
|
||||||
|
then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
X_e = ov.parameter([2], np.float32, "X")
|
||||||
|
Y_e = ov.parameter([2], np.float32, "Y")
|
||||||
|
add_e = ov.add(X_e, Y_e)
|
||||||
|
else_body_res_1 = ov.result(add_e)
|
||||||
|
else_body = GraphBody([X_e, Y_e], [else_body_res_1])
|
||||||
|
else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
X = ov.constant([3, 4], dtype=np.float32)
|
||||||
|
Y = ov.constant([2, 1], dtype=np.float32)
|
||||||
|
if_node = ov.if_op(condition, [X, Y], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
relu = ov.relu(if_node)
|
||||||
|
return relu
|
||||||
|
|
||||||
|
|
||||||
|
def simple_if_without_parameters(condition_val):
|
||||||
|
condition = ov.constant(condition_val, dtype=np.bool)
|
||||||
|
|
||||||
|
# then_body
|
||||||
|
then_constant = ov.constant(0.7, dtype=np.float)
|
||||||
|
then_body_res_1 = ov.result(then_constant)
|
||||||
|
then_body = GraphBody([], [then_body_res_1])
|
||||||
|
then_body_inputs = []
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
else_const = ov.constant(9.0, dtype=np.float)
|
||||||
|
else_body_res_1 = ov.result(else_const)
|
||||||
|
else_body = GraphBody([], [else_body_res_1])
|
||||||
|
else_body_inputs = []
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
if_node = ov.if_op(condition, [], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
relu = ov.relu(if_node)
|
||||||
|
return relu
|
||||||
|
|
||||||
|
|
||||||
|
def check_results(results, expected_results):
|
||||||
|
assert len(results) == len(expected_results)
|
||||||
|
for id_result, res in enumerate(results):
|
||||||
|
assert np.allclose(res, expected_results[id_result])
|
||||||
|
|
||||||
|
|
||||||
|
def check_if(if_model, cond_val, exp_results):
|
||||||
|
last_node = if_model(cond_val)
|
||||||
|
runtime = get_runtime()
|
||||||
|
computation = runtime.computation(last_node)
|
||||||
|
results = computation()
|
||||||
|
check_results(results, exp_results)
|
||||||
|
|
||||||
|
|
||||||
|
# After deleting evalute method for if, constant folding stopped working.
|
||||||
|
# As result bug with id 67255 began to appear
|
||||||
|
@pytest.mark.xfail(reason="bug 67255")
|
||||||
|
def test_if_with_two_outputs():
|
||||||
|
check_if(create_simple_if_with_two_outputs, True,
|
||||||
|
[np.array([10], dtype=np.float32), np.array([-20], dtype=np.float32)])
|
||||||
|
check_if(create_simple_if_with_two_outputs, False,
|
||||||
|
[np.array([17], dtype=np.float32), np.array([16], dtype=np.float32)])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="bug 67255")
|
||||||
|
def test_diff_if_with_two_outputs():
|
||||||
|
check_if(create_diff_if_with_two_outputs, True,
|
||||||
|
[np.array([10], dtype=np.float32), np.array([6, 4], dtype=np.float32)])
|
||||||
|
check_if(create_diff_if_with_two_outputs, False,
|
||||||
|
[np.array([4], dtype=np.float32), np.array([12, 16], dtype=np.float32)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_if():
|
||||||
|
check_if(simple_if, True, [np.array([6, 4], dtype=np.float32)])
|
||||||
|
check_if(simple_if, False, [np.array([5, 5], dtype=np.float32)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_if_without_body_parameters():
|
||||||
|
check_if(simple_if_without_parameters, True, [np.array([0.7], dtype=np.float32)])
|
||||||
|
check_if(simple_if_without_parameters, False, [np.array([9.0], dtype=np.float32)])
|
175
src/bindings/python/tests_compatibility/test_ngraph/test_if.py
Normal file
175
src/bindings/python/tests_compatibility/test_ngraph/test_if.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
import ngraph as ng
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from ngraph.utils.tensor_iterator_types import (
|
||||||
|
GraphBody,
|
||||||
|
TensorIteratorInvariantInputDesc,
|
||||||
|
TensorIteratorBodyOutputDesc,
|
||||||
|
)
|
||||||
|
from tests_compatibility.runtime import get_runtime
|
||||||
|
|
||||||
|
|
||||||
|
def create_simple_if_with_two_outputs(condition_val):
|
||||||
|
condition = ng.constant(condition_val, dtype=np.bool)
|
||||||
|
|
||||||
|
# then_body
|
||||||
|
X_t = ng.parameter([], np.float32, "X")
|
||||||
|
Y_t = ng.parameter([], np.float32, "Y")
|
||||||
|
Z_t = ng.parameter([], np.float32, "Z")
|
||||||
|
|
||||||
|
add_t = ng.add(X_t, Y_t)
|
||||||
|
mul_t = ng.multiply(Y_t, Z_t)
|
||||||
|
then_body_res_1 = ng.result(add_t)
|
||||||
|
then_body_res_2 = ng.result(mul_t)
|
||||||
|
then_body = GraphBody([X_t, Y_t, Z_t], [then_body_res_1, then_body_res_2])
|
||||||
|
then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1),
|
||||||
|
TensorIteratorInvariantInputDesc(3, 2)]
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
X_e = ng.parameter([], np.float32, "X")
|
||||||
|
Z_e = ng.parameter([], np.float32, "Z")
|
||||||
|
W_e = ng.parameter([], np.float32, "W")
|
||||||
|
|
||||||
|
add_e = ng.add(X_e, W_e)
|
||||||
|
pow_e = ng.power(W_e, Z_e)
|
||||||
|
else_body_res_1 = ng.result(add_e)
|
||||||
|
else_body_res_2 = ng.result(pow_e)
|
||||||
|
else_body = GraphBody([X_e, Z_e, W_e], [else_body_res_1, else_body_res_2])
|
||||||
|
else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1),
|
||||||
|
TensorIteratorInvariantInputDesc(4, 2)]
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
X = ng.constant(15.0, dtype=np.float32)
|
||||||
|
Y = ng.constant(-5.0, dtype=np.float32)
|
||||||
|
Z = ng.constant(4.0, dtype=np.float32)
|
||||||
|
W = ng.constant(2.0, dtype=np.float32)
|
||||||
|
if_node = ng.if_op(condition, [X, Y, Z, W], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
return if_node
|
||||||
|
|
||||||
|
|
||||||
|
def create_diff_if_with_two_outputs(condition_val):
|
||||||
|
condition = ng.constant(condition_val, dtype=np.bool)
|
||||||
|
|
||||||
|
# then_body
|
||||||
|
X_t = ng.parameter([2], np.float32, "X")
|
||||||
|
Y_t = ng.parameter([2], np.float32, "Y")
|
||||||
|
mmul_t = ng.matmul(X_t, Y_t, False, False)
|
||||||
|
mul_t = ng.multiply(Y_t, X_t)
|
||||||
|
then_body_res_1 = ng.result(mmul_t)
|
||||||
|
then_body_res_2 = ng.result(mul_t)
|
||||||
|
then_body = GraphBody([X_t, Y_t], [then_body_res_1, then_body_res_2])
|
||||||
|
then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
X_e = ng.parameter([2], np.float32, "X")
|
||||||
|
Z_e = ng.parameter([], np.float32, "Z")
|
||||||
|
mul_e = ng.multiply(X_e, Z_e)
|
||||||
|
else_body_res_1 = ng.result(Z_e)
|
||||||
|
else_body_res_2 = ng.result(mul_e)
|
||||||
|
else_body = GraphBody([X_e, Z_e], [else_body_res_1, else_body_res_2])
|
||||||
|
else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(3, 1)]
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0), TensorIteratorBodyOutputDesc(1, 1)]
|
||||||
|
|
||||||
|
X = ng.constant([3, 4], dtype=np.float32)
|
||||||
|
Y = ng.constant([2, 1], dtype=np.float32)
|
||||||
|
Z = ng.constant(4.0, dtype=np.float32)
|
||||||
|
if_node = ng.if_op(condition, [X, Y, Z], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
return if_node
|
||||||
|
|
||||||
|
|
||||||
|
def simple_if(condition_val):
|
||||||
|
condition = ng.constant(condition_val, dtype=np.bool)
|
||||||
|
# then_body
|
||||||
|
X_t = ng.parameter([2], np.float32, "X")
|
||||||
|
Y_t = ng.parameter([2], np.float32, "Y")
|
||||||
|
|
||||||
|
then_mul = ng.multiply(X_t, Y_t)
|
||||||
|
then_body_res_1 = ng.result(then_mul)
|
||||||
|
then_body = GraphBody([X_t, Y_t], [then_body_res_1])
|
||||||
|
then_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
X_e = ng.parameter([2], np.float32, "X")
|
||||||
|
Y_e = ng.parameter([2], np.float32, "Y")
|
||||||
|
add_e = ng.add(X_e, Y_e)
|
||||||
|
else_body_res_1 = ng.result(add_e)
|
||||||
|
else_body = GraphBody([X_e, Y_e], [else_body_res_1])
|
||||||
|
else_body_inputs = [TensorIteratorInvariantInputDesc(1, 0), TensorIteratorInvariantInputDesc(2, 1)]
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
X = ng.constant([3, 4], dtype=np.float32)
|
||||||
|
Y = ng.constant([2, 1], dtype=np.float32)
|
||||||
|
if_node = ng.if_op(condition, [X, Y], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
relu = ng.relu(if_node)
|
||||||
|
return relu
|
||||||
|
|
||||||
|
|
||||||
|
def simple_if_without_parameters(condition_val):
|
||||||
|
condition = ng.constant(condition_val, dtype=np.bool)
|
||||||
|
|
||||||
|
# then_body
|
||||||
|
then_constant = ng.constant(0.7, dtype=np.float)
|
||||||
|
then_body_res_1 = ng.result(then_constant)
|
||||||
|
then_body = GraphBody([], [then_body_res_1])
|
||||||
|
then_body_inputs = []
|
||||||
|
then_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
# else_body
|
||||||
|
else_const = ng.constant(9.0, dtype=np.float)
|
||||||
|
else_body_res_1 = ng.result(else_const)
|
||||||
|
else_body = GraphBody([], [else_body_res_1])
|
||||||
|
else_body_inputs = []
|
||||||
|
else_body_outputs = [TensorIteratorBodyOutputDesc(0, 0)]
|
||||||
|
|
||||||
|
if_node = ng.if_op(condition, [], (then_body, else_body), (then_body_inputs, else_body_inputs),
|
||||||
|
(then_body_outputs, else_body_outputs))
|
||||||
|
relu = ng.relu(if_node)
|
||||||
|
return relu
|
||||||
|
|
||||||
|
|
||||||
|
def check_results(results, expected_results):
|
||||||
|
assert len(results) == len(expected_results)
|
||||||
|
for id_result, res in enumerate(results):
|
||||||
|
assert np.allclose(res, expected_results[id_result])
|
||||||
|
|
||||||
|
|
||||||
|
def check_if(if_model, cond_val, exp_results):
|
||||||
|
last_node = if_model(cond_val)
|
||||||
|
runtime = get_runtime()
|
||||||
|
computation = runtime.computation(last_node)
|
||||||
|
results = computation()
|
||||||
|
check_results(results, exp_results)
|
||||||
|
|
||||||
|
|
||||||
|
# After deleting evalute method for if, constant folding stopped working.
|
||||||
|
# As result bug with id 67255 began to appear
|
||||||
|
@pytest.mark.xfail(reason="bug 67255")
|
||||||
|
def test_if_with_two_outputs():
|
||||||
|
check_if(create_simple_if_with_two_outputs, True,
|
||||||
|
[np.array([10], dtype=np.float32), np.array([-20], dtype=np.float32)])
|
||||||
|
check_if(create_simple_if_with_two_outputs, False,
|
||||||
|
[np.array([17], dtype=np.float32), np.array([16], dtype=np.float32)])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="bug 67255")
|
||||||
|
def test_diff_if_with_two_outputs():
|
||||||
|
check_if(create_diff_if_with_two_outputs, True,
|
||||||
|
[np.array([10], dtype=np.float32), np.array([6, 4], dtype=np.float32)])
|
||||||
|
check_if(create_diff_if_with_two_outputs, False,
|
||||||
|
[np.array([4], dtype=np.float32), np.array([12, 16], dtype=np.float32)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_if():
|
||||||
|
check_if(simple_if, True, [np.array([6, 4], dtype=np.float32)])
|
||||||
|
check_if(simple_if, False, [np.array([5, 5], dtype=np.float32)])
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_if_without_body_parameters():
|
||||||
|
check_if(simple_if_without_parameters, True, [np.array([0.7], dtype=np.float32)])
|
||||||
|
check_if(simple_if_without_parameters, False, [np.array([9.0], dtype=np.float32)])
|
Loading…
Reference in New Issue
Block a user