[PT FE] Make NodeContext constant inside conversion rules (#16165)

* Make NodeContext constant inside conversion rules

* Use shared_ptr

* Fix ptr

* Fix logical not
This commit is contained in:
Maxim Vafin 2023-03-20 22:08:24 +01:00 committed by GitHub
parent 4ffecce63f
commit 7f8786d9aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
100 changed files with 173 additions and 180 deletions

View File

@ -60,7 +60,7 @@ protected:
bool supported_impl(const std::vector<ov::Any>& variants) const override; bool supported_impl(const std::vector<ov::Any>& variants) const override;
ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override; ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override;
std::map<std::string, PytorchCreatorFunction> m_op_translators; std::map<std::string, CreatorFunction> m_op_translators;
}; };
} // namespace pytorch } // namespace pytorch

View File

@ -19,20 +19,22 @@ typedef std::unordered_map<size_t, Output<Node>> TensorMap;
class NodeContext : public frontend::NodeContext { class NodeContext : public frontend::NodeContext {
public: public:
NodeContext(std::shared_ptr<TorchDecoder> decoder, NodeContext(std::shared_ptr<TorchDecoder> decoder,
TensorMap* tensor_map,
ParameterVector* external_parameters,
const TensorMap& ext_tensor_map, const TensorMap& ext_tensor_map,
std::shared_ptr<TensorMap> tensor_map,
std::shared_ptr<ParameterVector> external_parameters,
std::shared_ptr<std::set<size_t>> mutated_tensors,
TranslateSession* translate_session) TranslateSession* translate_session)
: frontend::NodeContext(decoder->get_op_type()), : frontend::NodeContext(decoder->get_op_type()),
m_decoder(decoder), m_decoder(decoder),
m_tensor_map(tensor_map),
m_ext_tensor_map(ext_tensor_map), m_ext_tensor_map(ext_tensor_map),
m_tensor_map(tensor_map),
m_external_parameters(external_parameters), m_external_parameters(external_parameters),
m_mutated_tensors(mutated_tensors),
m_translate_session(translate_session), m_translate_session(translate_session),
m_decoder_inputs(decoder->inputs()), m_decoder_inputs(decoder->inputs()),
m_decoder_outputs(decoder->outputs()) { m_decoder_outputs(decoder->outputs()) {
FRONT_END_GENERAL_CHECK(tensor_map != nullptr && external_parameters != nullptr && FRONT_END_GENERAL_CHECK(m_tensor_map != nullptr && m_external_parameters != nullptr &&
translate_session != nullptr); m_mutated_tensors != nullptr && m_translate_session != nullptr);
} }
// Do not search for input in tensor map; try to access it as a constant of specified type T and return its value // Do not search for input in tensor map; try to access it as a constant of specified type T and return its value
@ -106,11 +108,7 @@ public:
"There is no any named attributes in PyTorch node, query by attribute name is not implemented"); "There is no any named attributes in PyTorch node, query by attribute name is not implemented");
} }
void mutate_input(size_t index, Output<Node> ov_output); void mutate_input(size_t index, Output<Node> ov_output) const;
std::set<size_t> get_mutated_tensors() const {
return m_mutated_tensors;
}
std::shared_ptr<TorchDecoder> get_decoder() const { std::shared_ptr<TorchDecoder> get_decoder() const {
return m_decoder; return m_decoder;
@ -120,7 +118,7 @@ public:
return m_translate_session; return m_translate_session;
} }
void add_tensor_to_context(size_t index, Output<Node> ov_output); void add_tensor_to_context(size_t index, Output<Node> ov_output) const;
Output<Node> get_tensor_from_model(size_t index) const { Output<Node> get_tensor_from_model(size_t index) const {
if (m_tensor_map->find(index) != m_tensor_map->end()) { if (m_tensor_map->find(index) != m_tensor_map->end()) {
@ -130,22 +128,22 @@ public:
} }
} }
Output<Node> get_tensor_from_model_or_create_input(size_t index); Output<Node> get_tensor_from_model_or_create_input(size_t index) const;
Output<Node> get_input_from_visible_context(size_t index) const; Output<Node> get_input_from_visible_context(size_t index) const;
std::shared_ptr<ov::Model> convert_subgraph(size_t index); std::shared_ptr<ov::Model> convert_subgraph(size_t index) const;
private: private:
std::shared_ptr<TorchDecoder> m_decoder; std::shared_ptr<TorchDecoder> m_decoder;
std::set<size_t> m_mutated_tensors;
TensorMap* m_tensor_map;
const TensorMap& m_ext_tensor_map; const TensorMap& m_ext_tensor_map;
ParameterVector* m_external_parameters; std::shared_ptr<TensorMap> m_tensor_map;
TranslateSession* m_translate_session; std::shared_ptr<ParameterVector> m_external_parameters;
std::shared_ptr<std::set<size_t>> m_mutated_tensors;
TranslateSession* m_translate_session = nullptr;
const std::vector<size_t> m_decoder_inputs; const std::vector<size_t> m_decoder_inputs;
const std::vector<size_t> m_decoder_outputs; const std::vector<size_t> m_decoder_outputs;
}; };
using PytorchCreatorFunction = std::function<OutputVector(NodeContext&)>; using CreatorFunction = std::function<ov::OutputVector(const ov::frontend::pytorch::NodeContext&)>;
} // namespace pytorch } // namespace pytorch
} // namespace frontend } // namespace frontend

View File

