diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index 4b67980cad1..bee7d216dd6 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -91,7 +91,7 @@ pt_to_ov_type_map = { class TorchScriptPythonDecoder (Decoder): - def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True): + def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True, alias_db=None): Decoder.__init__(self) # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted self.m_decoders = [] @@ -113,8 +113,10 @@ class TorchScriptPythonDecoder (Decoder): " yourself, please refer to PyTorch documentation: " "https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html.") self.graph_element = pt_module.inlined_graph + self.alias_db = self.graph_element.alias_db() else: self.graph_element = graph_element + self.alias_db = alias_db self.pt_module = pt_module self.raw_inputs = list(self.graph_element.inputs()) self.raw_outputs = list(self.graph_element.outputs()) @@ -273,7 +275,7 @@ class TorchScriptPythonDecoder (Decoder): def visit_subgraph(self, node_visitor) -> None: # make sure topological order is satisfied for node in self.graph_element.nodes(): - decoder = TorchScriptPythonDecoder(self.pt_module, node) + decoder = TorchScriptPythonDecoder(self.pt_module, node, alias_db=self.alias_db) self.m_decoders.append(decoder) node_visitor(decoder) @@ -289,7 +291,7 @@ class TorchScriptPythonDecoder (Decoder): return list(self.graph_element.blocks()) def get_subgraph_decoder(self, index: int): - decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index]) + decoder = TorchScriptPythonDecoder(self.pt_module, self.get_subgraphs()[index], alias_db=self.alias_db) self.m_decoders.append(decoder) return decoder 
@@ -389,6 +391,16 @@ class TorchScriptPythonDecoder (Decoder): return pt_value is None return False + def may_produce_alias(self, in_index: int, out_index: int) -> bool: + if self.get_op_type() in ["aten::conv1d", "aten::conv2d", "aten::conv3d"]: + # AliasDB::may_contain_alias sometimes return True for tensors produced by convnd, we have to workaround that + return False + try: + return self.alias_db.may_contain_alias(self._raw_input(in_index), self._raw_output(out_index)) + except: + # Sometimes pytorch fails to get result with IndexError exception while these indexes exist in node + return False + @staticmethod def _transform_tensor_list_constants_to_listconstruct(graph: torch.Graph): # Function replaces prim::Constant containing List of Tensors with diff --git a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp index 4151a030257..1df1f7f0024 100644 --- a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp +++ b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp @@ -109,6 +109,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder { std::shared_ptr get_subgraph_decoder(size_t index) const override { PYBIND11_OVERRIDE_PURE(std::shared_ptr, TorchDecoder, get_subgraph_decoder, index); } + + bool may_produce_alias(size_t in_index, size_t out_index) const override { + PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, may_produce_alias, in_index, out_index); + } }; void regclass_frontend_pytorch_decoder(py::module m); diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp index 963a5e051c9..3c008e563fe 100644 --- a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp +++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp @@ -104,8 +104,11 @@ public: // node_visitor is a function that will be fed by nodes in subgraph for all nodes in 
graph virtual void visit_subgraph(std::function)> node_visitor) const = 0; - /// Probably this toghether with immediate nodes visitor is a replacement for visit_subgraphs with an index + /// Probably this together with immediate nodes visitor is a replacement for visit_subgraphs with an index virtual std::shared_ptr get_subgraph_decoder(size_t index) const = 0; + + /// \brief Returns if output may contain alias of input in AliasDB + virtual bool may_produce_alias(size_t in_index, size_t out_index) const = 0; }; } // namespace pytorch diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp index 8502101bdfb..2291bc4e651 100644 --- a/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp +++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp @@ -74,6 +74,10 @@ public: return m_decoder->input_is_none(index); } + Any get_output_type(size_t index) const { + return m_decoder->get_output_type(index); + } + size_t get_output_size() const { return m_decoder_outputs.size(); } diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index 054accca91b..472eb571ed4 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -16,6 +16,7 @@ #include "transformations/common_optimizations/remove_multi_subgraph_op_dangling_params.hpp" #include "transformations/common_optimizations/reverse_shape_and_type_infer.hpp" #include "transformations/control_flow/unroll_if.hpp" +#include "transformations/op_conversions/convert_convertlike.hpp" #include "transforms.hpp" #include "transforms/append_list_unpack_replacer.hpp" #include "transforms/aten_cat_replacer.hpp" @@ -25,6 +26,7 @@ #include "transforms/aten_stack_list_construct_replacer.hpp" #include "transforms/dict_resolver.hpp" #include "transforms/einsum_list_construct.hpp" +#include 
"transforms/index_loop_getitem_replacer.hpp" #include "transforms/listconstruct_replacer.hpp" #include "transforms/min_max_prim_list_construct_replacer.hpp" #include "transforms/prim_list_construct_pad.hpp" @@ -42,8 +44,10 @@ std::set get_unconverted_types_from_model(const std::shared_ptr unconverted_ops_types; for (const auto& node : model->get_ordered_ops()) { if (const auto& fw_node = ov::as_type_ptr(node)) { - auto op_type = fw_node->get_decoder()->get_op_type(); - unconverted_ops_types.insert(op_type); + auto attrs = fw_node->get_attrs(); + FRONT_END_GENERAL_CHECK(attrs.find("PtTypeName") != attrs.end(), + "FrameworkNode attributes do not contain operation type."); + unconverted_ops_types.insert(attrs.at("PtTypeName")); } if (const auto& fw_node = ov::as_type_ptr(node)) { for (size_t i = 0; i < fw_node->get_internal_subgraphs_size(); i++) { @@ -97,6 +101,11 @@ std::shared_ptr FrontEnd::decode(const InputModel::Ptr& model) const { void FrontEnd::normalize(const std::shared_ptr& model) const { ov::pass::Manager manager; + // the following 2 transformations are needed for keypoint detectron2 models to work. 
+ // AtenIndexToSelect will be called twice + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); @@ -114,6 +123,7 @@ void FrontEnd::normalize(const std::shared_ptr& model) const { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/src/frontends/pytorch/src/node_context.cpp b/src/frontends/pytorch/src/node_context.cpp index c3ae7f9008e..682deecfacb 100644 --- a/src/frontends/pytorch/src/node_context.cpp +++ b/src/frontends/pytorch/src/node_context.cpp @@ -6,7 +6,8 @@ #include "openvino/frontend/exception.hpp" #include "openvino/frontend/pytorch/decoder.hpp" -#include "openvino/opsets/opset10.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/parameter.hpp" #include "openvino/util/log.hpp" #include "pt_framework_node.hpp" #include "translate_session.hpp" @@ -15,6 +16,8 @@ namespace ov { namespace frontend { namespace pytorch { +using namespace ov::op; + OutputVector NodeContext::as_constant() const { auto dtype = m_decoder->get_output_type(0); if (dtype.is()) { @@ -52,11 +55,36 @@ void NodeContext::mutate_input(size_t index, Output ov_output) const { {m_decoder->get_input_debug_name(index), m_decoder->get_input_signature_name(index)}); (*m_tensor_map)[input_id] = ov_output; m_mutated_tensors->insert(input_id); + + // Resolve aliases + auto back_input_id = input_id; + auto back_node_input = ov_output; + while (m_translate_session->m_may_be_alias.count(back_input_id)) { + // Create node to backprop data. While loop is needed for the cases when alias to tensor point to another alias + // to tensor. 
In that case we need to create a chain of backprop ops + size_t in_tensor; + std::shared_ptr node; + Output node_converted_output; + std::tie(in_tensor, node, node_converted_output) = m_translate_session->m_may_be_alias.at(back_input_id); + auto backprop_node = m_translate_session->get_backprop_op(node, node_converted_output, back_node_input); + if (m_tensor_map->count(in_tensor)) { + // Tensor is not found in the scope of this body, need to get it from internal context and mark mutated + OPENVINO_DEBUG << "Couldn't find in the current body the initial aliased tensor: " << in_tensor + << " for operation: " << node->get_op_type() << " creating new body input."; + get_tensor_from_model_or_create_input(in_tensor); + } + m_translate_session->encode_tensor_name(backprop_node, in_tensor); + (*m_tensor_map)[in_tensor] = backprop_node; + m_mutated_tensors->insert(in_tensor); + OPENVINO_DEBUG << "Propagated back data from tensor: " << back_input_id << " to tensor: " << in_tensor << ".\n"; + back_input_id = in_tensor; + back_node_input = backprop_node; + } } void NodeContext::add_tensor_to_context(size_t index, Output ov_output) const { if (m_tensor_map->count(index)) { - OPENVINO_DEBUG << "[ WARNING ] Current context has tensor. Rewriting.\n"; + OPENVINO_DEBUG << "[ WARNING ] Current context has tensor " << index << ". 
Assuming mutated output.\n"; } m_translate_session->encode_tensor_name(ov_output, index); (*m_tensor_map)[index] = ov_output; @@ -67,7 +95,7 @@ Output NodeContext::get_tensor_from_model_or_create_input(size_t index) co return m_tensor_map->at(index); } else { // nested subgraphs case - auto parameter = std::make_shared(element::dynamic, PartialShape::dynamic()); + auto parameter = std::make_shared(element::dynamic, PartialShape::dynamic()); m_translate_session->encode_tensor_name(parameter->output(0), index); (*m_tensor_map)[index] = parameter; m_external_parameters->push_back(parameter); @@ -80,7 +108,7 @@ Output NodeContext::get_input_from_visible_context(size_t index) const { FRONT_END_GENERAL_CHECK(index < get_input_size(), "Index is lower then number of inputs."); auto input_tensor = get_input(static_cast(index)); auto input_node = input_tensor.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(input_node)) { + if (std::dynamic_pointer_cast(input_node)) { // We need to look into external context for inputs that would be feed into this parameter size_t tensor_idx = m_translate_session->decode_tensor_name(input_node->output(0)); if (m_ext_tensor_map.count(tensor_idx)) { @@ -116,10 +144,10 @@ std::shared_ptr NodeContext::convert_subgraph(size_t index) const { } namespace { -std::shared_ptr get_constant_at_input(const NodeContext& ctx, size_t index) { +std::shared_ptr get_constant_at_input(const NodeContext& ctx, size_t index) { FRONT_END_GENERAL_CHECK(!ctx.input_is_none(index), "Input with index: ", index, " is none."); auto input_node = ctx.get_input_from_visible_context(index).get_node_shared_ptr(); - auto input = std::dynamic_pointer_cast(input_node); + auto input = std::dynamic_pointer_cast(input_node); FRONT_END_GENERAL_CHECK(input, "Input with index ", index, " cannot be interpreted as Constant: ", input_node); return input; } @@ -185,7 +213,7 @@ std::string NodeContext::const_input(size_t index) const { namespace { template -Any get_constant_data(const 
std::shared_ptr& constant) { +Any get_constant_data(const std::shared_ptr& constant) { const T* ptr = reinterpret_cast(constant->get_data_ptr()); const auto& shape = constant->get_shape(); if (is_scalar(shape)) { @@ -206,7 +234,7 @@ Any NodeContext::get_values_from_const_input(int index) const { } auto input_node = get_input_from_visible_context(index).get_node_shared_ptr(); - if (auto constant = as_type_ptr(input_node)) { + if (auto constant = as_type_ptr(input_node)) { switch (constant->get_element_type()) { case element::f32: return get_constant_data(constant); diff --git a/src/frontends/pytorch/src/op/as_tensor.cpp b/src/frontends/pytorch/src/op/as_tensor.cpp index ae2c15d0a1e..0400d201bbe 100644 --- a/src/frontends/pytorch/src/op/as_tensor.cpp +++ b/src/frontends/pytorch/src/op/as_tensor.cpp @@ -3,9 +3,11 @@ // #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" #include "openvino/op/convert.hpp" #include "openvino/op/convert_like.hpp" +#include "openvino/op/unsqueeze.hpp" #include "pt_framework_node.hpp" #include "utils.hpp" @@ -19,24 +21,36 @@ using namespace ov::op; OutputVector translate_as_tensor(const NodeContext& context) { // aten::tensor(t[] data, *, ScalarType? dtype=None, Device? 
device=None, bool requires_grad=False) -> Tensor num_inputs_check(context, 1, 4); + // Input with index 2 is device, we skip this input + // Input with index 3 is flag requires_grad, we skip this input auto dtype = element::f32; + auto list_elems = get_list_as_outputs(context.get_input(0)); + bool is_converted = false; if (!context.input_is_none(1)) { auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr(); auto dtype_fw_node = std::dynamic_pointer_cast(dtype_ext_node); if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") { auto type_input = dtype_fw_node->input_value(0); - return {context.mark_node(std::make_shared(context.get_input(0), type_input))}; + std::for_each(list_elems.begin(), list_elems.end(), [&](Output& n) { + n = context.mark_node(std::make_shared(n, type_input)); + }); + is_converted = true; } if (auto dtype_const = std::dynamic_pointer_cast(dtype_ext_node)) { auto pt_type = dtype_const->cast_vector()[0]; dtype = convert_dtype(pt_type); } } - auto cast = context.mark_node(std::make_shared(context.get_input(0), dtype)); - - // Input with index 2 is device, we skip this input - // Input with index 3 is flag requires_grad, we skip this input - return {cast}; + if (!is_converted) { + std::for_each(list_elems.begin(), list_elems.end(), [&](Output& n) { + n = context.mark_node(std::make_shared(n, dtype)); + }); + } + auto zero = v0::Constant::create(element::i32, Shape{}, {0}); + std::for_each(list_elems.begin(), list_elems.end(), [&](Output& n) { + n = context.mark_node(std::make_shared(n, zero)); + }); + return {context.mark_node(std::make_shared(OutputVector(list_elems.begin(), list_elems.end()), 0))}; }; } // namespace op diff --git a/src/frontends/pytorch/src/op/copy.cpp b/src/frontends/pytorch/src/op/copy.cpp new file mode 100644 index 00000000000..271d06ffd92 --- /dev/null +++ b/src/frontends/pytorch/src/op/copy.cpp @@ -0,0 +1,35 @@ +// Copyright (C) 2018-2023 Intel Corporation +// 
SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/broadcast.hpp" +#include "openvino/op/convert_like.hpp" +#include "openvino/op/shape_of.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +using namespace ov::op; + +OutputVector translate_copy_(const NodeContext& context) { + // aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + num_inputs_check(context, 2, 3); + auto self = context.get_input(0); + auto src = context.get_input(1); + // Convert src to type of self + auto src_converted = context.mark_node(std::make_shared(src, self)); + // Broadcast src to shape of self + auto self_shape = context.mark_node(std::make_shared(self)); + Output res = context.mark_node(std::make_shared(src_converted, self_shape)); + context.mutate_input(0, res); + return {res}; +}; + +} // namespace op +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/op/floor_divide.cpp b/src/frontends/pytorch/src/op/floor_divide.cpp index 4fb1b230d44..8e9eb8a44f6 100644 --- a/src/frontends/pytorch/src/op/floor_divide.cpp +++ b/src/frontends/pytorch/src/op/floor_divide.cpp @@ -3,6 +3,7 @@ // #include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/op/convert.hpp" #include "openvino/op/divide.hpp" #include "openvino/op/floor.hpp" #include "utils.hpp" diff --git a/src/frontends/pytorch/src/op/floordiv.cpp b/src/frontends/pytorch/src/op/floordiv.cpp deleted file mode 100644 index 91c03e74d7f..00000000000 --- a/src/frontends/pytorch/src/op/floordiv.cpp +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (C) 2018-2023 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "openvino/frontend/pytorch/node_context.hpp" -#include "openvino/op/divide.hpp" -#include "utils.hpp" - -namespace ov { -namespace frontend { -namespace pytorch { -namespace 
op { - -OutputVector translate_floordiv(const NodeContext& context) { - num_inputs_check(context, 2, 2); - auto x = context.get_input(0); - auto y = context.get_input(1); - return {context.mark_node(std::make_shared(x, y, true))}; -}; - -} // namespace op -} // namespace pytorch -} // namespace frontend -} // namespace ov diff --git a/src/frontends/pytorch/src/op/if.cpp b/src/frontends/pytorch/src/op/if.cpp index 77015fb1dee..15d1c5e24c1 100644 --- a/src/frontends/pytorch/src/op/if.cpp +++ b/src/frontends/pytorch/src/op/if.cpp @@ -23,10 +23,19 @@ void align_result_types(const NodeContext& context, auto r2_tensor = r2->input_value(0); auto r1_type = r1_tensor.get_element_type(); auto r2_type = r2_tensor.get_element_type(); - if (r1_type.is_dynamic() || r2_type.is_dynamic()) + if (r1_type == r2_type) return; element::Type merged_type; - if (!element::Type::merge(merged_type, r1_type, r2_type)) { + if (element::Type::merge(merged_type, r1_type, r2_type)) { + if (r1_type != merged_type) { + auto convert1 = std::make_shared(r1_tensor, merged_type); + r1->set_argument(0, convert1); + } + if (r2_type != merged_type) { + auto convert2 = std::make_shared(r2_tensor, merged_type); + r2->set_argument(0, convert2); + } + } else { if (r1_type.bitwidth() >= r2_type.bitwidth()) { auto convert = std::make_shared(r2_tensor, r1_type); r2->set_argument(0, convert); @@ -131,7 +140,7 @@ OutputVector translate_if(const NodeContext& context) { then_body->add_parameters({new_parameter}); then_body->add_results({new_result}); then_body->validate_nodes_and_infer_types(); - FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in else body"); + FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in else body: ", output_idx); inputs_map[output_idx][0] = new_parameter; extra_then_body_results[output_idx] = new_result; OPENVINO_DEBUG << "Modified then body: " << if_node << '\n'; @@ -143,7 +152,7 @@ OutputVector translate_if(const NodeContext& 
context) { else_body->add_parameters({new_parameter}); else_body->add_results({new_result}); else_body->validate_nodes_and_infer_types(); - FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in then body"); + FRONT_END_OP_CONVERSION_CHECK(inputs_map.count(output_idx), "Input must exist in then body: ", output_idx); inputs_map[output_idx][1] = new_parameter; extra_else_body_results[output_idx] = new_result; OPENVINO_DEBUG << "Modified else body: " << if_node << '\n'; diff --git a/src/frontends/pytorch/src/op/layer_norm.cpp b/src/frontends/pytorch/src/op/layer_norm.cpp index 204d7164531..69a8e947d78 100644 --- a/src/frontends/pytorch/src/op/layer_norm.cpp +++ b/src/frontends/pytorch/src/op/layer_norm.cpp @@ -23,7 +23,7 @@ OutputVector translate_layer_norm(const NodeContext& context) { FRONT_END_OP_CONVERSION_CHECK(normalized_shape.size() == 1, "Translation for aten::layer_norm supports only single normalized_shape value, " "which means normalizing over the last dimension."); - // TODO: support any dimention + // TODO: support any dimension auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); auto out_node = context.mark_node(std::make_shared(context.get_input(0), axes, true, eps, MVNEpsMode::INSIDE_SQRT)); diff --git a/src/frontends/pytorch/src/op/loop.cpp b/src/frontends/pytorch/src/op/loop.cpp index 36369ea63bd..99c59b46f34 100644 --- a/src/frontends/pytorch/src/op/loop.cpp +++ b/src/frontends/pytorch/src/op/loop.cpp @@ -25,6 +25,20 @@ OutputVector translate_loop(const NodeContext& context) { ov::op::v5::Loop::SpecialBodyPorts spec_ports{0, 0}; loop->set_special_body_ports(spec_ports); + // process outputs first + auto session = context.get_session(); + auto body_results = body->get_results(); + FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition."); + std::map> output_idxs; + // 0 output is condition, do not need to connect it + for (size_t i = 1; i < 
body_results.size(); i++) { + auto result = body_results[i]; + auto out_idx = session->decode_tensor_name(result->input(0).get_source_output()); + FRONT_END_OP_CONVERSION_CHECK(output_idxs.count(out_idx) == 0, + "More then one body output with same tensor name."); + output_idxs[out_idx] = result; + } + auto body_parameters = body->get_parameters(); // #0 body parameter is counter; FRONT_END_OP_CONVERSION_CHECK(body_parameters.size() > 0, "At least one input to Loop body is required"); @@ -34,26 +48,28 @@ OutputVector translate_loop(const NodeContext& context) { // #0 loop input is trip_count, #1 loop input is condition // Connect other inputs for (size_t i = 2; i < inputs.size(); i++) { - loop->set_invariant_inputs(inputs[i], {body_parameters[i - 1]}); + if (i <= subgraph_decoder->num_of_outputs()) { + loop->set_merged_input(body_parameters[i - 1], inputs[i], body_results[i - 1]); + } else { + loop->set_invariant_input(body_parameters[i - 1], inputs[i]); + } } // Connect inputs from external context - auto session = context.get_session(); for (auto i = inputs.size() - 1; i < body_parameters.size(); i++) { auto param = body_parameters[i]; auto input_idx = session->decode_tensor_name(param->output(0)); auto external_output = context.get_tensor_from_model_or_create_input(input_idx); - loop->set_invariant_inputs(external_output, {param}); + if (output_idxs.count(input_idx)) { + loop->set_merged_input(param, external_output, output_idxs.at(input_idx)); + } else { + loop->set_invariant_input(param, external_output); + } } - auto body_results = body->get_results(); - FRONT_END_OP_CONVERSION_CHECK(body_results.size() > 0, "At least one output from loop is required - condition."); - std::set output_idxs; - // 0 output is condition, do not need to connect it + + // connect outputs for (size_t i = 1; i < body_results.size(); i++) { auto result = body_results[i]; auto out_idx = session->decode_tensor_name(result->input(0).get_source_output()); - 
FRONT_END_OP_CONVERSION_CHECK(output_idxs.count(out_idx) == 0, - "More then one body output with same tensor name."); - output_idxs.insert(out_idx); context.add_tensor_to_context(out_idx, loop->get_iter_value(result, -1)); } loop->validate_and_infer_types(); diff --git a/src/frontends/pytorch/src/op/upsample.cpp b/src/frontends/pytorch/src/op/upsample.cpp index 4ab2b843a9b..06caacbab9d 100644 --- a/src/frontends/pytorch/src/op/upsample.cpp +++ b/src/frontends/pytorch/src/op/upsample.cpp @@ -4,6 +4,7 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" #include "openvino/op/interpolate.hpp" #include "openvino/op/multiply.hpp" #include "utils.hpp" @@ -50,7 +51,10 @@ OutputVector base_translate_upsample(const NodeContext& context, if (context.input_is_none(1)) { FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(scale_id), "Scale or Output size should be provided"); auto spatial_scales = context.get_input(scale_id); - + if (context.get_input_type(1).is()) { + spatial_scales = concat_list_construct(spatial_scales); + } + spatial_scales = context.mark_node(std::make_shared(spatial_scales, element::f32)); size_mode = v11::Interpolate::ShapeCalcMode::SCALES; scales_sizes = context.mark_node(std::make_shared(spatial_scales, scales)); } else { @@ -58,6 +62,7 @@ OutputVector base_translate_upsample(const NodeContext& context, if (context.get_input_type(1).is()) { out_sizes = concat_list_construct(out_sizes); } + out_sizes = context.mark_node(std::make_shared(out_sizes, element::i32)); scales_sizes = context.mark_node(std::make_shared(out_sizes, output_sizes)); } auto attrs = v11::Interpolate::InterpolateAttrs(interpolate_mode, size_mode, pad, pad); diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 2b331800809..1b7e76f9474 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -37,6 +37,7 @@ 
OP_CONVERTER(translate_conv_transposend); OP_CONVERTER(translate_convnd); OP_CONVERTER(translate_convolution); OP_CONVERTER(translate_convolution_mode); +OP_CONVERTER(translate_copy_); OP_CONVERTER(translate_cumsum); OP_CONVERTER(translate_deform_conv); OP_CONVERTER(translate_derive_index); @@ -53,7 +54,6 @@ OP_CONVERTER(translate_fill_); OP_CONVERTER(translate_flatten); OP_CONVERTER(translate_flip); OP_CONVERTER(translate_floor_divide); -OP_CONVERTER(translate_floordiv); OP_CONVERTER(translate_frobenius_norm); OP_CONVERTER(translate_full); OP_CONVERTER(translate_full_like); @@ -217,6 +217,7 @@ const std::map get_supported_ops() { {"aten::conv3d", op::translate_convnd}, {"aten::convolution", op::translate_convolution}, {"aten::copy", op::skip_node}, + {"aten::copy_", op::translate_copy_}, {"aten::cos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, {"aten::cos_", op::inplace_op>}, {"aten::cosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, @@ -234,6 +235,7 @@ const std::map get_supported_ops() { {"aten::empty", op::translate_empty}, {"aten::eq", op::translate_1to1_match_2_inputs_align_types}, {"aten::exp", op::translate_1to1_match_1_inputs_with_fp32_type_alignment}, + {"aten::exp_", op::inplace_op>}, {"aten::expand", op::translate_expand}, {"aten::expand_as", op::translate_expand_as}, {"aten::eye", op::translate_eye}, @@ -243,7 +245,7 @@ const std::map get_supported_ops() { {"aten::floor", op::translate_1to1_match_1_inputs}, {"aten::floor_", op::inplace_op>}, {"aten::floor_divide", op::translate_floor_divide}, - {"aten::floordiv", op::translate_floordiv}, + {"aten::floordiv", op::translate_floor_divide}, {"aten::frobenius_norm", op::translate_frobenius_norm}, {"aten::full", op::translate_full}, {"aten::full_like", op::translate_full_like}, @@ -373,6 +375,7 @@ const std::map get_supported_ops() { {"aten::var_mean", op::translate_var_mean}, {"aten::view", op::translate_reshape}, {"aten::where", op::translate_where}, + {"aten::zero_", 
op::inplace_op}, {"aten::zeros", op::translate_zeros}, {"aten::zeros_like", op::translate_zeros_like}, {"prim::Constant", op::translate_constant}, diff --git a/src/frontends/pytorch/src/pt_framework_node.hpp b/src/frontends/pytorch/src/pt_framework_node.hpp index 05db41c190a..29018d3ccba 100644 --- a/src/frontends/pytorch/src/pt_framework_node.hpp +++ b/src/frontends/pytorch/src/pt_framework_node.hpp @@ -14,13 +14,21 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode { public: OPENVINO_OP("PtFrameworkNode", "util", ::ov::op::util::FrameworkNode); - PtFrameworkNode(const std::shared_ptr& decoder, const OutputVector& inputs, size_t output_size) + PtFrameworkNode(const std::shared_ptr& decoder, + const OutputVector& inputs, + size_t output_size, + bool is_backprop = false) : ov::op::util::FrameworkNode(inputs, output_size, decoder->get_subgraph_size()), m_decoder(decoder) { ov::op::util::FrameworkNodeAttrs attrs; attrs.set_type_name("PTFrameworkNode"); - attrs["PtTypeName"] = m_decoder->get_op_type(); - attrs["PtSchema"] = m_decoder->get_schema(); + if (is_backprop) { + attrs["PtTypeName"] = m_decoder->get_op_type() + "_backprop"; + attrs["PtSchema"] = "None"; + } else { + attrs["PtTypeName"] = m_decoder->get_op_type(); + attrs["PtSchema"] = m_decoder->get_schema(); + } set_attrs(attrs); // Set output shapes and types if recognized diff --git a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp new file mode 100644 index 00000000000..c9ca09c376b --- /dev/null +++ b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.cpp @@ -0,0 +1,141 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "index_loop_getitem_replacer.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/divide.hpp" +#include 
"openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/loop.hpp" +#include "openvino/op/mod.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/parameter.hpp" +#include "openvino/op/reduce_sum.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/util/framework_node.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +using namespace ov::pass; +using namespace ov::op; + +IndexLoopGetitemReplacer::IndexLoopGetitemReplacer() { + auto loop_pattern = pattern::wrap_type([](Output n) { + auto loop_op = ov::as_type_ptr(n.get_node_shared_ptr()); + bool check_len_input = false; + if (auto len_reduce = ov::as_type_ptr(loop_op->input_value(0).get_node_shared_ptr())) { + if (auto len_slice = ov::as_type_ptr(len_reduce->input_value(0).get_node_shared_ptr())) { + if (auto len_shape_of = ov::as_type_ptr(len_slice->input_value(0).get_node_shared_ptr())) { + check_len_input = true; + } + } + } + return check_len_input; + }); + ov::matcher_pass_callback callback = [](pattern::Matcher& m) { + auto loop_op = ov::as_type_ptr(m.get_match_root()); + std::shared_ptr chunk_op; + size_t chunk_idx = 0; + auto loop_inputs = loop_op->input_values(); + for (size_t i = 1; i < loop_inputs.size(); i++) { + if (cast_fw_node(loop_inputs.at(i).get_node_shared_ptr(), "aten::chunk")) { + chunk_op = loop_inputs.at(i).get_node_shared_ptr(); + chunk_idx = i; + break; + } + } + if (!chunk_op) + return false; + + auto body = loop_op->get_function(); + std::shared_ptr chunk_param; + for (auto input_desc : loop_op->get_input_descriptions()) { + if (input_desc->m_input_index == chunk_idx) { + chunk_param = body->get_parameters().at(input_desc->m_body_parameter_index); + break; + } + } + if (!chunk_param) + return false; + + auto 
param_targets = chunk_param->get_output_target_inputs(0); + if (param_targets.size() != 1) + return false; + + auto getitem = param_targets.begin()->get_node()->shared_from_this(); + if (!ov::as_type_ptr(getitem)) + return false; + + auto dim = chunk_op->input_value(2); + if (!ov::as_type_ptr(dim.get_node_shared_ptr())) + return false; + + // connect chunk input directly to loop + auto chunk_input = chunk_op->input_value(0); + chunk_op->output(0).replace(chunk_input); + // len(chunks) is number of iterations + auto chunks_outside = chunk_op->input_value(1); + loop_op->input_value(0).replace(chunks_outside); + + auto chunk_counter = getitem->input_value(1); + + pass::NodeRegistry rg; + auto tensor_0 = v0::Constant::create(element::i32, Shape{1}, {0}); + auto one_1d = v0::Constant::create(element::i32, Shape{1}, {1}); + + auto input_shape = rg.make(chunk_input, element::i32); + auto input_dimension = rg.make(input_shape, dim, tensor_0); + auto init_chunk_size = rg.make(input_dimension, chunks_outside, true); + + // Add 1 if input is not evenly divisible by chunks + auto last_chunk_size = rg.make(input_dimension, chunks_outside); + auto is_last_nonzero = rg.make(last_chunk_size, tensor_0); + auto is_last_nonzero_int = rg.make(is_last_nonzero, element::i32); + auto chunk_size = rg.make(init_chunk_size, is_last_nonzero_int); + auto dim_1d = rg.make(dim, one_1d, false); + + // Add new inputs in Loop: chunk_size and dim_1d + auto inp_descs = loop_op->get_input_descriptions(); + auto chunks_size_body = rg.make(element::i32, Shape{1}); + auto dim_body = rg.make(dim.get_element_type(), Shape{1}); + body->add_parameters({chunks_size_body, dim_body}); + loop_op->set_argument(loop_op->get_input_size(), chunk_size); + loop_op->set_argument(loop_op->get_input_size(), dim_1d); + inp_descs.push_back(std::make_shared( + loop_op->get_input_size() - 2, + body->get_parameters().size() - 2)); + inp_descs.push_back(std::make_shared( + loop_op->get_input_size() - 1, + 
body->get_parameters().size() - 1)); + loop_op->set_input_descriptions(0, inp_descs); + + auto start = rg.make(chunk_counter, chunks_size_body); + auto stop = rg.make(start, chunks_size_body); + auto curr_chunk = rg.make(chunk_param, start, stop, one_1d, dim_body); + replace_node(getitem, curr_chunk); + copy_runtime_info({chunk_op, getitem}, rg.get()); + curr_chunk->set_friendly_name(getitem->get_friendly_name()); + return true; + }; + + auto m = std::make_shared(loop_pattern, "ov::frontend::pytorch::pass::IndexLoopGetitemReplacer"); + this->register_matcher(m, callback); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov \ No newline at end of file diff --git a/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.hpp b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.hpp new file mode 100644 index 00000000000..9a8337b4ac2 --- /dev/null +++ b/src/frontends/pytorch/src/transforms/index_loop_getitem_replacer.hpp @@ -0,0 +1,28 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pass.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace pass { + +/** + * @brief IndexLoopGetitemReplacer transformation replaces following graph: + * aten::chunk->prim::Loop(aten::__getitem__) to Slice inside the Loop + */ +class IndexLoopGetitemReplacer : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("ov::frontend::pytorch::pass::IndexLoopGetitemReplacer"); + IndexLoopGetitemReplacer(); +}; + +} // namespace pass +} // namespace pytorch +} // namespace frontend +} // namespace ov diff --git a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp index 6b1792f7a63..719a3a2a1fb 100644 --- a/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp +++ 
b/src/frontends/pytorch/src/transforms/listconstruct_replacer.cpp @@ -9,7 +9,10 @@ #include "openvino/op/broadcast.hpp" #include "openvino/op/concat.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" #include "openvino/op/equal.hpp" +#include "openvino/op/interpolate.hpp" +#include "openvino/op/multiply.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/roll.hpp" #include "openvino/op/select.hpp" @@ -52,6 +55,11 @@ ListConstructReplacer::ListConstructReplacer() { auto transpose_op = pattern::wrap_type({pattern::any_input(), list}); // aten::split_with_sizes case auto vsplit_op = pattern::wrap_type({pattern::any_input(), pattern::any_input(), list}); + // aten::upsample... case inside the body when body was removed + auto interpolate_convert_op = pattern::wrap_type({list}); + auto interpolate_mul_op = pattern::wrap_type({interpolate_convert_op, pattern::any_input()}); + auto interpolate_op = + pattern::wrap_type({pattern::any_input(), interpolate_mul_op, pattern::any_input()}); auto lc_pattern = std::make_shared(OutputVector{reshape_op, roll_op, broadcast_op, @@ -61,7 +69,8 @@ ListConstructReplacer::ListConstructReplacer() { select_op, tile_op, transpose_op, - vsplit_op}); + vsplit_op, + interpolate_op}); ov::matcher_pass_callback callback = [=](pattern::Matcher& m) { auto& pattern_map = m.get_pattern_value_map(); diff --git a/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp b/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp index cb01409b0bd..11315353a3b 100644 --- a/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp +++ b/src/frontends/pytorch/src/transforms/string_equality_replacer.cpp @@ -6,6 +6,7 @@ #include "openvino/core/rt_info.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" #include "openvino/op/convert_like.hpp" #include "openvino/op/equal.hpp" #include "openvino/op/not_equal.hpp" @@ -25,9 +26,16 @@ using namespace ov::op; 
StringEqualityReplacer::StringEqualityReplacer() { auto framework_node_lhs = pattern::wrap_type(); auto framework_node_rhs = pattern::wrap_type(); - auto convert_like = pattern::wrap_type({framework_node_rhs, framework_node_lhs}); - auto equal_op = pattern::wrap_type({framework_node_lhs, convert_like}); - auto not_equal_op = pattern::wrap_type({framework_node_lhs, convert_like}); + auto convert_lhs = pattern::wrap_type({framework_node_lhs}); + auto convert_like_lhs = pattern::wrap_type({framework_node_lhs, framework_node_rhs}); + auto convert_rhs = pattern::wrap_type({framework_node_rhs}); + auto convert_like_rhs = pattern::wrap_type({framework_node_rhs, framework_node_lhs}); + auto lhs_pattern = + std::make_shared(OutputVector{framework_node_lhs, convert_lhs, convert_like_lhs}); + auto rhs_pattern = + std::make_shared(OutputVector{framework_node_rhs, convert_rhs, convert_like_rhs}); + auto equal_op = pattern::wrap_type({lhs_pattern, rhs_pattern}); + auto not_equal_op = pattern::wrap_type({lhs_pattern, rhs_pattern}); auto string_equality_pattern = std::make_shared(OutputVector{equal_op, not_equal_op}); diff --git a/src/frontends/pytorch/src/translate_session.cpp b/src/frontends/pytorch/src/translate_session.cpp index 2e4abe29fc8..9431bb51b0e 100644 --- a/src/frontends/pytorch/src/translate_session.cpp +++ b/src/frontends/pytorch/src/translate_session.cpp @@ -6,11 +6,18 @@ #include "input_model.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/gather.hpp" #include "openvino/op/parameter.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reduce_prod.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/result.hpp" +#include "openvino/op/scatter_nd_update.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" #include "openvino/op/transpose.hpp" #include "openvino/util/log.hpp" +#include "pt_framework_node.hpp" #include "utils.hpp" namespace ov { @@ -143,13 +150,41 @@ std::shared_ptr 
TranslateSession::convert_pytorch_model( context.get_op_type(), " outputs greater then number of converted outputs."); - // TODO: Make sure that mapping of fw_outputs to converted_outputs does always work - // FIXME: Now it is not true for at least prim::Constant for (size_t i = 0; i < fw_outputs.size(); ++i) { size_t fw_tensor_id = node->output(i); + if (node->inputs().size() > 0 && node->may_produce_alias(0, i)) { + // TODO: do we need to check other inputs, not only 0? + auto in_tensor_id = node->inputs().at(0); + if (m_may_be_alias.count(fw_tensor_id)) { + size_t recorded_in_tensor_id; + std::shared_ptr recorded_node; + std::tie(recorded_in_tensor_id, recorded_node, std::ignore) = m_may_be_alias.at(fw_tensor_id); + FRONT_END_GENERAL_CHECK(recorded_in_tensor_id == in_tensor_id, + "Operation ", + context.get_op_type(), + " creates alias to tensor which was already created before by ", + recorded_node->get_op_type(), + ", but from different tensor: ", + in_tensor_id, + " vs ", + recorded_in_tensor_id); + } + m_may_be_alias[fw_tensor_id] = {node->inputs().at(0), node, converted_outputs[i]}; + OPENVINO_DEBUG << "Registered alias: " << fw_tensor_id << " of tensor: " << node->inputs().at(0) + << " of operation: " << context.get_op_type(); + } FRONT_END_GENERAL_CHECK(tensor_map->find(fw_tensor_id) == tensor_map->end(), "Duplicated producer for PT value with unique ID: ", fw_tensor_id); + auto out_type = context.get_output_type(i); + if (out_type.is()) { + if (!converted_outputs[i].get_element_type().compatible(out_type.as())) { + OPENVINO_DEBUG << "[WARNING] Produced output type for operation " << context.get_op_type() + << " for tensor id: " << fw_tensor_id << " is incompatible: produced " + << converted_outputs[i].get_element_type() << " vs " + << out_type.as(); + } + } (*tensor_map)[fw_tensor_id] = converted_outputs[i]; encode_tensor_name(converted_outputs[i], fw_tensor_id, {node->get_output_debug_name(i)}); } @@ -197,6 +232,8 @@ std::shared_ptr 
TranslateSession::convert_pytorch_model( // additional outputs in that case. if (mutated_tensor.get_target_inputs().empty() && !external_tensor_map.empty()) results.push_back(std::make_shared(tensor_map->at(tensor_id))); + } else { + OPENVINO_DEBUG << "Mutated tensor with id " << tensor_id << " doesn't exist in inputs, skipping."; } } resulting_model = std::make_shared(results, *parameters); @@ -215,10 +252,9 @@ OutputVector TranslateSession::convert_node(const NodeContext& context) { } catch (std::exception& e) { OPENVINO_DEBUG << "Exception happened during conversion of op: " << context.get_op_type() - << " with schema: " << context.get_schema() << ": " << e.what() << '\n'; + << " with schema: " << context.get_schema() << ": " << e.what(); } catch (...) { - OPENVINO_DEBUG << "Some exception happened during conversion of node of type: " << context.get_op_type() - << '\n'; + OPENVINO_DEBUG << "Some exception happened during conversion of node of type: " << context.get_op_type(); } // Create PtFrameworkNode for everything that wasn't able to be converted normally return make_framework_node(context); @@ -228,7 +264,9 @@ void TranslateSession::encode_tensor_name(Output output, size_t tensor_idx, std::vector additional_names) { if (!output.get_names().empty()) { - OPENVINO_DEBUG << "Tensor names already exist: " << output.get_any_name() << ". Rewriting with " << tensor_idx; + OPENVINO_DEBUG << "Tensor names already exist: " << output.get_any_name() << ". Will not be rewritten with " + << tensor_idx << ". 
This is likely a mutated tensor."; + return; } auto name = std::to_string(tensor_idx); std::unordered_set names; @@ -243,7 +281,6 @@ void TranslateSession::encode_tensor_name(Output output, pair.second.set_names({new_name}); pair.second = output; output.set_names(names); - } else { m_counter_map[tensor_idx] = {0, output}; output.set_names(names); @@ -257,6 +294,110 @@ size_t TranslateSession::decode_tensor_name(const Output& output) { return static_cast(std::stoll(name)); } +namespace { +Output slice_backprop(const Output& slice_output, const Output& value) { + auto slice_node = slice_output.get_node_shared_ptr(); + FRONT_END_OP_CONVERSION_CHECK(ov::as_type_ptr(slice_node), + "Conversion rule for aten::slice doesn't contain Slice node."); + + auto zero = v0::Constant::create(element::i64, Shape{}, {0}); + auto one = v0::Constant::create(element::i64, Shape{}, {1}); + auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1}); + auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1}); + + // Get 1d indices [0..numel) + auto to_insert_data = slice_node->input_value(0); + auto input_shape = std::make_shared(to_insert_data, element::i64); + auto numel = std::make_shared(input_shape, zero, false); + auto full_data_indices_1d = std::make_shared(zero, numel, one, element::i64); + + // Slice indices by same start, stop, slice, axes as initial Slice + auto full_data_indices = std::make_shared(full_data_indices_1d, input_shape, false); + Output data_indices; + if (slice_node->get_input_size() == 5) { + data_indices = std::make_shared(full_data_indices, + slice_node->input_value(1), + slice_node->input_value(2), + slice_node->input_value(3), + slice_node->input_value(4)); + } else if (slice_node->get_input_size() == 4) { + data_indices = std::make_shared(full_data_indices, + slice_node->input_value(1), + slice_node->input_value(2), + slice_node->input_value(3)); + } else { + FRONT_END_OP_CONVERSION_CHECK(false, "Incorrect number of Slice 
inputs"); + } + + // Scatter in flattened tensor with indices and flattened data to be inserted + auto to_insert_data_1d = std::make_shared(to_insert_data, neg_one_1d, false); + auto data_indices_1d = std::make_shared(data_indices, scattering_shape, false); + auto to_be_inserted_data_1d = std::make_shared(value, neg_one_1d, false); + auto updated_data_1d = + std::make_shared(to_insert_data_1d, data_indices_1d, to_be_inserted_data_1d); + + // Reshape to initial shape + return std::make_shared(updated_data_1d, input_shape, false); +} + +Output select_backprop(const Output& select_output, const Output& value) { + auto gather_node = select_output.get_node_shared_ptr(); + FRONT_END_OP_CONVERSION_CHECK(ov::as_type_ptr(gather_node), + "Conversion rule for aten::select doesn't contain Gather node."); + + auto zero = v0::Constant::create(element::i64, Shape{}, {0}); + auto one = v0::Constant::create(element::i64, Shape{}, {1}); + auto neg_one_1d = v0::Constant::create(element::i64, Shape{1}, {-1}); + auto scattering_shape = v0::Constant::create(element::i64, Shape{2}, {-1, 1}); + + // Get 1d indices [0..numel) + auto to_insert_data = gather_node->input_value(0); + auto input_shape = std::make_shared(to_insert_data, element::i64); + auto numel = std::make_shared(input_shape, zero, false); + auto full_data_indices_1d = std::make_shared(zero, numel, one, element::i64); + + // Pick indices using the same inputs 1 and 2 of the initial Gather (presumably its indices and axis) + auto full_data_indices = std::make_shared(full_data_indices_1d, input_shape, false); + Output data_indices = + std::make_shared(full_data_indices, gather_node->input_value(1), gather_node->input_value(2)); + + // Scatter in flattened tensor with indices and flattened data to be inserted + auto to_insert_data_1d = std::make_shared(to_insert_data, neg_one_1d, false); + auto data_indices_1d = std::make_shared(data_indices, scattering_shape, false); + auto to_be_inserted_data_1d = std::make_shared(value, neg_one_1d, false); + auto 
updated_data_1d = + std::make_shared(to_insert_data_1d, data_indices_1d, to_be_inserted_data_1d); + + // Reshape to initial shape + return std::make_shared(updated_data_1d, input_shape, false); +} +} // namespace + +using BackpropCreatorFunction = std::function(const Output&, const Output&)>; + +Output TranslateSession::get_backprop_op(const std::shared_ptr& node, + const Output& direct_op_output, + const Output& value) { + static const std::map backprop_map = { + {"aten::slice", slice_backprop}, + {"aten::select", select_backprop}, + }; + + // A failed backprop conversion is only logged; control then falls through to the framework-node fallback + try { + auto it = backprop_map.find(node->get_op_type()); + if (it != backprop_map.end()) { + return it->second(direct_op_output, value); + } + + } catch (std::exception& e) { + OPENVINO_DEBUG << "Exception happened during conversion of backprop op: " << node->get_op_type() + << " with schema: " << node->get_schema() << ": " << e.what(); + } + // Create PtFrameworkNode representing unconverted backprop operation + return std::make_shared(node, OutputVector{value}, 1, true); +} + } // namespace pytorch } // namespace frontend } // namespace ov diff --git a/src/frontends/pytorch/src/translate_session.hpp b/src/frontends/pytorch/src/translate_session.hpp index fb1c06a48e8..940edf1d867 100644 --- a/src/frontends/pytorch/src/translate_session.hpp +++ b/src/frontends/pytorch/src/translate_session.hpp @@ -35,13 +35,25 @@ public: const TensorMap& external_tensor_map = {}, const std::unordered_map& external_descriptors = {}); + /// \brief Returns backprop operations for direct operation + Output get_backprop_op(const std::shared_ptr& node, + const Output& direct_op_output, + const Output& value); + + /// \brief Writes pytorch tensor index into openvino tensor void encode_tensor_name(Output tensor_desc, size_t tensor_idx, std::vector additional_names = {}); + + /// \brief Gets pytorch tensor index from openvino tensor size_t decode_tensor_name(const Output& tensor_desc); size_t m_friendly_name_counter = 0; + // Maps tensor 
index to initial tensor index which it is alias to, and to decoder of the node produced this alias + // and to the output produced during conversion of this node + std::map, Output>> m_may_be_alias; + private: OutputVector convert_node(const NodeContext& context); diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index dcfcce0d3c2..cf5b7da6592 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -339,6 +339,19 @@ std::unordered_map bit_to_int{ void align_eltwise_input_types(const NodeContext& context, Output& lhs, Output& rhs, bool align_scalars) { const auto& lhs_type = lhs.get_element_type(); const auto& rhs_type = rhs.get_element_type(); + auto out_type = context.get_output_type(0); + if (out_type.is()) { + auto otype = out_type.as(); + if (otype.is_real()) { + if (otype != lhs_type) { + lhs = context.mark_node(std::make_shared(lhs, otype)); + } + if (otype != rhs_type) { + rhs = context.mark_node(std::make_shared(rhs, otype)); + } + return; + } + } if (lhs_type.is_dynamic() || rhs_type.is_dynamic()) { // if any of types is not known, align to lhs type. // TODO: can be fixed with special operation? 
@@ -410,6 +423,18 @@ void align_eltwise_input_types(const NodeContext& context, Output& lhs, Ou } } +void align_output_types(const NodeContext& context, OutputVector& outputs) { + for (size_t i = 0; i < outputs.size(); i++) { + auto dtype_any = context.get_output_type(i); + if (dtype_any.is()) { + auto dtype = dtype_any.as(); + if (dtype.is_static() && dtype != outputs[i].get_element_type()) { + outputs[i] = std::make_shared(outputs[i], dtype); + } + } + } +} + std::deque> get_list_as_outputs(const Output& start) { std::deque> res; auto current_output = start; diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index aea3fd505c5..3762079b259 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -6,6 +6,7 @@ #include "openvino/frontend/pytorch/node_context.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" namespace ov { @@ -64,6 +65,8 @@ void align_eltwise_input_types(const NodeContext& context, Output& rhs, bool align_scalars = false); +void align_output_types(const NodeContext& context, OutputVector& outputs); + std::deque> get_list_as_outputs(const Output& start); namespace op { @@ -79,7 +82,15 @@ OutputVector inplace_op(const NodeContext& context) { template OutputVector translate_1to1_match_1_inputs(const NodeContext& context) { FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None."); - return {context.mark_node(std::make_shared(context.get_input(0)))}; + auto res = context.mark_node(std::make_shared(context.get_input(0))); + auto out_type = context.get_output_type(0); + if (out_type.is()) { + auto dtype = out_type.as(); + if (dtype.is_static() && dtype != res->output(0).get_element_type()) { + res = context.mark_node(std::make_shared(res, dtype)); + } + } + return {res}; } template @@ -106,7 +117,9 @@ OutputVector translate_1to1_match_2_inputs_align_types(const NodeContext& contex auto lhs = context.get_input(0); auto rhs = 
context.get_input(1); align_eltwise_input_types(context, lhs, rhs, true); - return {context.mark_node(std::make_shared(lhs, rhs))}; + OutputVector res = {context.mark_node(std::make_shared(lhs, rhs))}; + align_output_types(context, res); + return res; } inline OutputVector return_false_scalar(const NodeContext& context) { diff --git a/tests/layer_tests/pytorch_tests/test_aliases.py b/tests/layer_tests/pytorch_tests/test_aliases.py new file mode 100644 index 00000000000..1919aab5fbc --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_aliases.py @@ -0,0 +1,38 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pytorch_layer_test_class import PytorchLayerTest + + +class aten_alias(torch.nn.Module): + def forward(self, x): + x[:, 1, :, :] = 4. + return x + + +class aten_loop_alias(torch.nn.Module): + def forward(self, x): + for i in range(2): + x[:, i, :, :] = 4. + return x + + +class TestAliases(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(1, 3, 224, 224).astype(np.float32),) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_alias(self, ie_device, precision, ir_version): + self._test(aten_alias(), None, [ + "aten::slice", "aten::select", "aten::copy_"], ie_device, precision, ir_version) + + @pytest.mark.nightly + @pytest.mark.precommit + def test_loop_alias(self, ie_device, precision, ir_version): + self._test(aten_loop_alias(), None, [ + "aten::slice", "aten::select", "aten::copy_", "prim::Loop"], ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_cat.py b/tests/layer_tests/pytorch_tests/test_cat.py index cc314bf1b4d..b1d3fcef5ea 100644 --- a/tests/layer_tests/pytorch_tests/test_cat.py +++ b/tests/layer_tests/pytorch_tests/test_cat.py @@ -54,6 +54,7 @@ class TestCat(PytorchLayerTest): @pytest.mark.nightly @pytest.mark.precommit + @pytest.mark.xfail(reason="Transformation 
RemoveMultiSubGraphOpDanglingParamsResults doesn't support removing unused merged inputs, ticket 112833.") def test_loop_append_cat(self, ie_device, precision, ir_version): self._test(aten_loop_append_cat(), None, ["aten::cat", "aten::append", "prim::ListConstruct", "prim::Loop"], ie_device, precision, ir_version, freeze_model=False) diff --git a/tests/layer_tests/pytorch_tests/test_chunk.py b/tests/layer_tests/pytorch_tests/test_chunk.py index 50b09c30789..7da36d568d9 100644 --- a/tests/layer_tests/pytorch_tests/test_chunk.py +++ b/tests/layer_tests/pytorch_tests/test_chunk.py @@ -120,3 +120,41 @@ class TestChunk(PytorchLayerTest): for idx in [0, 1, output_chunks - 1]: self._test(aten_chunk_getitem(chunks, dim, idx), None, "aten::chunk", ie_device, precision, ir_version) + + +class aten_chunk_loop_getitem(torch.nn.Module): + def __init__(self, num_chunks) -> None: + torch.nn.Module.__init__(self) + self.num_chunks = num_chunks + + def forward(self, input_tensor): + chunks = torch.chunk(torch.arange( + input_tensor.shape[0]), self.num_chunks) + + for inds in chunks: + input_tensor[inds] *= 10 + return input_tensor + + +class TestChunkLoopGetitem(PytorchLayerTest): + def _prepare_input(self): + return (np.random.rand(*self.input_shape),) + + @pytest.mark.parametrize("input_shape", [ + (4, 4), + (5, 9, 7), + (10, 13, 11), + (8, 7, 6, 5, 4), + ]) + @pytest.mark.parametrize("chunks", [ + 2, + 3, + 4 + ]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_chunk_loop_getitem(self, input_shape, chunks, ie_device, precision, ir_version): + self.input_shape = input_shape + + self._test(aten_chunk_loop_getitem(chunks), None, ["aten::chunk", "prim::Loop", "aten::__getitem__"], + ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_copy.py b/tests/layer_tests/pytorch_tests/test_copy.py new file mode 100644 index 00000000000..b78af602712 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_copy.py @@ -0,0 +1,33 @@ +# Copyright (C) 
2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from pytorch_layer_test_class import PytorchLayerTest + + +class TestCopy(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(1, 3, 224, 224).astype(np.float32),) + + def create_model(self, value): + import torch + + class aten_copy(torch.nn.Module): + def __init__(self, value): + super(aten_copy, self).__init__() + self.value = torch.tensor(value) + + def forward(self, x): + return x.copy_(self.value) + + ref_net = None + + return aten_copy(value), ref_net, "aten::copy_" + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("value", [1, [2.5], range(224)]) + def test_copy_(self, value, ie_device, precision, ir_version): + self._test(*self.create_model(value), ie_device, precision, ir_version) diff --git a/tests/layer_tests/pytorch_tests/test_derive_index_range_length.py b/tests/layer_tests/pytorch_tests/test_derive_index_range_length.py index fbabd072e30..8ae5f11d267 100644 --- a/tests/layer_tests/pytorch_tests/test_derive_index_range_length.py +++ b/tests/layer_tests/pytorch_tests/test_derive_index_range_length.py @@ -40,7 +40,7 @@ class TestDeriveIndexRangeLength(PytorchLayerTest): step = int(x[2]) accumulator = 0 for idx in range(start, stop, step): - accumulator = idx + accumulator += idx return accumulator ref_net = None diff --git a/tests/layer_tests/pytorch_tests/test_full.py b/tests/layer_tests/pytorch_tests/test_full.py index e35d60c91bc..4ce42db7fa9 100644 --- a/tests/layer_tests/pytorch_tests/test_full.py +++ b/tests/layer_tests/pytorch_tests/test_full.py @@ -130,6 +130,31 @@ class TestFill(PytorchLayerTest): self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={'value': value, 'shape': shape, "input_dtype": input_dtype, "value_dtype": value_dtype}) +class TestZero(PytorchLayerTest): + def _prepare_input(self, shape, input_dtype): + return 
(np.random.randn(*shape).astype(input_dtype),) + + def create_model(self): + import torch + + class aten_zero(torch.nn.Module): + + def forward(self, input_t: torch.Tensor): + return input_t.zero_() + ref_net = None + + model = aten_zero() + + return model, ref_net, "aten::zero_" + + @pytest.mark.parametrize("shape", [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5, 6]]) + @pytest.mark.parametrize("input_dtype", ["int8", "int32", "int64", "float32", "float64"]) + @pytest.mark.nightly + @pytest.mark.precommit + def test_zero(self, shape, input_dtype, ie_device, precision, ir_version): + self._test(*self.create_model(), ie_device, precision, ir_version, + kwargs_to_prepare_input={'shape': shape, "input_dtype": input_dtype}) + class TestFullLike(PytorchLayerTest): def _prepare_input(self, value, shape): return (np.random.randn(*shape).astype(np.float32), np.array(value, dtype=np.float32),) diff --git a/tests/layer_tests/pytorch_tests/test_unary_ops.py b/tests/layer_tests/pytorch_tests/test_unary_ops.py index c1bf42d96b8..2f1e75753b1 100644 --- a/tests/layer_tests/pytorch_tests/test_unary_ops.py +++ b/tests/layer_tests/pytorch_tests/test_unary_ops.py @@ -64,6 +64,7 @@ class TestUnaryOp(PytorchLayerTest): @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("op,op_type", [ # some pytorch inplace ops do not support int + (torch.exp_, "aten::exp_"), (torch.sigmoid_, "aten::sigmoid_"), # trigonometry (torch.cos_, "aten::cos_"),