[PT FE] Support inplace operations on aliases of tensors (#17856)
* Support operations on aliases of tensors
* Add tests
* Fix issue with convnd
* Fix code style
* Fix issue with tensor index of mutated tensor
* Fix if types alignment
* Fix issues in keypoint detectron2
* Fix issue with masks in detectron2
* Fix accuracy issue in mobilevitv2 models
* Remove unused includes
* Return upsample case in listconstruct replacer
* Fix types, apply review feedback
* Apply feedback
* Revert change of not using shared_from_this for getitem
* Fix issue in prim::device transformation
* Fix layer tests
* Apply review feedback
* Fix issue with non-existing alias to tensor
This commit is contained in: parent ec2db81468, commit 48dec1000e
@@ -91,7 +91,7 @@ pt_to_ov_type_map = {
class TorchScriptPythonDecoder (Decoder):
def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True):
def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True, alias_db=None):
Decoder.__init__(self)
# We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted
self.m_decoders = []
@@ -113,8 +113,10 @@ class TorchScriptPythonDecoder (Decoder):
" yourself, please refer to PyTorch documentation: "
"https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.")
self.graph_element = pt_module.inlined_graph
self.alias_db = self.graph_element.alias_db()
else:
self.graph_element = graph_element
self.alias_db = alias_db
self.pt_module = pt_module
self.raw_inputs = list(self.graph_element.inputs())
self.raw_outputs = list(self.graph_element.outputs())
@@ -273,7 +275,7 @@ class TorchScriptPythonDecoder (Decoder):
def visit_subgraph(self, node_visitor) -> None:
# make sure topological order is satisfied
for node in self.graph_element.nodes():
decoder = TorchScriptPythonDecoder(self.pt_module, node)
decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db)
self.m_decoders.append(decoder)
node_visitor(decoder)
@@ -289,7 +291,7 @@ class TorchScriptPythonDecoder (Decoder):
return list(self.graph_element.blocks())
def get_subgraph_decoder(self, index: int):
decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index])
decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db)
self.m_decoders.append(decoder)
return decoder
@@ -389,6 +391,16 @@ class TorchScriptPythonDecoder (Decoder):
return pt_value is None
return False
def may_produce_alias(self, in_index: int, out_index: int) -> bool:
if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d"]:
# AliasDB::may_contain_alias sometimes return True for tensors produced by convnd, we have to workaround that
return False
try:
return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index))
except:
# Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node
return False
@staticmethod
def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph):
# Function replaces prim::Constant containing List of Tensors with
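Illustration (not part of the diff): the alias information exposed by may_produce_alias comes from TorchScript's alias analysis of the inlined graph. A minimal sketch, assuming a PyTorch build where inlined_graph and Graph.alias_db() are available as used in the hunk above:

    import torch

    class Mutate(torch.nn.Module):
        def forward(self, x):
            x[:, 1, :, :] = 4.0   # lowered to aten::slice/aten::select views plus aten::copy_
            return x

    scripted = torch.jit.script(Mutate())
    graph = scripted.inlined_graph          # the same graph object the decoder wraps
    print(graph)                            # the copy_ target is an alias (view) of the input x
    # The decoder's may_produce_alias() asks graph.alias_db().may_contain_alias(input_value, output_value)
    # for such value pairs, exactly as the new method above does.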
@@ -109,6 +109,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
std::shared_ptr<TorchDecoder> get_subgraph_decoder(size_t index) const override {
PYBIND11_OVERRIDE_PURE(std::shared_ptr<TorchDecoder>, TorchDecoder, get_subgraph_decoder, index);
}
bool may_produce_alias(size_t in_index, size_t out_index) const override {
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, may_produce_alias, in_index, out_index);
}
};
void regclass_frontend_pytorch_decoder(py::module m);
@@ -104,8 +104,11 @@ public:
// node_visitor is a function that will be fed by nodes in subgraph for all nodes in graph
virtual void visit_subgraph(std::function<void(std::shared_ptr<TorchDecoder>)> node_visitor) const = 0;
/// Probably this toghether with immediate nodes visitor is a replacement for visit_subgraphs with an index
/// Probably this together with immediate nodes visitor is a replacement for visit_subgraphs with an index
virtual std::shared_ptr<TorchDecoder> get_subgraph_decoder(size_t index) const = 0;
/// \brief Returns if output may contain alias of input in AliasDB
virtual bool may_produce_alias(size_t in_index, size_t out_index) const = 0;
};
} // namespace pytorch
@@ -74,6 +74,10 @@ public:
return m_decoder->input_is_none(index);
}
Any get_output_type(size_t index) const {
return m_decoder->get_output_type(index);
}
size_t get_output_size() const {
return m_decoder_outputs.size();
}
@@ -16,6 +16,7 @@
#include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp"
#include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp"
#include "transformations/control_flow/unroll_if.hpp"
#include "transformations/op_conversions/convert_convertlike.hpp"
#include "transforms.hpp"
#include "transforms/append_list_unpack_replacer.hpp"
#include "transforms/aten_cat_replacer.hpp"
@@ -25,6 +26,7 @@
#include "transforms/aten_stack_list_construct_replacer.hpp"
#include "transforms/dict_resolver.hpp"
#include "transforms/einsum_list_construct.hpp"
#include "transforms/index_loop_getitem_replacer.hpp"
#include "transforms/listconstruct_replacer.hpp"
#include "transforms/min_max_prim_list_construct_replacer.hpp"
#include "transforms/prim_list_construct_pad.hpp"
@@ -42,8 +44,10 @@ std::set<std::string> get_unconverted_types_from_model(const std::shared_ptr<Mod
std::set<std::string> unconverted_ops_types;
for (const auto& node : model->get_ordered_ops()) {
if (const auto& fw_node = ov::as_type_ptr<PtFrameworkNode>(node)) {
auto op_type = fw_node->get_decoder()->get_op_type();
unconverted_ops_types.insert(op_type);
auto attrs = fw_node->get_attrs();
FRONT_END_GENERAL_CHECK(attrs.find("PtTypeName") != attrs.end(),
"FrameworkNode attributes do not contain operation type.");
unconverted_ops_types.insert(attrs.at("PtTypeName"));
}
if (const auto& fw_node = ov::as_type_ptr<ov::op::util::MultiSubGraphOp>(node)) {
for (size_t i = 0; i < fw_node->get_internal_subgraphs_size(); i++) {
@@ -97,6 +101,11 @@ std::shared_ptr<Model> FrontEnd::decode(const InputModel::Ptr& model) const {
void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
ov::pass::Manager manager;
// the following 2 transformations are needed for keypoint detectron2 models to work.
// AtenIndexToSelect will be called twice
manager.register_pass<ov::pass::ConvertConvertLike>();
manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();
manager.register_pass<ov::pass::ConstantFolding>();
manager.register_pass<ov::pass::PushConstantToSubgraph>();
manager.register_pass<ov::pass::UnrollIf>();
@@ -114,6 +123,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
manager.register_pass<ov::frontend::pytorch::pass::DictResolver>();
manager.register_pass<ov::frontend::pytorch::pass::IndexLoopGetitemReplacer>();
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParamsResults>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
@@ -6,7 +6,8 @@
#include "openvino/frontend/exception.hpp"
#include "openvino/frontend/pytorch/decoder.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/util/log.hpp"
#include "pt_framework_node.hpp"
#include "translate_session.hpp"
@@ -15,6 +16,8 @@ namespace ov {
namespace frontend {
namespace pytorch {
using namespace ov::op;
OutputVector NodeContext::as_constant() const {
auto dtype = m_decoder->get_output_type(0);
if (dtype.is<type::Str>()) {
@@ -52,11 +55,36 @@ void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
{m_decoder->get_input_debug_name(index), m_decoder->get_input_signature_name(index)});
(*m_tensor_map)[input_id] = ov_output;
m_mutated_tensors->insert(input_id);
// Resolve aliases
auto back_input_id = input_id;
auto back_node_input = ov_output;
while (m_translate_session->m_may_be_alias.count(back_input_id)) {
// Create node to backprop data. While loop is needed for the cases when alias to tensor point to another alias
// to tensor. In that case we need to create a chain of backprop ops
size_t in_tensor;
std::shared_ptr<TorchDecoder> node;
Output<Node> node_converted_output;
std::tie(in_tensor, node, node_converted_output) = m_translate_session->m_may_be_alias.at(back_input_id);
auto backprop_node = m_translate_session->get_backprop_op(node, node_converted_output, back_node_input);
if (m_tensor_map->count(in_tensor)) {
// Tensor is not found in the scope of this body, need to get it from internal context and mark mutated
OPENVINO_DEBUG << "Couldn't find in the current body the initial aliased tensor: " << in_tensor
<< " for operation: " << node->get_op_type() << " creating new body input.";
get_tensor_from_model_or_create_input(in_tensor);
}
m_translate_session->encode_tensor_name(backprop_node, in_tensor);
(*m_tensor_map)[in_tensor] = backprop_node;
m_mutated_tensors->insert(in_tensor);
OPENVINO_DEBUG << "Propagated back data from tensor: " << back_input_id << " to tensor: " << in_tensor << ".\n";
back_input_id = in_tensor;
back_node_input = backprop_node;
}
}
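Illustration (not part of the diff): in eager PyTorch an alias shares storage with its source, so a write through the alias is visible in the source tensor; the converted graph is purely functional, so the "backprop" chain built above re-creates that effect by writing the new value back into each source tensor along the alias chain. A small sketch of the intended semantics:

    import torch

    x = torch.zeros(4, 3)
    view = x[0:2]              # aten::slice output, an alias of x
    view += 1.0                # in-place update through the alias
    # Functional equivalent of the write-back built by mutate_input: take the original
    # source tensor and re-insert the mutated slice at the same positions.
    x_functional = torch.zeros(4, 3)
    x_functional[0:2] = view
    assert torch.equal(x, x_functional)   # eager aliasing and the rebuilt write-back agree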
void NodeContext::add_tensor_to_context(size_t index, Output<Node> ov_output) const {
if (m_tensor_map->count(index)) {
OPENVINO_DEBUG << "[ WARNING ] Current context has tensor. Rewriting.\n";
OPENVINO_DEBUG << "[ WARNING ] Current context has tensor " << index << ". Assuming mutated output.\n";
}
m_translate_session->encode_tensor_name(ov_output, index);
(*m_tensor_map)[index] = ov_output;
@@ -67,7 +95,7 @@ Output<Node> NodeContext::get_tensor_from_model_or_create_input(size_t index) co
return m_tensor_map->at(index);
} else {
// nested subgraphs case
auto parameter = std::make_shared<opset10::Parameter>(element::dynamic, PartialShape::dynamic());
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, PartialShape::dynamic());
m_translate_session->encode_tensor_name(parameter->output(0), index);
(*m_tensor_map)[index] = parameter;
m_external_parameters->push_back(parameter);
@@ -80,7 +108,7 @@ Output<Node> NodeContext::get_input_from_visible_context(size_t index) const {
FRONT_END_GENERAL_CHECK(index < get_input_size(), "Index is lower then number of inputs.");
auto input_tensor = get_input(static_cast<int>(index));
auto input_node = input_tensor.get_node_shared_ptr();
if (std::dynamic_pointer_cast<opset10::Parameter>(input_node)) {
if (std::dynamic_pointer_cast<v0::Parameter>(input_node)) {
// We need to look into external context for inputs that would be feed into this parameter
size_t tensor_idx = m_translate_session->decode_tensor_name(input_node->output(0));
if (m_ext_tensor_map.count(tensor_idx)) {
@@ -116,10 +144,10 @@ std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
}
namespace {
std::shared_ptr<opset10::Constant> get_constant_at_input(const NodeContext& ctx, size_t index) {
std::shared_ptr<v0::Constant> get_constant_at_input(const NodeContext& ctx, size_t index) {
FRONT_END_GENERAL_CHECK(!ctx.input_is_none(index), "Input with index: ", index, " is none.");
auto input_node = ctx.get_input_from_visible_context(index).get_node_shared_ptr();
auto input = std::dynamic_pointer_cast<opset10::Constant>(input_node);
auto input = std::dynamic_pointer_cast<v0::Constant>(input_node);
FRONT_END_GENERAL_CHECK(input, "Input with index ", index, " cannot be interpreted as Constant: ", input_node);
return input;
}
@@ -185,7 +213,7 @@ std::string NodeContext::const_input<std::string>(size_t index) const {
namespace {
template <typename T>
Any get_constant_data(const std::shared_ptr<opset10::Constant>& constant) {
Any get_constant_data(const std::shared_ptr<v0::Constant>& constant) {
const T* ptr = reinterpret_cast<const T*>(constant->get_data_ptr());
const auto& shape = constant->get_shape();
if (is_scalar(shape)) {
@@ -206,7 +234,7 @@ Any NodeContext::get_values_from_const_input(int index) const {
}
auto input_node = get_input_from_visible_context(index).get_node_shared_ptr();
if (auto constant = as_type_ptr<opset10::Constant>(input_node)) {
if (auto constant = as_type_ptr<v0::Constant>(input_node)) {
switch (constant->get_element_type()) {
case element::f32:
return get_constant_data<float>(constant);
@@ -3,9 +3,11 @@
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
@@ -19,24 +21,36 @@ using namespace ov::op;
OutputVector translate_as_tensor(const NodeContext& context) {
// aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor
num_inputs_check(context, 1, 4);
// Input with index 2 is device, we skip this input
// Input with index 3 is flag requires_grad, we skip this input
auto dtype = element::f32;
auto list_elems = get_list_as_outputs(context.get_input(0));
bool is_converted = false;
if (!context.input_is_none(1)) {
auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
auto type_input = dtype_fw_node->input_value(0);
return {context.mark_node(std::make_shared<v1::ConvertLike>(context.get_input(0), type_input))};
std::for_each(list_elems.begin(), list_elems.end(), [&](Output<Node>& n) {
n = context.mark_node(std::make_shared<v1::ConvertLike>(n, type_input));
});
is_converted = true;
}
if (auto dtype_const = std::dynamic_pointer_cast<v0::Constant>(dtype_ext_node)) {
auto pt_type = dtype_const->cast_vector<int64_t>()[0];
dtype = convert_dtype(pt_type);
}
}
auto cast = context.mark_node(std::make_shared<v0::Convert>(context.get_input(0), dtype));
// Input with index 2 is device, we skip this input
// Input with index 3 is flag requires_grad, we skip this input
return {cast};
if (!is_converted) {
std::for_each(list_elems.begin(), list_elems.end(), [&](Output<Node>& n) {
n = context.mark_node(std::make_shared<v0::Convert>(n, dtype));
});
}
auto zero = v0::Constant::create(element::i32, Shape{}, {0});
std::for_each(list_elems.begin(), list_elems.end(), [&](Output<Node>& n) {
n = context.mark_node(std::make_shared<v0::Unsqueeze>(n, zero));
});
return {context.mark_node(std::make_shared<v0::Concat>(OutputVector(list_elems.begin(), list_elems.end()), 0))};
};
} // namespace op
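Illustration (not part of the diff): the list-of-outputs path above corresponds to TorchScript graphs where aten::tensor receives a prim::ListConstruct of scalars; each element is converted, unsqueezed to 1-D and concatenated. A hedged eager-mode sketch of the pattern being translated:

    import torch

    class AsTensor(torch.nn.Module):
        def forward(self, a: float, b: float):
            # scripted as prim::ListConstruct -> aten::tensor
            return torch.tensor([a, b], dtype=torch.float32)

    print(torch.jit.script(AsTensor()).inlined_graph)
    print(AsTensor()(1.0, 2.0))   # tensor([1., 2.]) - per-element convert, then concat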
src/frontends/pytorch/src/op/copy.cpp (new file, 35 lines)
@@ -0,0 +1,35 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_copy_(const NodeContext& context) {
// aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
num_inputs_check(context, 2, 3);
auto self = context.get_input(0);
auto src = context.get_input(1);
// Convert src to type of self
auto src_converted = context.mark_node(std::make_shared<v1::ConvertLike>(src, self));
// Broadcast src to shape of self
auto self_shape = context.mark_node(std::make_shared<v3::ShapeOf>(self));
Output<Node> res = context.mark_node(std::make_shared<v3::Broadcast>(src_converted, self_shape));
context.mutate_input(0, res);
return {res};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
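Illustration (not part of the diff): the translation above mirrors the eager semantics of aten::copy_, which casts src to self's dtype and broadcasts it to self's shape before the in-place write. A quick PyTorch sanity check:

    import torch

    x = torch.zeros(2, 3)
    src = torch.tensor([1, 2, 3])      # different dtype (int64) and shape (3,)
    x.copy_(src)                       # converted to float32 and broadcast to (2, 3)
    print(x)                           # tensor([[1., 2., 3.], [1., 2., 3.]])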
@@ -3,6 +3,7 @@
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/floor.hpp"
#include "utils.hpp"
@@ -1,24 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/divide.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_floordiv(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto x = context.get_input(0);
auto y = context.get_input(1);
return {context.mark_node(std::make_shared<ov::op::v1::Divide>(x, y, true))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
@@ -23,10 +23,19 @@ void align_result_types(const NodeContext& context,
auto r2_tensor = r2->input_value(0);
auto r1_type = r1_tensor.get_element_type();
auto r2_type = r2_tensor.get_element_type();
if (r1_type.is_dynamic() || r2_type.is_dynamic())
if (r1_type == r2_type)
return;
element::Type merged_type;
if (!element::Type::merge(merged_type, r1_type, r2_type)) {
if (element::Type::merge(merged_type, r1_type, r2_type)) {
if (r1_type != merged_type) {
auto convert1 = std::make_shared<opset10::Convert>(r1_tensor, merged_type);
r1->set_argument(0, convert1);
}
if (r2_type != merged_type) {
auto convert2 = std::make_shared<opset10::Convert>(r2_tensor, merged_type);
r2->set_argument(0, convert2);
}
} else {
if (r1_type.bitwidth() >= r2_type.bitwidth()) {
auto convert = std::make_shared<opset10::Convert>(r2_tensor, r1_type);
r2->set_argument(0, convert);
@@ -131,7 +140,7 @@ OutputVector translate_if(const NodeContext& context) {
then_body->add_parameters({new_parameter});
then_body->add_results({new_result});
then_body->validate_nodes_and_infer_types();
FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in else body");
FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in else body: ", output_idx);
inputs_map[output_idx][0] = new_parameter;
extra_then_body_results[output_idx] = new_result;
OPENVINO_DEBUG << "Modified then body: " << if_node << '\n';
@@ -143,7 +152,7 @@ OutputVector translate_if(const NodeContext& context) {
else_body->add_parameters({new_parameter});
else_body->add_results({new_result});
else_body->validate_nodes_and_infer_types();
FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in then body");
FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in then body: ", output_idx);
inputs_map[output_idx][1] = new_parameter;
extra_else_body_results[output_idx] = new_result;
OPENVINO_DEBUG << "Modified else body: " << if_node << '\n';
@@ -23,7 +23,7 @@ OutputVector translate_layer_norm(const NodeContext& context) {
FRONT_END_OP_CONVERSION_CHECK(normalized_shape.size() == 1,
"Translation for aten::layer_norm supports only single normalized_shape value, "
"which means normalizing over the last dimension.");
// TODO: support any dimention
// TODO: support any dimension
auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto out_node =
context.mark_node(std::make_shared<v6::MVN>(context.get_input(0), axes, true, eps, MVNEpsMode::INSIDE_SQRT));
@@ -25,6 +25,20 @@ OutputVector translate_loop(const NodeContext& context) {
ov::op::v5::Loop::SpecialBodyPorts spec_ports{0, 0};
loop->set_special_body_ports(spec_ports);
// process outputs first
auto session = context.get_session();
auto body_results = body->get_results();
FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition.");
std::map<size_t, Output<Node>> output_idxs;
// 0 output is condition, do not need to connect it
for (size_t i = 1; i < body_results.size(); i++) {
auto result = body_results[i];
auto out_idx = session->decode_tensor_name(result->input(0).get_source_output());
FRONT_END_OP_CONVERSION_CHECK(output_idxs.count(out_idx) == 0,
"More then one body output with same tensor name.");
output_idxs[out_idx] = result;
}
auto body_parameters = body->get_parameters();
// #0 body parameter is counter;
FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required");
@@ -34,26 +48,28 @@ OutputVector translate_loop(const NodeContext& context) {
// #0 loop input is trip_count, #1 loop input is condition
// Connect other inputs
for (size_t i = 2; i < inputs.size(); i++) {
loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]});
if (i <= subgraph_decoder->num_of_outputs()) {
loop->set_merged_input(body_parameters[i - 1], inputs[i], body_results[i - 1]);
} else {
loop->set_invariant_input(body_parameters[i - 1], inputs[i]);
}
}
// Connect inputs from external context
auto session = context.get_session();
for (auto i = inputs.size() - 1; i < body_parameters.size(); i++) {
auto param = body_parameters[i];
auto input_idx = session->decode_tensor_name(param->output(0));
auto external_output = context.get_tensor_from_model_or_create_input(input_idx);
loop->set_invariant_inputs(external_output, {param});
if (output_idxs.count(input_idx)) {
loop->set_merged_input(param, external_output, output_idxs.at(input_idx));
} else {
loop->set_invariant_input(param, external_output);
}
}
auto body_results = body->get_results();
FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition.");
std::set<size_t> output_idxs;
// 0 output is condition, do not need to connect it
// connect outputs
for (size_t i = 1; i < body_results.size(); i++) {
auto result = body_results[i];
auto out_idx = session->decode_tensor_name(result->input(0).get_source_output());
FRONT_END_OP_CONVERSION_CHECK(output_idxs.count(out_idx) == 0,
"More then one body output with same tensor name.");
output_idxs.insert(out_idx);
context.add_tensor_to_context(out_idx, loop->get_iter_value(result, -1));
}
loop->validate_and_infer_types();
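Illustration (not part of the diff): the model shape that motivates the merged-input handling above is an in-place write inside a Python loop, where the tensor mutated in iteration i must feed iteration i+1 - exactly what Loop merged inputs express. A sketch that mirrors the aten_loop_alias layer test added in this PR:

    import torch

    class LoopWrite(torch.nn.Module):
        def forward(self, x):
            for i in range(2):
                x[:, i, :, :] = 4.0   # mutates an alias of x on every iteration
            return x

    # The scripted graph contains a prim::Loop whose body writes into x via aten::copy_,
    # so the converted Loop carries x as a merged (loop-carried) input.
    print(torch.jit.script(LoopWrite()).inlined_graph)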
@@ -4,6 +4,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/interpolate.hpp"
#include "openvino/op/multiply.hpp"
#include "utils.hpp"
@@ -50,7 +51,10 @@ OutputVector base_translate_upsample(const NodeContext& context,
if (context.input_is_none(1)) {
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(scale_id), "Scale or Output size should be provided");
auto spatial_scales = context.get_input(scale_id);
if (context.get_input_type(1).is<type::List>()) {
spatial_scales = concat_list_construct(spatial_scales);
}
spatial_scales = context.mark_node(std::make_shared<v0::Convert>(spatial_scales, element::f32));
size_mode = v11::Interpolate::ShapeCalcMode::SCALES;
scales_sizes = context.mark_node(std::make_shared<v1::Multiply>(spatial_scales, scales));
} else {
@@ -58,6 +62,7 @@ OutputVector base_translate_upsample(const NodeContext& context,
if (context.get_input_type(1).is<type::List>()) {
out_sizes = concat_list_construct(out_sizes);
}
out_sizes = context.mark_node(std::make_shared<v0::Convert>(out_sizes, element::i32));
scales_sizes = context.mark_node(std::make_shared<v1::Multiply>(out_sizes, output_sizes));
}
auto attrs = v11::Interpolate::InterpolateAttrs(interpolate_mode, size_mode, pad, pad);
@@ -37,6 +37,7 @@ OP_CONVERTER(translate_conv_transposend);
OP_CONVERTER(translate_convnd);
OP_CONVERTER(translate_convolution);
OP_CONVERTER(translate_convolution_mode);
OP_CONVERTER(translate_copy_);
OP_CONVERTER(translate_cumsum);
OP_CONVERTER(translate_deform_conv);
OP_CONVERTER(translate_derive_index);
@@ -53,7 +54,6 @@ OP_CONVERTER(translate_fill_);
OP_CONVERTER(translate_flatten);
OP_CONVERTER(translate_flip);
OP_CONVERTER(translate_floor_divide);
OP_CONVERTER(translate_floordiv);
OP_CONVERTER(translate_frobenius_norm);
OP_CONVERTER(translate_full);
OP_CONVERTER(translate_full_like);
@@ -217,6 +217,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::conv3d", op::translate_convnd},
{"aten::convolution", op::translate_convolution},
{"aten::copy", op::skip_node},
{"aten::copy_", op::translate_copy_},
{"aten::cos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cos>},
{"aten::cos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cos>>},
{"aten::cosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Cosh>},
@@ -234,6 +235,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::empty", op::translate_empty},
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::exp", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>},
{"aten::exp_", op::inplace_op<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>>},
{"aten::expand", op::translate_expand},
{"aten::expand_as", op::translate_expand_as},
{"aten::eye", op::translate_eye},
@@ -243,7 +245,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
{"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
{"aten::floor_divide", op::translate_floor_divide},
{"aten::floordiv", op::translate_floordiv},
{"aten::floordiv", op::translate_floor_divide},
{"aten::frobenius_norm", op::translate_frobenius_norm},
{"aten::full", op::translate_full},
{"aten::full_like", op::translate_full_like},
@@ -373,6 +375,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::var_mean", op::translate_var_mean},
{"aten::view", op::translate_reshape},
{"aten::where", op::translate_where},
{"aten::zero_", op::inplace_op<op::translate_zeros_like>},
{"aten::zeros", op::translate_zeros},
{"aten::zeros_like", op::translate_zeros_like},
{"prim::Constant", op::translate_constant},
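Illustration (not part of the diff): remapping aten::floordiv to translate_floor_divide makes both ops share one translation. For reference, PyTorch floor division rounds toward negative infinity in recent releases:

    import torch

    print(torch.tensor(7) // 2)    # tensor(3)
    print(torch.tensor(-7) // 2)   # tensor(-4), floor division rounds toward -inf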
@@ -14,13 +14,21 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode {
public:
OPENVINO_OP("PtFrameworkNode", "util", ::ov::op::util::FrameworkNode);
PtFrameworkNode(const std::shared_ptr<TorchDecoder>& decoder, const OutputVector& inputs, size_t output_size)
PtFrameworkNode(const std::shared_ptr<TorchDecoder>& decoder,
const OutputVector& inputs,
size_t output_size,
bool is_backprop = false)
: ov::op::util::FrameworkNode(inputs, output_size, decoder->get_subgraph_size()),
m_decoder(decoder) {
ov::op::util::FrameworkNodeAttrs attrs;
attrs.set_type_name("PTFrameworkNode");
attrs["PtTypeName"] = m_decoder->get_op_type();
attrs["PtSchema"] = m_decoder->get_schema();
if (is_backprop) {
attrs["PtTypeName"] = m_decoder->get_op_type() + "_backprop";
attrs["PtSchema"] = "None";
} else {
attrs["PtTypeName"] = m_decoder->get_op_type();
attrs["PtSchema"] = m_decoder->get_schema();
}
set_attrs(attrs);
// Set output shapes and types if recognized
@@ -0,0 +1,141 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "index_loop_getitem_replacer.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/loop.hpp"
#include "openvino/op/mod.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/util/framework_node.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
using namespace ov::pass;
using namespace ov::op;
IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() {
auto loop_pattern = pattern::wrap_type<v5::Loop>([](Output<Node> n) {
auto loop_op = ov::as_type_ptr<v5::Loop>(n.get_node_shared_ptr());
bool check_len_input = false;
if (auto len_reduce = ov::as_type_ptr<v1::ReduceSum>(loop_op->input_value(0).get_node_shared_ptr())) {
if (auto len_slice = ov::as_type_ptr<v8::Slice>(len_reduce->input_value(0).get_node_shared_ptr())) {
if (auto len_shape_of = ov::as_type_ptr<v3::ShapeOf>(len_slice->input_value(0).get_node_shared_ptr())) {
check_len_input = true;
}
}
}
return check_len_input;
});
ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto loop_op = ov::as_type_ptr<v5::Loop>(m.get_match_root());
std::shared_ptr<Node> chunk_op;
size_t chunk_idx = 0;
auto loop_inputs = loop_op->input_values();
for (size_t i = 1; i < loop_inputs.size(); i++) {
if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(), "aten::chunk")) {
chunk_op = loop_inputs.at(i).get_node_shared_ptr();
chunk_idx = i;
break;
}
}
if (!chunk_op)
return false;
auto body = loop_op->get_function();
std::shared_ptr<Node> chunk_param;
for (auto input_desc : loop_op->get_input_descriptions()) {
if (input_desc->m_input_index == chunk_idx) {
chunk_param = body->get_parameters().at(input_desc->m_body_parameter_index);
break;
}
}
if (!chunk_param)
return false;
auto param_targets = chunk_param->get_output_target_inputs(0);
if (param_targets.size() != 1)
return false;
auto getitem = param_targets.begin()->get_node()->shared_from_this();
if (!ov::as_type_ptr<v8::Gather>(getitem))
return false;
auto dim = chunk_op->input_value(2);
if (!ov::as_type_ptr<v0::Constant>(dim.get_node_shared_ptr()))
return false;
// connect chunk input directly to loop
auto chunk_input = chunk_op->input_value(0);
chunk_op->output(0).replace(chunk_input);
// len(chunks) is number of iterations
auto chunks_outside = chunk_op->input_value(1);
loop_op->input_value(0).replace(chunks_outside);
auto chunk_counter = getitem->input_value(1);
pass::NodeRegistry rg;
auto tensor_0 = v0::Constant::create(element::i32, Shape{1}, {0});
auto one_1d = v0::Constant::create(element::i32, Shape{1}, {1});
auto input_shape = rg.make<v3::ShapeOf>(chunk_input, element::i32);
auto input_dimension = rg.make<v8::Gather>(input_shape, dim, tensor_0);
auto init_chunk_size = rg.make<v1::Divide>(input_dimension, chunks_outside, true);
// Add 1 if input is not evenly divisible by chunks
auto last_chunk_size = rg.make<v1::Mod>(input_dimension, chunks_outside);
auto is_last_nonzero = rg.make<v1::Greater>(last_chunk_size, tensor_0);
auto is_last_nonzero_int = rg.make<v0::Convert>(is_last_nonzero, element::i32);
auto chunk_size = rg.make<v1::Add>(init_chunk_size, is_last_nonzero_int);
auto dim_1d = rg.make<v1::Reshape>(dim, one_1d, false);
// Add new inputs in Loop: chunk_size and dim_1d
auto inp_descs = loop_op->get_input_descriptions();
auto chunks_size_body = rg.make<v0::Parameter>(element::i32, Shape{1});
auto dim_body = rg.make<v0::Parameter>(dim.get_element_type(), Shape{1});
body->add_parameters({chunks_size_body, dim_body});
loop_op->set_argument(loop_op->get_input_size(), chunk_size);
loop_op->set_argument(loop_op->get_input_size(), dim_1d);
inp_descs.push_back(std::make_shared<ov::op::util::MultiSubGraphOp::InvariantInputDescription>(
loop_op->get_input_size() - 2,
body->get_parameters().size() - 2));
inp_descs.push_back(std::make_shared<ov::op::util::MultiSubGraphOp::InvariantInputDescription>(
loop_op->get_input_size() - 1,
body->get_parameters().size() - 1));
loop_op->set_input_descriptions(0, inp_descs);
auto start = rg.make<v1::Multiply>(chunk_counter, chunks_size_body);
auto stop = rg.make<v1::Add>(start, chunks_size_body);
auto curr_chunk = rg.make<v8::Slice>(chunk_param, start, stop, one_1d, dim_body);
replace_node(getitem, curr_chunk);
copy_runtime_info({chunk_op, getitem}, rg.get());
curr_chunk->set_friendly_name(getitem->get_friendly_name());
return true;
};
auto m = std::make_shared<pattern::Matcher>(loop_pattern, "ov::frontend::pytorch::pass::IndexLoopGetitemReplacer");
this->register_matcher(m, callback);
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace pass {
/**
* @brief IndexLoopGetitemReplacer transformation replaces following graph:
* aten::chunk->prim::Loop(aten::__getitem__) to Slice inside the Loop
*/
class IndexLoopGetitemReplacer : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ov::frontend::pytorch::pass::IndexLoopGetitemReplacer");
IndexLoopGetitemReplacer();
};
} // namespace pass
} // namespace pytorch
} // namespace frontend
} // namespace ov
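Illustration (not part of the diff): the graph shape this pass targets comes from code like the following, which mirrors the aten_chunk_loop_getitem layer test added below. Scripting it yields aten::chunk feeding a prim::Loop whose body indexes the chunk list with aten::__getitem__:

    import torch

    class ChunkLoop(torch.nn.Module):
        def forward(self, x):
            chunks = torch.chunk(torch.arange(x.shape[0]), 2)
            for inds in chunks:        # prim::Loop over aten::__getitem__(chunks, i)
                x[inds] *= 10
            return x

    print(torch.jit.script(ChunkLoop()).inlined_graph)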
@@ -9,7 +9,10 @@
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/interpolate.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/roll.hpp"
#include "openvino/op/select.hpp"
@@ -52,6 +55,11 @@ ListConstructReplacer::ListConstructReplacer() {
auto transpose_op = pattern::wrap_type<v1::Transpose>({pattern::any_input(), list});
// aten::split_with_sizes case
auto vsplit_op = pattern::wrap_type<v1::VariadicSplit>({pattern::any_input(), pattern::any_input(), list});
// aten::upsample... case inside the body when body was removed
auto interpolate_convert_op = pattern::wrap_type<v0::Convert>({list});
auto interpolate_mul_op = pattern::wrap_type<v1::Multiply>({interpolate_convert_op, pattern::any_input()});
auto interpolate_op =
pattern::wrap_type<v11::Interpolate>({pattern::any_input(), interpolate_mul_op, pattern::any_input()});
auto lc_pattern = std::make_shared<pattern::op::Or>(OutputVector{reshape_op,
roll_op,
broadcast_op,
@@ -61,7 +69,8 @@ ListConstructReplacer::ListConstructReplacer() {
select_op,
tile_op,
transpose_op,
vsplit_op});
vsplit_op,
interpolate_op});
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto& pattern_map = m.get_pattern_value_map();
@@ -6,6 +6,7 @@
#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/equal.hpp"
#include "openvino/op/not_equal.hpp"
@@ -25,9 +26,16 @@ using namespace ov::op;
StringEqualityReplacer::StringEqualityReplacer() {
auto framework_node_lhs = pattern::wrap_type<PtFrameworkNode>();
auto framework_node_rhs = pattern::wrap_type<PtFrameworkNode>();
auto convert_like = pattern::wrap_type<v1::ConvertLike>({framework_node_rhs, framework_node_lhs});
auto equal_op = pattern::wrap_type<v1::Equal>({framework_node_lhs, convert_like});
auto not_equal_op = pattern::wrap_type<v1::NotEqual>({framework_node_lhs, convert_like});
auto convert_lhs = pattern::wrap_type<v0::Convert>({framework_node_lhs});
auto convert_like_lhs = pattern::wrap_type<v1::ConvertLike>({framework_node_lhs, framework_node_rhs});
auto convert_rhs = pattern::wrap_type<v0::Convert>({framework_node_rhs});
auto convert_like_rhs = pattern::wrap_type<v1::ConvertLike>({framework_node_rhs, framework_node_lhs});
auto lhs_pattern =
std::make_shared<pattern::op::Or>(OutputVector{framework_node_lhs, convert_lhs, convert_like_lhs});
auto rhs_pattern =
std::make_shared<pattern::op::Or>(OutputVector{framework_node_rhs, convert_rhs, convert_like_rhs});
auto equal_op = pattern::wrap_type<v1::Equal>({lhs_pattern, rhs_pattern});
auto not_equal_op = pattern::wrap_type<v1::NotEqual>({lhs_pattern, rhs_pattern});
auto string_equality_pattern = std::make_shared<pattern::op::Or>(OutputVector{equal_op, not_equal_op});
@@ -6,11 +6,18 @@
#include "input_model.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/result.hpp"
#include "openvino/op/scatter_nd_update.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/util/log.hpp"
#include "pt_framework_node.hpp"
#include "utils.hpp"
namespace ov {
@@ -143,13 +150,41 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
context.get_op_type(),
" outputs greater then number of converted outputs.");
// TODO: Make sure that mapping of fw_outputs to converted_outputs does always work
// FIXME: Now it is not true for at least prim::Constant
for (size_t i = 0; i < fw_outputs.size(); ++i) {
size_t fw_tensor_id = node->output(i);
if (node->inputs().size() > 0 && node->may_produce_alias(0, i)) {
// TODO: do we need to check other inputs, not only 0?
auto in_tensor_id = node->inputs().at(0);
if (m_may_be_alias.count(fw_tensor_id)) {
size_t recorded_in_tensor_id;
std::shared_ptr<TorchDecoder> recorded_node;
std::tie(recorded_in_tensor_id, recorded_node, std::ignore) = m_may_be_alias.at(fw_tensor_id);
FRONT_END_GENERAL_CHECK(recorded_in_tensor_id == in_tensor_id,
"Operation ",
context.get_op_type(),
" creates alias to tensor which was already created before by ",
recorded_node->get_op_type(),
", but from different tensor: ",
in_tensor_id,
" vs ",
recorded_in_tensor_id);
}
m_may_be_alias[fw_tensor_id] = {node->inputs().at(0), node, converted_outputs[i]};
OPENVINO_DEBUG << "Registered alias: " << fw_tensor_id << " of tensor: " << node->inputs().at(0)
<< " of operation: " << context.get_op_type();
}
FRONT_END_GENERAL_CHECK(tensor_map->find(fw_tensor_id) == tensor_map->end(),
"Duplicated producer for PT value with unique ID: ",
fw_tensor_id);
auto out_type = context.get_output_type(i);
if (out_type.is<element::Type>()) {
if (!converted_outputs[i].get_element_type().compatible(out_type.as<element::Type>())) {
OPENVINO_DEBUG << "[WARNING] Produced output type for operation " << context.get_op_type()
<< " for tensor id: " << fw_tensor_id << " is incompatible: produced "
<< converted_outputs[i].get_element_type() << " vs "
<< out_type.as<element::Type>();
}
}
(*tensor_map)[fw_tensor_id] = converted_outputs[i];
encode_tensor_name(converted_outputs[i], fw_tensor_id, {node->get_output_debug_name(i)});
}
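Illustration (not part of the diff): the tensors recorded in m_may_be_alias are outputs that may view their input, such as the results of aten::slice or aten::select. In eager PyTorch the aliasing is directly observable:

    import torch

    x = torch.zeros(4)
    y = x[1:3]            # aten::slice: y is a view (alias) of x
    y.fill_(7.0)
    print(x)              # tensor([0., 7., 7., 0.]) - the write is visible through x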
@@ -197,6 +232,8 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
// additional outputs in that case.
if (mutated_tensor.get_target_inputs().empty() && !external_tensor_map.empty())
results.push_back(std::make_shared<v0::Result>(tensor_map->at(tensor_id)));
} else {
OPENVINO_DEBUG << "Mutated tensor with id " << tensor_id << " doesn't exist in inputs, skipping.";
}
}
resulting_model = std::make_shared<Model>(results, *parameters);
@@ -215,10 +252,9 @@ OutputVector TranslateSession::convert_node(const NodeContext& context) {
} catch (std::exception& e) {
OPENVINO_DEBUG << "Exception happened during conversion of op: " << context.get_op_type()
<< " with schema: " << context.get_schema() << ": " << e.what() << '\n';
<< " with schema: " << context.get_schema() << ": " << e.what();
} catch (...) {
OPENVINO_DEBUG << "Some exception happened during conversion of node of type: " << context.get_op_type()
<< '\n';
OPENVINO_DEBUG << "Some exception happened during conversion of node of type: " << context.get_op_type();
}
// Create PtFrameworkNode for everything that wasn't able to be converted normally
return make_framework_node(context);
@@ -228,7 +264,9 @@ void TranslateSession::encode_tensor_name(Output<Node> output,
size_t tensor_idx,
std::vector<std::string> additional_names) {
if (!output.get_names().empty()) {
OPENVINO_DEBUG << "Tensor names already exist: " << output.get_any_name() << ". Rewriting with " << tensor_idx;
OPENVINO_DEBUG << "Tensor names already exist: " << output.get_any_name() << ". Will not be rewritten with "
<< tensor_idx << ". This is likely a mutated tensor.";
return;
}
auto name = std::to_string(tensor_idx);
std::unordered_set<std::string> names;
@@ -243,7 +281,6 @@ void TranslateSession::encode_tensor_name(Output<Node> output,
pair.second.set_names({new_name});
pair.second = output;
output.set_names(names);
} else {
m_counter_map[tensor_idx] = {0, output};
output.set_names(names);
@@ -257,6 +294,110 @@ size_t TranslateSession::decode_tensor_name(const Output<Node>& output) {
return static_cast<size_t>(std::stoll(name));
}
namespace {
Output<Node> slice_backprop(const Output<Node>& slice_output, const Output<Node>& value) {
auto slice_node = slice_output.get_node_shared_ptr();
FRONT_END_OP_CONVERSION_CHECK(ov::as_type_ptr<v8::Slice>(slice_node),
"Conversion rule for aten::slice doesn't contain Slice node.");
auto zero = v0::Constant::create(element::i64, Shape{}, {0});
auto one = v0::Constant::create(element::i64, Shape{}, {1});
auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1});
auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1});
// Get 1d indices [0..numel)
auto to_insert_data = slice_node->input_value(0);
auto input_shape = std::make_shared<v3::ShapeOf>(to_insert_data, element::i64);
auto numel = std::make_shared<v1::ReduceProd>(input_shape, zero, false);
auto full_data_indices_1d = std::make_shared<v4::Range>(zero, numel, one, element::i64);
// Slice indices by same start, stop, slice, axes as initial Slice
auto full_data_indices = std::make_shared<v1::Reshape>(full_data_indices_1d, input_shape, false);
Output<Node> data_indices;
if (slice_node->get_input_size() == 5) {
data_indices = std::make_shared<v8::Slice>(full_data_indices,
slice_node->input_value(1),
slice_node->input_value(2),
slice_node->input_value(3),
slice_node->input_value(4));
} else if (slice_node->get_input_size() == 4) {
data_indices = std::make_shared<v8::Slice>(full_data_indices,
slice_node->input_value(1),
slice_node->input_value(2),
slice_node->input_value(3));
} else {
FRONT_END_OP_CONVERSION_CHECK(false, "Incorrect number of Slice inputs");
}
// Scatter in flattened tensor with indices and flattened data to be inserted
auto to_insert_data_1d = std::make_shared<v1::Reshape>(to_insert_data, neg_one_1d, false);
auto data_indices_1d = std::make_shared<v1::Reshape>(data_indices, scattering_shape, false);
auto to_be_inserted_data_1d = std::make_shared<v1::Reshape>(value, neg_one_1d, false);
auto updated_data_1d =
std::make_shared<v3::ScatterNDUpdate>(to_insert_data_1d, data_indices_1d, to_be_inserted_data_1d);
// Reshape to initial shape
return std::make_shared<v1::Reshape>(updated_data_1d, input_shape, false);
}
Output<Node> select_backprop(const Output<Node>& select_output, const Output<Node>& value) {
auto gather_node = select_output.get_node_shared_ptr();
FRONT_END_OP_CONVERSION_CHECK(ov::as_type_ptr<v8::Gather>(gather_node),
"Conversion rule for aten::select doesn't contain Gather node.");
auto zero = v0::Constant::create(element::i64, Shape{}, {0});
auto one = v0::Constant::create(element::i64, Shape{}, {1});
auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1});
auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1});
// Get 1d indices [0..numel)
auto to_insert_data = gather_node->input_value(0);
auto input_shape = std::make_shared<v3::ShapeOf>(to_insert_data, element::i64);
auto numel = std::make_shared<v1::ReduceProd>(input_shape, zero, false);
auto full_data_indices_1d = std::make_shared<v4::Range>(zero, numel, one, element::i64);
// Slice indices by same start, stop, slice, axes as initial Slice
auto full_data_indices = std::make_shared<v1::Reshape>(full_data_indices_1d, input_shape, false);
Output<Node> data_indices =
std::make_shared<v8::Gather>(full_data_indices, gather_node->input_value(1), gather_node->input_value(2));
// Scatter in flattened tensor with indices and flattened data to be inserted
auto to_insert_data_1d = std::make_shared<v1::Reshape>(to_insert_data, neg_one_1d, false);
auto data_indices_1d = std::make_shared<v1::Reshape>(data_indices, scattering_shape, false);
auto to_be_inserted_data_1d = std::make_shared<v1::Reshape>(value, neg_one_1d, false);
auto updated_data_1d =
std::make_shared<v3::ScatterNDUpdate>(to_insert_data_1d, data_indices_1d, to_be_inserted_data_1d);
// Reshape to initial shape
return std::make_shared<v1::Reshape>(updated_data_1d, input_shape, false);
}
} // namespace
using BackpropCreatorFunction = std::function<ov::Output<ov::Node>(const Output<Node>&, const Output<Node>&)>;
Output<Node> TranslateSession::get_backprop_op(const std::shared_ptr<TorchDecoder>& node,
const Output<Node>& direct_op_output,
const Output<Node>& value) {
std::map<std::string, BackpropCreatorFunction> backprop_map = {
{"aten::slice", slice_backprop},
{"aten::select", select_backprop},
};
Output<Node> backprop_node;
try {
auto it = backprop_map.find(node->get_op_type());
if (it != backprop_map.end()) {
return it->second(direct_op_output, value);
}
} catch (std::exception& e) {
OPENVINO_DEBUG << "Exception happened during conversion of backprop op: " << node->get_op_type()
<< " with schema: " << node->get_schema() << ": " << e.what();
}
// Create PtFrameworkNode representing unconverted backprop operation
return std::make_shared<PtFrameworkNode>(node, OutputVector{value}, 1, true);
}
} // namespace pytorch
} // namespace frontend
} // namespace ov
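Illustration (not part of the diff): the ScatterNDUpdate construction in slice_backprop/select_backprop can be pictured as enumerating the flat indices of the source tensor, applying the very same slice to those indices, and scattering the new values back at the selected flat positions. A minimal numeric sketch of that idea (illustrative only, not the OpenVINO code):

    import torch

    x = torch.zeros(3, 4)
    value = torch.ones(3, 2)                             # data written into x[:, 1:3]
    flat_idx = torch.arange(x.numel()).reshape(x.shape)  # flat index of every element of x
    sel = flat_idx[:, 1:3].reshape(-1)                   # same slice applied to the indices
    updated = x.reshape(-1).scatter(0, sel, value.reshape(-1)).reshape(x.shape)
    expected = x.clone()
    expected[:, 1:3] = value
    assert torch.equal(updated, expected)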
@@ -35,13 +35,25 @@ public:
const TensorMap& external_tensor_map = {},
const std::unordered_map<size_t, PlaceDesc>& external_descriptors = {});
/// \brief Returns backprop operations for direct operation
Output<Node> get_backprop_op(const std::shared_ptr<TorchDecoder>& node,
const Output<Node>& direct_op_output,
const Output<Node>& value);
/// \brief Writes pytorch tensor index into openvino tensor
void encode_tensor_name(Output<Node> tensor_desc,
size_t tensor_idx,
std::vector<std::string> additional_names = {});
/// \brief Gets pytorch tensor index from openvino tensor
size_t decode_tensor_name(const Output<Node>& tensor_desc);
size_t m_friendly_name_counter = 0;
// Maps tensor index to initial tensor index which it is alias to, and to decoder of the node produced this alias
// and to the output produced during conversion of this node
std::map<size_t, std::tuple<size_t, std::shared_ptr<TorchDecoder>, Output<Node>>> m_may_be_alias;
private:
OutputVector convert_node(const NodeContext& context);
@@ -339,6 +339,19 @@ std::unordered_map<size_t, element::Type> bit_to_int{
void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Output<Node>& rhs, bool align_scalars) {
const auto& lhs_type = lhs.get_element_type();
const auto& rhs_type = rhs.get_element_type();
auto out_type = context.get_output_type(0);
if (out_type.is<element::Type>()) {
auto otype = out_type.as<element::Type>();
if (otype.is_real()) {
if (otype != lhs_type) {
lhs = context.mark_node(std::make_shared<ov::op::v0::Convert>(lhs, otype));
}
if (otype != lhs_type) {
rhs = context.mark_node(std::make_shared<ov::op::v0::Convert>(rhs, otype));
}
return;
}
}
if (lhs_type.is_dynamic() || rhs_type.is_dynamic()) {
// if any of types is not known, align to lhs type.
// TODO: can be fixed with special operation?
@@ -410,6 +423,18 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
}
}
void align_output_types(const NodeContext& context, OutputVector& outputs) {
for (size_t i = 0; i < outputs.size(); i++) {
auto dtype_any = context.get_output_type(i);
if (dtype_any.is<element::Type>()) {
auto dtype = dtype_any.as<element::Type>();
if (dtype.is_static() && dtype != outputs[i].get_element_type()) {
outputs[i] = std::make_shared<opset10::Convert>(outputs[i], dtype);
}
}
}
}
std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start) {
std::deque<Output<Node>> res;
auto current_output = start;
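Illustration (not part of the diff): align_output_types exists because the dtype PyTorch reports for a node's output is authoritative; comparison ops, for instance, produce torch.bool regardless of input types, so the converted node may need a trailing Convert:

    import torch

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([1, 0, 3])
    print((a == b).dtype)   # torch.bool - the expected output type the converter aligns to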
@@ -6,6 +6,7 @@
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
namespace ov {
@@ -64,6 +65,8 @@ void align_eltwise_input_types(const NodeContext& context,
Output<Node>& rhs,
bool align_scalars = false);
void align_output_types(const NodeContext& context, OutputVector& outputs);
std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start);
namespace op {
@@ -79,7 +82,15 @@ OutputVector inplace_op(const NodeContext& context) {
template <typename T>
OutputVector translate_1to1_match_1_inputs(const NodeContext& context) {
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None.");
return {context.mark_node(std::make_shared<T>(context.get_input(0)))};
auto res = context.mark_node(std::make_shared<T>(context.get_input(0)));
auto out_type = context.get_output_type(0);
if (out_type.is<element::Type>()) {
auto dtype = out_type.as<element::Type>();
if (dtype.is_static() && dtype != res->output(0).get_element_type()) {
res = context.mark_node(std::make_shared<ov::op::v0::Convert>(res, dtype));
}
}
return {res};
}
template <typename T>
@@ -106,7 +117,9 @@ OutputVector translate_1to1_match_2_inputs_align_types(const NodeContext& contex
auto lhs = context.get_input(0);
auto rhs = context.get_input(1);
align_eltwise_input_types(context, lhs, rhs, true);
return {context.mark_node(std::make_shared<T>(lhs, rhs))};
OutputVector res = {context.mark_node(std::make_shared<T>(lhs, rhs))};
align_output_types(context, res);
return res;
}
inline OutputVector return_false_scalar(const NodeContext& context) {
tests/layer_tests/pytorch_tests/test_aliases.py (new file, 38 lines)
@@ -0,0 +1,38 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class aten_alias(torch.nn.Module):
def forward(self, x):
x[:, 1, :, :] = 4.
return x
class aten_loop_alias(torch.nn.Module):
def forward(self, x):
for i in range(2):
x[:, i, :, :] = 4.
return x
class TestAliases(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
@pytest.mark.nightly
@pytest.mark.precommit
def test_alias(self, ie_device, precision, ir_version):
self._test(aten_alias(), None, [
"aten::slice", "aten::select", "aten::copy_"], ie_device, precision, ir_version)
@pytest.mark.nightly
@pytest.mark.precommit
def test_loop_alias(self, ie_device, precision, ir_version):
self._test(aten_loop_alias(), None, [
"aten::slice", "aten::select", "aten::copy_", "prim::Loop"], ie_device, precision, ir_version)
@@ -54,6 +54,7 @@ class TestCat(PytorchLayerTest):
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.xfail(reason="Transformation RemoveMultiSubGraphOpDanglingParamsResults doesn't support removing unused merged inputs, ticket 112833.")
def test_loop_append_cat(self, ie_device, precision, ir_version):
self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"],
ie_device, precision, ir_version, freeze_model=False)
@@ -120,3 +120,41 @@ class TestChunk(PytorchLayerTest):
for idx in [0, 1, output_chunks - 1]:
self._test(aten_chunk_getitem(chunks, dim, idx), None, "aten::chunk",
ie_device, precision, ir_version)
class aten_chunk_loop_getitem(torch.nn.Module):
def __init__(self, num_chunks) -> None:
torch.nn.Module.__init__(self)
self.num_chunks = num_chunks
def forward(self, input_tensor):
chunks = torch.chunk(torch.arange(
input_tensor.shape[0]), self.num_chunks)
for inds in chunks:
input_tensor[inds] *= 10
return input_tensor
class TestChunkLoopGetitem(PytorchLayerTest):
def _prepare_input(self):
return (np.random.rand(*self.input_shape),)
@pytest.mark.parametrize("input_shape", [
(4, 4),
(5, 9, 7),
(10, 13, 11),
(8, 7, 6, 5, 4),
])
@pytest.mark.parametrize("chunks", [
2,
3,
4
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_chunk_loop_getitem(self, input_shape, chunks, ie_device, precision, ir_version):
self.input_shape = input_shape
self._test(aten_chunk_loop_getitem(chunks), None, ["aten::chunk", "prim::Loop", "aten::__getitem__"],
ie_device, precision, ir_version)
tests/layer_tests/pytorch_tests/test_copy.py (new file, 33 lines)
@@ -0,0 +1,33 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestCopy(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
def create_model(self, value):
import torch
class aten_copy(torch.nn.Module):
def __init__(self, value):
super(aten_copy, self).__init__()
self.value = torch.tensor(value)
def forward(self, x):
return x.copy_(self.value)
ref_net = None
return aten_copy(value), ref_net, "aten::copy_"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("value", [1, [2.5], range(224)])
def test_copy_(self, value, ie_device, precision, ir_version):
self._test(*self.create_model(value), ie_device, precision, ir_version)
@@ -40,7 +40,7 @@ class TestDeriveIndexRangeLength(PytorchLayerTest):
step = int(x[2])
accumulator = 0
for idx in range(start, stop, step):
accumulator = idx
accumulator += idx
return accumulator
ref_net = None
@@ -130,6 +130,31 @@ class TestFill(PytorchLayerTest):
self._test(*self.create_model(), ie_device, precision, ir_version,
kwargs_to_prepare_input={'value': value, 'shape': shape, "input_dtype": input_dtype, "value_dtype": value_dtype})
class TestZero(PytorchLayerTest):
def _prepare_input(self, shape, input_dtype):
return (np.random.randn(*shape).astype(input_dtype),)
def create_model(self):
import torch
class aten_zero(torch.nn.Module):
def forward(self, input_t: torch.Tensor):
return input_t.zero_()
ref_net = None
model = aten_zero()
return model, ref_net, "aten::zero_"
@pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]])
@pytest.mark.parametrize("input_dtype", ["int8", "int32", "int64", "float32", "float64"])
@pytest.mark.nightly
@pytest.mark.precommit
def test_zero(self, shape, input_dtype, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version,
kwargs_to_prepare_input={'shape': shape, "input_dtype": input_dtype})
class TestFullLike(PytorchLayerTest):
def _prepare_input(self, value, shape):
return (np.random.randn(*shape).astype(np.float32), np.array(value, dtype=np.float32),)
@@ -64,6 +64,7 @@ class TestUnaryOp(PytorchLayerTest):
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("op,op_type", [
# some pytorch inplace ops do not support int
(torch.exp_, "aten::exp_"),
(torch.sigmoid_, "aten::sigmoid_"),
# trigonometry
(torch.cos_, "aten::cos_"),