[PT FE] Implement custom op for types alignment (#20431)

* [PT FE] Implement custom op for types alignment

* Fix code style

* Fix inplace ops

* Fix layer tests

* Remove no longer needed change

* Fix ovc tests

* Fix fe tests
This commit is contained in:
Maxim Vafin 2023-10-23 22:54:08 +02:00 committed by GitHub
parent 009ef5657c
commit 8d0381b0fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 430 additions and 165 deletions

View File

@ -372,8 +372,8 @@ class TorchScriptPythonDecoder (Decoder):
return False
def may_produce_alias(self, in_index: int, out_index: int) -> bool:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::matmul"]:
# AliasDB::may_contain_alias sometimes return True for tensors produced by convnd, we have to workaround that
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d", "aten::_convolution", "aten::matmul"]:
# AliasDB::may_contain_alias sometimes return True for tensors produced by convolution or matmul, we have to workaround that
return False
try:
return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))

View File

@ -23,6 +23,7 @@ NodeInput = Union[Node, NumericData]
openvino_to_numpy_types_map = [
(Type.boolean, bool),
(Type.boolean, np.bool_),
(Type.f16, np.float16),
(Type.f32, np.float32),
(Type.f64, np.float64),
@ -39,6 +40,7 @@ openvino_to_numpy_types_map = [
openvino_to_numpy_types_str_map = [
("boolean", bool),
("boolean", np.bool_),
("f16", np.float16),
("f32", np.float32),
("f64", np.float64),

View File

@ -81,7 +81,7 @@ public:
explicit FrameworkNode(const OutputVector& inputs, size_t output_size = 1, size_t num_subgraphs = 0);
void validate_and_infer_types() override;
virtual void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;

View File

@ -20,6 +20,7 @@
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transformations/resolve_names_collisions.hpp"
#include "transforms.hpp"
#include "transforms/align_types_removal.hpp"
#include "transforms/append_list_unpack_replacer.hpp"
#include "transforms/aten_cat_replacer.hpp"
#include "transforms/aten_getitem_replacer.hpp"
@ -177,6 +178,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
element::TypeVector{element::u8, element::i8, element::u4, element::i4});
manager.register_pass<ov::pass::ConstantFolding>();
manager.register_pass<ov::frontend::pytorch::pass::AlignTypesRemoval>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
manager.register_pass<ov::pass::UnrollIf>();
manager.register_pass<ov::frontend::pytorch::pass::TupleUnpackInBodyReplacer>();
@ -204,6 +206,8 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::U4BlockRepack>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
// Second pass of AlignTypesRemoval after all converting transformations
manager.register_pass<ov::frontend::pytorch::pass::AlignTypesRemoval>();
manager.register_pass<ov::pass::ResolveNameCollisions>();
manager.run_passes(model);

View File

@ -0,0 +1,43 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <vector>
#include "internal_op.hpp"
#include "openvino/frontend/decoder.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
class AlignTypes : public InternalOperation {
public:
AlignTypes(const Output<Node>& lhs, const Output<Node>& rhs, bool align_scalars)
: InternalOperation("ov::align_types",
{lhs, rhs},
2,
"This is internal operation for type alignment and should be removed "
"at normalization step. It can't be removed if types can't be resolved."),
m_align_scalars(align_scalars) {
validate_and_infer_types();
}
void validate_and_infer_types() override {
auto lhs = input_value(0);
auto rhs = input_value(1);
auto out_type = infer_types(lhs, rhs, m_align_scalars);
set_output_type(0, out_type, get_input_partial_shape(0));
set_output_type(1, out_type, get_input_partial_shape(1));
}
private:
const bool m_align_scalars;
};
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,56 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <vector>
#include "openvino/frontend/decoder.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
class InternalOpDecoder : public DummyDecoder {
public:
explicit InternalOpDecoder(const std::string& op_type, const size_t num_outputs)
: m_op_type(op_type),
m_num_outputs(num_outputs) {}
const std::string& get_op_type() const override {
return m_op_type;
}
size_t num_of_outputs() const override {
return m_num_outputs;
}
size_t get_subgraph_size() const override {
return 0;
}
const std::string& decoder_type_name() const override {
return m_decoder_type;
}
private:
const std::string m_op_type;
const std::string m_decoder_type = "internal_op";
const size_t m_num_outputs;
};
class InternalOperation : public PtFrameworkNode {
protected:
InternalOperation(const std::string& op_type,
const OutputVector& inputs,
size_t num_outputs,
const std::string& no_conversion_reason)
: PtFrameworkNode(std::make_shared<InternalOpDecoder>(op_type, num_outputs), inputs) {
auto attrs = get_attrs();
attrs[PtFrameworkNode::failed_conversion_key] = no_conversion_reason;
set_attrs(attrs);
}
};
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -15,7 +15,9 @@ namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_add(const NodeContext& context) {
using namespace ov::op;
OutputVector translate_add_common(const NodeContext& context, bool inplace) {
num_inputs_check(context, 2, 3);
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
@ -26,12 +28,28 @@ OutputVector translate_add(const NodeContext& context) {
// Case when two lists gets concatenated
FRONT_END_OP_CONVERSION_CHECK(false, "aten::add is used for concatenation of lists, not possible to convert");
}
align_eltwise_input_types(context, lhs, rhs, true);
if (!context.input_is_none(2)) {
auto converted_alpha = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(context.get_input(2), rhs));
rhs = context.mark_node(std::make_shared<ov::op::v1::Multiply>(converted_alpha, rhs));
if (inplace) {
if (lhs.get_element_type().is_dynamic() || lhs.get_element_type() != rhs.get_element_type())
rhs = context.mark_node(std::make_shared<v1::ConvertLike>(rhs, lhs));
} else {
align_eltwise_input_types(context, lhs, rhs, true);
}
return {context.mark_node(std::make_shared<ov::op::v1::Add>(lhs, rhs))};
if (!context.input_is_none(2)) {
auto converted_alpha = context.mark_node(std::make_shared<v1::ConvertLike>(context.get_input(2), rhs));
rhs = context.mark_node(std::make_shared<v1::Multiply>(converted_alpha, rhs));
}
auto add = context.mark_node(std::make_shared<v1::Add>(lhs, rhs));
if (inplace)
context.mutate_input(0, add);
return {add};
};
OutputVector translate_add(const NodeContext& context) {
return translate_add_common(context, false);
};
OutputVector translate_add_(const NodeContext& context) {
return translate_add_common(context, true);
};
} // namespace op

View File

@ -17,7 +17,7 @@ namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_div(const NodeContext& context) {
OutputVector translate_div_common(const NodeContext& context, bool inplace) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
@ -34,7 +34,12 @@ OutputVector translate_div(const NodeContext& context) {
y = context.mark_node(std::make_shared<v0::Convert>(y, element::f32));
}
}
align_eltwise_input_types(context, x, y, true);
if (inplace) {
if (x.get_element_type().is_dynamic() || x.get_element_type() != y.get_element_type())
y = context.mark_node(std::make_shared<v1::ConvertLike>(x, y));
} else {
align_eltwise_input_types(context, x, y, true);
}
auto res = context.mark_node(std::make_shared<v1::Divide>(x, y, true));
// TODO: ticket 103296; Temporarily disable ConvertDivide transformation
disable_divide_conversion(res);
@ -44,9 +49,19 @@ OutputVector translate_div(const NodeContext& context) {
const auto convert = context.mark_node(std::make_shared<v0::Convert>(res, element::i32));
res = context.mark_node(std::make_shared<v1::ConvertLike>(convert, x));
}
if (inplace)
context.mutate_input(0, res);
return {res};
};
OutputVector translate_div(const NodeContext& context) {
return translate_div_common(context, false);
};
OutputVector translate_div_(const NodeContext& context) {
return translate_div_common(context, true);
};
} // namespace op
} // namespace pytorch
} // namespace frontend

