Add eltwise types resolving. Support big int constants. (#15415)
* Add eltwise types resolving. Support big int constants. * Update src/bindings/python/src/openvino/frontend/pytorch/decoder.py * Small fix * Fix some cases * Add tests for add in different types * Add tests for mul * Add tests for sub and div * Small fixes * Return list handling (needed for empty lists) * Add test for empty list * Update src/frontends/pytorch/src/op/mul.cpp Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> * Use refs instead of ptrs * Apply suggestions from code review Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> * Apply code review suggestions * Fix code style * Add more eltwise ops --------- Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
8051c2d535
commit
92649105ed
@ -16,7 +16,10 @@ def get_type_from_py_type(value):
|
||||
if isinstance(value, float):
|
||||
return OVType.f32
|
||||
if isinstance(value, int):
|
||||
return OVType.i32
|
||||
# Python int is 64 bit, but we will convert it to int32 except cases when it can't fit in 32 bits
|
||||
if torch.iinfo(torch.int).min <= value <= torch.iinfo(torch.int).max:
|
||||
return OVType.i32
|
||||
return OVType.i64
|
||||
if isinstance(value, bool):
|
||||
return OVType.boolean
|
||||
return OVType.dynamic
|
||||
@ -27,13 +30,13 @@ def ivalue_to_constant(ivalue):
|
||||
if ov_type.is_static():
|
||||
return op.Constant(ov_type, Shape([]), [ivalue]).outputs()
|
||||
|
||||
if isinstance(ivalue, list):
|
||||
if isinstance(ivalue, (list, tuple)):
|
||||
assert len(ivalue) > 0, "Can't deduce type for empty list"
|
||||
ov_type = get_type_from_py_type(ivalue[0])
|
||||
assert ov_type.is_static(), "Can't deduce type for list"
|
||||
return op.Constant(ov_type, Shape([len(ivalue)]), ivalue).outputs()
|
||||
|
||||
if ivalue.type() in pt_to_ov_type_map:
|
||||
if isinstance(ivalue, torch.Tensor) and ivalue.type() in pt_to_ov_type_map:
|
||||
try:
|
||||
ovshape = PartialShape(ivalue.size())
|
||||
ovtype = pt_to_ov_type_map[ivalue.type()]
|
||||
@ -46,6 +49,7 @@ def ivalue_to_constant(ivalue):
|
||||
ovshape = PartialShape(nvalues.shape)
|
||||
ov_const = op.Constant(ovtype, ovshape.get_shape(), nvalues.flatten().tolist())
|
||||
return ov_const.outputs()
|
||||
return None
|
||||
|
||||
|
||||
def get_value_from_getattr(getattr_node, self_module):
|
||||
@ -69,25 +73,22 @@ def get_value_from_getattr(getattr_node, self_module):
|
||||
pt_to_ov_type_map = {
|
||||
"float": OVType.f32,
|
||||
"int": OVType.i32,
|
||||
"bool": OVType.boolean,
|
||||
"torch.float16": OVType.f16,
|
||||
"torch.float32": OVType.f32,
|
||||
"torch.float64": OVType.f64,
|
||||
"torch.uint8": OVType.u8,
|
||||
"torch.int8": OVType.i8,
|
||||
"torch.int32": OVType.i32,
|
||||
"torch.bool": OVType.boolean,
|
||||
"torch.int64": OVType.i64,
|
||||
"torch.bool": OVType.boolean,
|
||||
"torch.DoubleTensor": OVType.f64,
|
||||
"torch.FloatTensor": OVType.f32,
|
||||
"torch.IntTensor": OVType.i32,
|
||||
"torch.LongTensor": OVType.i64,
|
||||
"torch.BoolTensor": OVType.boolean,
|
||||
}
|
||||
|
||||
pt_to_py_type_map = {
|
||||
"float": "float",
|
||||
"int": "int",
|
||||
"torch.float32": "float",
|
||||
"torch.int32": "int",
|
||||
"torch.int64": "int",
|
||||
"torch.bool": "bool",
|
||||
}
|
||||
|
||||
np_to_ov_type_map = {
|
||||
"float32": OVType.f32,
|
||||
"int32": OVType.i32,
|
||||
@ -106,7 +107,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
self.graph_element = graph_element
|
||||
self.pt_module = pt_module
|
||||
|
||||
def inputs(self):
|
||||
def inputs(self) -> list:
|
||||
return [x.unique() for x in self.graph_element.inputs()]
|
||||
|
||||
def get_input(self, index: int):
|
||||
@ -150,7 +151,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
# Not yet recognized
|
||||
return OVAny(OVType.dynamic)
|
||||
|
||||
def get_shape_for_value(self, value):
|
||||
def get_shape_for_value(self, value: torch.Value):
|
||||
if value.isCompleteTensor():
|
||||
ps = PartialShape(value.type().sizes())
|
||||
return ps
|
||||
@ -161,7 +162,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
pass
|
||||
return PartialShape.dynamic()
|
||||
|
||||
def get_type_for_value(self, value):
|
||||
def get_type_for_value(self, value: torch.Value):
|
||||
full_type = self._get_known_type_for_value(value.type())
|
||||
return full_type
|
||||
|
||||
@ -184,46 +185,46 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
def get_subgraph_size(self) -> int:
|
||||
return len(self.get_subgraphs()) if hasattr(self.graph_element, "blocks") else 1
|
||||
|
||||
def visit_subgraph(self, node_visitor):
|
||||
def visit_subgraph(self, node_visitor) -> None:
|
||||
# make sure topological order is satisfied
|
||||
for node in self.graph_element.nodes():
|
||||
decoder = TorchScriptPythonDecoder(self.pt_module, node)
|
||||
self.m_decoders.append(decoder)
|
||||
node_visitor(decoder)
|
||||
|
||||
def get_subgraphs(self):
|
||||
def get_subgraphs(self) -> list:
|
||||
return list(self.graph_element.blocks())
|
||||
|
||||
def get_subgraph_decoder(self, index):
|
||||
def get_subgraph_decoder(self, index: int):
|
||||
decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index])
|
||||
self.m_decoders.append(decoder)
|
||||
return decoder
|
||||
|
||||
def get_op_type(self):
|
||||
def get_op_type(self) -> str:
|
||||
return self.graph_element.kind()
|
||||
|
||||
def get_schema(self):
|
||||
def get_schema(self) -> str:
|
||||
return self.graph_element.schema()
|
||||
|
||||
def outputs(self):
|
||||
def outputs(self) -> list:
|
||||
return [x.unique() for x in self.graph_element.outputs()]
|
||||
|
||||
def _raw_outputs(self):
|
||||
def _raw_outputs(self) -> list:
|
||||
return list(self.graph_element.outputs())
|
||||
|
||||
def _raw_output(self, index):
|
||||
def _raw_output(self, index: int):
|
||||
return self._raw_outputs()[index]
|
||||
|
||||
def _raw_inputs(self):
|
||||
def _raw_inputs(self) -> list:
|
||||
return list(self.graph_element.inputs())
|
||||
|
||||
def _raw_input(self, index):
|
||||
def _raw_input(self, index: int):
|
||||
return self._raw_inputs()[index]
|
||||
|
||||
def num_of_outputs(self):
|
||||
return len(self.outputs())
|
||||
|
||||
def output(self, index):
|
||||
def output(self, index: int):
|
||||
return self.outputs()[index]
|
||||
|
||||
def mark_node(self, node):
|
||||
@ -232,7 +233,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
def try_decode_get_attr(self):
|
||||
pt_value = get_value_from_getattr(self.graph_element, self.pt_module)
|
||||
assert pt_value is not None, "Couldn't retrieve value from prim::GetAttr"
|
||||
if not isinstance(pt_value, torch.jit.ScriptModule) or isinstance(pt_value, torch.jit.TracedModule):
|
||||
if not isinstance(pt_value, (torch.jit.ScriptModule, torch.jit.TracedModule)):
|
||||
return ivalue_to_constant(pt_value)
|
||||
else:
|
||||
return []
|
||||
@ -244,17 +245,10 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
|
||||
pt_type = pt_value.type()
|
||||
if isinstance(pt_type, torch.TensorType):
|
||||
return self.as_constant_tensor(pt_value)
|
||||
return self._as_constant_tensor(pt_value)
|
||||
if isinstance(pt_type, torch.ListType):
|
||||
return self.as_constant_list(pt_value)
|
||||
if str(pt_type) in ["torch.int32", "int"]:
|
||||
return op.Constant(OVType.i32, Shape([]), [pt_value.toIValue()]).outputs()
|
||||
if str(pt_type) in ["torch.float", "torch.FloatType", "float"]:
|
||||
return op.Constant(OVType.f32, Shape([]), [pt_value.toIValue()]).outputs()
|
||||
if str(pt_type) in ["torch.bool", "bool"]:
|
||||
return op.Constant(OVType.boolean, Shape([]), [pt_value.toIValue()]).outputs()
|
||||
|
||||
return None
|
||||
return self._as_constant_list(pt_value)
|
||||
return ivalue_to_constant(pt_value.toIValue())
|
||||
|
||||
def as_string(self):
|
||||
if not self.get_op_type() == "prim::Constant":
|
||||
@ -265,7 +259,8 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
return pt_value.toIValue()
|
||||
return None
|
||||
|
||||
def as_constant_tensor(self, pt_value):
|
||||
@staticmethod
|
||||
def _as_constant_tensor(pt_value: torch.Value):
|
||||
ivalue = pt_value.toIValue()
|
||||
if pt_value.isCompleteTensor():
|
||||
try:
|
||||
@ -295,7 +290,8 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
return ivalue_to_constant(ivalue)
|
||||
return None
|
||||
|
||||
def as_constant_list(self, pt_value):
|
||||
@staticmethod
|
||||
def _as_constant_list(pt_value: torch.Value):
|
||||
# For now it is treat a list as a 1D tensor; it is required by converters to avoid need to massively
|
||||
# rewrite them in that part where constant attributes are queried
|
||||
pt_element_type = str(pt_value.type().getElementType())
|
||||
@ -308,7 +304,7 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
ov_const = op.Constant(ovtype, ovshape.get_shape(), ivalue)
|
||||
return ov_const.outputs()
|
||||
|
||||
def input_is_none(self, index):
|
||||
def input_is_none(self, index: int) -> bool:
|
||||
if index >= len(self.inputs()) or self._raw_input(index) is None:
|
||||
return True
|
||||
else:
|
||||
|
@ -2,8 +2,11 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/add.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,12 +15,14 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_add(NodeContext& context) {
|
||||
auto lhs = context.get_input(0);
|
||||
auto rhs = context.get_input(1);
|
||||
align_eltwise_input_types(context, lhs, rhs);
|
||||
if (!context.input_is_none(2)) {
|
||||
auto converted_alpha = std::make_shared<opset10::ConvertLike>(context.get_input(2), rhs);
|
||||
rhs = std::make_shared<opset10::Multiply>(converted_alpha, rhs);
|
||||
auto converted_alpha = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(context.get_input(2), rhs));
|
||||
rhs = context.mark_node(std::make_shared<ov::op::v1::Multiply>(converted_alpha, rhs));
|
||||
}
|
||||
return {context.mark_node(std::make_shared<opset10::Add>(context.get_input(0), rhs))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Add>(lhs, rhs))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,9 +3,14 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "openvino/op/floor.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
@ -14,21 +19,27 @@ namespace op {
|
||||
OutputVector translate_div(NodeContext& context) {
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
auto res = context.mark_node(std::make_shared<opset10::Divide>(x, y, true));
|
||||
std::string rounding_mode = "";
|
||||
if (!context.input_is_none(2)) {
|
||||
auto rounding_mode = context.const_input<std::string>(2);
|
||||
if (rounding_mode == "floor") {
|
||||
res = context.mark_node(std::make_shared<opset10::Floor>(res));
|
||||
} else if (rounding_mode == "trunc") {
|
||||
const auto convert = context.mark_node(std::make_shared<opset10::Convert>(res, element::i64));
|
||||
res = context.mark_node(std::make_shared<opset10::ConvertLike>(convert, x));
|
||||
} else {
|
||||
FRONT_END_OP_CONVERSION_CHECK(false,
|
||||
"Openvino Pytorch Frontend doesn't support rounding mode ",
|
||||
rounding_mode,
|
||||
" for aten::div");
|
||||
rounding_mode = context.const_input<std::string>(2);
|
||||
}
|
||||
if (rounding_mode.empty()) {
|
||||
// if no rounding mode and both inputs are ints cast BOTH to fp32
|
||||
const auto x_dtype = x.get_element_type();
|
||||
const auto y_dtype = y.get_element_type();
|
||||
if (x_dtype.is_static() && x_dtype.is_integral() && y_dtype.is_static() && y_dtype.is_integral()) {
|
||||
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
|
||||
y = context.mark_node(std::make_shared<v0::Convert>(y, element::f32));
|
||||
}
|
||||
}
|
||||
align_eltwise_input_types(context, x, y, true);
|
||||
auto res = context.mark_node(std::make_shared<v1::Divide>(x, y, true));
|
||||
if (rounding_mode == "floor") {
|
||||
res = context.mark_node(std::make_shared<v0::Floor>(res));
|
||||
} else if (rounding_mode == "trunc") {
|
||||
const auto convert = context.mark_node(std::make_shared<v0::Convert>(res, element::i64));
|
||||
res = context.mark_node(std::make_shared<v1::ConvertLike>(convert, x));
|
||||
}
|
||||
return {res};
|
||||
};
|
||||
|
||||
|
25
src/frontends/pytorch/src/op/pow.cpp
Normal file
25
src/frontends/pytorch/src/op/pow.cpp
Normal file
@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/power.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_pow(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto lhs = context.get_input(0);
|
||||
auto rhs = context.get_input(1);
|
||||
align_eltwise_input_types(context, lhs, rhs, true);
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Power>(lhs, rhs))};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -3,9 +3,13 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
@ -14,13 +18,14 @@ namespace op {
|
||||
OutputVector translate_sub(NodeContext& context) {
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
align_eltwise_input_types(context, x, y);
|
||||
// default alpha is 1 so no need to multiply if alpha is not provided
|
||||
if (!context.input_is_none(2)) {
|
||||
auto alpha = context.get_input(2);
|
||||
auto casted_alpha = context.mark_node(std::make_shared<opset10::ConvertLike>(alpha, y));
|
||||
y = context.mark_node(std::make_shared<opset10::Multiply>(casted_alpha, y));
|
||||
auto casted_alpha = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, y));
|
||||
y = context.mark_node(std::make_shared<v1::Multiply>(casted_alpha, y));
|
||||
}
|
||||
return {context.mark_node(std::make_shared<opset10::Subtract>(x, y))};
|
||||
return {context.mark_node(std::make_shared<v1::Subtract>(x, y))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -77,6 +77,7 @@ OP_CONVERTER(translate_numel);
|
||||
OP_CONVERTER(translate_ones);
|
||||
OP_CONVERTER(translate_ones_like);
|
||||
OP_CONVERTER(translate_pad);
|
||||
OP_CONVERTER(translate_pow);
|
||||
OP_CONVERTER(translate_reciprocal);
|
||||
OP_CONVERTER(translate_relu6);
|
||||
OP_CONVERTER(translate_remainder);
|
||||
@ -175,7 +176,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::dropout_", op::skip_node},
|
||||
{"aten::elu", op::translate_elu},
|
||||
{"aten::embedding", op::translate_embedding},
|
||||
{"aten::eq", op::translate_1to1_match_2_inputs<opset10::Equal>},
|
||||
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
|
||||
{"aten::exp", op::translate_1to1_match_1_inputs<opset10::Exp>},
|
||||
{"aten::expand", op::translate_expand},
|
||||
{"aten::expand_as", op::translate_expand_as},
|
||||
@ -191,8 +192,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::gelu", op::translate_gelu},
|
||||
{"aten::glu", op::translate_glu},
|
||||
{"aten::group_norm", op::translate_group_norm},
|
||||
{"aten::ge", op::translate_1to1_match_2_inputs<opset10::GreaterEqual>},
|
||||
{"aten::gt", op::translate_1to1_match_2_inputs<opset10::Greater>},
|
||||
{"aten::ge", op::translate_1to1_match_2_inputs_align_types<opset10::GreaterEqual>},
|
||||
{"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
|
||||
{"aten::grid_sampler", op::translate_grid_sampler},
|
||||
{"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset10::HSigmoid>},
|
||||
{"aten::hardswish", op::translate_1to1_match_1_inputs<opset10::HSwish>},
|
||||
@ -209,8 +210,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::leaky_relu_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::PRelu>>},
|
||||
{"aten::len", op::translate_len},
|
||||
{"aten::linear", op::translate_linear},
|
||||
{"aten::le", op::translate_1to1_match_2_inputs<opset10::LessEqual>},
|
||||
{"aten::lt", op::translate_1to1_match_2_inputs<opset10::Less>},
|
||||
{"aten::le", op::translate_1to1_match_2_inputs_align_types<opset10::LessEqual>},
|
||||
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
|
||||
{"aten::log", op::translate_log},
|
||||
{"aten::log_", op::inplace_op<op::translate_log>},
|
||||
{"aten::log2", op::translate_log2},
|
||||
@ -228,9 +229,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::mul", op::translate_1to1_match_2_inputs<opset10::Multiply>},
|
||||
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::Multiply>>},
|
||||
{"aten::ne", op::translate_1to1_match_2_inputs<opset10::NotEqual>},
|
||||
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
|
||||
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
|
||||
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
|
||||
{"aten::neg", op::translate_neg},
|
||||
{"aten::norm", op::translate_norm},
|
||||
{"aten::nonzero", op::translate_nonzero},
|
||||
@ -242,7 +243,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::ones_like", op::translate_ones_like},
|
||||
{"aten::pad", op::translate_pad},
|
||||
{"aten::permute", op::translate_1to1_match_2_inputs<opset10::Transpose>},
|
||||
{"aten::pow", op::translate_1to1_match_2_inputs<opset10::Power>},
|
||||
{"aten::pow", op::translate_pow},
|
||||
{"aten::reciprocal", op::translate_reciprocal},
|
||||
{"aten::relu", op::translate_1to1_match_1_inputs<opset10::Relu>},
|
||||
{"aten::relu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Relu>>},
|
||||
|
@ -459,6 +459,107 @@ Any simplified_type_interpret(Any type) {
|
||||
return type;
|
||||
}
|
||||
|
||||
namespace {
|
||||
std::unordered_map<size_t, element::Type> bit_to_float{
|
||||
{16, element::f16},
|
||||
{32, element::f32},
|
||||
{64, element::f64},
|
||||
};
|
||||
std::unordered_map<size_t, element::Type> bit_to_int{
|
||||
// {4, element::i4}, torch don't have int4
|
||||
{8, element::i8},
|
||||
{16, element::i16},
|
||||
{32, element::i32},
|
||||
{64, element::i64},
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void align_eltwise_input_types(const NodeContext& context,
|
||||
ov::Output<ov::Node>& lhs,
|
||||
ov::Output<ov::Node>& rhs,
|
||||
bool align_scalars) {
|
||||
const auto& lhs_type = lhs.get_element_type();
|
||||
const auto& rhs_type = rhs.get_element_type();
|
||||
if (lhs_type.is_dynamic() || rhs_type.is_dynamic()) {
|
||||
// if any of types is not known, align to lhs type.
|
||||
// TODO: can be fixed with special operation?
|
||||
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
|
||||
return;
|
||||
}
|
||||
|
||||
// Both types are static, align types. If float and int types are used convert int type to f32, after that align
|
||||
// to the largest bitness, if both float or both int, just align bitness
|
||||
if (lhs_type == rhs_type)
|
||||
return;
|
||||
|
||||
// if one of operands is scalar, the resulting type is taken from the other operand except when scalar is float
|
||||
// type and other operand is int, in that case BOTH operands get fp32 type
|
||||
const auto& lhs_rank = lhs.get_partial_shape().rank();
|
||||
const auto& rhs_rank = rhs.get_partial_shape().rank();
|
||||
// consider dynamic rank as non scalar
|
||||
const auto is_lhs_scalar = lhs_rank.is_static() && lhs_rank.get_length() == 0;
|
||||
const auto is_rhs_scalar = rhs_rank.is_static() && rhs_rank.get_length() == 0;
|
||||
if (is_lhs_scalar && is_rhs_scalar) {
|
||||
// if both scalar, align to lhs
|
||||
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
|
||||
return;
|
||||
}
|
||||
auto lhs_dst_type = lhs_type;
|
||||
auto rhs_dst_type = rhs_type;
|
||||
if (is_lhs_scalar) {
|
||||
if (lhs_type.is_real() && !rhs_type.is_real()) {
|
||||
// if div we need to also align float types to highest bitness regardless of scalar
|
||||
if (!align_scalars)
|
||||
lhs_dst_type = element::f32;
|
||||
rhs_dst_type = element::f32;
|
||||
} else {
|
||||
lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
|
||||
return;
|
||||
}
|
||||
} else if (is_rhs_scalar) {
|
||||
if (!lhs_type.is_real() && rhs_type.is_real()) {
|
||||
lhs_dst_type = element::f32;
|
||||
// if div we need to also align float types to highest bitness regardless of scalar
|
||||
if (!align_scalars)
|
||||
rhs_dst_type = element::f32;
|
||||
} else {
|
||||
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
|
||||
// Do nothing with bool
|
||||
return;
|
||||
}
|
||||
|
||||
if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) {
|
||||
lhs_dst_type = element::f32;
|
||||
} else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) {
|
||||
rhs_dst_type = element::f32;
|
||||
}
|
||||
// Align bitness to higher
|
||||
if (lhs_dst_type.bitwidth() != rhs_dst_type.bitwidth()) {
|
||||
const auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth());
|
||||
element::Type* type_to_align = &lhs_dst_type;
|
||||
if (rhs_dst_type.bitwidth() < dst_bitness)
|
||||
type_to_align = &rhs_dst_type;
|
||||
if (type_to_align->is_real()) {
|
||||
*type_to_align = bit_to_float.at(dst_bitness);
|
||||
} else {
|
||||
*type_to_align = bit_to_int.at(dst_bitness);
|
||||
}
|
||||
}
|
||||
|
||||
// Cast to destination types
|
||||
if (lhs_dst_type != lhs_type) {
|
||||
lhs = context.mark_node(std::make_shared<opset10::Convert>(lhs, lhs_dst_type));
|
||||
}
|
||||
if (rhs_dst_type != rhs_type) {
|
||||
rhs = context.mark_node(std::make_shared<opset10::Convert>(rhs, rhs_dst_type));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
|
@ -55,6 +55,11 @@ std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node>
|
||||
// TODO: Elimitate the need of this function by implementing more accurate custom data type handling
|
||||
Any simplified_type_interpret(Any type);
|
||||
|
||||
void align_eltwise_input_types(const NodeContext& context,
|
||||
ov::Output<ov::Node>& lhs,
|
||||
ov::Output<ov::Node>& rhs,
|
||||
bool align_scalars = false);
|
||||
|
||||
namespace op {
|
||||
template <OutputVector (*T)(NodeContext&), size_t idx = 0>
|
||||
OutputVector inplace_op(NodeContext& context) {
|
||||
@ -87,6 +92,20 @@ OutputVector translate_1to1_match_2_inputs(NodeContext& context) {
|
||||
return {context.mark_node(std::make_shared<T>(inputs[0], inputs[1]))};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputVector translate_1to1_match_2_inputs_align_types(NodeContext& context) {
|
||||
auto inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Operation has less then 2 inputs.");
|
||||
for (int i = 2; i < inputs.size(); i++) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
|
||||
}
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
|
||||
auto lhs = inputs[0];
|
||||
auto rhs = inputs[1];
|
||||
align_eltwise_input_types(context, lhs, rhs);
|
||||
return {context.mark_node(std::make_shared<T>(lhs, rhs))};
|
||||
}
|
||||
|
||||
inline OutputVector return_false_scalar(NodeContext& context) {
|
||||
return {context.mark_node(ov::op::v0::Constant::create(element::boolean, Shape{}, {false}))};
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ def get_scripted_model(model):
|
||||
model = torch.jit.script(model)
|
||||
model.eval()
|
||||
model = torch.jit.freeze(model)
|
||||
print(model.inlined_graph) # will help debugging
|
||||
return model
|
||||
|
||||
|
||||
@ -82,3 +83,364 @@ def test_pytorch_decoder_get_input_type_none():
|
||||
assert isinstance(list(div_node.inputs())[2].type(), torch.NoneType)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, div_node)
|
||||
assert isinstance(nc_decoder.get_input_type(2).value, DecoderType.PyNone)
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_fp16_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.float16)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.f16
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_fp32_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.float32)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.f32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_fp64_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.float64)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.f64
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_bool_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 0], dtype=torch.bool)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.boolean
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_u8_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.uint8)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.u8
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_i8_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.int8)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.i8
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_i32_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.int)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.i32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_i64_tensor():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class SomeTensor(torch.nn.Module):
|
||||
def forward(self):
|
||||
return torch.tensor([1, 2], dtype=torch.int64)
|
||||
|
||||
model = get_scripted_model(SomeTensor())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.i64
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_int64_max():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
|
||||
class I64MaxConst(torch.nn.Module):
|
||||
def forward(self):
|
||||
return 9223372036854775807
|
||||
|
||||
model = get_scripted_model(I64MaxConst())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
int64_const = consts[0]
|
||||
print(int64_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, int64_const)
|
||||
assert nc_decoder.as_constant() is not None
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_int_list():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class ListConst(torch.nn.Module):
|
||||
def forward(self):
|
||||
return [1, 2]
|
||||
|
||||
model = get_scripted_model(ListConst())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
print(some_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.i32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_float_list():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class ListConst(torch.nn.Module):
|
||||
def forward(self):
|
||||
return [float(1), float(2)]
|
||||
|
||||
model = get_scripted_model(ListConst())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
print(some_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.f32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_bool_list():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class ListConst(torch.nn.Module):
|
||||
def forward(self):
|
||||
return [True, False]
|
||||
|
||||
model = get_scripted_model(ListConst())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
print(some_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.boolean
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_int_tuple():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class ListConst(torch.nn.Module):
|
||||
def forward(self):
|
||||
return (1, 2)
|
||||
|
||||
model = get_scripted_model(ListConst())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
print(some_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.i32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_float_tuple():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class ListConst(torch.nn.Module):
|
||||
def forward(self):
|
||||
return (float(1), float(2))
|
||||
|
||||
model = get_scripted_model(ListConst())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
print(some_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.f32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.xfail(reason="Bool tuple gets converted to i32 tuple.")
|
||||
def test_pytorch_decoder_can_convert_bool_tuple():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class ListConst(torch.nn.Module):
|
||||
def forward(self):
|
||||
return (True, False)
|
||||
|
||||
model = get_scripted_model(ListConst())
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 0
|
||||
some_const = consts[0]
|
||||
print(some_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, some_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.boolean
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([2])
|
||||
|
||||
|
||||
@pytest.mark.precommit
|
||||
def test_pytorch_decoder_can_convert_empty_list():
|
||||
from openvino.frontend.pytorch.decoder import TorchScriptPythonDecoder
|
||||
from openvino.runtime import PartialShape, Type
|
||||
|
||||
class aten_roll(torch.nn.Module):
|
||||
def __init__(self, shifts):
|
||||
super(aten_roll, self).__init__()
|
||||
self.shits = shifts
|
||||
|
||||
def forward(self, x):
|
||||
# roll has optional input dim, which is empty int list by default
|
||||
return torch.roll(x, self.shits)
|
||||
|
||||
model = get_scripted_model(aten_roll(1))
|
||||
consts = [n for n in model.inlined_graph.nodes() if n.kind() ==
|
||||
"prim::Constant"]
|
||||
assert len(consts) > 1
|
||||
empty_const = consts[1]
|
||||
print(empty_const)
|
||||
nc_decoder = TorchScriptPythonDecoder(model, empty_const)
|
||||
ov_const = nc_decoder.as_constant()
|
||||
assert ov_const is not None
|
||||
assert len(ov_const) == 1
|
||||
assert ov_const[0].get_element_type() == Type.i32
|
||||
assert ov_const[0].get_partial_shape() == PartialShape([0])
|
||||
|
@ -10,7 +10,8 @@ from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
@pytest.mark.parametrize('alpha', (-0.5, 0, 0.5, 1, 2))
|
||||
@pytest.mark.parametrize('input_rhs', (np.random.randn(2, 5, 3, 4).astype(np.float32),
|
||||
np.random.randn(1, 5, 3, 4).astype(np.float32),
|
||||
np.random.randn(
|
||||
1, 5, 3, 4).astype(np.float32),
|
||||
np.random.randn(1).astype(np.float32)))
|
||||
class TestAdd(PytorchLayerTest):
|
||||
|
||||
@ -36,3 +37,66 @@ class TestAdd(PytorchLayerTest):
|
||||
def test_add(self, ie_device, precision, ir_version, alpha, input_rhs):
|
||||
self.input_rhs = input_rhs
|
||||
self._test(*self.create_model(alpha), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestAddTypes(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
if len(self.lhs_shape) == 0:
|
||||
return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),)
|
||||
elif len(self.rhs_shape) == 0:
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
torch.randn(self.rhs_shape).to(self.rhs_type).numpy())
|
||||
|
||||
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
|
||||
class aten_add(torch.nn.Module):
|
||||
def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
super().__init__()
|
||||
self.lhs_type = lhs_type
|
||||
self.rhs_type = rhs_type
|
||||
if len(lhs_shape) == 0:
|
||||
self.forward = self.forward1
|
||||
elif len(rhs_shape) == 0:
|
||||
self.forward = self.forward2
|
||||
else:
|
||||
self.forward = self.forward3
|
||||
|
||||
def forward1(self, rhs):
|
||||
return torch.add(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
|
||||
|
||||
def forward2(self, lhs):
|
||||
return torch.add(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2)
|
||||
|
||||
def forward3(self, lhs, rhs):
|
||||
return torch.add(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_add(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::add"
|
||||
|
||||
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
|
||||
[[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float32],
|
||||
[torch.int32, torch.float64],
|
||||
[torch.int64, torch.int32],
|
||||
[torch.int64, torch.float32],
|
||||
[torch.int64, torch.float64],
|
||||
[torch.float32, torch.int32],
|
||||
[torch.float32, torch.int64],
|
||||
[torch.float32, torch.float64],
|
||||
])
|
||||
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
|
||||
([2, 3], []),
|
||||
([], [2, 3]),
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_add_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
self.lhs_type = lhs_type
|
||||
self.lhs_shape = lhs_shape
|
||||
self.rhs_type = rhs_type
|
||||
self.rhs_shape = rhs_shape
|
||||
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
|
||||
ie_device, precision, ir_version)
|
||||
|
@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
@ -12,8 +13,6 @@ class TestComp(PytorchLayerTest):
|
||||
return (np.random.randn(1, 3, 24, 24).astype(np.float32), np.random.randn(1, 3, 24, 24).astype(np.float32))
|
||||
|
||||
def create_model(self, op_type):
|
||||
import torch
|
||||
|
||||
class aten_eq(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x == y
|
||||
@ -57,3 +56,79 @@ class TestComp(PytorchLayerTest):
|
||||
@pytest.mark.precommit
|
||||
def test_comp(self, op, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(op), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestCompMixedTypes(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
if len(self.lhs_shape) == 0:
|
||||
return (torch.randint(0, 3, self.rhs_shape).to(self.rhs_type).numpy(),)
|
||||
elif len(self.rhs_shape) == 0:
|
||||
return (torch.randint(0, 3, self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (torch.randint(0, 3, self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
torch.randint(0, 3, self.rhs_shape).to(self.rhs_type).numpy())
|
||||
|
||||
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape, op):
|
||||
|
||||
ops = {
|
||||
"eq": torch.eq,
|
||||
"ne": torch.ne,
|
||||
"lt": torch.lt,
|
||||
"gt": torch.gt,
|
||||
"ge": torch.ge,
|
||||
"le": torch.le
|
||||
}
|
||||
|
||||
op_fn = ops[op]
|
||||
|
||||
class aten_comp(torch.nn.Module):
|
||||
def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape, op_fn):
|
||||
super().__init__()
|
||||
self.lhs_type = lhs_type
|
||||
self.rhs_type = rhs_type
|
||||
self.op_fn = op_fn
|
||||
if len(lhs_shape) == 0:
|
||||
self.forward = self.forward1
|
||||
elif len(rhs_shape) == 0:
|
||||
self.forward = self.forward2
|
||||
else:
|
||||
self.forward = self.forward3
|
||||
|
||||
def forward1(self, rhs):
|
||||
return self.op_fn(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type))
|
||||
|
||||
def forward2(self, lhs):
|
||||
return self.op_fn(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type))
|
||||
|
||||
def forward3(self, lhs, rhs):
|
||||
return self.op_fn(lhs.to(self.lhs_type), rhs.to(self.rhs_type))
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_comp(lhs_type, lhs_shape, rhs_type, rhs_shape, op_fn), ref_net, f"aten::{op}"
|
||||
|
||||
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
|
||||
[[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float32],
|
||||
[torch.int32, torch.float64],
|
||||
[torch.int64, torch.int32],
|
||||
[torch.int64, torch.float32],
|
||||
[torch.int64, torch.float64],
|
||||
[torch.float32, torch.int32],
|
||||
[torch.float32, torch.int64],
|
||||
[torch.float32, torch.float64],
|
||||
])
|
||||
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
|
||||
([2, 3], []),
|
||||
([], [2, 3]),
|
||||
])
|
||||
@pytest.mark.parametrize("op", ["eq", "ne", "lt", "gt", "le", "ge"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_eq_mixed_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape, op):
|
||||
self.lhs_type = lhs_type
|
||||
self.lhs_shape = lhs_shape
|
||||
self.rhs_type = rhs_type
|
||||
self.rhs_shape = rhs_shape
|
||||
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape, op),
|
||||
ie_device, precision, ir_version)
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
@ -12,7 +13,6 @@ class TestDiv(PytorchLayerTest):
|
||||
return (self.input_array.astype(self.input_type), self.other_array.astype(self.other_type))
|
||||
|
||||
def create_model(self, rounding_mode):
|
||||
import torch
|
||||
|
||||
class aten_div(torch.nn.Module):
|
||||
def __init__(self, rounding_mode):
|
||||
@ -26,34 +26,6 @@ class TestDiv(PytorchLayerTest):
|
||||
|
||||
return aten_div(rounding_mode), ref_net, "aten::div"
|
||||
|
||||
@pytest.mark.parametrize(("input_array", "other_array"), [
|
||||
[10 * np.random.rand(5, 5), np.random.uniform(low=1, high=5, size=(1))],
|
||||
[10 * np.random.rand(5, 5, 1), np.random.uniform(low=1, high=5, size=(1))],
|
||||
[10 * np.random.rand(1, 1, 5, 5), np.random.uniform(
|
||||
low=1, high=5, size=(1))],
|
||||
[10 * np.random.rand(5, 5, 1), np.random.uniform(
|
||||
low=1, high=5, size=(5, 1))]
|
||||
])
|
||||
@pytest.mark.parametrize(("types"), [
|
||||
(np.float32, np.float32),
|
||||
pytest.param((np.int32, np.float32), marks=pytest.mark.xfail),
|
||||
pytest.param((np.float32, np.int32), marks=pytest.mark.xfail),
|
||||
pytest.param((np.int32, np.int32), marks=pytest.mark.xfail)
|
||||
])
|
||||
@pytest.mark.parametrize('rounding_mode', ([
|
||||
None,
|
||||
"floor",
|
||||
"trunc"
|
||||
]))
|
||||
@pytest.mark.nightly
|
||||
def test_div(self, input_array, other_array, types, rounding_mode, ie_device, precision, ir_version):
|
||||
self.input_array = input_array
|
||||
self.input_type = types[0]
|
||||
self.other_array = other_array
|
||||
self.other_type = types[1]
|
||||
self._test(*self.create_model(rounding_mode),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
@pytest.mark.parametrize(("input_array", "other_array"), [
|
||||
[np.array([0.7620, 2.5548, -0.5944, -0.7438, 0.9274]), np.array(0.5)],
|
||||
[np.array([[-0.3711, -1.9353, -0.4605, -0.2917],
|
||||
@ -76,3 +48,74 @@ class TestDiv(PytorchLayerTest):
|
||||
self.other_type = np.float32
|
||||
self._test(*self.create_model(rounding_mode),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestDivTypes(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
if len(self.lhs_shape) == 0:
|
||||
return (torch.randint(2, 5, self.rhs_shape).to(self.rhs_type).numpy(),)
|
||||
elif len(self.rhs_shape) == 0:
|
||||
return (10 * torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (10 * torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
torch.randint(2, 5, self.rhs_shape).to(self.rhs_type).numpy())
|
||||
|
||||
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode):
|
||||
|
||||
class aten_div(torch.nn.Module):
|
||||
def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode):
|
||||
super().__init__()
|
||||
self.lhs_type = lhs_type
|
||||
self.rhs_type = rhs_type
|
||||
self.rm = rounding_mode
|
||||
if len(lhs_shape) == 0:
|
||||
self.forward = self.forward1
|
||||
elif len(rhs_shape) == 0:
|
||||
self.forward = self.forward2
|
||||
else:
|
||||
self.forward = self.forward3
|
||||
|
||||
def forward1(self, rhs):
|
||||
return torch.div(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), rounding_mode=self.rm)
|
||||
|
||||
def forward2(self, lhs):
|
||||
return torch.div(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), rounding_mode=self.rm)
|
||||
|
||||
def forward3(self, lhs, rhs):
|
||||
return torch.div(lhs.to(self.lhs_type), rhs.to(self.rhs_type), rounding_mode=self.rm)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_div(lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode), ref_net, "aten::div"
|
||||
|
||||
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
|
||||
[[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float32],
|
||||
[torch.int32, torch.float64],
|
||||
[torch.int64, torch.int32],
|
||||
[torch.int64, torch.float32],
|
||||
[torch.int64, torch.float64],
|
||||
[torch.float32, torch.int32],
|
||||
[torch.float32, torch.int64],
|
||||
[torch.float32, torch.float64],
|
||||
])
|
||||
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
|
||||
([2, 3], []),
|
||||
([], [2, 3]),
|
||||
])
|
||||
@pytest.mark.parametrize('rounding_mode', ([
|
||||
None,
|
||||
"floor",
|
||||
"trunc"
|
||||
]))
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_div_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode):
|
||||
self.lhs_type = lhs_type
|
||||
self.lhs_shape = lhs_shape
|
||||
self.rhs_type = rhs_type
|
||||
self.rhs_shape = rhs_shape
|
||||
if rounding_mode == "floor" and not lhs_type.is_floating_point and not rhs_type.is_floating_point:
|
||||
pytest.skip("Floor rounding mode and int inputs produce wrong results")
|
||||
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape, rounding_mode),
|
||||
ie_device, precision, ir_version)
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
@ -12,7 +13,6 @@ class TestMul(PytorchLayerTest):
|
||||
return (self.input_array.astype(self.input_type), self.other_array.astype(self.other_type))
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class aten_mul(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -26,35 +26,78 @@ class TestMul(PytorchLayerTest):
|
||||
return aten_mul(), ref_net, "aten::mul"
|
||||
|
||||
@pytest.mark.parametrize(("input_array", "other_array"), [
|
||||
[np.random.rand(1, 2), np.random.rand(2, 1)],
|
||||
[np.random.rand(3, 1, 2), np.random.rand(3, 1, 2)],
|
||||
[np.random.rand(4, 1, 1), np.random.rand(1, 1, 4)],
|
||||
])
|
||||
@pytest.mark.parametrize(("types"), [
|
||||
(np.float32, np.float32),
|
||||
# Type promotion
|
||||
pytest.param((np.int32, np.float32), marks=pytest.mark.xfail(reason="101869")),
|
||||
pytest.param((np.float32, np.int32), marks=pytest.mark.xfail(reason="101869")),
|
||||
pytest.param((np.int32, np.int32), marks=pytest.mark.xfail(reason="101869"))
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
def test_mul_random(self, input_array, other_array, types, ie_device, precision, ir_version):
|
||||
self.input_array = input_array
|
||||
self.input_type = types[0]
|
||||
self.other_array = other_array
|
||||
self.other_type = types[1]
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("input_array", "other_array"), [
|
||||
[np.array([ 0.2015, -0.4255, 2.6087]), np.array(100)],
|
||||
[np.array([[ 1.1207], [-0.3137], [0.0700], [0.8378]]), np.array([[0.5146, 0.1216, -0.5244, 2.2382]])],
|
||||
[np.array([0.2015, -0.4255, 2.6087]), np.array(100)],
|
||||
[np.array([[1.1207], [-0.3137], [0.0700], [0.8378]]),
|
||||
np.array([[0.5146, 0.1216, -0.5244, 2.2382]])],
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_mul_pt_spec(self, input_array, other_array, ie_device, precision, ir_version):
|
||||
self.input_array = input_array
|
||||
self.input_array = input_array
|
||||
self.input_type = np.float32
|
||||
self.other_array = other_array
|
||||
self.other_type = np.float32
|
||||
self.other_type = np.float32
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestMulTypes(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
if len(self.lhs_shape) == 0:
|
||||
return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),)
|
||||
elif len(self.rhs_shape) == 0:
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
torch.randn(self.rhs_shape).to(self.rhs_type).numpy())
|
||||
|
||||
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
|
||||
class aten_mul(torch.nn.Module):
|
||||
def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
super().__init__()
|
||||
self.lhs_type = lhs_type
|
||||
self.rhs_type = rhs_type
|
||||
if len(lhs_shape) == 0:
|
||||
self.forward = self.forward1
|
||||
elif len(rhs_shape) == 0:
|
||||
self.forward = self.forward2
|
||||
else:
|
||||
self.forward = self.forward3
|
||||
|
||||
def forward1(self, rhs):
|
||||
return torch.mul(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type))
|
||||
|
||||
def forward2(self, lhs):
|
||||
return torch.mul(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type))
|
||||
|
||||
def forward3(self, lhs, rhs):
|
||||
return torch.mul(lhs.to(self.lhs_type), rhs.to(self.rhs_type))
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_mul(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::mul"
|
||||
|
||||
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
|
||||
[[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float32],
|
||||
[torch.int32, torch.float64],
|
||||
[torch.int64, torch.int32],
|
||||
[torch.int64, torch.float32],
|
||||
[torch.int64, torch.float64],
|
||||
[torch.float32, torch.int32],
|
||||
[torch.float32, torch.int64],
|
||||
[torch.float32, torch.float64],
|
||||
])
|
||||
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
|
||||
([2, 3], []),
|
||||
([], [2, 3]),
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_mul_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
self.lhs_type = lhs_type
|
||||
self.lhs_shape = lhs_shape
|
||||
self.rhs_type = rhs_type
|
||||
self.rhs_shape = rhs_shape
|
||||
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
|
||||
ie_device, precision, ir_version)
|
||||
|
@ -42,3 +42,65 @@ class TestPow(PytorchLayerTest):
|
||||
def test_pow(self, ie_device, precision, ir_version, test_input):
|
||||
self.test_input = test_input
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestPowMixedTypes(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
if len(self.lhs_shape) == 0:
|
||||
return (torch.randn(self.rhs_shape) * 2 + 0.6).to(self.rhs_type).numpy(),
|
||||
elif len(self.rhs_shape) == 0:
|
||||
return (torch.randint(1, 3, self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (torch.randint(1, 3, self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
(torch.randn(self.rhs_shape) * 2 + 0.6).to(self.rhs_type).numpy())
|
||||
|
||||
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
|
||||
class aten_pow(torch.nn.Module):
|
||||
def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
super().__init__()
|
||||
self.lhs_type = lhs_type
|
||||
self.rhs_type = rhs_type
|
||||
if len(lhs_shape) == 0:
|
||||
self.forward = self.forward1
|
||||
elif len(rhs_shape) == 0:
|
||||
self.forward = self.forward2
|
||||
else:
|
||||
self.forward = self.forward3
|
||||
|
||||
def forward1(self, rhs):
|
||||
return torch.pow(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type))
|
||||
|
||||
def forward2(self, lhs):
|
||||
return torch.pow(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type))
|
||||
|
||||
def forward3(self, lhs, rhs):
|
||||
return torch.pow(lhs.to(self.lhs_type), rhs.to(self.rhs_type))
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_pow(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::pow"
|
||||
|
||||
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
|
||||
[[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float32],
|
||||
[torch.int32, torch.float64],
|
||||
[torch.int64, torch.int32],
|
||||
[torch.int64, torch.float32],
|
||||
[torch.int64, torch.float64],
|
||||
[torch.float32, torch.int32],
|
||||
[torch.float32, torch.int64],
|
||||
[torch.float32, torch.float64],
|
||||
])
|
||||
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
|
||||
([2, 3], []),
|
||||
([], [2, 3]),
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_pow_mixed_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
self.lhs_type = lhs_type
|
||||
self.lhs_shape = lhs_shape
|
||||
self.rhs_type = rhs_type
|
||||
self.rhs_shape = rhs_shape
|
||||
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
|
||||
ie_device, precision, ir_version)
|
||||
|
@ -24,13 +24,78 @@ class TestSub(PytorchLayerTest):
|
||||
return aten_sub(), ref_net, "aten::sub"
|
||||
|
||||
@pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 4).astype(np.float32),
|
||||
np.random.randn(2, 3, 4).astype(np.float32),
|
||||
np.random.randn(
|
||||
2, 3, 4).astype(np.float32),
|
||||
np.random.randn(1)),
|
||||
(np.random.randn(4, 2, 3).astype(np.float32),
|
||||
np.random.randn(1, 2, 3).astype(np.float32),
|
||||
np.random.randn(1)),])
|
||||
np.random.randn(
|
||||
1, 2, 3).astype(np.float32),
|
||||
np.random.randn(1)), ])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_sub(self, ie_device, precision, ir_version, input_data):
|
||||
self.input_data = input_data
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestSubTypes(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
if len(self.lhs_shape) == 0:
|
||||
return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),)
|
||||
elif len(self.rhs_shape) == 0:
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
torch.randn(self.rhs_shape).to(self.rhs_type).numpy())
|
||||
|
||||
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
|
||||
class aten_sub(torch.nn.Module):
|
||||
def __init__(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
super().__init__()
|
||||
self.lhs_type = lhs_type
|
||||
self.rhs_type = rhs_type
|
||||
if len(lhs_shape) == 0:
|
||||
self.forward = self.forward1
|
||||
elif len(rhs_shape) == 0:
|
||||
self.forward = self.forward2
|
||||
else:
|
||||
self.forward = self.forward3
|
||||
|
||||
def forward1(self, rhs):
|
||||
return torch.sub(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
|
||||
|
||||
def forward2(self, lhs):
|
||||
return torch.sub(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2)
|
||||
|
||||
def forward3(self, lhs, rhs):
|
||||
return torch.sub(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_sub(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::sub"
|
||||
|
||||
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
|
||||
[[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float32],
|
||||
# [torch.int32, torch.float64], fp64 produce ov error of eltwise constant fold
|
||||
[torch.int64, torch.int32],
|
||||
[torch.int64, torch.float32],
|
||||
# [torch.int64, torch.float64], fp64 produce ov error of eltwise constant fold
|
||||
[torch.float32, torch.int32],
|
||||
[torch.float32, torch.int64],
|
||||
# [torch.float32, torch.float64], fp64 produce ov error of eltwise constant fold
|
||||
])
|
||||
@pytest.mark.parametrize(("lhs_shape", "rhs_shape"), [([2, 3], [2, 3]),
|
||||
([2, 3], []),
|
||||
([], [2, 3]),
|
||||
])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_sub_types(self, ie_device, precision, ir_version, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
self.lhs_type = lhs_type
|
||||
self.lhs_shape = lhs_shape
|
||||
self.rhs_type = rhs_type
|
||||
self.rhs_shape = rhs_shape
|
||||
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
|
||||
ie_device, precision, ir_version)
|
||||
|
Loading…
Reference in New Issue
Block a user