@ -42,16 +42,16 @@ std::shared_ptr<Node> NodeContext::mark_node(std::shared_ptr<Node> ov_node) cons
return m_decoder->mark_node(ov_node); return m_decoder->mark_node(ov_node);
} }
void NodeContext::mutate_input(size_t index, Output<Node> ov_output) { void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index); FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index);
auto input_id = m_decoder_inputs.at(index); auto input_id = m_decoder_inputs.at(index);
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist."); FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist.");
m_translate_session->encode_tensor_name(ov_output, input_id, m_decoder->get_input_debug_name(index)); m_translate_session->encode_tensor_name(ov_output, input_id, m_decoder->get_input_debug_name(index));
(*m_tensor_map)[input_id] = ov_output; (*m_tensor_map)[input_id] = ov_output;
m_mutated_tensors.insert(input_id); m_mutated_tensors->insert(input_id);
} }
void NodeContext::add_tensor_to_context(size_t index, Output<Node> ov_output) { void NodeContext::add_tensor_to_context(size_t index, Output<Node> ov_output) const {
if (m_tensor_map->count(index)) { if (m_tensor_map->count(index)) {
OPENVINO_DEBUG << "[ WARNING ] Current context has tensor. Rewriting.\n"; OPENVINO_DEBUG << "[ WARNING ] Current context has tensor. Rewriting.\n";
} }
@ -59,7 +59,7 @@ void NodeContext::add_tensor_to_context(size_t index, Output<Node> ov_output) {
(*m_tensor_map)[index] = ov_output; (*m_tensor_map)[index] = ov_output;
} }
Output<Node> NodeContext::get_tensor_from_model_or_create_input(size_t index) { Output<Node> NodeContext::get_tensor_from_model_or_create_input(size_t index) const {
if (m_tensor_map->find(index) != m_tensor_map->end()) { if (m_tensor_map->find(index) != m_tensor_map->end()) {
return m_tensor_map->at(index); return m_tensor_map->at(index);
} else { } else {
@ -87,7 +87,7 @@ Output<Node> NodeContext::get_input_from_visible_context(size_t index) const {
return input_tensor; return input_tensor;
} }
std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) { std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
auto subgraph_decoder = m_decoder->get_subgraph_decoder(index); auto subgraph_decoder = m_decoder->get_subgraph_decoder(index);
// Extend external context with internal tensors except Parameter nodes, because internal Parameters are created to // Extend external context with internal tensors except Parameter nodes, because internal Parameters are created to

View File

@ -19,7 +19,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_adaptive_avg_pool3d(NodeContext& context) { OutputVector translate_adaptive_avg_pool3d(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto const_tile_params = context.mark_node(v0::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1})); auto const_tile_params = context.mark_node(v0::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_adaptive_max_pool2d(NodeContext& context) { OutputVector translate_adaptive_max_pool2d(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -15,7 +15,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_add(NodeContext& context) { OutputVector translate_add(const NodeContext& context) {
num_inputs_check(context, 2, 3); num_inputs_check(context, 2, 3);
auto lhs = context.get_input(0); auto lhs = context.get_input(0);
auto rhs = context.get_input(1); auto rhs = context.get_input(1);

View File

@ -17,7 +17,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_addcmul(NodeContext& context) { OutputVector translate_addcmul(const NodeContext& context) {
num_inputs_check(context, 4, 4); num_inputs_check(context, 4, 4);
const auto eltwise_mult = std::make_shared<v1::Multiply>(context.get_input(1), context.get_input(2)); const auto eltwise_mult = std::make_shared<v1::Multiply>(context.get_input(1), context.get_input(2));
const auto value = context.get_input(3); const auto value = context.get_input(3);

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_addmm(NodeContext& context) { OutputVector translate_addmm(const NodeContext& context) {
num_inputs_check(context, 5, 5); num_inputs_check(context, 5, 5);
auto input = context.get_input(0); auto input = context.get_input(0);
auto m1 = context.get_input(1); auto m1 = context.get_input(1);

View File

@ -17,7 +17,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_arange(NodeContext& context) { OutputVector translate_arange(const NodeContext& context) {
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
int dtype_port = -1; int dtype_port = -1;

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_as_tensor(NodeContext& context) { OutputVector translate_as_tensor(const NodeContext& context) {
// aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor // aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor
num_inputs_check(context, 1, 4); num_inputs_check(context, 1, 4);
auto dtype = element::f32; auto dtype = element::f32;

View File

@ -18,7 +18,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_avg_poolnd(NodeContext& context) { OutputVector translate_avg_poolnd(const NodeContext& context) {
num_inputs_check(context, 6, 7); num_inputs_check(context, 6, 7);
auto input = context.get_input(0); auto input = context.get_input(0);
auto kernel = context.const_input<Shape>(1); auto kernel = context.const_input<Shape>(1);

View File

@ -32,7 +32,7 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
} }
} // namespace } // namespace
OutputVector translate_batch_norm(NodeContext& context) { OutputVector translate_batch_norm(const NodeContext& context) {
// Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, // Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var,
// bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor // bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
num_inputs_check(context, 8, 9); num_inputs_check(context, 8, 9);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_bitwise_not(NodeContext& context) { OutputVector translate_bitwise_not(const NodeContext& context) {
num_inputs_check(context, 1, 2); num_inputs_check(context, 1, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean), FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_bool(NodeContext& context) { OutputVector translate_bool(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::boolean))}; return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::boolean))};
}; };

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_cat(NodeContext& context) { OutputVector translate_cat(const NodeContext& context) {
// This translator is only needed to get axis as constant from external scope // This translator is only needed to get axis as constant from external scope
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
const auto&& list_elems = get_list_as_outputs(context.get_input(0)); const auto&& list_elems = get_list_as_outputs(context.get_input(0));

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_clamp(NodeContext& context) { OutputVector translate_clamp(const NodeContext& context) {
num_inputs_check(context, 1, 3); num_inputs_check(context, 1, 3);
auto x = context.get_input(0); auto x = context.get_input(0);
if (!context.input_is_none(1)) { if (!context.input_is_none(1)) {

View File

@ -9,7 +9,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_constant(NodeContext& context) { OutputVector translate_constant(const NodeContext& context) {
return context.as_constant(); return context.as_constant();
}; };

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_conv_transposend(NodeContext& context) { OutputVector translate_conv_transposend(const NodeContext& context) {
num_inputs_check(context, 8, 8); num_inputs_check(context, 8, 8);
auto strides = context.const_input<Strides>(3); auto strides = context.const_input<Strides>(3);
// PyTorch support only symmetric padding, padding sizes are the same for begins and ends for each dimension // PyTorch support only symmetric padding, padding sizes are the same for begins and ends for each dimension

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_convnd(NodeContext& context) { OutputVector translate_convnd(const NodeContext& context) {
num_inputs_check(context, 7, 7); num_inputs_check(context, 7, 7);
auto strides = context.const_input<Strides>(3); auto strides = context.const_input<Strides>(3);
// In torch pads at beginning are same as at end // In torch pads at beginning are same as at end

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_convolution(NodeContext& context) { OutputVector translate_convolution(const NodeContext& context) {
// Schema: aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] // Schema: aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[]
// dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool // dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool
// cudnn_enabled, bool allow_tf32) -> Tensor // cudnn_enabled, bool allow_tf32) -> Tensor

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_convolution_mode(NodeContext& context) { OutputVector translate_convolution_mode(const NodeContext& context) {
// Schema: aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[] // Schema: aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[]
// dilation, int groups) -> Tensor // dilation, int groups) -> Tensor
num_inputs_check(context, 7, 7); num_inputs_check(context, 7, 7);

View File

@ -13,7 +13,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_cumsum(NodeContext& context) { OutputVector translate_cumsum(const NodeContext& context) {
// aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None, Tensor out=None) // aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None, Tensor out=None)
num_inputs_check(context, 2, 4); num_inputs_check(context, 2, 4);
auto x = context.get_input(0); auto x = context.get_input(0);

View File

@ -12,7 +12,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_dim(NodeContext& context) { OutputVector translate_dim(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
Output<Node> rank; Output<Node> rank;
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true); std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);

View File

@ -17,7 +17,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_div(NodeContext& context) { OutputVector translate_div(const NodeContext& context) {
num_inputs_check(context, 2, 3); num_inputs_check(context, 2, 3);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_elu(NodeContext& context) { OutputVector translate_elu(const NodeContext& context) {
// aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor // aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
num_inputs_check(context, 2, 4); num_inputs_check(context, 2, 4);
auto x = context.get_input(0); auto x = context.get_input(0);

View File

@ -13,7 +13,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_embedding(NodeContext& context) { OutputVector translate_embedding(const NodeContext& context) {
// aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool // aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool
// sparse=False) // sparse=False)
num_inputs_check(context, 5, 5); num_inputs_check(context, 5, 5);

View File

@ -30,7 +30,7 @@ OutputVector base_expand(const NodeContext& context, const Output<Node>& x, cons
}; };
} // namespace } // namespace
OutputVector translate_expand(NodeContext& context) { OutputVector translate_expand(const NodeContext& context) {
// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
num_inputs_check(context, 2, 3); num_inputs_check(context, 2, 3);
auto x = context.get_input(0); auto x = context.get_input(0);
@ -41,7 +41,7 @@ OutputVector translate_expand(NodeContext& context) {
return base_expand(context, x, sizes); return base_expand(context, x, sizes);
}; };
OutputVector translate_expand_as(NodeContext& context) { OutputVector translate_expand_as(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_eye(NodeContext& context) { OutputVector translate_eye(const NodeContext& context) {
size_t num_inputs = context.get_input_size(); size_t num_inputs = context.get_input_size();
auto x = context.get_input(0); auto x = context.get_input(0);
// num rows and cols should be integer, but at the moment conversion their data type can be unknown yet // num rows and cols should be integer, but at the moment conversion their data type can be unknown yet

View File

@ -18,7 +18,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_flatten(NodeContext& context) { OutputVector translate_flatten(const NodeContext& context) {
num_inputs_check(context, 1, 3); num_inputs_check(context, 1, 3);
auto x = context.get_input(0); auto x = context.get_input(0);
int64_t start_dim = 0; int64_t start_dim = 0;

View File

@ -14,7 +14,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_floor_divide(NodeContext& context) { OutputVector translate_floor_divide(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_floordiv(NodeContext& context) { OutputVector translate_floordiv(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -42,7 +42,7 @@ Output<Node> base_translate_full_with_convert(const NodeContext& context,
} }
} // namespace } // namespace
OutputVector translate_full(NodeContext& context) { OutputVector translate_full(const NodeContext& context) {
num_inputs_check(context, 2, 6); num_inputs_check(context, 2, 6);
auto sizes = context.get_input(0); auto sizes = context.get_input(0);
auto value = context.get_input(1); auto value = context.get_input(1);
@ -59,7 +59,7 @@ OutputVector translate_full(NodeContext& context) {
return {base_translate_full_with_convert(context, sizes, value, dtype_id)}; return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
}; };
OutputVector translate_full_like(NodeContext& context) { OutputVector translate_full_like(const NodeContext& context) {
num_inputs_check(context, 2, 7); num_inputs_check(context, 2, 7);
auto input = context.get_input(0); auto input = context.get_input(0);
auto value = context.get_input(1); auto value = context.get_input(1);
@ -71,7 +71,7 @@ OutputVector translate_full_like(NodeContext& context) {
return {base_translate_full_with_convertlike(context, sizes, value, out)}; return {base_translate_full_with_convertlike(context, sizes, value, out)};
}; };
OutputVector translate_fill_(NodeContext& context) { OutputVector translate_fill_(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto input = context.get_input(0); auto input = context.get_input(0);
auto value = context.get_input(1); auto value = context.get_input(1);
@ -79,7 +79,7 @@ OutputVector translate_fill_(NodeContext& context) {
return {base_translate_full_with_convertlike(context, sizes, value, input)}; return {base_translate_full_with_convertlike(context, sizes, value, input)};
}; };
OutputVector translate_new_full(NodeContext& context) { OutputVector translate_new_full(const NodeContext& context) {
num_inputs_check(context, 3, 7); num_inputs_check(context, 3, 7);
auto input = context.get_input(0); auto input = context.get_input(0);
auto sizes = context.get_input(1); auto sizes = context.get_input(1);
@ -90,7 +90,7 @@ OutputVector translate_new_full(NodeContext& context) {
return {base_translate_full_with_convertlike(context, sizes, value, input)}; return {base_translate_full_with_convertlike(context, sizes, value, input)};
}; };
OutputVector translate_zeros(NodeContext& context) { OutputVector translate_zeros(const NodeContext& context) {
num_inputs_check(context, 2, 5); num_inputs_check(context, 2, 5);
auto sizes = context.get_input(0); auto sizes = context.get_input(0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0})); auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
@ -107,7 +107,7 @@ OutputVector translate_zeros(NodeContext& context) {
return {base_translate_full_with_convert(context, sizes, value, dtype_id)}; return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
}; };
OutputVector translate_zeros_like(NodeContext& context) { OutputVector translate_zeros_like(const NodeContext& context) {
num_inputs_check(context, 1, 6); num_inputs_check(context, 1, 6);
auto input = context.get_input(0); auto input = context.get_input(0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0})); auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
@ -119,7 +119,7 @@ OutputVector translate_zeros_like(NodeContext& context) {
return {base_translate_full_with_convertlike(context, sizes, value, out)}; return {base_translate_full_with_convertlike(context, sizes, value, out)};
}; };
OutputVector translate_new_zeros(NodeContext& context) { OutputVector translate_new_zeros(const NodeContext& context) {
num_inputs_check(context, 2, 6); num_inputs_check(context, 2, 6);
auto input = context.get_input(0); auto input = context.get_input(0);
auto sizes = context.get_input(1); auto sizes = context.get_input(1);
@ -130,7 +130,7 @@ OutputVector translate_new_zeros(NodeContext& context) {
return {base_translate_full_with_convertlike(context, sizes, value, input)}; return {base_translate_full_with_convertlike(context, sizes, value, input)};
}; };
OutputVector translate_ones(NodeContext& context) { OutputVector translate_ones(const NodeContext& context) {
num_inputs_check(context, 1, 5); num_inputs_check(context, 1, 5);
auto sizes = context.get_input(0); auto sizes = context.get_input(0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1})); auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
@ -147,7 +147,7 @@ OutputVector translate_ones(NodeContext& context) {
return {base_translate_full_with_convert(context, sizes, value, dtype_id)}; return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
}; };
OutputVector translate_ones_like(NodeContext& context) { OutputVector translate_ones_like(const NodeContext& context) {
num_inputs_check(context, 1, 6); num_inputs_check(context, 1, 6);
auto input = context.get_input(0); auto input = context.get_input(0);
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1})); auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
@ -159,7 +159,7 @@ OutputVector translate_ones_like(NodeContext& context) {
return {base_translate_full_with_convertlike(context, sizes, value, out)}; return {base_translate_full_with_convertlike(context, sizes, value, out)};
}; };
OutputVector translate_new_ones(NodeContext& context) { OutputVector translate_new_ones(const NodeContext& context) {
num_inputs_check(context, 2, 6); num_inputs_check(context, 2, 6);
auto input = context.get_input(0); auto input = context.get_input(0);
auto sizes = context.get_input(1); auto sizes = context.get_input(1);
@ -170,7 +170,7 @@ OutputVector translate_new_ones(NodeContext& context) {
return {base_translate_full_with_convertlike(context, sizes, value, input)}; return {base_translate_full_with_convertlike(context, sizes, value, input)};
}; };
OutputVector translate_empty(NodeContext& context) { OutputVector translate_empty(const NodeContext& context) {
// aten::empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? // aten::empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
// pin_memory=None, MemoryFormat? memory_format=None) -> Tensor layout, device and work with memory ignored on our // pin_memory=None, MemoryFormat? memory_format=None) -> Tensor layout, device and work with memory ignored on our
// side, so just skip these parameters // side, so just skip these parameters

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_gelu(NodeContext& context) { OutputVector translate_gelu(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto approximate = context.const_input<std::string>(1); auto approximate = context.const_input<std::string>(1);

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_get_attr(NodeContext& context) { OutputVector translate_get_attr(const NodeContext& context) {
auto res = context.get_decoder()->try_decode_get_attr(); auto res = context.get_decoder()->try_decode_get_attr();
FRONT_END_OP_CONVERSION_CHECK(res.size() > 0, "GetAttr must have at least one output."); FRONT_END_OP_CONVERSION_CHECK(res.size() > 0, "GetAttr must have at least one output.");
return res; return res;

View File

@ -13,7 +13,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_getitem(NodeContext& context) { OutputVector translate_getitem(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto input = context.get_input(0); auto input = context.get_input(0);
if (std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) { if (std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_glu(NodeContext& context) { OutputVector translate_glu(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto dim = context.input_is_none(1) ? context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})) auto dim = context.input_is_none(1) ? context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}))

View File

@ -13,7 +13,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_grid_sampler(NodeContext& context) { OutputVector translate_grid_sampler(const NodeContext& context) {
num_inputs_check(context, 4, 5); num_inputs_check(context, 4, 5);
auto x = context.get_input(0); auto x = context.get_input(0);
auto grid = context.get_input(1); auto grid = context.get_input(1);

View File

@ -20,7 +20,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_group_norm(NodeContext& context) { OutputVector translate_group_norm(const NodeContext& context) {
// aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float // aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float
// eps=1.0000000000000001e-05, bool cudnn_enabled=True) -> Tensor // eps=1.0000000000000001e-05, bool cudnn_enabled=True) -> Tensor
num_inputs_check(context, 2, 6); num_inputs_check(context, 2, 6);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_hardtanh(NodeContext& context) { OutputVector translate_hardtanh(const NodeContext& context) {
num_inputs_check(context, 1, 3); num_inputs_check(context, 1, 3);
float min = -1; float min = -1;
float max = 1; float max = 1;

View File

@ -13,7 +13,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_if(NodeContext& context) { OutputVector translate_if(const NodeContext& context) {
auto if_node = std::make_shared<opset10::If>(context.get_input(0)); auto if_node = std::make_shared<opset10::If>(context.get_input(0));
context.mark_node(if_node); context.mark_node(if_node);
auto decoder = context.get_decoder(); auto decoder = context.get_decoder();

View File

@ -56,7 +56,7 @@ std::shared_ptr<Node> get_im2col_indices_along_dim(const NodeContext& context,
} }
} // namespace } // namespace
OutputVector translate_im2col(NodeContext& context) { OutputVector translate_im2col(const NodeContext& context) {
num_inputs_check(context, 5, 5); num_inputs_check(context, 5, 5);
auto input = context.get_input(0); auto input = context.get_input(0);
auto kernel_size = context.const_input<std::vector<int64_t>>(1); auto kernel_size = context.const_input<std::vector<int64_t>>(1);

View File

@ -10,9 +10,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
using namespace ov::op; OutputVector translate_index_put_(const NodeContext& context) {
OutputVector translate_index_put_(NodeContext& context) {
// Pass as PtFrameworkNode to register as `inplace_op`. Conversion to OV operators is done as transformation. // Pass as PtFrameworkNode to register as `inplace_op`. Conversion to OV operators is done as transformation.
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs()); auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
return {context.mark_node(node)}; return {context.mark_node(node)};

View File

@ -88,7 +88,7 @@ OutputVector translate_instance_norm_train(const NodeContext& context,
} // namespace } // namespace
OutputVector translate_instance_norm(NodeContext& context) { OutputVector translate_instance_norm(const NodeContext& context) {
num_inputs_check(context, 8, 9); num_inputs_check(context, 8, 9);
auto input = context.get_input(0); auto input = context.get_input(0);
auto eps = context.const_input<float>(7); auto eps = context.const_input<float>(7);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_int(NodeContext& context) { OutputVector translate_int(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::i32))}; return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::i32))};
}; };

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_layer_norm(NodeContext& context) { OutputVector translate_layer_norm(const NodeContext& context) {
num_inputs_check(context, 5, 6); num_inputs_check(context, 5, 6);
auto eps = context.const_input<float>(4); auto eps = context.const_input<float>(4);
auto normalized_shape = context.const_input<Shape>(1); auto normalized_shape = context.const_input<Shape>(1);

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_len(NodeContext& context) { OutputVector translate_len(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0})); auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_linear(NodeContext& context) { OutputVector translate_linear(const NodeContext& context) {
// schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor // schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
num_inputs_check(context, 2, 3); num_inputs_check(context, 2, 3);
auto x = context.get_input(0); auto x = context.get_input(0);

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_list_construct(NodeContext& context) { OutputVector translate_list_construct(const NodeContext& context) {
// Process the case when prim::ListConstruct has all inputs constant // Process the case when prim::ListConstruct has all inputs constant
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
ov::OutputVector consts; ov::OutputVector consts;

View File

@ -17,7 +17,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_log(NodeContext& context) { OutputVector translate_log(const NodeContext& context) {
// torch.log returns a tensor with the natural logarithm of the elements of input. // torch.log returns a tensor with the natural logarithm of the elements of input.
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto x = context.get_input(0); auto x = context.get_input(0);
@ -26,7 +26,7 @@ OutputVector translate_log(NodeContext& context) {
return {log}; return {log};
}; };
OutputVector translate_log2(NodeContext& context) { OutputVector translate_log2(const NodeContext& context) {
// torch.log2 returns a tensor with the logarithm to the base 2 of the elements of input. // torch.log2 returns a tensor with the logarithm to the base 2 of the elements of input.
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto x = context.get_input(0); auto x = context.get_input(0);

View File

@ -13,7 +13,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_loop(NodeContext& context) { OutputVector translate_loop(const NodeContext& context) {
const auto& inputs = context.inputs(); const auto& inputs = context.inputs();
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Loop must have at least 2 inputs."); FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Loop must have at least 2 inputs.");
auto loop = std::make_shared<ov::op::v5::Loop>(inputs[0], inputs[1]); auto loop = std::make_shared<ov::op::v5::Loop>(inputs[0], inputs[1]);

View File

@ -18,7 +18,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_masked_fill(NodeContext& context) { OutputVector translate_masked_fill(const NodeContext& context) {
num_inputs_check(context, 3, 3); num_inputs_check(context, 3, 3);
auto data = context.get_input(0); auto data = context.get_input(0);
auto mask = context.get_input(1); auto mask = context.get_input(1);

View File

@ -13,7 +13,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_max_poolnd(NodeContext& context) { OutputVector translate_max_poolnd(const NodeContext& context) {
num_inputs_check(context, 6, 6); num_inputs_check(context, 6, 6);
auto kernel = context.const_input<Shape>(1); auto kernel = context.const_input<Shape>(1);
auto strides = context.const_input<Strides>(2); auto strides = context.const_input<Strides>(2);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_mean(NodeContext& context) { OutputVector translate_mean(const NodeContext& context) {
num_inputs_check(context, 3, 4); num_inputs_check(context, 3, 4);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -10,7 +10,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_meshgrid(NodeContext& context) { OutputVector translate_meshgrid(const NodeContext& context) {
std::string indexing = "ij"; std::string indexing = "ij";
if (!context.input_is_none(1)) { if (!context.input_is_none(1)) {
indexing = context.const_input<std::string>(1); indexing = context.const_input<std::string>(1);

View File

@ -20,7 +20,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_max(NodeContext& context) { OutputVector translate_max(const NodeContext& context) {
// torch.max (same for torch.min) actually has two interfaces smashed together: // torch.max (same for torch.min) actually has two interfaces smashed together:
// torch.max(x, dim, keepdim) and torch.max(x, y) // torch.max(x, dim, keepdim) and torch.max(x, y)
num_inputs_check(context, 1, 3); num_inputs_check(context, 1, 3);
@ -49,7 +49,7 @@ OutputVector translate_max(NodeContext& context) {
return {values, indicies}; return {values, indicies};
}; };
OutputVector translate_min(NodeContext& context) { OutputVector translate_min(const NodeContext& context) {
// torch.min (same for torch.max) actually has two interfaces smashed together: // torch.min (same for torch.max) actually has two interfaces smashed together:
// torch.min(x, dim, keepdim) and torch.min(x, y) // torch.min(x, dim, keepdim) and torch.min(x, y)
num_inputs_check(context, 1, 3); num_inputs_check(context, 1, 3);

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_narrow(NodeContext& context) { OutputVector translate_narrow(const NodeContext& context) {
num_inputs_check(context, 4, 4); num_inputs_check(context, 4, 4);
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_neg(NodeContext& context) { OutputVector translate_neg(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto x = context.get_input(0); auto x = context.get_input(0);
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));

View File

@ -18,7 +18,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_nms(NodeContext& context) { OutputVector translate_nms(const NodeContext& context) {
num_inputs_check(context, 3, 3); num_inputs_check(context, 3, 3);
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1})); auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_nonzero(NodeContext& context) { OutputVector translate_nonzero(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto cond = context.get_input(0); auto cond = context.get_input(0);
auto non_zero = context.mark_node(std::make_shared<v3::NonZero>(cond)); auto non_zero = context.mark_node(std::make_shared<v3::NonZero>(cond));

View File

@ -20,7 +20,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_norm(NodeContext& context) { OutputVector translate_norm(const NodeContext& context) {
num_inputs_check(context, 4, 4); num_inputs_check(context, 4, 4);
auto input_tensor = context.get_input(0); auto input_tensor = context.get_input(0);
auto p = context.const_input<float>(1); auto p = context.const_input<float>(1);

View File

@ -10,7 +10,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_numel(NodeContext& context) { OutputVector translate_numel(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
return {numel(context, context.get_input(0))}; return {numel(context, context.get_input(0))};
}; };

View File

@ -22,7 +22,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_pad(NodeContext& context) { OutputVector translate_pad(const NodeContext& context) {
num_inputs_check(context, 2, 4); num_inputs_check(context, 2, 4);
auto data = context.get_input(0); auto data = context.get_input(0);
auto paddings = context.const_input<std::vector<int64_t>>(1); auto paddings = context.const_input<std::vector<int64_t>>(1);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_pow(NodeContext& context) { OutputVector translate_pow(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto lhs = context.get_input(0); auto lhs = context.get_input(0);
auto rhs = context.get_input(1); auto rhs = context.get_input(1);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_pythonop(NodeContext& context) { OutputVector translate_pythonop(const NodeContext& context) {
auto decoder = context.get_decoder(); auto decoder = context.get_decoder();
FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1, FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1,
"PythonOp must have 1 subgraph to be able to translate it to OV."); "PythonOp must have 1 subgraph to be able to translate it to OV.");

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_reciprocal(NodeContext& context) { OutputVector translate_reciprocal(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto x = context.get_input(0); auto x = context.get_input(0);
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1})); auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_relu6(NodeContext& context) { OutputVector translate_relu6(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto x = context.get_input(0); auto x = context.get_input(0);
return {context.mark_node(std::make_shared<ov::op::v0::Clamp>(x, 0., 6.))}; return {context.mark_node(std::make_shared<ov::op::v0::Clamp>(x, 0., 6.))};

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_remainder(NodeContext& context) { OutputVector translate_remainder(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_repeat(NodeContext& context) { OutputVector translate_repeat(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
auto repeats = context.get_input(1); auto repeats = context.get_input(1);

View File

@ -34,7 +34,7 @@ OutputVector generate_indices_from_repeats_tensor(const NodeContext& context, co
}; };
} // namespace } // namespace
OutputVector translate_repeat_interleave(NodeContext& context) { OutputVector translate_repeat_interleave(const NodeContext& context) {
num_inputs_check(context, 2, 3); num_inputs_check(context, 2, 3);
// constants // constants
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0})); auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_reshape(NodeContext& context) { OutputVector translate_reshape(const NodeContext& context) {
// Translation is used by both aten::view and aten::reshape. // Translation is used by both aten::view and aten::reshape.
// Schema: aten::view(Tensor input, int[] shape) -> Tensor // Schema: aten::view(Tensor input, int[] shape) -> Tensor
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor // Schema: aten::reshape(Tensor input, int[] shape) -> Tensor

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_reshape_as(NodeContext& context) { OutputVector translate_reshape_as(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
auto input_tensor = context.get_input(0); auto input_tensor = context.get_input(0);
auto shape_tesnor = context.get_input(1); auto shape_tesnor = context.get_input(1);

View File

@ -19,7 +19,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_roi_align(NodeContext& context) { OutputVector translate_roi_align(const NodeContext& context) {
num_inputs_check(context, 7, 7); num_inputs_check(context, 7, 7);
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));

View File

@ -17,7 +17,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_roll(NodeContext& context) { OutputVector translate_roll(const NodeContext& context) {
num_inputs_check(context, 3, 3); num_inputs_check(context, 3, 3);
const auto data = context.get_input(0); const auto data = context.get_input(0);
const auto shifts = context.get_input(1); const auto shifts = context.get_input(1);

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_rsqrt(NodeContext& context) { OutputVector translate_rsqrt(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto data = context.get_input(0); auto data = context.get_input(0);
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32)); auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_rsub(NodeContext& context) { OutputVector translate_rsub(const NodeContext& context) {
num_inputs_check(context, 3, 3); num_inputs_check(context, 3, 3);
auto self = context.get_input(0); auto self = context.get_input(0);
auto other = context.get_input(1); auto other = context.get_input(1);

View File

@ -20,7 +20,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_select(NodeContext& context) { OutputVector translate_select(const NodeContext& context) {
num_inputs_check(context, 3, 3); num_inputs_check(context, 3, 3);
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1})); auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1})); auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_selu(NodeContext& context) { OutputVector translate_selu(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto x = context.get_input(0); auto x = context.get_input(0);
auto alpha = context.mark_node(v0::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717})); auto alpha = context.mark_node(v0::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_set_item(NodeContext& context) { OutputVector translate_set_item(const NodeContext& context) {
// schema: aten::_set_item.t(t[](a!) l, int idx, t(b -> *) el) -> t[](a!) // schema: aten::_set_item.t(t[](a!) l, int idx, t(b -> *) el) -> t[](a!)
// _set_item inserts element in list // _set_item inserts element in list
num_inputs_check(context, 3, 3); num_inputs_check(context, 3, 3);

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_size(NodeContext& context) { OutputVector translate_size(const NodeContext& context) {
num_inputs_check(context, 1, 2); num_inputs_check(context, 1, 2);
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(context.get_input(0), element::i32)); auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(context.get_input(0), element::i32));
if (context.input_is_none(1)) { if (context.input_is_none(1)) {

View File

@ -18,7 +18,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_slice(NodeContext& context) { OutputVector translate_slice(const NodeContext& context) {
// aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[]) // aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])
// aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a)) // aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a))
ov::Output<ov::Node> dim; ov::Output<ov::Node> dim;

View File

@ -13,7 +13,7 @@ namespace pytorch {
namespace op { namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_softmax(NodeContext& context) { OutputVector translate_softmax(const NodeContext& context) {
num_inputs_check(context, 2, 3); num_inputs_check(context, 2, 3);
auto x = context.get_input(0); auto x = context.get_input(0);
auto axis = context.const_input<int64_t>(1); auto axis = context.const_input<int64_t>(1);

View File

@ -9,7 +9,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_sort(NodeContext& context) { OutputVector translate_sort(const NodeContext& context) {
num_inputs_check(context, 3, 4); num_inputs_check(context, 3, 4);
const auto input_tensor = context.get_input(0); const auto input_tensor = context.get_input(0);
bool stable, descending; bool stable, descending;
@ -40,7 +40,7 @@ OutputVector translate_sort(NodeContext& context) {
return topk->outputs(); return topk->outputs();
}; };
OutputVector translate_argsort(NodeContext& context) { OutputVector translate_argsort(const NodeContext& context) {
auto sort = translate_sort(context); auto sort = translate_sort(context);
return {sort[1]}; return {sort[1]};
}; };

View File

@ -14,7 +14,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_square(NodeContext& context) { OutputVector translate_square(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
auto input_0 = context.get_input(0); auto input_0 = context.get_input(0);
auto const_2 = context.mark_node(v0::Constant::create(input_0.get_element_type(), Shape{1}, {2})); auto const_2 = context.mark_node(v0::Constant::create(input_0.get_element_type(), Shape{1}, {2}));

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_squeeze(NodeContext& context) { OutputVector translate_squeeze(const NodeContext& context) {
num_inputs_check(context, 1, 2); num_inputs_check(context, 1, 2);
auto x = context.get_input(0); auto x = context.get_input(0);
if (context.input_is_none(1)) { if (context.input_is_none(1)) {

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_sub(NodeContext& context) { OutputVector translate_sub(const NodeContext& context) {
num_inputs_check(context, 2, 3); num_inputs_check(context, 2, 3);
auto x = context.get_input(0); auto x = context.get_input(0);
auto y = context.get_input(1); auto y = context.get_input(1);

View File

@ -11,7 +11,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_sum(NodeContext& context) { OutputVector translate_sum(const NodeContext& context) {
num_inputs_check(context, 1, 3); num_inputs_check(context, 1, 3);
bool keep_dims = false; bool keep_dims = false;
ov::Output<ov::Node> axes; ov::Output<ov::Node> axes;

View File

@ -16,7 +16,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_to(NodeContext& context) { OutputVector translate_to(const NodeContext& context) {
int dtype_idx; int dtype_idx;
int memory_format_idx; int memory_format_idx;
if (context.get_input_size() == 5) { if (context.get_input_size() == 5) {

View File

@ -15,7 +15,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_topk(NodeContext& context) { OutputVector translate_topk(const NodeContext& context) {
num_inputs_check(context, 5, 5); num_inputs_check(context, 5, 5);
const auto input_tensor = context.get_input(0); const auto input_tensor = context.get_input(0);
const auto largest = context.const_input<bool>(3); const auto largest = context.const_input<bool>(3);

View File

@ -20,7 +20,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_transpose(NodeContext& context) { OutputVector translate_transpose(const NodeContext& context) {
num_inputs_check(context, 3, 3); num_inputs_check(context, 3, 3);
auto dim0 = context.const_input<int64_t>(1); auto dim0 = context.const_input<int64_t>(1);
auto dim1 = context.const_input<int64_t>(2); auto dim1 = context.const_input<int64_t>(2);

View File

@ -60,11 +60,11 @@ OutputVector translate_base_triu_tril(const NodeContext& context, bool upper) {
} }
}; // namespace }; // namespace
OutputVector translate_triu(NodeContext& context) { OutputVector translate_triu(const NodeContext& context) {
return translate_base_triu_tril(context, true); return translate_base_triu_tril(context, true);
}; };
OutputVector translate_tril(NodeContext& context) { OutputVector translate_tril(const NodeContext& context) {
return translate_base_triu_tril(context, false); return translate_base_triu_tril(context, false);
}; };

View File

@ -13,7 +13,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
OutputVector translate_unfold(NodeContext& context) { OutputVector translate_unfold(const NodeContext& context) {
num_inputs_check(context, 4, 4); num_inputs_check(context, 4, 4);
// constants // constants
auto const_0 = context.mark_node(Constant::create(element::i32, Shape{}, {0})); auto const_0 = context.mark_node(Constant::create(element::i32, Shape{}, {0}));

View File

@ -69,32 +69,32 @@ OutputVector base_translate_upsample(const NodeContext& context,
}; };
} // namespace } // namespace
OutputVector translate_upsample_linear1d(NodeContext& context) { OutputVector translate_upsample_linear1d(const NodeContext& context) {
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 1); return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 1);
}; };
OutputVector translate_upsample_bilinear2d(NodeContext& context) { OutputVector translate_upsample_bilinear2d(const NodeContext& context) {
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 2); return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 2);
}; };
OutputVector translate_upsample_trilinear3d(NodeContext& context) { OutputVector translate_upsample_trilinear3d(const NodeContext& context) {
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 3); return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 3);
}; };
OutputVector translate_upsample_nearest1d(NodeContext& context) { OutputVector translate_upsample_nearest1d(const NodeContext& context) {
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 1); return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 1);
}; };
OutputVector translate_upsample_nearest2d(NodeContext& context) { OutputVector translate_upsample_nearest2d(const NodeContext& context) {
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 2); return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 2);
}; };
OutputVector translate_upsample_nearest3d(NodeContext& context) { OutputVector translate_upsample_nearest3d(const NodeContext& context) {
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 3); return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 3);
}; };
// bicubic is only supported for 2d in pytorch // bicubic is only supported for 2d in pytorch
OutputVector translate_upsample_bicubic2d(NodeContext& context) { OutputVector translate_upsample_bicubic2d(const NodeContext& context) {
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::CUBIC, 2); return base_translate_upsample(context, v4::Interpolate::InterpolateMode::CUBIC, 2);
}; };

View File

@ -20,7 +20,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_var_mean(NodeContext& context) { OutputVector translate_var_mean(const NodeContext& context) {
num_inputs_check(context, 1, 4); num_inputs_check(context, 1, 4);
auto data = context.get_input(0); auto data = context.get_input(0);
bool unbiased = true; bool unbiased = true;
@ -75,7 +75,7 @@ OutputVector translate_var_mean(NodeContext& context) {
return {var, mean}; return {var, mean};
}; };
OutputVector translate_var(NodeContext& context) { OutputVector translate_var(const NodeContext& context) {
auto res = translate_var_mean(context); auto res = translate_var_mean(context);
return {res[0]}; return {res[0]};
} }

View File

@ -14,7 +14,7 @@ namespace op {
using namespace ov::op; using namespace ov::op;
OutputVector translate_where(NodeContext& context) { OutputVector translate_where(const NodeContext& context) {
num_inputs_check(context, 1, 3); num_inputs_check(context, 1, 3);
auto cond = context.get_input(0); auto cond = context.get_input(0);
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(1), "aten::where(cond) unsupported"); FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(1), "aten::where(cond) unsupported");

View File

@ -12,7 +12,7 @@ namespace frontend {
namespace pytorch { namespace pytorch {
namespace op { namespace op {
#define OP_CONVERTER(op) OutputVector op(NodeContext& node) #define OP_CONVERTER(op) OutputVector op(const NodeContext& node)
OP_CONVERTER(translate_adaptive_avg_pool3d); OP_CONVERTER(translate_adaptive_avg_pool3d);
OP_CONVERTER(translate_adaptive_max_pool2d); OP_CONVERTER(translate_adaptive_max_pool2d);
@ -130,7 +130,7 @@ OP_CONVERTER(translate_zeros_like);
} // namespace op } // namespace op
const std::map<std::string, PytorchCreatorFunction> get_supported_ops() { const std::map<std::string, CreatorFunction> get_supported_ops() {
return { return {
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases {"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
{"aten::__getitem__", op::translate_getitem}, {"aten::__getitem__", op::translate_getitem},

View File

@ -10,7 +10,7 @@ namespace ov {
namespace frontend { namespace frontend {
namespace pytorch { namespace pytorch {
const std::map<std::string, PytorchCreatorFunction> get_supported_ops(); const std::map<std::string, CreatorFunction> get_supported_ops();
} // namespace pytorch } // namespace pytorch
} // namespace frontend } // namespace frontend

View File

@ -20,7 +20,7 @@ namespace pytorch {
using namespace ov::op; using namespace ov::op;
TranslateSession::TranslateSession(const ov::frontend::InputModel::Ptr& input_model, TranslateSession::TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
const std::map<std::string, PytorchCreatorFunction>& translator_map) const std::map<std::string, CreatorFunction>& translator_map)
: m_input_model(input_model), : m_input_model(input_model),
m_translator_map(translator_map), m_translator_map(translator_map),
m_ov_model(nullptr) {} m_ov_model(nullptr) {}
@ -45,9 +45,9 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
const std::unordered_map<size_t, PlaceDesc>& external_descriptors) { const std::unordered_map<size_t, PlaceDesc>& external_descriptors) {
std::shared_ptr<Model> resulting_model; // define here to make a conversion in a nested scope std::shared_ptr<Model> resulting_model; // define here to make a conversion in a nested scope
{ {
ParameterVector parameters; auto parameters = std::make_shared<ParameterVector>();
TensorMap tensor_map; // tensor map of the current context auto tensor_map = std::make_shared<TensorMap>(); // tensor map of the current context
std::set<size_t> mutated_tensors; auto mutated_tensors = std::make_shared<std::set<size_t>>();
// Go over all pytorch_model inputs and register them in the tensor map: // Go over all pytorch_model inputs and register them in the tensor map:
auto inputs = pytorch_model->inputs(); auto inputs = pytorch_model->inputs();
@ -74,7 +74,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
if (!input_node) { if (!input_node) {
auto parameter = std::make_shared<v0::Parameter>(type, pshape); auto parameter = std::make_shared<v0::Parameter>(type, pshape);
encode_tensor_name(parameter->output(0), inputs.at(i), pytorch_model->get_input_debug_name(i)); encode_tensor_name(parameter->output(0), inputs.at(i), pytorch_model->get_input_debug_name(i));
parameters.push_back(parameter); parameters->push_back(parameter);
input_node = parameter; input_node = parameter;
auto order = pytorch_model->get_input_transpose_order(i); auto order = pytorch_model->get_input_transpose_order(i);
if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) { if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
@ -91,7 +91,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
input_node = transpose; input_node = transpose;
} }
} }
tensor_map[inputs.at(i)] = input_node; (*tensor_map)[inputs.at(i)] = input_node;
} }
auto node_visitor = [&](std::shared_ptr<TorchDecoder> node) { auto node_visitor = [&](std::shared_ptr<TorchDecoder> node) {
@ -102,7 +102,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
auto raw_inputs = node->inputs(); auto raw_inputs = node->inputs();
for (size_t i = 0; i < raw_inputs.size(); ++i) { for (size_t i = 0; i < raw_inputs.size(); ++i) {
auto input = raw_inputs.at(i); auto input = raw_inputs.at(i);
if (tensor_map.find(input) == tensor_map.end()) { if (tensor_map->find(input) == tensor_map->end()) {
// Input refers value in the outer scope, need to create a new Parameter in the current scope // Input refers value in the outer scope, need to create a new Parameter in the current scope
// Linkage to external scope will be performed on the level of the parent operation (if or loop) // Linkage to external scope will be performed on the level of the parent operation (if or loop)
// TODO: Eliminate duplication with the main code for Parameters creation // TODO: Eliminate duplication with the main code for Parameters creation
@ -111,18 +111,15 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
// TODO: Use special API to set custom type specification // TODO: Use special API to set custom type specification
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, ps); auto parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
// TODO: Missing get_input_transpose_order handling for not trivial layouts // TODO: Missing get_input_transpose_order handling for not trivial layouts
tensor_map[input] = parameter; (*tensor_map)[input] = parameter;
// set name of parameter to the index of node in the model // set name of parameter to the index of node in the model
encode_tensor_name(parameter->output(0), input); encode_tensor_name(parameter->output(0), input);
parameters.push_back(parameter); parameters->push_back(parameter);
} }
} }
auto context = NodeContext(node, &tensor_map, &parameters, external_tensor_map, this); auto context = NodeContext(node, external_tensor_map, tensor_map, parameters, mutated_tensors, this);
auto converted_outputs = convert_node(context); auto converted_outputs = convert_node(context);
auto mutated_t = context.get_mutated_tensors();
mutated_tensors.insert(mutated_t.begin(), mutated_t.end());
auto fw_outputs = node->outputs(); auto fw_outputs = node->outputs();
// Ops with subgraphs or with mutated inputs may have more outputs after conversion compared to pytorch ones // Ops with subgraphs or with mutated inputs may have more outputs after conversion compared to pytorch ones
FRONT_END_OP_CONVERSION_CHECK(fw_outputs.size() <= converted_outputs.size(), FRONT_END_OP_CONVERSION_CHECK(fw_outputs.size() <= converted_outputs.size(),
@ -134,10 +131,10 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
// FIXME: Now it is not true for at least prim::Constant // FIXME: Now it is not true for at least prim::Constant
for (size_t i = 0; i < fw_outputs.size(); ++i) { for (size_t i = 0; i < fw_outputs.size(); ++i) {
size_t fw_tensor_id = node->output(i); size_t fw_tensor_id = node->output(i);
FRONT_END_GENERAL_CHECK(tensor_map.find(fw_tensor_id) == tensor_map.end(), FRONT_END_GENERAL_CHECK(tensor_map->find(fw_tensor_id) == tensor_map->end(),
"Duplicated producer for PT value with unique ID: ", "Duplicated producer for PT value with unique ID: ",
fw_tensor_id); fw_tensor_id);
tensor_map[fw_tensor_id] = converted_outputs[i]; (*tensor_map)[fw_tensor_id] = converted_outputs[i];
encode_tensor_name(converted_outputs[i], fw_tensor_id, node->get_output_debug_name(i)); encode_tensor_name(converted_outputs[i], fw_tensor_id, node->get_output_debug_name(i));
} }
}; };
@ -148,14 +145,14 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
ResultVector results; ResultVector results;
for (size_t i = 0; i < pytorch_model->num_of_outputs(); ++i) { for (size_t i = 0; i < pytorch_model->num_of_outputs(); ++i) {
size_t id = pytorch_model->output(i); size_t id = pytorch_model->output(i);
if (tensor_map.find(id) == tensor_map.end()) { if (tensor_map->find(id) == tensor_map->end()) {
// Not found in this scope, adding Parameter to connect to external scope // Not found in this scope, adding Parameter to connect to external scope
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, PartialShape::dynamic()); auto parameter = std::make_shared<v0::Parameter>(element::dynamic, PartialShape::dynamic());
encode_tensor_name(parameter->output(0), id); encode_tensor_name(parameter->output(0), id);
parameters.push_back(parameter); parameters->push_back(parameter);
tensor_map[id] = parameter; (*tensor_map)[id] = parameter;
} }
auto ov_output = tensor_map[id]; auto ov_output = tensor_map->at(id);
auto order = pytorch_model->get_output_transpose_order(i); auto order = pytorch_model->get_output_transpose_order(i);
FRONT_END_GENERAL_CHECK(order.size() == 0 || std::is_sorted(order.begin(), order.end()), FRONT_END_GENERAL_CHECK(order.size() == 0 || std::is_sorted(order.begin(), order.end()),
"Output strides have wrong order."); "Output strides have wrong order.");
@ -168,32 +165,32 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
// Since parameters can be added we need to list all current parameters // Since parameters can be added we need to list all current parameters
std::set<size_t> param_names; std::set<size_t> param_names;
for (const auto& param : parameters) { for (const auto& param : *parameters) {
auto input_idx = decode_tensor_name(param->output(0)); auto input_idx = decode_tensor_name(param->output(0));
param_names.insert(input_idx); param_names.insert(input_idx);
} }
for (const auto& tensor_id : mutated_tensors) { for (const auto& tensor_id : *mutated_tensors) {
if (param_names.count(tensor_id)) { if (param_names.count(tensor_id)) {
FRONT_END_GENERAL_CHECK(tensor_map.count(tensor_id), FRONT_END_GENERAL_CHECK(tensor_map->count(tensor_id),
"Tensor with id: ", "Tensor with id: ",
tensor_id, tensor_id,
" doesn't exist in tensor map."); " doesn't exist in tensor map.");
// model input was mutated we need to make a result for it // model input was mutated we need to make a result for it
auto mutated_tensor = tensor_map.at(tensor_id); auto mutated_tensor = tensor_map->at(tensor_id);
// empty external_tensor_map means this is main body of the model and we don't want to create // empty external_tensor_map means this is main body of the model and we don't want to create
// additional outputs in that case. // additional outputs in that case.
if (mutated_tensor.get_target_inputs().empty() && !external_tensor_map.empty()) if (mutated_tensor.get_target_inputs().empty() && !external_tensor_map.empty())
results.push_back(std::make_shared<v0::Result>(tensor_map.at(tensor_id))); results.push_back(std::make_shared<v0::Result>(tensor_map->at(tensor_id)));
} }
} }
resulting_model = std::make_shared<Model>(results, parameters); resulting_model = std::make_shared<Model>(results, *parameters);
// Did a conversion in a nested scope to automatically remove any holders of nodes except those in the graph // Did a conversion in a nested scope to automatically remove any holders of nodes except those in the graph
} }
return resulting_model; return resulting_model;
} }
OutputVector TranslateSession::convert_node(NodeContext& context) { OutputVector TranslateSession::convert_node(const NodeContext& context) {
try { try {
auto it = m_translator_map.find(context.get_op_type()); auto it = m_translator_map.find(context.get_op_type());
if (it != m_translator_map.end()) { if (it != m_translator_map.end()) {

View File

@ -17,7 +17,7 @@ namespace pytorch {
class TranslateSession { class TranslateSession {
public: public:
TranslateSession(const frontend::InputModel::Ptr& input_model, TranslateSession(const frontend::InputModel::Ptr& input_model,
const std::map<std::string, PytorchCreatorFunction>& translator_map); const std::map<std::string, CreatorFunction>& translator_map);
std::shared_ptr<Model> get_converted_model(); std::shared_ptr<Model> get_converted_model();
std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model); std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model);
@ -38,10 +38,10 @@ public:
size_t m_friendly_name_counter = 0; size_t m_friendly_name_counter = 0;
private: private:
OutputVector convert_node(NodeContext& context); OutputVector convert_node(const NodeContext& context);
const frontend::InputModel::Ptr m_input_model; const frontend::InputModel::Ptr m_input_model;
const std::map<std::string, PytorchCreatorFunction>& m_translator_map; const std::map<std::string, CreatorFunction>& m_translator_map;
std::shared_ptr<Model> m_ov_model; std::shared_ptr<Model> m_ov_model;
std::map<size_t, std::pair<size_t, Output<Node>>> m_counter_map; std::map<size_t, std::pair<size_t, Output<Node>>> m_counter_map;

View File

@ -177,7 +177,7 @@ std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input) {
return input; return input;
} }
OutputVector make_framework_node(NodeContext& context) { OutputVector make_framework_node(const NodeContext& context) {
auto schema = context.get_schema(); auto schema = context.get_schema();
// TODO: properly process schema to get the actual position of mutable input // TODO: properly process schema to get the actual position of mutable input
// Hack. Can indicate mutable inputs, but can it be reliable? // Hack. Can indicate mutable inputs, but can it be reliable?

View File

@ -48,7 +48,7 @@ op::PadType convert_pad(const std::string& pt_pad);
std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input); std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input);
OutputVector make_framework_node(NodeContext& context); OutputVector make_framework_node(const NodeContext& context);
std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type); std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);
@ -63,8 +63,8 @@ void align_eltwise_input_types(const NodeContext& context,
std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start); std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start);
namespace op { namespace op {
template <OutputVector (*T)(NodeContext&), size_t idx = 0> template <OutputVector (*T)(const NodeContext&), size_t idx = 0>
OutputVector inplace_op(NodeContext& context) { OutputVector inplace_op(const NodeContext& context) {
auto translation_res = T(context); auto translation_res = T(context);
FRONT_END_OP_CONVERSION_CHECK(translation_res.size() == 1, FRONT_END_OP_CONVERSION_CHECK(translation_res.size() == 1,
"inplace_op function must be used on single output translators"); "inplace_op function must be used on single output translators");
@ -73,21 +73,21 @@ OutputVector inplace_op(NodeContext& context) {
} }
template <typename T> template <typename T>
OutputVector translate_1to1_match_1_inputs(NodeContext& context) { OutputVector translate_1to1_match_1_inputs(const NodeContext& context) {
num_inputs_check(context, 1, 1); num_inputs_check(context, 1, 1);
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None."); FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None.");
return {context.mark_node(std::make_shared<T>(context.get_input(0)))}; return {context.mark_node(std::make_shared<T>(context.get_input(0)))};
} }
template <typename T> template <typename T>
OutputVector translate_1to1_match_2_inputs(NodeContext& context) { OutputVector translate_1to1_match_2_inputs(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None."); FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
return {context.mark_node(std::make_shared<T>(context.get_input(0), context.get_input(1)))}; return {context.mark_node(std::make_shared<T>(context.get_input(0), context.get_input(1)))};
} }
template <typename T> template <typename T>
OutputVector translate_1to1_match_2_inputs_align_types(NodeContext& context) { OutputVector translate_1to1_match_2_inputs_align_types(const NodeContext& context) {
num_inputs_check(context, 2, 2); num_inputs_check(context, 2, 2);
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None."); FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
auto lhs = context.get_input(0); auto lhs = context.get_input(0);
@ -96,11 +96,11 @@ OutputVector translate_1to1_match_2_inputs_align_types(NodeContext& context) {
return {context.mark_node(std::make_shared<T>(lhs, rhs))}; return {context.mark_node(std::make_shared<T>(lhs, rhs))};
} }
inline OutputVector return_false_scalar(NodeContext& context) { inline OutputVector return_false_scalar(const NodeContext& context) {
return {context.mark_node(ov::op::v0::Constant::create(element::boolean, Shape{}, {false}))}; return {context.mark_node(ov::op::v0::Constant::create(element::boolean, Shape{}, {false}))};
} }
inline OutputVector skip_node(NodeContext& context) { inline OutputVector skip_node(const NodeContext& context) {
return {context.get_input(0).get_node_shared_ptr()}; return {context.get_input(0).get_node_shared_ptr()};
} }