View File

@ -15,18 +15,34 @@ namespace op {
using namespace ov::op;
OutputVector translate_sub(const NodeContext& context) {
OutputVector translate_sub_common(const NodeContext& context, bool inplace) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto y = context.get_input(1);
align_eltwise_input_types(context, x, y);
if (inplace) {
if (x.get_element_type().is_dynamic() || x.get_element_type() != y.get_element_type())
y = context.mark_node(std::make_shared<v1::ConvertLike>(x, y));
} else {
align_eltwise_input_types(context, x, y);
}
// default alpha is 1 so no need to multiply if alpha is not provided
if (!context.input_is_none(2)) {
auto alpha = context.get_input(2);
auto casted_alpha = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, y));
y = context.mark_node(std::make_shared<v1::Multiply>(casted_alpha, y));
}
return {context.mark_node(std::make_shared<v1::Subtract>(x, y))};
auto sub = context.mark_node(std::make_shared<v1::Subtract>(x, y));
if (inplace)
context.mutate_input(0, sub);
return {sub};
};
OutputVector translate_sub(const NodeContext& context) {
return translate_sub_common(context, false);
};
OutputVector translate_sub_(const NodeContext& context) {
return translate_sub_common(context, true);
};
} // namespace op

View File

@ -23,6 +23,7 @@ OP_CONVERTER(translate_adaptive_max_pool3d);
OP_CONVERTER(translate_adaptive_max_pool2d);
OP_CONVERTER(translate_adaptive_max_pool1d);
OP_CONVERTER(translate_add);
OP_CONVERTER(translate_add_);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
OP_CONVERTER(translate_all);
@ -57,6 +58,7 @@ OP_CONVERTER(translate_deform_conv);
OP_CONVERTER(translate_derive_index);
OP_CONVERTER(translate_dim);
OP_CONVERTER(translate_div);
OP_CONVERTER(translate_div_);
OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_embedding_bag);
@ -175,6 +177,7 @@ OP_CONVERTER(translate_squeeze);
OP_CONVERTER(translate_std);
OP_CONVERTER(translate_std_mean);
OP_CONVERTER(translate_sub);
OP_CONVERTER(translate_sub_);
OP_CONVERTER(translate_sum);
OP_CONVERTER(translate_t);
OP_CONVERTER(translate_to);
@ -247,7 +250,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::adaptive_max_pool2d", op::quantizable_op<op::translate_adaptive_max_pool2d>},
{"aten::adaptive_max_pool3d", op::quantizable_op<op::translate_adaptive_max_pool3d>},
{"aten::add", op::translate_add},
{"aten::add_", op::inplace_op<op::translate_add>},
{"aten::add_", op::translate_add_},
{"aten::addcmul", op::translate_addcmul},
{"aten::addmm", op::translate_addmm},
{"aten::all", op::translate_all},
@ -309,7 +312,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::dequantize", op::skip_node}, // we convert model to fp32 using FQ, so dequantization is not needed
{"aten::dim", op::translate_dim},
{"aten::div", op::translate_div},
{"aten::div_", op::inplace_op<op::translate_div>},
{"aten::div_", op::translate_div_},
{"aten::dropout", op::skip_node},
{"aten::dropout_", op::skip_node},
{"aten::elu", op::translate_elu},
@ -404,9 +407,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::minimum", op::translate_minimum},
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
{"aten::mul_", op::inplace_translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::multiply", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::multiply_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
{"aten::multiply_", op::inplace_translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::narrow", op::translate_narrow},
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
{"aten::neg", op::translate_neg},
@ -477,7 +480,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::std", op::translate_std},
{"aten::std_mean", op::translate_std_mean},
{"aten::sub", op::translate_sub},
{"aten::sub_", op::inplace_op<op::translate_sub>},
{"aten::sub_", op::translate_sub_},
{"aten::sum", op::translate_sum},
{"aten::swapaxes", op::quantizable_op<op::translate_transpose>},
{"aten::t", op::translate_t},

View File

@ -20,14 +20,17 @@ public:
PtFrameworkNode(const std::shared_ptr<TorchDecoder>& decoder,
const OutputVector& inputs,
size_t output_size,
bool is_backprop = false)
bool is_reverseprop = false)
: ov::op::util::FrameworkNode(inputs, output_size, decoder->get_subgraph_size()),
m_decoder(decoder) {
ov::op::util::FrameworkNodeAttrs attrs;
attrs.set_type_name("PTFrameworkNode");
if (is_backprop) {
attrs[op_type_key] = m_decoder->get_op_type() + "_backprop";
if (is_reverseprop) {
attrs[op_type_key] = m_decoder->get_op_type() + "_reverseprop";
attrs[schema_key] = "None";
attrs[failed_conversion_key] =
"This is an internal openvino operation representing reverse data propagation. It should not appear in "
"graph in normal conversion flow and might be result of other failures.";
} else {
attrs[op_type_key] = m_decoder->get_op_type();
attrs[schema_key] = m_decoder->get_schema();

View File

@ -0,0 +1,60 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "align_types_removal.hpp"
#include <memory>
#include <utility>
#include "helper_ops/align_types.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
using namespace ov::op;
AlignTypesRemoval::AlignTypesRemoval() {
auto align_types_pattern = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
auto align_types = std::dynamic_pointer_cast<AlignTypes>(m.get_match_root());
if (!align_types)
return false;
auto lhs_itype = align_types->get_input_element_type(0);
auto rhs_itype = align_types->get_input_element_type(1);
auto lhs_otype = align_types->get_output_element_type(0);
auto rhs_otype = align_types->get_output_element_type(1);
if (lhs_otype.is_static() && rhs_otype.is_static()) {
auto out1 = align_types->input_value(0);
auto out2 = align_types->input_value(1);
if (lhs_itype != lhs_otype)
out1 = std::make_shared<v0::Convert>(align_types->input_value(0), lhs_otype);
if (rhs_itype != rhs_otype)
out2 = std::make_shared<v0::Convert>(align_types->input_value(1), rhs_otype);
align_types->output(0).replace(out1);
align_types->output(1).replace(out2);
return true;
}
return false;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(align_types_pattern,
"ov::frontend::pytorch::pass::AlignTypesRemoval");
this->register_matcher(m, callback);
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
class AlignTypesRemoval : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::AlignTypesRemoval");
AlignTypesRemoval();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -26,16 +26,8 @@ using namespace ov::op;
StringEqualityReplacer::StringEqualityReplacer() {
auto framework_node_lhs = pattern::wrap_type<PtFrameworkNode>();
auto framework_node_rhs = pattern::wrap_type<PtFrameworkNode>();
auto convert_lhs = pattern::wrap_type<v0::Convert>({framework_node_lhs});
auto convert_like_lhs = pattern::wrap_type<v1::ConvertLike>({framework_node_lhs, framework_node_rhs});
auto convert_rhs = pattern::wrap_type<v0::Convert>({framework_node_rhs});
auto convert_like_rhs = pattern::wrap_type<v1::ConvertLike>({framework_node_rhs, framework_node_lhs});
auto lhs_pattern =
std::make_shared<pattern::op::Or>(OutputVector{framework_node_lhs, convert_lhs, convert_like_lhs});
auto rhs_pattern =
std::make_shared<pattern::op::Or>(OutputVector{framework_node_rhs, convert_rhs, convert_like_rhs});
auto equal_op = pattern::wrap_type<v1::Equal>({lhs_pattern, rhs_pattern});
auto not_equal_op = pattern::wrap_type<v1::NotEqual>({lhs_pattern, rhs_pattern});
auto equal_op = pattern::wrap_type<v1::Equal>({framework_node_lhs, framework_node_rhs});
auto not_equal_op = pattern::wrap_type<v1::NotEqual>({framework_node_lhs, framework_node_rhs});
auto string_equality_pattern = std::make_shared<pattern::op::Or>(OutputVector{equal_op, not_equal_op});

View File

@ -4,6 +4,7 @@
#include "utils.hpp"
#include "helper_ops/align_types.hpp"
#include "op_table.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/frontend/pytorch/decoder.hpp"
@ -381,33 +382,17 @@ std::unordered_map<size_t, element::Type> bit_to_int{
};
} // namespace
void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Output<Node>& rhs, bool align_scalars) {
element::Type infer_types(const Output<Node>& lhs, const Output<Node>& rhs, bool align_scalars) {
const auto& lhs_type = lhs.get_element_type();
const auto& rhs_type = rhs.get_element_type();
auto out_type = context.get_output_type(0);
if (out_type.is<element::Type>()) {
auto otype = out_type.as<element::Type>();
if (otype.is_real()) {
if (otype != lhs_type) {
lhs = context.mark_node(std::make_shared<ov::op::v0::Convert>(lhs, otype));
}
if (otype != rhs_type) {
rhs = context.mark_node(std::make_shared<ov::op::v0::Convert>(rhs, otype));
}
return;
}
}
if (lhs_type.is_dynamic() || rhs_type.is_dynamic()) {
// if any of types is not known, align to lhs type.
// TODO: can be fixed with special operation?
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
return;
return element::dynamic;
}
// Both types are static, align types. If float and int types are used convert int type to f32, after that align
// to the largest bitness, if both float or both int, just align bitness
if (lhs_type == rhs_type)
return;
return lhs_type;
// if one of operands is scalar, the resulting type is taken from the other operand except when scalar is float
// type and other operand is int, in that case BOTH operands get fp32 type
@ -429,11 +414,9 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
if (!align_scalars)
rhs_dst_type = element::f32;
} else if (is_lhs_scalar && rhs_type != element::boolean) {
lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
return;
return rhs_type;
} else if (is_rhs_scalar && lhs_type != element::boolean) {
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
return;
return lhs_type;
}
if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) {
@ -470,13 +453,39 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
}
}
}
return lhs_dst_type;
}
// Cast to destination types
if (lhs_dst_type != lhs_type) {
lhs = context.mark_node(std::make_shared<opset10::Convert>(lhs, lhs_dst_type));
void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Output<Node>& rhs, bool align_scalars) {
const auto& lhs_type = lhs.get_element_type();
const auto& rhs_type = rhs.get_element_type();
auto out_type = context.get_output_type(0);
if (out_type.is<element::Type>()) {
auto otype = out_type.as<element::Type>();
if (otype.is_real()) {
if (otype != lhs_type) {
lhs = context.mark_node(std::make_shared<ov::op::v0::Convert>(lhs, otype));
}
if (otype != rhs_type) {
rhs = context.mark_node(std::make_shared<ov::op::v0::Convert>(rhs, otype));
}
return;
}
}
if (rhs_dst_type != rhs_type) {
rhs = context.mark_node(std::make_shared<opset10::Convert>(rhs, rhs_dst_type));
auto dst_type = infer_types(lhs, rhs, align_scalars);
if (dst_type.is_dynamic()) {
// We can't decide the type at this point, create a special operation
auto at = std::make_shared<AlignTypes>(lhs, rhs, align_scalars);
lhs = at->output(0);
rhs = at->output(1);
return;
}
// Cast to destination type
if (dst_type != lhs_type) {
lhs = context.mark_node(std::make_shared<opset10::Convert>(lhs, dst_type));
}
if (dst_type != rhs_type) {
rhs = context.mark_node(std::make_shared<opset10::Convert>(rhs, dst_type));
}
}

View File

@ -7,6 +7,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
namespace ov {
@ -65,11 +66,11 @@ Any simplified_type_interpret(Any type);
void add_exception_to_fw_node(std::shared_ptr<Node> node, const std::string& msg);
element::Type infer_types(const Output<Node>& lhs, const Output<Node>& rhs, bool align_scalars);
void align_eltwise_input_types(const NodeContext& context,
Output<Node>& lhs,
Output<Node>& rhs,
bool align_scalars = false);
void align_output_types(const NodeContext& context, OutputVector& outputs);
std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start);
@ -125,12 +126,31 @@ OutputVector translate_1to1_match_2_inputs_align_types(const NodeContext& contex
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
align_eltwise_input_types(context, lhs, rhs, true);
auto lhs_type = context.get_input_type(0);
auto rhs_type = context.get_input_type(1);
// If type is string or None, we shouldn't align
if (!lhs_type.is<type::Str>() && !rhs_type.is<type::Str>() && !lhs_type.is<type::PyNone>() &&
!rhs_type.is<type::PyNone>())
align_eltwise_input_types(context, lhs, rhs, true);
OutputVector res = {context.mark_node(std::make_shared<T>(lhs, rhs))};
align_output_types(context, res);
return res;
}
template <typename T, size_t idx = 0>
OutputVector inplace_translate_1to1_match_2_inputs_align_types(const NodeContext& context) {
num_inputs_check(context, 2, 2);
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
// For inplace op we know direction of type alignment
if (lhs.get_element_type().is_dynamic() || lhs.get_element_type() != rhs.get_element_type())
rhs = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(rhs, lhs));
OutputVector res = {context.mark_node(std::make_shared<T>(lhs, rhs))};
context.mutate_input(idx, res[0]);
return res;
}
inline OutputVector return_false_scalar(const NodeContext& context) {
return {context.mark_node(ov::op::v0::Constant::create(element::boolean, Shape{}, {false}))};
}
@ -168,7 +188,7 @@ public:
FRONT_END_NOT_IMPLEMENTED(get_output_debug_name);
}
virtual PartialShape get_output_shape(size_t index) const override {
FRONT_END_NOT_IMPLEMENTED(get_output_shape);
return PartialShape::dynamic();
}
virtual Any get_output_type(size_t index) const override {
FRONT_END_NOT_IMPLEMENTED(get_output_type);
@ -189,7 +209,7 @@ public:
FRONT_END_NOT_IMPLEMENTED(get_op_type);
}
virtual const std::string& get_schema() const override {
FRONT_END_NOT_IMPLEMENTED(get_schema);
return m_schema;
}
virtual size_t num_of_outputs() const override {
FRONT_END_NOT_IMPLEMENTED(num_of_outputs);
@ -218,6 +238,9 @@ public:
virtual OutputVector inlined_inputs(size_t start_index) const override {
FRONT_END_NOT_IMPLEMENTED(inlined_inputs);
}
private:
const std::string m_schema = "NONE";
};
} // namespace pytorch

View File

@ -284,8 +284,8 @@ def create_pytorch_jit_script_function(tmp_dir):
return torch.sigmoid(torch.relu(x * y))
inp_shape = PartialShape([Dimension(1, -1), Dimension(-1, 5), 10])
ref_model = make_ref_pt_model_two_inputs(inp_shape, dtype=Type.dynamic)
return scripted_fn, ref_model, {'input': [(inp_shape), (inp_shape)]}
ref_model = make_ref_pt_model_two_inputs(inp_shape)
return scripted_fn, ref_model, {'input': [(inp_shape, Type.f32), (inp_shape, Type.f32)]}
def create_pytorch_nn_module_layout_list(tmp_dir):
@ -472,9 +472,9 @@ def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir):
def create_pytorch_nn_module_shapes_list_static(tmp_dir):
pt_model = make_pt_model_two_inputs()
ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20], dtype=Type.dynamic)
ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20])
return pt_model, ref_model, {'input': [[1, 3, 20, 20], [1, 3, 20, 20]]}
return pt_model, ref_model, {'input': [([1, 3, 20, 20], Type.f32), ([1, 3, 20, 20], Type.f32)]}
def create_pytorch_nn_module_shapes_list_static_via_input(tmp_dir):
@ -490,17 +490,16 @@ def create_pytorch_nn_module_shapes_list_dynamic(tmp_dir):
[-1, 3, 20, Dimension(-1, 20)]]
param1 = ov.opset8.parameter(PartialShape(
inp_shapes[0]), name="x", dtype=Type.dynamic)
inp_shapes[0]), name="x", dtype=Type.f32)
param2 = ov.opset8.parameter(PartialShape(
inp_shapes[1]), name="y", dtype=Type.dynamic)
cl = ov.opset8.convert_like(param2, param1)
mul = ov.opset8.multiply(param1, cl)
inp_shapes[1]), name="y", dtype=Type.f32)
mul = ov.opset8.multiply(param1, param2)
relu = ov.opset8.relu(mul)
sigm = ov.opset8.sigmoid(relu)
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input': inp_shapes}
return pt_model, ref_model, {'input': [(inp_shapes[0], Type.f32), (inp_shapes[1], Type.f32)]}
def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir):
@ -523,8 +522,8 @@ def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir):
def create_pytorch_nn_module_shapes_list_dynamic_single_input(tmp_dir):
pt_model = make_pt_model_one_input()
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)]]
ref_model = make_ref_pt_model_one_input(inp_shapes[0], dtype=Type.dynamic)
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)], Type.f32]
ref_model = make_ref_pt_model_one_input(inp_shapes[0])
return pt_model, ref_model, {'input': inp_shapes}
@ -537,8 +536,8 @@ def create_pytorch_nn_module_shapes_list_dynamic_single_input_via_input(tmp_dir)
def create_pytorch_nn_module_shapes_list_static_single_input(tmp_dir):
pt_model = make_pt_model_one_input()
inp_shapes = [[1, 3, 20, 20]]
ref_model = make_ref_pt_model_one_input(inp_shapes[0], dtype=Type.dynamic)
inp_shapes = [[1, 3, 20, 20], Type.f32]
ref_model = make_ref_pt_model_one_input(inp_shapes[0])
return pt_model, ref_model, {'input': inp_shapes}

View File

@ -281,8 +281,8 @@ def create_pytorch_jit_script_function(tmp_dir):
return torch.sigmoid(torch.relu(x * y))
inp_shape = PartialShape([Dimension(1, -1), Dimension(-1, 5), 10])
ref_model = make_ref_pt_model_two_inputs(inp_shape, dtype=Type.dynamic)
return scripted_fn, ref_model, {'input': [(inp_shape), (inp_shape)]}
ref_model = make_ref_pt_model_two_inputs(inp_shape)
return scripted_fn, ref_model, {'input': [(inp_shape, Type.f32), (inp_shape, Type.f32)]}
def create_pytorch_nn_module_layout_list(tmp_dir):
@ -469,9 +469,9 @@ def create_pytorch_nn_module_scale_list_compression_enabled(tmp_dir):
def create_pytorch_nn_module_shapes_list_static(tmp_dir):
pt_model = make_pt_model_two_inputs()
ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20], dtype=Type.dynamic)
ref_model = make_ref_pt_model_two_inputs([1, 3, 20, 20])
return pt_model, ref_model, {'input': [[1, 3, 20, 20], [1, 3, 20, 20]]}
return pt_model, ref_model, {'input': [([1, 3, 20, 20], Type.f32), ([1, 3, 20, 20], Type.f32)]}
def create_pytorch_nn_module_shapes_list_static_via_input(tmp_dir):
@ -487,17 +487,16 @@ def create_pytorch_nn_module_shapes_list_dynamic(tmp_dir):
[-1, 3, 20, Dimension(-1, 20)]]
param1 = ov.opset8.parameter(PartialShape(
inp_shapes[0]), name="x", dtype=Type.dynamic)
inp_shapes[0]), name="x", dtype=Type.f32)
param2 = ov.opset8.parameter(PartialShape(
inp_shapes[1]), name="y", dtype=Type.dynamic)
cl = ov.opset8.convert_like(param2, param1)
mul = ov.opset8.multiply(param1, cl)
inp_shapes[1]), name="y", dtype=Type.f32)
mul = ov.opset8.multiply(param1, param2)
relu = ov.opset8.relu(mul)
sigm = ov.opset8.sigmoid(relu)
parameter_list = [param1, param2]
ref_model = Model([sigm], parameter_list, "test")
return pt_model, ref_model, {'input': inp_shapes}
return pt_model, ref_model, {'input': [(inp_shapes[0], Type.f32), (inp_shapes[1], Type.f32)]}
def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir):
@ -520,8 +519,8 @@ def create_pytorch_nn_module_shapes_list_dynamic_via_input(tmp_dir):
def create_pytorch_nn_module_shapes_list_dynamic_single_input(tmp_dir):
pt_model = make_pt_model_one_input()
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)]]
ref_model = make_ref_pt_model_one_input(inp_shapes[0], dtype=Type.dynamic)
inp_shapes = [[Dimension(-1), 3, 20, Dimension(20, -1)], Type.f32]
ref_model = make_ref_pt_model_one_input(inp_shapes[0])
return pt_model, ref_model, {'input': inp_shapes}
@ -534,8 +533,8 @@ def create_pytorch_nn_module_shapes_list_dynamic_single_input_via_input(tmp_dir)
def create_pytorch_nn_module_shapes_list_static_single_input(tmp_dir):
pt_model = make_pt_model_one_input()
inp_shapes = [[1, 3, 20, 20]]
ref_model = make_ref_pt_model_one_input(inp_shapes[0], dtype=Type.dynamic)
inp_shapes = [[1, 3, 20, 20], Type.f32]
ref_model = make_ref_pt_model_one_input(inp_shapes[0])
return pt_model, ref_model, {'input': inp_shapes}

View File

@ -641,7 +641,7 @@ def f(x, y):
@pytest.mark.precommit
def test_pytorch_decoder_can_convert_scripted_function():
from openvino.tools.mo import convert_model
from openvino import convert_model, Type
scripted = torch.jit.script(f)
model = convert_model(scripted)
model = convert_model(scripted, input=[Type.f32, Type.f32])
assert model is not None

View File

@ -22,6 +22,9 @@ class aten_relu(torch.nn.Module):
class aten_multi_input_output(torch.nn.Module):
def forward(self, x, y, z):
x = x.to(torch.float32)
y = y.to(torch.float32)
z = z.to(torch.float32)
return torch.nn.functional.relu(x), x * y, z / x

View File

@ -77,10 +77,13 @@ class PytorchLayerTest:
if use_torch_compile_backend():
self.torch_compile_backend_test(model, torch_inputs, custom_eps)
else:
trace_model = kwargs.get('trace_model', False)
freeze_model = kwargs.get('freeze_model', True)
with torch.no_grad():
trace_model = kwargs.get('trace_model', False)
freeze_model = kwargs.get('freeze_model', True)
smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
if kwargs.get('use_convert_model', False):
smodel, converted_model = self.convert_via_mo(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
else:
smodel, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
if kind is not None and not isinstance(kind, (tuple, list)):
kind = [kind]
@ -162,12 +165,13 @@ class PytorchLayerTest:
raise RuntimeError("Please provide inputs generation function")
def convert_via_mo(self, model, example_input, trace_model, dynamic_shapes, ov_inputs, freeze_model):
from openvino.tools.ovc import convert_model
kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]}
from openvino import convert_model, PartialShape
if trace_model:
decoder = TorchScriptPythonDecoder(model, example_input=example_input, skip_freeze=not freeze_model)
kwargs = {"example_input": example_input if len(example_input) > 1 else example_input[0]}
else:
decoder = TorchScriptPythonDecoder(model, skip_freeze=not freeze_model)
kwargs = {"input": [(i.dtype, PartialShape([-1] * len(i.shape))) for i in example_input]}
smodel = decoder.pt_module
print(smodel.inlined_graph)
if not dynamic_shapes:

View File

@ -43,7 +43,7 @@ class TestAdd(PytorchLayerTest):
@pytest.mark.parametrize("op_type", ["add", "add_"])
def test_add(self, ie_device, precision, ir_version, alpha, input_rhs, op_type):
self.input_rhs = input_rhs
self._test(*self.create_model(alpha, op_type), ie_device, precision, ir_version)
self._test(*self.create_model(alpha, op_type), ie_device, precision, ir_version, use_convert_model=True)
class TestAddTypes(PytorchLayerTest):

View File

@ -55,7 +55,7 @@ class TestComp(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_comp(self, op, ie_device, precision, ir_version):
self._test(*self.create_model(op), ie_device, precision, ir_version)
self._test(*self.create_model(op), ie_device, precision, ir_version, use_convert_model=True)
class TestCompMixedTypes(PytorchLayerTest):

View File

@ -10,15 +10,6 @@ from pytorch_layer_test_class import PytorchLayerTest
from torchvision.ops import deform_conv2d
def xfail_106712(test_param):
return pytest.param(
test_param,
marks=pytest.mark.xfail(
reason="Depending on number of groups and number of output channels, deformable convolution may return incorrect reasults. Ticket 106712"
),
)
params = [
{
"weights_shape": [64, 64, 3, 3],
@ -62,15 +53,13 @@ params = [
"padding": (2, 2),
"dilation": (1, 1),
},
xfail_106712(
{
"weights_shape": [64, 16, 3, 3],
"offset_shape": [1, 18, 64, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (1, 1),
}
),
{
"weights_shape": [64, 16, 3, 3],
"offset_shape": [1, 18, 64, 64],
"stride": (1, 1),
"padding": (1, 1),
"dilation": (1, 1),
},
{
"weights_shape": [60, 16, 3, 3],
"offset_shape": [1, 18, 64, 64],
@ -92,15 +81,13 @@ params = [
"padding": (1, 1),
"dilation": (1, 1),
},
xfail_106712(
{
"weights_shape": [64, 32, 3, 3],
"offset_shape": [1, 36, 68, 68],
"stride": (1, 1),
"padding": (3, 3),
"dilation": (1, 1),
}
),
{
"weights_shape": [64, 32, 3, 3],
"offset_shape": [1, 36, 68, 68],
"stride": (1, 1),
"padding": (3, 3),
"dilation": (1, 1),
},
{
"weights_shape": [62, 32, 3, 3],
"offset_shape": [1, 36, 68, 68],

View File

@ -56,7 +56,8 @@ class TestDevice(PytorchLayerTest):
ie_device,
precision,
ir_version,
trace_model=False
trace_model=False,
use_convert_model=True,
)
@pytest.mark.parametrize("device_string", ["cpu", "cuda"])
@ -68,5 +69,6 @@ class TestDevice(PytorchLayerTest):
ie_device,
precision,
ir_version,
trace_model=False
trace_model=False,
use_convert_model=True,
)

View File

@ -23,4 +23,4 @@ class TestDict(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_dict(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)

View File

@ -34,7 +34,7 @@ class TestCdist(PytorchLayerTest):
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
reason='Ticket - 122715')
def test_cdist(self, p, ie_device, precision, ir_version):
self._test(*self.create_model(p), ie_device, precision, ir_version)
self._test(*self.create_model(p), ie_device, precision, ir_version, use_convert_model=True)
class TestPairwiseDistance(PytorchLayerTest):
@ -68,4 +68,4 @@ class TestPairwiseDistance(PytorchLayerTest):
@pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64',
reason='Ticket - 122715')
def test_cdist(self, p, eps, keepdim, ie_device, precision, ir_version):
self._test(*self.create_model(p, eps, keepdim), ie_device, precision, ir_version)
self._test(*self.create_model(p, eps, keepdim), ie_device, precision, ir_version, use_convert_model=True)

View File

@ -49,7 +49,7 @@ class TestDiv(PytorchLayerTest):
self.other_array = other_array
self.other_type = np.float32
self._test(*self.create_model(rounding_mode),
ie_device, precision, ir_version)
ie_device, precision, ir_version, use_convert_model=True)
class TestDivTypes(PytorchLayerTest):

View File

@ -134,7 +134,7 @@ class TestNewEmpty(PytorchLayerTest):
@pytest.mark.precommit
def test_new_empty(self, shape, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape), ie_device, precision, ir_version,
kwargs_to_prepare_input={'input_dtype': input_dtype})
kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True)
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
@pytest.mark.parametrize("input_dtype", [bool, np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
@ -142,4 +142,4 @@ class TestNewEmpty(PytorchLayerTest):
@pytest.mark.nightly
def test_new_empty_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
kwargs_to_prepare_input={'input_dtype': input_dtype})
kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True)

View File

@ -45,4 +45,4 @@ class TestEq(PytorchLayerTest):
self.input_type = types[0]
self.other_array = other_array
self.other_type = types[1]
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)

View File

@ -59,7 +59,7 @@ class TestFloorDivide(PytorchLayerTest):
def test_floor_divide(self, input_tensor, other_tensor, ie_device, precision, ir_version):
self.input_tensor = input_tensor
self.other_tensor = other_tensor
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)
@pytest.mark.parametrize('input_tensor', ([
np.random.randint(low=0, high=10, size=5).astype(np.float32),

View File

@ -29,7 +29,7 @@ class TestBF16(PytorchLayerTest):
@pytest.mark.parametrize("to_trace", [True, False])
def test_bf16(self, ie_device, precision, ir_version, to_trace):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=to_trace, freeze_model=False)
ir_version, trace_model=to_trace, freeze_model=False, use_convert_model=True)
class TestFP16(PytorchLayerTest):
@ -53,4 +53,4 @@ class TestFP16(PytorchLayerTest):
@pytest.mark.parametrize("to_trace", [True, False])
def test_fp16(self, ie_device, precision, ir_version, to_trace):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=to_trace, freeze_model=False)
ir_version, trace_model=to_trace, freeze_model=False, use_convert_model=True)

View File

@ -347,7 +347,7 @@ class TestNewFull(PytorchLayerTest):
@pytest.mark.precommit
def test_new_full(self, shape, value, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape), ie_device, precision, ir_version,
kwargs_to_prepare_input={'value': value, 'input_dtype': input_dtype})
kwargs_to_prepare_input={'value': value, 'input_dtype': input_dtype}, use_convert_model=True)
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
@pytest.mark.parametrize("value,input_dtype", [(0, np.uint8), (1, np.int32), (-1, np.float32), (0.5, np.float64)])
@ -355,7 +355,7 @@ class TestNewFull(PytorchLayerTest):
@pytest.mark.nightly
def test_new_full_with_dtype(self, value, shape, dtype, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
kwargs_to_prepare_input={'value': value, 'input_dtype': input_dtype})
kwargs_to_prepare_input={'value': value, 'input_dtype': input_dtype}, use_convert_model=True)
class TestZerosAndOnes(PytorchLayerTest):
@ -562,7 +562,7 @@ class TestNewZeros(PytorchLayerTest):
@pytest.mark.precommit
def test_new_zeros(self, shape, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape), ie_device, precision, ir_version,
kwargs_to_prepare_input={'input_dtype': input_dtype})
kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True)
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
@pytest.mark.parametrize("input_dtype", [bool, np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
@ -570,7 +570,7 @@ class TestNewZeros(PytorchLayerTest):
@pytest.mark.nightly
def test_new_zeros_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
kwargs_to_prepare_input={'input_dtype': input_dtype})
kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True)
class TestNewOnes(PytorchLayerTest):
@ -621,7 +621,7 @@ class TestNewOnes(PytorchLayerTest):
@pytest.mark.precommit
def test_new_ones(self, shape, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape), ie_device, precision, ir_version,
kwargs_to_prepare_input={'input_dtype': input_dtype})
kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True)
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
@pytest.mark.parametrize("input_dtype", [bool, np.uint8, np.int8, np.int32, np.int64, np.float32, np.float64])
@ -629,4 +629,4 @@ class TestNewOnes(PytorchLayerTest):
@pytest.mark.nightly
def test_new_ones_with_dtype(self, shape, dtype, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(shape, dtype=dtype, used_dtype=True), ie_device, precision, ir_version,
kwargs_to_prepare_input={'input_dtype': input_dtype})
kwargs_to_prepare_input={'input_dtype': input_dtype}, use_convert_model=True)

View File

@ -102,4 +102,4 @@ class TestAddGetItem(PytorchLayerTest):
@pytest.mark.parametrize("idx", [-4, -3, -2, -1, 0, 1, 2, 3])
def test_add_cat(self, ie_device, precision, ir_version, idx):
self._test(aten_add_getitem(idx), None, ["aten::__getitem__", "aten::add", "prim::ListConstruct"],
ie_device, precision, ir_version, freeze_model=False)
ie_device, precision, ir_version, freeze_model=False, use_convert_model=True)

View File

@ -39,4 +39,4 @@ class TestIf(PytorchLayerTest):
@pytest.mark.skipif(os.getenv("GITHUB_ACTIONS") == 'true', reason="Ticket - 114818")
def test_if(self, y, ie_device, precision, ir_version):
self.y = y
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)

View File

@ -150,4 +150,4 @@ class TestIndexMask(PytorchLayerTest):
[2, 2, 3, 4]))
def test_index_mask(self, input_shape, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
"input_shape": input_shape}, trace_model=True)
"input_shape": input_shape}, trace_model=True, use_convert_model=True)

View File

@ -162,7 +162,7 @@ class TestNonZero_IndexPut(PytorchLayerTest):
self.values = input_data["values"]
self.indices_0 = indices[0]
self.indices_1 = indices[1]
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True)
self._test(*self.create_model(accumulate), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)
class TestMask_IndexPut(PytorchLayerTest):
def _prepare_input(self):
@ -181,4 +181,4 @@ class TestMask_IndexPut(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_nonzero_index_put_(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)

View File

@ -48,7 +48,7 @@ class TestLen(PytorchLayerTest):
def test_len_int_list(self, ie_device, precision, ir_version, input_tensor):
self.input_tensor = input_tensor
self._test(*self.create_model_int_list(),
ie_device, precision, ir_version)
ie_device, precision, ir_version, use_convert_model=True)
class TestLenEmpty(PytorchLayerTest):

View File

@ -123,6 +123,7 @@ class TestListUnpack(PytorchLayerTest):
ie_device,
precision,
ir_version,
use_convert_model=True,
)
class TestMeshgridListUnpack(PytorchLayerTest):

View File

@ -37,7 +37,7 @@ class TestMul(PytorchLayerTest):
self.input_type = np.float32
self.other_array = other_array
self.other_type = np.float32
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)
class TestMulTypes(PytorchLayerTest):

View File

@ -25,4 +25,5 @@ class TestLog(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_or(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, dynamic_shapes=False, trace_model=True)
self._test(*self.create_model(), ie_device, precision, ir_version,
dynamic_shapes=False, trace_model=True, use_convert_model=True)

View File

@ -41,7 +41,7 @@ class TestPow(PytorchLayerTest):
@pytest.mark.precommit
def test_pow(self, ie_device, precision, ir_version, test_input):
self.test_input = test_input
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)
class TestPowMixedTypes(PytorchLayerTest):

View File

@ -32,7 +32,7 @@ class TestRemainder(PytorchLayerTest):
@pytest.mark.precommit
def test_remainder(self, ie_device, precision, ir_version, input_rhs):
self.input_rhs = input_rhs
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)
class TestRemainderTypes(PytorchLayerTest):

View File

@ -77,4 +77,4 @@ class TestRepeatFromFlanT5(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_repeat_t5(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True)
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, use_convert_model=True)

View File

@ -40,9 +40,9 @@ class TestRsub(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_rsub(self, ie_device, precision, ir_version, input_data):
def test_rsub_f(self, ie_device, precision, ir_version, input_data):
self.input_data = input_data
self._test(*self.create_model(second_type="float"), ie_device, precision, ir_version)
self._test(*self.create_model(second_type="float"), ie_device, precision, ir_version, use_convert_model=True)
@pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 4).astype(np.float32),
np.array(5).astype(int),
@ -50,9 +50,9 @@ class TestRsub(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_rsub(self, ie_device, precision, ir_version, input_data):
def test_rsub_i(self, ie_device, precision, ir_version, input_data):
self.input_data = input_data
self._test(*self.create_model(second_type="int"), ie_device, precision, ir_version)
self._test(*self.create_model(second_type="int"), ie_device, precision, ir_version, use_convert_model=True)
class TestRsubTypes(PytorchLayerTest):

View File

@ -31,4 +31,4 @@ class TestStrides(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
def test_strides(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(), ie_device, precision, ir_version, use_convert_model=True)

View File

@ -50,7 +50,7 @@ class TestSub(PytorchLayerTest):
@pytest.mark.precommit
def test_sub(self, ie_device, precision, ir_version, input_data, inplace):
self.input_data = input_data
self._test(*self.create_model(inplace), ie_device, precision, ir_version)
self._test(*self.create_model(inplace), ie_device, precision, ir_version, use_convert_model=True)
class TestSubTypes(PytorchLayerTest):

View File

@ -91,4 +91,5 @@ class TestTSmall(PytorchLayerTest):
precision,
ir_version,
kwargs_to_prepare_input={"num_dims": num_dims, "input_dtype": input_dtype},
use_convert_model=True,
)

View File

@ -60,7 +60,7 @@ class TestTupleConstruct(PytorchLayerTest):
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "tensor_tail", "list_and_tuple"])
@pytest.mark.nightly
def test_tuple_construct(self, case, ie_device, precision, ir_version):
self._test(*self.create_model(case), ie_device, precision, ir_version)
self._test(*self.create_model(case), ie_device, precision, ir_version, use_convert_model=True)
class TestTupleConstructTupleUnpack(PytorchLayerTest):
@ -86,7 +86,7 @@ class TestTupleConstructTupleUnpack(PytorchLayerTest):
@pytest.mark.nightly
def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device,
precision, ir_version, freeze_model=False)
precision, ir_version, freeze_model=False, use_convert_model=True)
class TestTupleUnpackParameterSingle(PytorchLayerTest):
@ -208,7 +208,7 @@ class TestTupleIndex(PytorchLayerTest):
@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=False, freeze_model=False)
ir_version, trace_model=False, freeze_model=False, use_convert_model=True)
class TestTcOutsideTuInsideIfBody(PytorchLayerTest):
@ -236,4 +236,4 @@ class TestTcOutsideTuInsideIfBody(PytorchLayerTest):
@pytest.mark.nightly
def test(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision,
ir_version, trace_model=False, freeze_model=False)
ir_version, trace_model=False, freeze_model=False, use_convert_model=True)