[PT FE] Check number of inputs, use rank function (#15656)
* Check number of inputs, use rank function * Change min inputs in flatten * Fix code style * Fix aten::tensor and aten::_convolution number of inputs * Refactor NodeCOntext a little * Fix codestyle
This commit is contained in:
parent
d0a97af629
commit
6c72ea4bea
@ -106,9 +106,11 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
else:
|
||||
self.graph_element = graph_element
|
||||
self.pt_module = pt_module
|
||||
self.raw_inputs = list(self.graph_element.inputs())
|
||||
self.raw_outputs = list(self.graph_element.outputs())
|
||||
|
||||
def inputs(self) -> list:
|
||||
return [x.unique() for x in self.graph_element.inputs()]
|
||||
return [x.unique() for x in self.raw_inputs]
|
||||
|
||||
def get_input(self, index: int):
|
||||
return self.inputs()[index]
|
||||
@ -207,22 +209,16 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
return self.graph_element.schema()
|
||||
|
||||
def outputs(self) -> list:
|
||||
return [x.unique() for x in self.graph_element.outputs()]
|
||||
|
||||
def _raw_outputs(self) -> list:
|
||||
return list(self.graph_element.outputs())
|
||||
return [x.unique() for x in self.raw_outputs]
|
||||
|
||||
def _raw_output(self, index: int):
|
||||
return self._raw_outputs()[index]
|
||||
|
||||
def _raw_inputs(self) -> list:
|
||||
return list(self.graph_element.inputs())
|
||||
return self.raw_outputs[index]
|
||||
|
||||
def _raw_input(self, index: int):
|
||||
return self._raw_inputs()[index]
|
||||
return self.raw_inputs[index]
|
||||
|
||||
def num_of_outputs(self):
|
||||
return len(self.outputs())
|
||||
return len(self.raw_outputs)
|
||||
|
||||
def output(self, index: int):
|
||||
return self.outputs()[index]
|
||||
|
@ -18,10 +18,6 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
|
||||
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, const_input, index);
|
||||
}
|
||||
|
||||
size_t input(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(size_t, TorchDecoder, get_input, index);
|
||||
}
|
||||
|
||||
const std::vector<size_t>& inputs() const override {
|
||||
PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, inputs);
|
||||
}
|
||||
|
@ -29,9 +29,6 @@ public:
|
||||
|
||||
// TODO: set of input and output methods are not aligned; also they are not aligned with the rest of FEs
|
||||
|
||||
// Input tensor id
|
||||
virtual size_t input(size_t index) const = 0;
|
||||
|
||||
virtual const std::vector<size_t>& inputs() const = 0;
|
||||
|
||||
// ------------------------------
|
||||
|
@ -26,21 +26,23 @@ public:
|
||||
m_decoder(decoder),
|
||||
m_tensor_map(tensor_map),
|
||||
m_ext_tensor_map(ext_tensor_map),
|
||||
m_external_parameters(external_parameters) {}
|
||||
m_external_parameters(external_parameters),
|
||||
m_decoder_inputs(decoder->inputs()),
|
||||
m_decoder_outputs(decoder->outputs()) {}
|
||||
|
||||
// Do not search for input in tensor map; try to access it as a constant of specified type T and return its value
|
||||
template <typename T>
|
||||
T const_input(size_t index) const;
|
||||
|
||||
size_t get_input_size() const override {
|
||||
return m_decoder->inputs().size();
|
||||
return m_decoder_inputs.size();
|
||||
};
|
||||
|
||||
// Search for input in tensor map and return an output port for already converted op
|
||||
// TODO: int due to base class uses it, but naturally it should be size_t for PT
|
||||
Output<Node> get_input(int index) const override {
|
||||
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index);
|
||||
auto input = m_decoder->input(index);
|
||||
auto input = m_decoder_inputs.at(index);
|
||||
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
|
||||
return m_tensor_map->at(input);
|
||||
}
|
||||
@ -48,7 +50,7 @@ public:
|
||||
// TODO: upstream to base class
|
||||
OutputVector inputs() const {
|
||||
OutputVector res;
|
||||
for (size_t input : m_decoder->inputs()) {
|
||||
for (auto input : m_decoder_inputs) {
|
||||
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding index: ", input, " exist.");
|
||||
res.push_back(m_tensor_map->at(input));
|
||||
}
|
||||
@ -63,29 +65,22 @@ public:
|
||||
return m_decoder->input_is_none(index);
|
||||
}
|
||||
|
||||
size_t get_output_size() const {
|
||||
return m_decoder_outputs.size();
|
||||
}
|
||||
|
||||
std::vector<size_t> outputs() const {
|
||||
return m_decoder_outputs;
|
||||
}
|
||||
|
||||
// Convert the resulting value of this node to ov Constant; works correctly only for nodes that produce
|
||||
// constant value, naturally for prim::Constant
|
||||
OutputVector as_constant() const;
|
||||
|
||||
/*
|
||||
TODO: Should be uncommented when explicit NodeContext ctor won't require passing op_type
|
||||
const std::string& get_op_type() const override {
|
||||
return m_decoder->get_op_type();
|
||||
}
|
||||
*/
|
||||
|
||||
std::string get_schema() const {
|
||||
return m_decoder->get_schema();
|
||||
}
|
||||
|
||||
size_t num_of_outputs() const {
|
||||
return m_decoder->num_of_outputs();
|
||||
}
|
||||
|
||||
std::vector<size_t> outputs() const {
|
||||
return m_decoder->outputs();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> mark_node(std::shared_ptr<Node> ov_node) const {
|
||||
return m_decoder->mark_node(ov_node);
|
||||
}
|
||||
@ -105,7 +100,7 @@ public:
|
||||
|
||||
void mutate_input(size_t index, Output<Node> ov_output) {
|
||||
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index);
|
||||
auto input = m_decoder->input(index);
|
||||
auto input = m_decoder_inputs.at(index);
|
||||
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input), "No tensor corresponding input: ", input, " exist.");
|
||||
m_tensor_map->at(input).get_tensor().set_names({std::to_string(input) + "_"});
|
||||
// TODO: find out why this doesn't work
|
||||
@ -148,6 +143,8 @@ private:
|
||||
TensorMap* m_tensor_map;
|
||||
const TensorMap& m_ext_tensor_map;
|
||||
ParameterVector* m_external_parameters;
|
||||
const std::vector<size_t> m_decoder_inputs;
|
||||
const std::vector<size_t> m_decoder_outputs;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
|
@ -3,7 +3,13 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/adaptive_avg_pool.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/tile.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,23 +17,26 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_adaptive_avg_pool3d(NodeContext& context) {
|
||||
auto const_tile_params = context.mark_node(opset10::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1}));
|
||||
auto const_0 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto const_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_neg_3 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-3}));
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto const_tile_params = context.mark_node(v0::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1}));
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_neg_3 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
|
||||
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto given_shape = context.get_input(1);
|
||||
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input_tensor, element::i32));
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input_tensor, element::i32));
|
||||
auto shape_begin =
|
||||
context.mark_node(std::make_shared<opset10::Slice>(input_shape, const_0, const_neg_3, const_1, const_0));
|
||||
auto output_shape = context.mark_node(std::make_shared<opset10::Concat>(OutputVector{shape_begin, given_shape}, 0));
|
||||
context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, const_neg_3, const_1, const_0));
|
||||
auto output_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{shape_begin, given_shape}, 0));
|
||||
|
||||
auto tile = context.mark_node(std::make_shared<opset10::Tile>(input_tensor, const_tile_params));
|
||||
auto adaptive_avg_pool = context.mark_node(std::make_shared<opset10::AdaptiveAvgPool>(tile, given_shape));
|
||||
auto reshape = context.mark_node(std::make_shared<opset10::Reshape>(adaptive_avg_pool, output_shape, false));
|
||||
auto tile = context.mark_node(std::make_shared<v0::Tile>(input_tensor, const_tile_params));
|
||||
auto adaptive_avg_pool = context.mark_node(std::make_shared<v8::AdaptiveAvgPool>(tile, given_shape));
|
||||
auto reshape = context.mark_node(std::make_shared<v1::Reshape>(adaptive_avg_pool, output_shape, false));
|
||||
|
||||
return {reshape};
|
||||
};
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/adaptive_max_pool.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,9 +12,10 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_adaptive_max_pool2d(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
auto adaptive_max_pool = context.mark_node(std::make_shared<opset10::AdaptiveMaxPool>(x, y, ov::element::i32));
|
||||
auto adaptive_max_pool = context.mark_node(std::make_shared<ov::op::v8::AdaptiveMaxPool>(x, y, ov::element::i32));
|
||||
return {adaptive_max_pool->output(0), adaptive_max_pool->output(1)};
|
||||
};
|
||||
|
||||
|
@ -16,6 +16,7 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_add(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto lhs = context.get_input(0);
|
||||
auto rhs = context.get_input(1);
|
||||
auto dtype0 = context.get_input_type(0);
|
||||
|
@ -5,7 +5,9 @@
|
||||
#include <climits>
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -13,13 +15,16 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_addcmul(NodeContext& context) {
|
||||
const auto eltwise_mult = std::make_shared<opset10::Multiply>(context.get_input(1), context.get_input(2));
|
||||
num_inputs_check(context, 4, 4);
|
||||
const auto eltwise_mult = std::make_shared<v1::Multiply>(context.get_input(1), context.get_input(2));
|
||||
const auto value = context.get_input(3);
|
||||
const auto converted_value = std::make_shared<opset10::ConvertLike>(value, context.get_input(1));
|
||||
const auto scalar_mult = std::make_shared<opset10::Multiply>(eltwise_mult, converted_value);
|
||||
const auto converted_value = std::make_shared<v1::ConvertLike>(value, context.get_input(1));
|
||||
const auto scalar_mult = std::make_shared<v1::Multiply>(eltwise_mult, converted_value);
|
||||
context.mark_nodes({eltwise_mult, converted_value, scalar_mult});
|
||||
return {context.mark_node(std::make_shared<opset10::Add>(context.get_input(0), scalar_mult))};
|
||||
return {context.mark_node(std::make_shared<v1::Add>(context.get_input(0), scalar_mult))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,10 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/matmul.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,18 +14,21 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_addmm(NodeContext& context) {
|
||||
num_inputs_check(context, 5, 5);
|
||||
auto input = context.get_input(0);
|
||||
auto m1 = context.get_input(1);
|
||||
auto m2 = context.get_input(2);
|
||||
auto beta = context.get_input(3);
|
||||
auto alpha = context.get_input(4);
|
||||
auto beta_converted = context.mark_node(std::make_shared<opset10::ConvertLike>(beta, input));
|
||||
auto mm = context.mark_node(std::make_shared<opset10::MatMul>(m1, m2));
|
||||
auto alpha_converted = context.mark_node(std::make_shared<opset10::ConvertLike>(alpha, mm));
|
||||
auto input_beta = context.mark_node(std::make_shared<opset10::Multiply>(input, beta_converted));
|
||||
auto mm_alpha = context.mark_node(std::make_shared<opset10::Multiply>(mm, alpha_converted));
|
||||
return {context.mark_node(std::make_shared<opset10::Add>(input_beta, mm_alpha))};
|
||||
auto beta_converted = context.mark_node(std::make_shared<v1::ConvertLike>(beta, input));
|
||||
auto mm = context.mark_node(std::make_shared<v0::MatMul>(m1, m2));
|
||||
auto alpha_converted = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, mm));
|
||||
auto input_beta = context.mark_node(std::make_shared<v1::Multiply>(input, beta_converted));
|
||||
auto mm_alpha = context.mark_node(std::make_shared<v1::Multiply>(mm, alpha_converted));
|
||||
return {context.mark_node(std::make_shared<v1::Add>(input_beta, mm_alpha))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,10 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,9 +14,11 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_arange(NodeContext& context) {
|
||||
auto zero = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto one = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto dtype = element::f32;
|
||||
bool dtype_applied = false;
|
||||
auto num_inputs = context.get_input_size();
|
||||
@ -22,29 +27,26 @@ OutputVector translate_arange(NodeContext& context) {
|
||||
ov::Output<Node> start = zero;
|
||||
ov::Output<Node> step = one;
|
||||
|
||||
// aten::arange(Scalar end, tensor out)
|
||||
if (num_inputs == 2) {
|
||||
// aten::arange(Scalar end, tensor out)
|
||||
end = context.get_input(0);
|
||||
out_tensor = context.input_is_none(1) ? end : context.get_input(1);
|
||||
}
|
||||
// aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
|
||||
if (num_inputs == 4) {
|
||||
} else if (num_inputs == 4) {
|
||||
// aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
|
||||
start = context.get_input(0);
|
||||
end = context.get_input(1);
|
||||
step = context.get_input(2);
|
||||
out_tensor = context.input_is_none(3) ? end : context.get_input(3);
|
||||
}
|
||||
// aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
if (num_inputs == 5) {
|
||||
} else if (num_inputs == 5) {
|
||||
// aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
end = context.get_input(0);
|
||||
out_tensor = end;
|
||||
if (!context.input_is_none(1)) {
|
||||
dtype = convert_dtype(context.const_input<int64_t>(1));
|
||||
dtype_applied = true;
|
||||
}
|
||||
}
|
||||
// aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
if (num_inputs == 6) {
|
||||
} else if (num_inputs == 6) {
|
||||
// aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
start = context.get_input(0);
|
||||
end = context.get_input(1);
|
||||
out_tensor = end;
|
||||
@ -52,9 +54,8 @@ OutputVector translate_arange(NodeContext& context) {
|
||||
dtype = convert_dtype(context.const_input<int64_t>(2));
|
||||
dtype_applied = true;
|
||||
}
|
||||
}
|
||||
// aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
if (num_inputs == 7) {
|
||||
} else if (num_inputs == 7) {
|
||||
// aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
|
||||
start = context.get_input(0);
|
||||
end = context.get_input(1);
|
||||
step = context.get_input(2);
|
||||
@ -63,13 +64,15 @@ OutputVector translate_arange(NodeContext& context) {
|
||||
dtype = convert_dtype(context.const_input<int64_t>(3));
|
||||
dtype_applied = true;
|
||||
}
|
||||
} else {
|
||||
FRONT_END_OP_CONVERSION_CHECK(false, "Not expected number of inputs for ", context.get_op_type());
|
||||
}
|
||||
auto r_end = context.mark_node(std::make_shared<opset10::Convert>(end, dtype));
|
||||
auto r_start = context.mark_node(std::make_shared<opset10::Convert>(start, dtype));
|
||||
auto r_step = context.mark_node(std::make_shared<opset10::Convert>(step, dtype));
|
||||
auto range = context.mark_node(std::make_shared<opset10::Range>(r_start, r_end, r_step, dtype));
|
||||
auto r_end = context.mark_node(std::make_shared<v0::Convert>(end, dtype));
|
||||
auto r_start = context.mark_node(std::make_shared<v0::Convert>(start, dtype));
|
||||
auto r_step = context.mark_node(std::make_shared<v0::Convert>(step, dtype));
|
||||
auto range = context.mark_node(std::make_shared<v4::Range>(r_start, r_end, r_step, dtype));
|
||||
if (!dtype_applied) {
|
||||
range = context.mark_node(std::make_shared<opset10::ConvertLike>(range, out_tensor));
|
||||
range = context.mark_node(std::make_shared<v1::ConvertLike>(range, out_tensor));
|
||||
}
|
||||
return {range};
|
||||
};
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@ -12,24 +14,28 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_as_tensor(NodeContext& context) {
|
||||
// aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor
|
||||
num_inputs_check(context, 1, 4);
|
||||
auto dtype = element::f32;
|
||||
Output<Node> cast;
|
||||
if (!context.input_is_none(1)) {
|
||||
auto dtype_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
|
||||
auto dtype_fw_node = std::dynamic_pointer_cast<PtFrameworkNode>(dtype_ext_node);
|
||||
if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
|
||||
auto type_input = dtype_fw_node->input_value(0);
|
||||
return {context.mark_node(std::make_shared<opset10::ConvertLike>(context.get_input(0), type_input))};
|
||||
return {context.mark_node(std::make_shared<v1::ConvertLike>(context.get_input(0), type_input))};
|
||||
}
|
||||
if (auto dtype_const = std::dynamic_pointer_cast<opset10::Constant>(dtype_ext_node)) {
|
||||
if (auto dtype_const = std::dynamic_pointer_cast<v0::Constant>(dtype_ext_node)) {
|
||||
auto pt_type = dtype_const->cast_vector<int64_t>()[0];
|
||||
dtype = convert_dtype(pt_type);
|
||||
}
|
||||
}
|
||||
cast = context.mark_node(std::make_shared<opset10::Convert>(context.get_input(0), dtype));
|
||||
auto cast = context.mark_node(std::make_shared<v0::Convert>(context.get_input(0), dtype));
|
||||
|
||||
// Input with index 2 is device, we skip this input
|
||||
// Input with index 3 is flag requires_grad, we skip this input
|
||||
return {cast};
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,12 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/avg_pool.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/pad.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,7 +16,10 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_avg_poolnd(NodeContext& context) {
|
||||
num_inputs_check(context, 6, 7);
|
||||
auto input = context.get_input(0);
|
||||
auto kernel = context.const_input<Shape>(1);
|
||||
auto strides = context.const_input<Strides>(2);
|
||||
@ -25,23 +33,22 @@ OutputVector translate_avg_poolnd(NodeContext& context) {
|
||||
// PyTorch allows sliding window go off bound, which leads to this accommodation.
|
||||
// More detail on https://github.com/pytorch/pytorch/issues/57178
|
||||
if (count_include_pad) {
|
||||
auto zero = context.mark_node(opset10::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto zero_i32 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input, element::i32));
|
||||
auto rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, element::i32));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto zero_i32 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, input);
|
||||
auto pad_values = context.get_input(3);
|
||||
auto pads_len = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {pads.size()}));
|
||||
auto pads_diff = context.mark_node(std::make_shared<opset10::Subtract>(rank, pads_len));
|
||||
auto pads_remaining = context.mark_node(std::make_shared<opset10::Broadcast>(zero_i32, pads_diff));
|
||||
auto pads_len = context.mark_node(v0::Constant::create(element::i32, Shape{}, {pads.size()}));
|
||||
auto pads_diff = context.mark_node(std::make_shared<v1::Subtract>(rank, pads_len));
|
||||
auto pads_remaining = context.mark_node(std::make_shared<v3::Broadcast>(zero_i32, pads_diff));
|
||||
auto padding = context.mark_node(
|
||||
std::make_shared<opset10::Concat>(NodeVector{pads_remaining, pad_values.get_node_shared_ptr()}, 0));
|
||||
input =
|
||||
context.mark_node(std::make_shared<opset10::Pad>(input, padding, padding, zero, ov::op::PadMode::CONSTANT));
|
||||
std::make_shared<v0::Concat>(NodeVector{pads_remaining, pad_values.get_node_shared_ptr()}, 0));
|
||||
input = context.mark_node(std::make_shared<v1::Pad>(input, padding, padding, zero, ov::op::PadMode::CONSTANT));
|
||||
pads = Shape(pads.size(), 0);
|
||||
}
|
||||
|
||||
return {context.mark_node(
|
||||
std::make_shared<opset10::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};
|
||||
std::make_shared<v1::AvgPool>(input, strides, pads, pads, kernel, !count_include_pad, rounding_type))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -2,8 +2,14 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/batch_norm.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,33 +17,38 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
Output<Node> broadcast_const_to_channel_dim(NodeContext& context, Output<Node> input, Output<Node> value) {
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input));
|
||||
auto zero_i = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto one_i = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto channel_dim = context.mark_node(std::make_shared<opset10::Gather>(input_shape, one_i, zero_i));
|
||||
auto channel_dim_exp = context.mark_node(std::make_shared<opset10::Unsqueeze>(channel_dim, zero_i));
|
||||
return context.mark_node(std::make_shared<opset10::Broadcast>(value, channel_dim_exp));
|
||||
Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
|
||||
const Output<Node>& input,
|
||||
const Output<Node>& value) {
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
|
||||
auto zero_i = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto one_i = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto channel_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, one_i, zero_i));
|
||||
auto channel_dim_exp = context.mark_node(std::make_shared<v0::Unsqueeze>(channel_dim, zero_i));
|
||||
return context.mark_node(std::make_shared<v3::Broadcast>(value, channel_dim_exp));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_batch_norm(NodeContext& context) {
|
||||
// Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var,
|
||||
// bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
|
||||
num_inputs_check(context, 8, 9);
|
||||
auto input = context.get_input(0);
|
||||
Output<Node> weight;
|
||||
Output<Node> bias;
|
||||
if (!context.input_is_none(1)) {
|
||||
weight = context.get_input(1);
|
||||
} else {
|
||||
auto one_f = context.mark_node(opset10::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto one_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
weight = broadcast_const_to_channel_dim(context, input, one_f);
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
bias = context.get_input(2);
|
||||
} else {
|
||||
auto zero_f = context.mark_node(opset10::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
bias = broadcast_const_to_channel_dim(context, input, zero_f);
|
||||
}
|
||||
// index 3 running_mean and index 4 running_var can be none for training case only, check that not training before
|
||||
@ -45,10 +56,11 @@ OutputVector translate_batch_norm(NodeContext& context) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(!training, "Translation for aten::batch_norm do not support training mode.");
|
||||
auto running_mean = context.get_input(3);
|
||||
auto running_var = context.get_input(4);
|
||||
// Index with index 6 is momentum, it is used only in training mode
|
||||
// Input with index 6 is momentum, it is used only in training mode
|
||||
auto epsilon = context.const_input<float>(7);
|
||||
// Input with index 8 is flag "cudnn_enabled" we can ignore it
|
||||
return {context.mark_node(
|
||||
std::make_shared<opset10::BatchNormInference>(input, weight, bias, running_mean, running_var, epsilon))};
|
||||
std::make_shared<v5::BatchNormInference>(input, weight, bias, running_mean, running_var, epsilon))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/maximum.hpp"
|
||||
#include "openvino/op/minimum.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,17 +13,20 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_clamp(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto x = context.get_input(0);
|
||||
if (!context.input_is_none(1)) {
|
||||
auto min_clip = context.get_input(1);
|
||||
min_clip = context.mark_node(std::make_shared<opset10::ConvertLike>(min_clip, x));
|
||||
x = context.mark_node(std::make_shared<opset10::Maximum>(x, min_clip));
|
||||
min_clip = context.mark_node(std::make_shared<v1::ConvertLike>(min_clip, x));
|
||||
x = context.mark_node(std::make_shared<v1::Maximum>(x, min_clip));
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
auto max_clip = context.get_input(2);
|
||||
max_clip = context.mark_node(std::make_shared<opset10::ConvertLike>(max_clip, x));
|
||||
x = context.mark_node(std::make_shared<opset10::Minimum>(x, max_clip));
|
||||
max_clip = context.mark_node(std::make_shared<v1::ConvertLike>(max_clip, x));
|
||||
x = context.mark_node(std::make_shared<v1::Minimum>(x, max_clip));
|
||||
}
|
||||
return {x};
|
||||
};
|
||||
|
@ -13,9 +13,10 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_conv_transposend(NodeContext& context) {
|
||||
auto num_inputs = context.get_input_size();
|
||||
FRONT_END_OP_CONVERSION_CHECK(num_inputs == 8, "Unsupported number of inputs: ", num_inputs);
|
||||
num_inputs_check(context, 8, 8);
|
||||
auto strides = context.const_input<Strides>(3);
|
||||
// PyTorch support only symmetric padding, padding sizes are the same for begins and ends for each dimension
|
||||
auto pads = context.const_input<CoordinateDiff>(4);
|
||||
@ -27,16 +28,16 @@ OutputVector translate_conv_transposend(NodeContext& context) {
|
||||
|
||||
std::shared_ptr<ov::Node> conv;
|
||||
if (groups == 1) {
|
||||
conv = std::make_shared<ov::op::v1::ConvolutionBackpropData>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
pad_type,
|
||||
output_padding);
|
||||
conv = std::make_shared<v1::ConvolutionBackpropData>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
pad_type,
|
||||
output_padding);
|
||||
} else {
|
||||
conv = std::make_shared<ov::op::v1::GroupConvolutionBackpropData>(
|
||||
conv = std::make_shared<v1::GroupConvolutionBackpropData>(
|
||||
context.get_input(0),
|
||||
reshape_kernel_for_group(context, context.get_input(1), groups),
|
||||
strides,
|
||||
@ -52,7 +53,7 @@ OutputVector translate_conv_transposend(NodeContext& context) {
|
||||
if (bias_rank == 1) {
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
conv = context.mark_node(std::make_shared<ov::op::v1::Add>(conv, bias));
|
||||
conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));
|
||||
}
|
||||
|
||||
return {conv};
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/convolution.hpp"
|
||||
#include "openvino/op/group_conv.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,7 +13,10 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_convnd(NodeContext& context) {
|
||||
num_inputs_check(context, 7, 7);
|
||||
auto strides = context.const_input<Strides>(3);
|
||||
// In torch pads at beginning are same as at end
|
||||
auto pads = CoordinateDiff(strides.size(), 0);
|
||||
@ -28,22 +33,21 @@ OutputVector translate_convnd(NodeContext& context) {
|
||||
|
||||
std::shared_ptr<ov::Node> conv;
|
||||
if (groups == 1) {
|
||||
conv = std::make_shared<opset10::Convolution>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
conv = std::make_shared<v1::Convolution>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
pad_type);
|
||||
} else {
|
||||
conv = std::make_shared<v1::GroupConvolution>(context.get_input(0),
|
||||
reshape_kernel_for_group(context, context.get_input(1), groups),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
pad_type);
|
||||
} else {
|
||||
conv =
|
||||
std::make_shared<opset10::GroupConvolution>(context.get_input(0),
|
||||
reshape_kernel_for_group(context, context.get_input(1), groups),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
pad_type);
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
auto bias = context.get_input(2);
|
||||
@ -51,7 +55,7 @@ OutputVector translate_convnd(NodeContext& context) {
|
||||
if (bias_rank == 1) {
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
|
||||
conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));
|
||||
}
|
||||
|
||||
return {conv};
|
||||
|
@ -2,8 +2,11 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/convolution.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/group_conv.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,11 +14,14 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_convolution(NodeContext& context) {
|
||||
// Schema: aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[]
|
||||
// dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool
|
||||
// cudnn_enabled, bool allow_tf32) -> Tensor
|
||||
|
||||
num_inputs_check(context, 9, 13);
|
||||
auto strides = context.const_input<Strides>(3);
|
||||
auto pads = context.const_input<CoordinateDiff>(4);
|
||||
auto dilations = context.const_input<Strides>(5);
|
||||
@ -26,25 +32,25 @@ OutputVector translate_convolution(NodeContext& context) {
|
||||
std::shared_ptr<ov::Node> conv;
|
||||
if (groups == 1) {
|
||||
if (!transposed) {
|
||||
conv = context.mark_node(std::make_shared<opset10::Convolution>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations));
|
||||
conv = context.mark_node(std::make_shared<v1::Convolution>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations));
|
||||
} else {
|
||||
conv = context.mark_node(std::make_shared<opset10::ConvolutionBackpropData>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
ov::op::PadType::EXPLICIT,
|
||||
output_padding));
|
||||
conv = context.mark_node(std::make_shared<v1::ConvolutionBackpropData>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pads,
|
||||
pads,
|
||||
dilations,
|
||||
ov::op::PadType::EXPLICIT,
|
||||
output_padding));
|
||||
}
|
||||
} else {
|
||||
if (!transposed) {
|
||||
conv = context.mark_node(std::make_shared<opset10::GroupConvolution>(
|
||||
conv = context.mark_node(std::make_shared<v1::GroupConvolution>(
|
||||
context.get_input(0),
|
||||
context.mark_output(reshape_kernel_for_group(context, context.get_input(1), groups)),
|
||||
strides,
|
||||
@ -52,7 +58,7 @@ OutputVector translate_convolution(NodeContext& context) {
|
||||
pads,
|
||||
dilations));
|
||||
} else {
|
||||
conv = context.mark_node(std::make_shared<opset10::GroupConvolutionBackpropData>(
|
||||
conv = context.mark_node(std::make_shared<v1::GroupConvolutionBackpropData>(
|
||||
context.get_input(0),
|
||||
context.mark_output(reshape_kernel_for_group(context, context.get_input(1), groups)),
|
||||
strides,
|
||||
@ -70,7 +76,7 @@ OutputVector translate_convolution(NodeContext& context) {
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
|
||||
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
|
||||
conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));
|
||||
}
|
||||
|
||||
return {context.mark_output(conv)};
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/convolution.hpp"
|
||||
#include "openvino/op/group_conv.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,9 +13,12 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_convolution_mode(NodeContext& context) {
|
||||
// Schema: aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[]
|
||||
// dilation, int groups) -> Tensor
|
||||
num_inputs_check(context, 7, 7);
|
||||
auto strides = context.const_input<Strides>(3);
|
||||
auto pad_mode = context.const_input<std::string>(4);
|
||||
auto dilations = context.const_input<Strides>(5);
|
||||
@ -24,15 +29,15 @@ OutputVector translate_convolution_mode(NodeContext& context) {
|
||||
|
||||
std::shared_ptr<ov::Node> conv;
|
||||
if (groups == 1) {
|
||||
conv = context.mark_node(std::make_shared<opset10::Convolution>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pad_const,
|
||||
pad_const,
|
||||
dilations,
|
||||
auto_pad_mode));
|
||||
conv = context.mark_node(std::make_shared<v1::Convolution>(context.get_input(0),
|
||||
context.get_input(1),
|
||||
strides,
|
||||
pad_const,
|
||||
pad_const,
|
||||
dilations,
|
||||
auto_pad_mode));
|
||||
} else {
|
||||
conv = context.mark_node(std::make_shared<opset10::GroupConvolution>(
|
||||
conv = context.mark_node(std::make_shared<v1::GroupConvolution>(
|
||||
context.get_input(0),
|
||||
context.mark_output(reshape_kernel_for_group(context, context.get_input(1), groups)),
|
||||
strides,
|
||||
@ -49,7 +54,7 @@ OutputVector translate_convolution_mode(NodeContext& context) {
|
||||
bias = reshape_channelwise(context, bias, conv);
|
||||
}
|
||||
|
||||
conv = context.mark_node(std::make_shared<opset10::Add>(conv, bias));
|
||||
conv = context.mark_node(std::make_shared<v1::Add>(conv, bias));
|
||||
}
|
||||
return {context.mark_output(conv)};
|
||||
};
|
||||
|
@ -3,7 +3,6 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,12 +10,13 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_dim(NodeContext& context) {
|
||||
auto shape = std::make_shared<opset10::ShapeOf>(context.get_input(0), element::i32);
|
||||
auto rank = std::make_shared<opset10::ShapeOf>(shape, element::i32);
|
||||
auto squeeze = std::make_shared<opset10::Squeeze>(rank);
|
||||
context.mark_nodes({shape, rank, squeeze});
|
||||
return squeeze->outputs();
|
||||
num_inputs_check(context, 1, 1);
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
|
||||
return {rank};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -18,6 +18,7 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_div(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
std::string rounding_mode = "";
|
||||
|
@ -2,8 +2,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/elu.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,9 +13,16 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_elu(NodeContext& context) {
|
||||
// aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
|
||||
num_inputs_check(context, 2, 4);
|
||||
auto x = context.get_input(0);
|
||||
auto alpha = context.const_input<float>(1);
|
||||
return {context.mark_node(std::make_shared<opset10::Elu>(x, alpha))};
|
||||
// TODO: Figure out what scale and input_scale do
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<int64_t>(2) == 1,
|
||||
"Unexpected value of scale input for elu operation");
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(3) || context.const_input<int64_t>(3) == 1,
|
||||
"Unexpected value of input_scale input for elu operation");
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Elu>(x, alpha))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,8 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,14 +13,15 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_embedding(NodeContext& context) {
|
||||
num_inputs_check(context, 5, 5);
|
||||
auto data = context.get_input(0);
|
||||
auto indices = context.get_input(1);
|
||||
// TODO: find out the meaning of input idx 2
|
||||
FRONT_END_OP_CONVERSION_CHECK(
|
||||
context.const_input<bool>(3) == false && context.const_input<bool>(4) == false,
|
||||
"Only False is supported on inputs with indexes 3 and 4 for aten::embedding translation");
|
||||
auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
return {context.mark_node(std::make_shared<opset10::Gather>(data, indices, axis_0))};
|
||||
auto axis_0 = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
return {context.mark_node(std::make_shared<ov::op::v8::Gather>(data, indices, axis_0))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,11 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/equal.hpp"
|
||||
#include "openvino/op/select.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,29 +15,37 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
OutputVector base_expand(NodeContext& context, ov::Output<ov::Node> x, ov::Output<ov::Node> sizes) {
|
||||
auto one = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto sizes_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(sizes, element::i32));
|
||||
auto neg_one = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {-1}));
|
||||
auto neg_ones = context.mark_node(std::make_shared<opset10::Broadcast>(neg_one, sizes_shape));
|
||||
auto ones = context.mark_node(std::make_shared<opset10::Broadcast>(one, sizes_shape));
|
||||
auto neg_sizes = context.mark_node(std::make_shared<opset10::Equal>(sizes, neg_ones));
|
||||
auto shape = context.mark_node(std::make_shared<opset10::Select>(neg_sizes, ones, sizes));
|
||||
return {std::make_shared<opset10::Broadcast>(x, shape, ov::op::BroadcastType::BIDIRECTIONAL)};
|
||||
OutputVector base_expand(const NodeContext& context, const Output<Node>& x, const Output<Node>& sizes) {
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto sizes_shape = context.mark_node(std::make_shared<v3::ShapeOf>(sizes, element::i32));
|
||||
auto neg_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
|
||||
auto neg_ones = context.mark_node(std::make_shared<v3::Broadcast>(neg_one, sizes_shape));
|
||||
auto ones = context.mark_node(std::make_shared<v3::Broadcast>(one, sizes_shape));
|
||||
auto neg_sizes = context.mark_node(std::make_shared<v1::Equal>(sizes, neg_ones));
|
||||
auto shape = context.mark_node(std::make_shared<v1::Select>(neg_sizes, ones, sizes));
|
||||
return {context.mark_node(std::make_shared<v3::Broadcast>(x, shape, BroadcastType::BIDIRECTIONAL))};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_expand(NodeContext& context) {
|
||||
// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto sizes = context.get_input(1);
|
||||
// TODO: figure out what implicit means
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(2) || context.const_input<bool>(2) == false,
|
||||
"Unexpected value of implicit for expand operation");
|
||||
return base_expand(context, x, sizes);
|
||||
};
|
||||
|
||||
OutputVector translate_expand_as(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
auto sizes = context.mark_node(std::make_shared<opset10::ShapeOf>(y, element::i32));
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(y, element::i32));
|
||||
return base_expand(context, x, sizes);
|
||||
};
|
||||
|
||||
|
@ -14,16 +14,18 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_eye(NodeContext& context) {
|
||||
size_t num_inputs = context.get_input_size();
|
||||
auto x = context.get_input(0);
|
||||
// num rows and cols should be integer, but at the moment conversion their data type can be unknown yet
|
||||
x = context.mark_node(std::make_shared<ov::op::v0::Convert>(x, element::i64));
|
||||
x = context.mark_node(std::make_shared<v0::Convert>(x, element::i64));
|
||||
Output<Node> y;
|
||||
size_t dtype_id;
|
||||
auto dtype = element::f32;
|
||||
// aten::eye support only main diagonal
|
||||
auto diagonal = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto diagonal = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
if (num_inputs == 5) {
|
||||
// aten::eye(n, dtype, layout, device, pin_memory)
|
||||
y = x;
|
||||
@ -31,7 +33,7 @@ OutputVector translate_eye(NodeContext& context) {
|
||||
} else if (num_inputs == 6) {
|
||||
// aten::eye(n, m, dtype, layout, device, pin_memory)
|
||||
y = context.get_input(1);
|
||||
y = context.mark_node(std::make_shared<ov::op::v0::Convert>(y, element::i64));
|
||||
y = context.mark_node(std::make_shared<v0::Convert>(y, element::i64));
|
||||
dtype_id = 2;
|
||||
} else {
|
||||
FRONT_END_OP_CONVERSION_CHECK(false, "Unsupported number of inputs: ", num_inputs, " for aten::eye");
|
||||
@ -39,8 +41,8 @@ OutputVector translate_eye(NodeContext& context) {
|
||||
if (!context.input_is_none(dtype_id)) {
|
||||
dtype = convert_dtype(context.const_input<int64_t>(dtype_id));
|
||||
}
|
||||
auto eye = context.mark_node(std::make_shared<ov::op::v9::Eye>(x, y, diagonal, element::i32));
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Convert>(eye, dtype))};
|
||||
auto eye = context.mark_node(std::make_shared<v9::Eye>(x, y, diagonal, element::i32));
|
||||
return {context.mark_node(std::make_shared<v0::Convert>(eye, dtype))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -19,17 +19,32 @@ namespace op {
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_flatten(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto start_dim = context.const_input<int64_t>(1);
|
||||
auto end_dim = context.const_input<int64_t>(2);
|
||||
|
||||
int64_t start_dim = 0;
|
||||
int64_t end_dim = -1;
|
||||
if (!context.input_is_none(1)) {
|
||||
start_dim = context.const_input<int64_t>(1);
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
end_dim = context.const_input<int64_t>(2);
|
||||
}
|
||||
Output<Node> shape;
|
||||
Output<Node> rank;
|
||||
std::tie(shape, rank) = get_shape_rank(context, x, true);
|
||||
// Use opset::If for dim normalization. For now we only have flatten with constant start and end
|
||||
auto start_dim_node = context.get_input(1);
|
||||
auto end_dim_node = context.get_input(2);
|
||||
Output<Node> start_dim_node;
|
||||
Output<Node> end_dim_node;
|
||||
if (!context.input_is_none(1)) {
|
||||
start_dim_node = context.get_input(1);
|
||||
} else {
|
||||
start_dim_node = v0::Constant::create(element::i32, Shape{}, {start_dim});
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
end_dim_node = context.get_input(2);
|
||||
} else {
|
||||
end_dim_node = v0::Constant::create(element::i32, Shape{}, {end_dim});
|
||||
}
|
||||
if (start_dim < 0) {
|
||||
start_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, start_dim_node));
|
||||
}
|
||||
@ -51,7 +66,7 @@ OutputVector translate_flatten(NodeContext& context) {
|
||||
|
||||
context.mark_nodes({zero, one, int_max, start_dim_u, end_dim_u, slice_begin, slice_end, neg_1_const, new_shape});
|
||||
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(context.get_input(0), new_shape, true))};
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(x, new_shape, true))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,8 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "openvino/op/floor.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,11 +12,14 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_floor_divide(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
auto div = context.mark_node(std::make_shared<opset10::Divide>(x, y, true));
|
||||
return {context.mark_node(std::make_shared<opset10::Floor>(div))};
|
||||
auto div = context.mark_node(std::make_shared<v1::Divide>(x, y, true));
|
||||
return {context.mark_node(std::make_shared<v0::Floor>(div))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,9 +12,10 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_floordiv(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
return {context.mark_node(std::make_shared<opset10::Divide>(x, y, true))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Divide>(x, y, true))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -15,33 +15,36 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
ov::Output<Node> base_translate_full(NodeContext& context, ov::Output<Node> sizes, ov::Output<Node> value) {
|
||||
return context.mark_node(std::make_shared<ov::op::v3::Broadcast>(value, sizes));
|
||||
Output<Node> base_translate_full(const NodeContext& context, const Output<Node>& sizes, const Output<Node>& value) {
|
||||
return context.mark_node(std::make_shared<v3::Broadcast>(value, sizes));
|
||||
}
|
||||
|
||||
ov::Output<Node> base_translate_full_with_convert(NodeContext& context,
|
||||
ov::Output<Node> sizes,
|
||||
ov::Output<Node> value,
|
||||
size_t dtype_id) {
|
||||
Output<Node> base_translate_full_with_convert(const NodeContext& context,
|
||||
const Output<Node>& sizes,
|
||||
const Output<Node>& value,
|
||||
size_t dtype_id) {
|
||||
auto filled_tensor = base_translate_full(context, sizes, value);
|
||||
if (!context.input_is_none(dtype_id)) {
|
||||
auto dtype = convert_dtype(context.const_input<int64_t>(dtype_id));
|
||||
filled_tensor = context.mark_node(std::make_shared<ov::op::v0::Convert>(filled_tensor, dtype));
|
||||
filled_tensor = context.mark_node(std::make_shared<v0::Convert>(filled_tensor, dtype));
|
||||
}
|
||||
return filled_tensor;
|
||||
}
|
||||
|
||||
ov::Output<Node> base_translate_full_with_convertlike(NodeContext& context,
|
||||
ov::Output<Node> sizes,
|
||||
ov::Output<Node> value,
|
||||
ov::Output<Node> out) {
|
||||
Output<Node> base_translate_full_with_convertlike(const NodeContext& context,
|
||||
const Output<Node>& sizes,
|
||||
const Output<Node>& value,
|
||||
const Output<Node>& out) {
|
||||
auto filled_tensor = base_translate_full(context, sizes, value);
|
||||
return context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(filled_tensor, out));
|
||||
return context.mark_node(std::make_shared<v1::ConvertLike>(filled_tensor, out));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_full(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 6);
|
||||
auto sizes = context.get_input(0);
|
||||
auto value = context.get_input(1);
|
||||
auto num_inputs = context.get_input_size();
|
||||
@ -58,9 +61,10 @@ OutputVector translate_full(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_full_like(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 7);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.get_input(1);
|
||||
auto sizes = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
|
||||
if (context.get_input_size() == 7) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 2)};
|
||||
}
|
||||
@ -69,13 +73,15 @@ OutputVector translate_full_like(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_fill_(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.get_input(1);
|
||||
auto sizes = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
||||
};
|
||||
|
||||
OutputVector translate_new_full(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 7);
|
||||
auto input = context.get_input(0);
|
||||
auto sizes = context.get_input(1);
|
||||
auto value = context.get_input(2);
|
||||
@ -86,8 +92,9 @@ OutputVector translate_new_full(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_zeros(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 5);
|
||||
auto sizes = context.get_input(0);
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto num_inputs = context.get_input_size();
|
||||
if (num_inputs < 5) {
|
||||
int out_id = num_inputs == 2 ? 1 : 2;
|
||||
@ -102,9 +109,10 @@ OutputVector translate_zeros(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_zeros_like(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto sizes = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
|
||||
if (context.get_input_size() == 6) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 1)};
|
||||
}
|
||||
@ -113,9 +121,10 @@ OutputVector translate_zeros_like(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_new_zeros(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto sizes = context.get_input(1);
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 2)};
|
||||
}
|
||||
@ -123,8 +132,9 @@ OutputVector translate_new_zeros(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_ones(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 5);
|
||||
auto sizes = context.get_input(0);
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto num_inputs = context.get_input_size();
|
||||
if (num_inputs < 5) {
|
||||
int out_id = num_inputs == 2 ? 1 : 2;
|
||||
@ -139,9 +149,10 @@ OutputVector translate_ones(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_ones_like(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto sizes = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
|
||||
if (context.get_input_size() == 6) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 1)};
|
||||
}
|
||||
@ -150,9 +161,10 @@ OutputVector translate_ones_like(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_new_ones(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto sizes = context.get_input(1);
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
if (context.get_input_size() == 6 && !context.input_is_none(2)) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, 2)};
|
||||
}
|
||||
@ -160,11 +172,12 @@ OutputVector translate_new_ones(NodeContext& context) {
|
||||
};
|
||||
|
||||
OutputVector translate_empty(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto sizes = context.get_input(0);
|
||||
// In OV uninitialised data is not supported, so we create a tensor filled with zeros with a given shape and type.
|
||||
auto value = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
int dtype_id = 1;
|
||||
ov::Output<ov::Node> empty;
|
||||
Output<Node> empty;
|
||||
if (!context.input_is_none(dtype_id)) {
|
||||
empty = base_translate_full_with_convert(context, sizes, value, dtype_id);
|
||||
} else {
|
||||
|
@ -2,8 +2,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/gelu.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,11 +13,12 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_gelu(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto approximate = context.const_input<std::string>(1);
|
||||
// TODO: Add support for "tanh" approximate
|
||||
FRONT_END_OP_CONVERSION_CHECK(approximate == "none", "Unsupported approximate for Gelu: ", approximate);
|
||||
return {context.mark_node(std::make_shared<opset10::Gelu>(x))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v7::Gelu>(x))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -13,6 +13,7 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_getitem(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto input = context.get_input(0);
|
||||
FRONT_END_OP_CONVERSION_CHECK(cast_fw_node(input.get_node_shared_ptr(), "prim::ListConstruct") == nullptr,
|
||||
"unsupported case for aten::getitem");
|
||||
|
@ -7,21 +7,25 @@
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/sigmoid.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_glu(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto dim = context.input_is_none(1) ? context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {-1}))
|
||||
auto dim = context.input_is_none(1) ? context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}))
|
||||
: context.get_input(1);
|
||||
auto split = context.mark_node(std::make_shared<ov::op::v1::Split>(x, dim, 2));
|
||||
auto split = context.mark_node(std::make_shared<v1::Split>(x, dim, 2));
|
||||
auto first = split->output(0);
|
||||
auto second = split->output(1);
|
||||
auto sigmoid = context.mark_node(std::make_shared<ov::op::v0::Sigmoid>(second));
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Multiply>(first, sigmoid))};
|
||||
auto sigmoid = context.mark_node(std::make_shared<v0::Sigmoid>(second));
|
||||
return {context.mark_node(std::make_shared<v1::Multiply>(first, sigmoid))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -4,25 +4,29 @@
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/grid_sample.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_grid_sampler(NodeContext& context) {
|
||||
num_inputs_check(context, 4, 5);
|
||||
auto x = context.get_input(0);
|
||||
auto grid = context.get_input(1);
|
||||
ov::op::v9::GridSample::Attributes attrs{};
|
||||
const std::unordered_map<int64_t, ov::op::v9::GridSample::InterpolationMode> grid_sample_mode_map{
|
||||
{0, ov::op::v9::GridSample::InterpolationMode::BILINEAR},
|
||||
{1, ov::op::v9::GridSample::InterpolationMode::NEAREST},
|
||||
{2, ov::op::v9::GridSample::InterpolationMode::BICUBIC},
|
||||
v9::GridSample::Attributes attrs{};
|
||||
const std::unordered_map<int64_t, v9::GridSample::InterpolationMode> grid_sample_mode_map{
|
||||
{0, v9::GridSample::InterpolationMode::BILINEAR},
|
||||
{1, v9::GridSample::InterpolationMode::NEAREST},
|
||||
{2, v9::GridSample::InterpolationMode::BICUBIC},
|
||||
};
|
||||
const std::unordered_map<int64_t, ov::op::v9::GridSample::PaddingMode> grid_sample_padding_mode_map{
|
||||
{0, ov::op::v9::GridSample::PaddingMode::ZEROS},
|
||||
{1, ov::op::v9::GridSample::PaddingMode::BORDER},
|
||||
{2, ov::op::v9::GridSample::PaddingMode::REFLECTION}};
|
||||
const std::unordered_map<int64_t, v9::GridSample::PaddingMode> grid_sample_padding_mode_map{
|
||||
{0, v9::GridSample::PaddingMode::ZEROS},
|
||||
{1, v9::GridSample::PaddingMode::BORDER},
|
||||
{2, v9::GridSample::PaddingMode::REFLECTION}};
|
||||
auto mode = context.const_input<int64_t>(2);
|
||||
FRONT_END_OP_CONVERSION_CHECK(grid_sample_mode_map.count(mode), "Unknown interpolation mode: ", mode);
|
||||
attrs.mode = grid_sample_mode_map.at(mode);
|
||||
@ -37,7 +41,7 @@ OutputVector translate_grid_sampler(NodeContext& context) {
|
||||
}
|
||||
attrs.align_corners = align_corners;
|
||||
|
||||
return {context.mark_node(std::make_shared<ov::op::v9::GridSample>(x, grid, attrs))};
|
||||
return {context.mark_node(std::make_shared<v9::GridSample>(x, grid, attrs))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,14 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/mvn.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,35 +18,40 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_group_norm(NodeContext& context) {
|
||||
// aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float
|
||||
// eps=1.0000000000000001e-05, bool cudnn_enabled=True) -> Tensor
|
||||
num_inputs_check(context, 2, 6);
|
||||
auto data = context.get_input(0);
|
||||
auto num_groups = context.const_input<int64_t>(1);
|
||||
// input 2 - weights and input 3 - bias are optional without default value, we handle them later
|
||||
auto eps = static_cast<float>(context.const_input<double>(4));
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(data, element::i64));
|
||||
auto scalar_one = context.mark_node(opset10::Constant::create(element::i64, {}, {1}));
|
||||
Output<Node> input_shape;
|
||||
Output<Node> input_rank;
|
||||
std::tie(input_shape, input_rank) = get_shape_rank(context, data, true, element::i64);
|
||||
auto scalar_one = context.mark_node(v0::Constant::create(element::i64, {}, {1}));
|
||||
auto shape = context.mark_node(
|
||||
std::make_shared<opset10::Constant>(element::i64, Shape({3}), std::vector<int64_t>{0, num_groups, -1}));
|
||||
auto reshaped_input = context.mark_node(std::make_shared<opset10::Reshape>(data, shape, true));
|
||||
auto reduction_axes =
|
||||
context.mark_node(opset10::Constant::create(element::i64, Shape({1}), std::vector<int64_t>(1, 2)));
|
||||
std::make_shared<v0::Constant>(element::i64, Shape({3}), std::vector<int64_t>{0, num_groups, -1}));
|
||||
auto reshaped_input = context.mark_node(std::make_shared<v1::Reshape>(data, shape, true));
|
||||
auto reduction_axes = context.mark_node(v0::Constant::create(element::i64, Shape({1}), std::vector<int64_t>(1, 2)));
|
||||
auto reshaped_norm = context.mark_node(
|
||||
std::make_shared<opset10::MVN>(reshaped_input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
|
||||
auto norm = context.mark_node(std::make_shared<opset10::Reshape>(reshaped_norm, input_shape, true));
|
||||
auto input_rank2d = context.mark_node(std::make_shared<opset10::ShapeOf>(input_shape, element::i64));
|
||||
auto input_rank = context.mark_node(std::make_shared<opset10::Squeeze>(input_rank2d));
|
||||
auto skip_last = context.mark_node(std::make_shared<opset10::Subtract>(input_rank, scalar_one));
|
||||
auto axes = context.mark_node(std::make_shared<opset10::Range>(scalar_one, skip_last, scalar_one, element::i64));
|
||||
std::make_shared<v6::MVN>(reshaped_input, reduction_axes, true, eps, MVNEpsMode::INSIDE_SQRT));
|
||||
auto norm = context.mark_node(std::make_shared<v1::Reshape>(reshaped_norm, input_shape, true));
|
||||
auto skip_last = context.mark_node(std::make_shared<v1::Subtract>(input_rank, scalar_one));
|
||||
auto axes = context.mark_node(std::make_shared<v4::Range>(scalar_one, skip_last, scalar_one, element::i64));
|
||||
if (!context.input_is_none(2)) {
|
||||
auto weights = context.get_input(2);
|
||||
weights = context.mark_node(std::make_shared<opset10::Unsqueeze>(weights, axes));
|
||||
norm = context.mark_node(std::make_shared<opset10::Multiply>(norm, weights));
|
||||
weights = context.mark_node(std::make_shared<v0::Unsqueeze>(weights, axes));
|
||||
norm = context.mark_node(std::make_shared<v1::Multiply>(norm, weights));
|
||||
}
|
||||
if (!context.input_is_none(3)) {
|
||||
auto bias = context.get_input(3);
|
||||
bias = context.mark_node(std::make_shared<opset10::Unsqueeze>(bias, axes));
|
||||
norm = context.mark_node(std::make_shared<opset10::Add>(norm, bias));
|
||||
bias = context.mark_node(std::make_shared<v0::Unsqueeze>(bias, axes));
|
||||
norm = context.mark_node(std::make_shared<v1::Add>(norm, bias));
|
||||
}
|
||||
// Input with index 5 is flag "cudnn_enabled" we can ignore it
|
||||
return {norm};
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/clamp.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,6 +12,7 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_hardtanh(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
float min = -1;
|
||||
float max = 1;
|
||||
if (!context.input_is_none(1)) {
|
||||
@ -20,7 +21,7 @@ OutputVector translate_hardtanh(NodeContext& context) {
|
||||
if (!context.input_is_none(2)) {
|
||||
max = context.const_input<float>(2);
|
||||
}
|
||||
return {context.mark_node(std::make_shared<opset10::Clamp>(context.get_input(0), min, max))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Clamp>(context.get_input(0), min, max))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -56,7 +56,7 @@ OutputVector translate_if(NodeContext& context) {
|
||||
}
|
||||
}
|
||||
OutputVector res;
|
||||
const auto num_outs = context.num_of_outputs();
|
||||
const auto num_outs = context.get_output_size();
|
||||
const auto then_results = then_body->get_results();
|
||||
const auto else_results = else_body->get_results();
|
||||
FRONT_END_OP_CONVERSION_CHECK(then_results.size() >= num_outs && else_results.size() >= num_outs,
|
||||
|
@ -3,7 +3,20 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/pad.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,37 +24,40 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
std::shared_ptr<Node> get_im2col_indices_along_dim(NodeContext& context,
|
||||
ov::Output<Node> input_d,
|
||||
std::shared_ptr<Node> get_im2col_indices_along_dim(const NodeContext& context,
|
||||
const Output<Node>& input_d,
|
||||
int64_t kernel_size_d,
|
||||
int64_t dilation_d,
|
||||
int64_t padding_d,
|
||||
int64_t stride_d) {
|
||||
auto zero = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto minus_one = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {-1}));
|
||||
auto kernel_size = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {kernel_size_d}));
|
||||
auto padding_2 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {padding_d * 2}));
|
||||
auto stride = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {stride_d}));
|
||||
auto input_d_squeezed = context.mark_node(std::make_shared<opset10::Squeeze>(input_d, zero));
|
||||
auto blocks_d = context.mark_node(std::make_shared<opset10::Add>(input_d_squeezed, padding_2));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto minus_one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}));
|
||||
auto kernel_size = context.mark_node(v0::Constant::create(element::i64, Shape{}, {kernel_size_d}));
|
||||
auto padding_2 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {padding_d * 2}));
|
||||
auto stride = context.mark_node(v0::Constant::create(element::i64, Shape{}, {stride_d}));
|
||||
auto input_d_squeezed = context.mark_node(std::make_shared<v0::Squeeze>(input_d, zero));
|
||||
auto blocks_d = context.mark_node(std::make_shared<v1::Add>(input_d_squeezed, padding_2));
|
||||
auto subtrahend =
|
||||
context.mark_node(opset10::Constant::create(element::i64, Shape{}, {dilation_d * (kernel_size_d - 1)}));
|
||||
blocks_d = context.mark_node(std::make_shared<opset10::Subtract>(blocks_d, subtrahend));
|
||||
auto blocks_d_indices = context.mark_node(std::make_shared<opset10::Range>(zero, blocks_d, stride, element::i64));
|
||||
blocks_d_indices = context.mark_node(std::make_shared<opset10::Unsqueeze>(blocks_d_indices, zero));
|
||||
context.mark_node(v0::Constant::create(element::i64, Shape{}, {dilation_d * (kernel_size_d - 1)}));
|
||||
blocks_d = context.mark_node(std::make_shared<v1::Subtract>(blocks_d, subtrahend));
|
||||
auto blocks_d_indices = context.mark_node(std::make_shared<v4::Range>(zero, blocks_d, stride, element::i64));
|
||||
blocks_d_indices = context.mark_node(std::make_shared<v0::Unsqueeze>(blocks_d_indices, zero));
|
||||
std::vector<int64_t> rng;
|
||||
for (int64_t i = 0; i < kernel_size_d * dilation_d; i += dilation_d) {
|
||||
rng.push_back(i);
|
||||
}
|
||||
|
||||
auto kernel_grid = context.mark_node(opset10::Constant::create(element::i64, Shape{rng.size()}, rng));
|
||||
auto kernel_mask = context.mark_node(std::make_shared<opset10::Unsqueeze>(kernel_grid, minus_one));
|
||||
return context.mark_node(std::make_shared<opset10::Add>(blocks_d_indices, kernel_mask));
|
||||
auto kernel_grid = context.mark_node(v0::Constant::create(element::i64, Shape{rng.size()}, rng));
|
||||
auto kernel_mask = context.mark_node(std::make_shared<v0::Unsqueeze>(kernel_grid, minus_one));
|
||||
return context.mark_node(std::make_shared<v1::Add>(blocks_d_indices, kernel_mask));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_im2col(NodeContext& context) {
|
||||
num_inputs_check(context, 5, 5);
|
||||
auto input = context.get_input(0);
|
||||
auto kernel_size = context.const_input<std::vector<int64_t>>(1);
|
||||
FRONT_END_OP_CONVERSION_CHECK(kernel_size.size() == 2, "kernel size should contains 2 elements");
|
||||
@ -51,13 +67,13 @@ OutputVector translate_im2col(NodeContext& context) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(kernel_size.size() == 2, "padding should contains 2 elements");
|
||||
auto stride = context.const_input<std::vector<int64_t>>(4);
|
||||
FRONT_END_OP_CONVERSION_CHECK(kernel_size.size() == 2, "stride should contains 2 elements");
|
||||
auto zero = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input));
|
||||
auto zero_f = context.mark_node(opset10::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto minus_one = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {-1}));
|
||||
auto two = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {2}));
|
||||
auto four = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {4}));
|
||||
auto input_shape_split = context.mark_node(std::make_shared<opset10::Split>(input_shape, zero, 4));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
|
||||
auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
auto minus_one = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-1}));
|
||||
auto two = context.mark_node(v0::Constant::create(element::i64, Shape{}, {2}));
|
||||
auto four = context.mark_node(v0::Constant::create(element::i64, Shape{}, {4}));
|
||||
auto input_shape_split = context.mark_node(std::make_shared<v1::Split>(input_shape, zero, 4));
|
||||
auto input_b = input_shape_split->output(0);
|
||||
auto input_c = input_shape_split->output(1);
|
||||
auto input_h = input_shape_split->output(2);
|
||||
@ -72,22 +88,22 @@ OutputVector translate_im2col(NodeContext& context) {
|
||||
auto kernel_w = kernel_size[1];
|
||||
auto blocks_row_indices = get_im2col_indices_along_dim(context, input_h, kernel_h, dilation_h, padding_h, stride_h);
|
||||
auto blocks_col_indices = get_im2col_indices_along_dim(context, input_w, kernel_w, dilation_w, padding_w, stride_w);
|
||||
auto kernel_window = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {kernel_h * kernel_w}));
|
||||
auto input_c_squeezed = context.mark_node(std::make_shared<opset10::Squeeze>(input_c, zero));
|
||||
auto channel_unfolded = context.mark_node(std::make_shared<opset10::Multiply>(input_c_squeezed, kernel_window));
|
||||
auto channel_unfolded_unsqueezed = context.mark_node(std::make_shared<opset10::Unsqueeze>(channel_unfolded, zero));
|
||||
auto kernel_window = context.mark_node(v0::Constant::create(element::i64, Shape{}, {kernel_h * kernel_w}));
|
||||
auto input_c_squeezed = context.mark_node(std::make_shared<v0::Squeeze>(input_c, zero));
|
||||
auto channel_unfolded = context.mark_node(std::make_shared<v1::Multiply>(input_c_squeezed, kernel_window));
|
||||
auto channel_unfolded_unsqueezed = context.mark_node(std::make_shared<v0::Unsqueeze>(channel_unfolded, zero));
|
||||
auto output_shape = context.mark_node(
|
||||
std::make_shared<opset10::Concat>(OutputVector{input_b, channel_unfolded_unsqueezed, minus_one}, 0));
|
||||
std::make_shared<v0::Concat>(OutputVector{input_b, channel_unfolded_unsqueezed, minus_one}, 0));
|
||||
auto pads = context.mark_node(
|
||||
opset10::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 0, padding_h, padding_w}));
|
||||
v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 0, padding_h, padding_w}));
|
||||
auto padded_input =
|
||||
context.mark_node(std::make_shared<opset10::Pad>(input, pads, pads, zero_f, ov::op::PadMode::CONSTANT));
|
||||
auto output = context.mark_node(std::make_shared<opset10::Gather>(padded_input, blocks_row_indices, two));
|
||||
output = context.mark_node(std::make_shared<opset10::Gather>(output, blocks_col_indices, four));
|
||||
context.mark_node(std::make_shared<v1::Pad>(input, pads, pads, zero_f, ov::op::PadMode::CONSTANT));
|
||||
auto output = context.mark_node(std::make_shared<v8::Gather>(padded_input, blocks_row_indices, two));
|
||||
output = context.mark_node(std::make_shared<v8::Gather>(output, blocks_col_indices, four));
|
||||
auto permutation_dims =
|
||||
context.mark_node(opset10::Constant::create(element::i64, Shape{6}, std::vector<int64_t>{0, 1, 2, 4, 3, 5}));
|
||||
output = context.mark_node(std::make_shared<opset10::Transpose>(output, permutation_dims));
|
||||
return {context.mark_node(std::make_shared<opset10::Reshape>(output, output_shape, false))};
|
||||
context.mark_node(v0::Constant::create(element::i64, Shape{6}, std::vector<int64_t>{0, 1, 2, 4, 3, 5}));
|
||||
output = context.mark_node(std::make_shared<v1::Transpose>(output, permutation_dims));
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(output, output_shape, false))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -24,65 +24,66 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
OutputVector translate_instance_norm_inference(const NodeContext& context,
|
||||
const Output<Node>& input,
|
||||
const Output<Node>& reduction_axes,
|
||||
float eps) {
|
||||
auto norm = context.mark_node(
|
||||
std::make_shared<ov::op::v6::MVN>(input, reduction_axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
|
||||
auto norm = context.mark_node(std::make_shared<v6::MVN>(input, reduction_axes, true, eps, MVNEpsMode::INSIDE_SQRT));
|
||||
if (!context.input_is_none(1)) {
|
||||
auto weight = context.get_input(1);
|
||||
weight = reshape_channelwise(context, weight, norm);
|
||||
norm = context.mark_node(std::make_shared<ov::op::v1::Multiply>(norm, weight));
|
||||
norm = context.mark_node(std::make_shared<v1::Multiply>(norm, weight));
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
auto bias = context.get_input(2);
|
||||
bias = reshape_channelwise(context, bias, norm);
|
||||
norm = context.mark_node(std::make_shared<ov::op::v1::Add>(norm, bias));
|
||||
norm = context.mark_node(std::make_shared<v1::Add>(norm, bias));
|
||||
}
|
||||
return {norm};
|
||||
}
|
||||
|
||||
OutputVector translate_instance_norm_train(NodeContext& context,
|
||||
OutputVector translate_instance_norm_train(const NodeContext& context,
|
||||
const Output<Node>& input,
|
||||
const Output<Node>& reduction_axes,
|
||||
float eps) {
|
||||
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto batch_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, zero, zero));
|
||||
auto channel_dim = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, one, zero));
|
||||
auto batch_dim_1d = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(batch_dim, zero));
|
||||
auto batch_norm_channels_1d = context.mark_node(std::make_shared<ov::op::v1::Multiply>(batch_dim_1d, channel_dim));
|
||||
auto one_1d = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto tail_shape = context.mark_node(std::make_shared<ov::op::v8::Gather>(input_shape, reduction_axes, zero));
|
||||
auto reshape_shape = context.mark_node(
|
||||
std::make_shared<ov::op::v0::Concat>(OutputVector{one_1d, batch_norm_channels_1d, tail_shape}, 0));
|
||||
auto reshaped_input = context.mark_node(std::make_shared<ov::op::v1::Reshape>(input, reshape_shape, false));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
|
||||
auto batch_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, zero, zero));
|
||||
auto channel_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, one, zero));
|
||||
auto batch_dim_1d = context.mark_node(std::make_shared<v0::Unsqueeze>(batch_dim, zero));
|
||||
auto batch_norm_channels_1d = context.mark_node(std::make_shared<v1::Multiply>(batch_dim_1d, channel_dim));
|
||||
auto one_1d = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto tail_shape = context.mark_node(std::make_shared<v8::Gather>(input_shape, reduction_axes, zero));
|
||||
auto reshape_shape =
|
||||
context.mark_node(std::make_shared<v0::Concat>(OutputVector{one_1d, batch_norm_channels_1d, tail_shape}, 0));
|
||||
auto reshaped_input = context.mark_node(std::make_shared<v1::Reshape>(input, reshape_shape, false));
|
||||
Output<Node> weight;
|
||||
Output<Node> bias;
|
||||
if (context.input_is_none(1)) {
|
||||
weight = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(one, batch_norm_channels_1d));
|
||||
weight = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(weight, input));
|
||||
weight = context.mark_node(std::make_shared<v3::Broadcast>(one, batch_norm_channels_1d));
|
||||
weight = context.mark_node(std::make_shared<v1::ConvertLike>(weight, input));
|
||||
} else {
|
||||
weight = context.get_input(1);
|
||||
weight = context.mark_node(std::make_shared<ov::op::v0::Tile>(weight, batch_dim_1d));
|
||||
weight = context.mark_node(std::make_shared<v0::Tile>(weight, batch_dim_1d));
|
||||
}
|
||||
if (context.input_is_none(2)) {
|
||||
bias = context.mark_node(std::make_shared<ov::op::v3::Broadcast>(zero, batch_norm_channels_1d));
|
||||
bias = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(bias, input));
|
||||
bias = context.mark_node(std::make_shared<v3::Broadcast>(zero, batch_norm_channels_1d));
|
||||
bias = context.mark_node(std::make_shared<v1::ConvertLike>(bias, input));
|
||||
} else {
|
||||
bias = context.get_input(2);
|
||||
bias = context.mark_node(std::make_shared<ov::op::v0::Tile>(bias, batch_dim_1d));
|
||||
bias = context.mark_node(std::make_shared<v0::Tile>(bias, batch_dim_1d));
|
||||
}
|
||||
auto running_mean = context.get_input(3);
|
||||
running_mean = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_mean, batch_dim_1d));
|
||||
running_mean = context.mark_node(std::make_shared<v0::Tile>(running_mean, batch_dim_1d));
|
||||
auto running_var = context.get_input(4);
|
||||
running_var = context.mark_node(std::make_shared<ov::op::v0::Tile>(running_var, batch_dim_1d));
|
||||
running_var = context.mark_node(std::make_shared<v0::Tile>(running_var, batch_dim_1d));
|
||||
auto batch_norm = context.mark_node(
|
||||
std::make_shared<ov::op::v5::BatchNormInference>(reshaped_input, weight, bias, running_mean, running_var, eps));
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Reshape>(batch_norm, input_shape, true))};
|
||||
std::make_shared<v5::BatchNormInference>(reshaped_input, weight, bias, running_mean, running_var, eps));
|
||||
return {context.mark_node(std::make_shared<v1::Reshape>(batch_norm, input_shape, true))};
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -91,12 +92,11 @@ OutputVector translate_instance_norm(NodeContext& context) {
|
||||
num_inputs_check(context, 8, 9);
|
||||
auto input = context.get_input(0);
|
||||
auto eps = context.const_input<float>(7);
|
||||
auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input));
|
||||
auto rank_1d = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input_shape));
|
||||
auto rank = context.mark_node(std::make_shared<ov::op::v0::Squeeze>(rank_1d));
|
||||
auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto two = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {2}));
|
||||
auto reduction_axes = context.mark_node(std::make_shared<ov::op::v4::Range>(two, rank, one, element::i64));
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, input, true, element::i64);
|
||||
auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto two = context.mark_node(v0::Constant::create(element::i64, Shape{}, {2}));
|
||||
auto reduction_axes = context.mark_node(std::make_shared<v4::Range>(two, rank, one, element::i64));
|
||||
if (context.input_is_none(3) && context.input_is_none(4)) {
|
||||
return translate_instance_norm_inference(context, input, reduction_axes, eps);
|
||||
}
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,7 +12,8 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_int(NodeContext& context) {
|
||||
return {context.mark_node(std::make_shared<opset10::Convert>(context.get_input(0), element::i32))};
|
||||
num_inputs_check(context, 1, 1);
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::i32))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,10 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/mvn.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,22 +14,26 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_layer_norm(NodeContext& context) {
|
||||
num_inputs_check(context, 5, 6);
|
||||
auto eps = context.const_input<float>(4);
|
||||
auto normalized_shape = context.const_input<Shape>(1);
|
||||
FRONT_END_OP_CONVERSION_CHECK(normalized_shape.size() == 1,
|
||||
"Translation for aten::layer_norm supports only single normalized_shape value, "
|
||||
"which means normalizing over the last dimension.");
|
||||
// TODO: support any dimention
|
||||
auto axes = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {-1}));
|
||||
auto out_node = context.mark_node(
|
||||
std::make_shared<opset10::MVN>(context.get_input(0), axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
|
||||
auto axes = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-1}));
|
||||
auto out_node =
|
||||
context.mark_node(std::make_shared<v6::MVN>(context.get_input(0), axes, true, eps, MVNEpsMode::INSIDE_SQRT));
|
||||
if (!context.input_is_none(2)) {
|
||||
out_node = context.mark_node(std::make_shared<opset10::Multiply>(out_node, context.get_input(2)));
|
||||
out_node = context.mark_node(std::make_shared<v1::Multiply>(out_node, context.get_input(2)));
|
||||
}
|
||||
if (!context.input_is_none(3)) {
|
||||
out_node = context.mark_node(std::make_shared<opset10::Add>(out_node, context.get_input(3)));
|
||||
out_node = context.mark_node(std::make_shared<v1::Add>(out_node, context.get_input(3)));
|
||||
}
|
||||
// Input with index 5 is flag "cudnn_enabled" we can ignore it
|
||||
return {out_node};
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,10 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,14 +14,17 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_len(NodeContext& context) {
|
||||
auto const_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}));
|
||||
auto const_1 = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto input = context.get_input(0);
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input, element::i64));
|
||||
using namespace ov::op;
|
||||
|
||||
auto slice = context.mark_node(std::make_shared<opset10::Slice>(input_shape, const_0, const_1, const_1));
|
||||
auto squeeze = std::make_shared<opset10::Squeeze>(slice, const_0);
|
||||
OutputVector translate_len(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {0}));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto input = context.get_input(0);
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i64));
|
||||
|
||||
auto slice = context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, const_1, const_1));
|
||||
auto squeeze = std::make_shared<v0::Squeeze>(slice, const_0);
|
||||
return {context.mark_node(squeeze)};
|
||||
};
|
||||
|
||||
|
@ -13,6 +13,7 @@ namespace op {
|
||||
|
||||
OutputVector translate_linear(NodeContext& context) {
|
||||
// schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
auto matmul = context.mark_node(std::make_shared<ov::op::v0::MatMul>(x, y, false, true));
|
||||
|
@ -3,7 +3,8 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,19 +12,21 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_list_construct(NodeContext& context) {
|
||||
// Process the case when prim::ListConstruct has all inputs constant
|
||||
ov::OutputVector consts;
|
||||
for (size_t i = 0; i < context.get_input_size(); i++) {
|
||||
auto input = context.get_input_from_visible_context(i);
|
||||
auto c_node = std::dynamic_pointer_cast<opset10::Constant>(input.get_node_shared_ptr());
|
||||
auto c_node = std::dynamic_pointer_cast<v0::Constant>(input.get_node_shared_ptr());
|
||||
FRONT_END_OP_CONVERSION_CHECK(c_node, "Translation for prim::ListConstruct support only constant inputs");
|
||||
if (c_node->get_shape().size() == 0) {
|
||||
c_node = std::make_shared<opset10::Constant>(c_node->get_element_type(), Shape{1}, c_node->get_data_ptr());
|
||||
c_node = std::make_shared<v0::Constant>(c_node->get_element_type(), Shape{1}, c_node->get_data_ptr());
|
||||
}
|
||||
consts.push_back(c_node);
|
||||
}
|
||||
auto list_construct = std::make_shared<opset10::Concat>(consts, 0);
|
||||
auto list_construct = std::make_shared<v0::Concat>(consts, 0);
|
||||
if (list_construct->has_evaluate()) {
|
||||
OutputVector replacements(list_construct->get_output_size());
|
||||
|
||||
|
@ -8,28 +8,33 @@
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_log(NodeContext& context) {
|
||||
// torch.log returns a tensor with the natural logarithm of the elements of input.
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
x = context.mark_node(std::make_shared<ov::op::v0::Convert>(x, element::f32));
|
||||
auto log = context.mark_node(std::make_shared<ov::op::v0::Log>(x));
|
||||
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
|
||||
auto log = context.mark_node(std::make_shared<v0::Log>(x));
|
||||
return {log};
|
||||
};
|
||||
|
||||
OutputVector translate_log2(NodeContext& context) {
|
||||
// torch.log2 returns a tensor with the logarithm to the base 2 of the elements of input.
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
auto two = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {2}));
|
||||
x = context.mark_node(std::make_shared<ov::op::v0::Convert>(x, element::f32));
|
||||
auto log2 = context.mark_node(std::make_shared<ov::op::v0::Log>(two));
|
||||
auto log = context.mark_node(std::make_shared<ov::op::v0::Log>(x));
|
||||
auto res = context.mark_node(std::make_shared<ov::op::v1::Divide>(log, log2));
|
||||
auto two = context.mark_node(v0::Constant::create(element::f32, Shape{}, {2}));
|
||||
x = context.mark_node(std::make_shared<v0::Convert>(x, element::f32));
|
||||
auto log2 = context.mark_node(std::make_shared<v0::Log>(two));
|
||||
auto log = context.mark_node(std::make_shared<v0::Log>(x));
|
||||
auto res = context.mark_node(std::make_shared<v1::Divide>(log, log2));
|
||||
return {res};
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,11 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/select.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,15 +15,18 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_masked_fill(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto data = context.get_input(0);
|
||||
auto mask = context.get_input(1);
|
||||
auto value = context.const_input<float>(2);
|
||||
auto data_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(data));
|
||||
auto value_const = context.mark_node(opset10::Constant::create(element::f32, Shape({}), {value}));
|
||||
auto broadcasted_value = context.mark_node(std::make_shared<opset10::Broadcast>(value_const, data_shape));
|
||||
auto bool_mask = context.mark_node(std::make_shared<opset10::Convert>(mask, element::boolean));
|
||||
return {context.mark_node(std::make_shared<opset10::Select>(bool_mask, broadcasted_value, data))};
|
||||
auto data_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data));
|
||||
auto value_const = context.mark_node(v0::Constant::create(element::f32, Shape({}), {value}));
|
||||
auto broadcasted_value = context.mark_node(std::make_shared<v3::Broadcast>(value_const, data_shape));
|
||||
auto bool_mask = context.mark_node(std::make_shared<v0::Convert>(mask, element::boolean));
|
||||
return {context.mark_node(std::make_shared<v1::Select>(bool_mask, broadcasted_value, data))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/max_pool.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,20 +11,18 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_max_poolnd(NodeContext& context) {
|
||||
num_inputs_check(context, 6, 6);
|
||||
auto kernel = context.const_input<Shape>(1);
|
||||
auto strides = context.const_input<Strides>(2);
|
||||
auto pads = context.const_input<Shape>(3); // pytorch supports only symmetric paddings
|
||||
auto dilations = context.const_input<Strides>(4);
|
||||
auto rounding_type = context.const_input<bool>(5) ? ov::op::RoundingType::CEIL : ov::op::RoundingType::FLOOR;
|
||||
auto rounding_type = context.const_input<bool>(5) ? RoundingType::CEIL : RoundingType::FLOOR;
|
||||
|
||||
return {context.mark_node(std::make_shared<opset10::MaxPool>(context.get_input(0),
|
||||
strides,
|
||||
dilations,
|
||||
pads,
|
||||
pads,
|
||||
kernel,
|
||||
rounding_type))};
|
||||
return {context.mark_node(
|
||||
std::make_shared<v8::MaxPool>(context.get_input(0), strides, dilations, pads, pads, kernel, rounding_type))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/reduce_mean.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,12 +12,13 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_mean(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 4);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
auto keep_dims = context.const_input<bool>(2);
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(3),
|
||||
"Only False is supported for input with index 3 for aten::mean");
|
||||
return {context.mark_node(std::make_shared<opset10::ReduceMean>(x, y, keep_dims))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(x, y, keep_dims))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -11,12 +11,11 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_meshgrid(NodeContext& context) {
|
||||
OutputVector inputs{context.get_input(0)};
|
||||
std::string indexing = "ij";
|
||||
if (!context.input_is_none(1)) {
|
||||
indexing = context.const_input<std::string>(1);
|
||||
}
|
||||
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), inputs);
|
||||
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
|
||||
auto attrs = node->get_attrs();
|
||||
attrs["indexing"] = indexing;
|
||||
node->set_attrs(attrs);
|
||||
|
@ -3,7 +3,14 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/maximum.hpp"
|
||||
#include "openvino/op/minimum.hpp"
|
||||
#include "openvino/op/reduce_max.hpp"
|
||||
#include "openvino/op/reduce_min.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/topk.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,31 +18,33 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_max(NodeContext& context) {
|
||||
// torch.max (same for torch.min) actually has two interfaces smashed together:
|
||||
// torch.max(x, dim, keepdim) and torch.max(x, y)
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto x = context.get_input(0);
|
||||
// torch.max(input)
|
||||
if (context.input_is_none(1) && context.input_is_none(2)) {
|
||||
auto axes = get_axes_range(context, 0);
|
||||
return {context.mark_node(std::make_shared<opset10::ReduceMax>(x, axes, false))};
|
||||
return {context.mark_node(std::make_shared<v1::ReduceMax>(x, axes, false))};
|
||||
}
|
||||
// torch.max(input, other)
|
||||
if (context.input_is_none(2)) {
|
||||
auto y = context.get_input(1);
|
||||
return {context.mark_node(std::make_shared<opset10::Maximum>(x, y))};
|
||||
return {context.mark_node(std::make_shared<v1::Maximum>(x, y))};
|
||||
}
|
||||
// torch.max(input, dim, keepdim), returns values and indicies
|
||||
auto axes_node = context.get_input(1);
|
||||
auto axis_const = context.const_input<int64_t>(1);
|
||||
auto keepdims = context.const_input<bool>(2);
|
||||
auto values = context.mark_node(std::make_shared<opset10::ReduceMax>(x, axes_node, keepdims));
|
||||
auto k = context.mark_node(std::make_shared<opset10::Constant>(element::i64, Shape{}, 1));
|
||||
auto topk =
|
||||
std::make_shared<opset10::TopK>(x, k, axis_const, opset10::TopK::Mode::MAX, opset10::TopK::SortType::NONE);
|
||||
auto indicies = context.mark_node(std::make_shared<opset10::Convert>(topk->output(1), element::i64));
|
||||
auto values = context.mark_node(std::make_shared<v1::ReduceMax>(x, axes_node, keepdims));
|
||||
auto k = context.mark_node(std::make_shared<v0::Constant>(element::i64, Shape{}, 1));
|
||||
auto topk = std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MAX, v3::TopK::SortType::NONE);
|
||||
auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
|
||||
if (!keepdims) {
|
||||
indicies = std::make_shared<opset10::Squeeze>(indicies, axes_node);
|
||||
indicies = std::make_shared<v0::Squeeze>(indicies, axes_node);
|
||||
}
|
||||
return {values, indicies};
|
||||
};
|
||||
@ -43,29 +52,28 @@ OutputVector translate_max(NodeContext& context) {
|
||||
OutputVector translate_min(NodeContext& context) {
|
||||
// torch.min (same for torch.max) actually has two interfaces smashed together:
|
||||
// torch.min(x, dim, keepdim) and torch.min(x, y)
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto x = context.get_input(0);
|
||||
// torch.min(input)
|
||||
if (context.input_is_none(1) && context.input_is_none(2)) {
|
||||
auto axes = get_axes_range(context, 0);
|
||||
return {context.mark_node(std::make_shared<opset10::ReduceMin>(x, axes, false))};
|
||||
return {context.mark_node(std::make_shared<v1::ReduceMin>(x, axes, false))};
|
||||
}
|
||||
// torch.min(input, other)
|
||||
if (context.input_is_none(2)) {
|
||||
auto y = context.get_input(1);
|
||||
return {context.mark_node(std::make_shared<opset10::Minimum>(x, y))};
|
||||
return {context.mark_node(std::make_shared<v1::Minimum>(x, y))};
|
||||
}
|
||||
// torch.min(input, dim, keepdim), returns values and indicies
|
||||
auto axes_node = context.get_input(1);
|
||||
auto axis_const = context.const_input<int64_t>(1);
|
||||
auto keepdims = context.const_input<bool>(2);
|
||||
auto values = context.mark_node(std::make_shared<opset10::ReduceMin>(x, axes_node, keepdims));
|
||||
auto k = context.mark_node(std::make_shared<opset10::Constant>(element::i64, Shape{}, 1));
|
||||
auto topk =
|
||||
std::make_shared<opset10::TopK>(x, k, axis_const, opset10::TopK::Mode::MIN, opset10::TopK::SortType::NONE);
|
||||
auto indicies = context.mark_node(std::make_shared<opset10::Convert>(topk->output(1), element::i64));
|
||||
|
||||
auto values = context.mark_node(std::make_shared<v1::ReduceMin>(x, axes_node, keepdims));
|
||||
auto k = context.mark_node(std::make_shared<v0::Constant>(element::i64, Shape{}, 1));
|
||||
auto topk = std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MIN, v3::TopK::SortType::NONE);
|
||||
auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
|
||||
if (!keepdims) {
|
||||
indicies = std::make_shared<opset10::Squeeze>(indicies, axes_node);
|
||||
indicies = std::make_shared<v0::Squeeze>(indicies, axes_node);
|
||||
}
|
||||
return {values, indicies};
|
||||
};
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,11 +13,14 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_neg(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
auto const_neg_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {-1}));
|
||||
auto cast = context.mark_node(std::make_shared<opset10::ConvertLike>(const_neg_1, x));
|
||||
return {context.mark_node(std::make_shared<opset10::Multiply>(x, cast))};
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
|
||||
auto cast = context.mark_node(std::make_shared<v1::ConvertLike>(const_neg_1, x));
|
||||
return {context.mark_node(std::make_shared<v1::Multiply>(x, cast))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,12 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/non_max_suppression.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,27 +16,29 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_nms(NodeContext& context) {
|
||||
auto const_0 = context.mark_node(opset9::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(opset9::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto const_2 = context.mark_node(opset9::Constant::create(element::i64, Shape{1}, {2}));
|
||||
// the shape that is required by PyTorch operator differs from the shape required in OpenVino
|
||||
auto boxes_shape = context.mark_node(opset9::Constant::create(element::i64, Shape{3}, {1, -1, 4}));
|
||||
using namespace ov::op;
|
||||
|
||||
auto boxes = context.mark_node(std::make_shared<opset9::Reshape>(context.get_input(0), boxes_shape, false));
|
||||
OutputVector translate_nms(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto const_2 = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2}));
|
||||
// the shape that is required by PyTorch operator differs from the shape required in OpenVino
|
||||
auto boxes_shape = context.mark_node(v0::Constant::create(element::i64, Shape{3}, {1, -1, 4}));
|
||||
|
||||
auto boxes = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(0), boxes_shape, false));
|
||||
// Unsqueeze operator is also used to align shapes required by PyTorch and OpenVino
|
||||
auto axis_01 = context.mark_node(opset9::Constant::create(element::i64, Shape{2}, {0, 1}));
|
||||
auto scores = context.mark_node(std::make_shared<opset9::Unsqueeze>(context.get_input(1), axis_01));
|
||||
auto axis_01 = context.mark_node(v0::Constant::create(element::i64, Shape{2}, {0, 1}));
|
||||
auto scores = context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(1), axis_01));
|
||||
auto max_output_per_class =
|
||||
context.mark_node(opset9::Constant::create(element::i64, Shape{1}, {std::numeric_limits<int64_t>::max()}));
|
||||
context.mark_node(v0::Constant::create(element::i64, Shape{1}, {std::numeric_limits<int64_t>::max()}));
|
||||
auto iou_threshold = context.get_input(2);
|
||||
|
||||
auto nms_out = context.mark_node(
|
||||
std::make_shared<opset9::NonMaxSuppression>(boxes, scores, max_output_per_class, iou_threshold));
|
||||
auto select = context.mark_node(std::make_shared<opset9::Gather>(nms_out, const_2, const_1));
|
||||
auto squeeze = std::make_shared<opset9::Squeeze>(select, const_1);
|
||||
auto nms_out =
|
||||
context.mark_node(std::make_shared<v9::NonMaxSuppression>(boxes, scores, max_output_per_class, iou_threshold));
|
||||
auto select = context.mark_node(std::make_shared<v8::Gather>(nms_out, const_2, const_1));
|
||||
|
||||
return {context.mark_node(squeeze)};
|
||||
return {context.mark_node(std::make_shared<v0::Squeeze>(select, const_1))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/non_zero.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,11 +13,14 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_nonzero(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto cond = context.get_input(0);
|
||||
auto non_zero = context.mark_node(std::make_shared<opset10::NonZero>(cond));
|
||||
auto input_order = context.mark_node(opset10::Constant::create(element::i64, Shape{2}, {1, 0}));
|
||||
return {context.mark_node(std::make_shared<opset10::Transpose>(non_zero, input_order))};
|
||||
auto non_zero = context.mark_node(std::make_shared<v3::NonZero>(cond));
|
||||
auto input_order = context.mark_node(v0::Constant::create(element::i64, Shape{2}, {1, 0}));
|
||||
return {context.mark_node(std::make_shared<v1::Transpose>(non_zero, input_order))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,14 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/abs.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/power.hpp"
|
||||
#include "openvino/op/reduce_l1.hpp"
|
||||
#include "openvino/op/reduce_l2.hpp"
|
||||
#include "openvino/op/reduce_max.hpp"
|
||||
#include "openvino/op/reduce_min.hpp"
|
||||
#include "openvino/op/reduce_sum.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,39 +18,35 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_norm(NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto p = context.const_input<float>(1);
|
||||
auto dim = context.get_input(2);
|
||||
auto keep_dim = context.const_input<bool>(3);
|
||||
|
||||
OutputVector res;
|
||||
|
||||
Output<Node> res;
|
||||
if (p == 1) {
|
||||
auto reduce_l1 = context.mark_node(std::make_shared<opset10::ReduceL1>(input_tensor, dim, keep_dim));
|
||||
res.push_back(reduce_l1);
|
||||
res = context.mark_node(std::make_shared<v4::ReduceL1>(input_tensor, dim, keep_dim));
|
||||
} else if (p == 2) {
|
||||
auto reduce_l2 = context.mark_node(std::make_shared<opset10::ReduceL2>(input_tensor, dim, keep_dim));
|
||||
res.push_back(reduce_l2);
|
||||
res = context.mark_node(std::make_shared<v4::ReduceL2>(input_tensor, dim, keep_dim));
|
||||
} else if (p == std::numeric_limits<float>::infinity()) {
|
||||
auto abs = context.mark_node(std::make_shared<opset10::Abs>(input_tensor));
|
||||
auto max = context.mark_node(std::make_shared<opset10::ReduceMax>(abs, dim, keep_dim));
|
||||
res.push_back(max);
|
||||
auto abs = context.mark_node(std::make_shared<v0::Abs>(input_tensor));
|
||||
res = context.mark_node(std::make_shared<v1::ReduceMax>(abs, dim, keep_dim));
|
||||
} else if (p == -std::numeric_limits<float>::infinity()) {
|
||||
auto abs = context.mark_node(std::make_shared<opset10::Abs>(input_tensor));
|
||||
auto min = context.mark_node(std::make_shared<opset10::ReduceMin>(abs, dim, keep_dim));
|
||||
res.push_back(min);
|
||||
auto abs = context.mark_node(std::make_shared<v0::Abs>(input_tensor));
|
||||
res = context.mark_node(std::make_shared<v1::ReduceMin>(abs, dim, keep_dim));
|
||||
} else {
|
||||
auto const_p = context.mark_node(opset10::Constant::create(element::f64, Shape{1}, {p}));
|
||||
auto const_p_inv = context.mark_node(opset10::Constant::create(element::f64, Shape{1}, {1.0 / p}));
|
||||
auto abs = context.mark_node(std::make_shared<opset10::Abs>(input_tensor));
|
||||
auto pow = context.mark_node(std::make_shared<opset10::Power>(abs, const_p));
|
||||
auto sum = context.mark_node(std::make_shared<opset10::ReduceSum>(pow, dim, keep_dim));
|
||||
auto pow_inv = context.mark_node(std::make_shared<opset10::Power>(sum, const_p_inv));
|
||||
res.push_back(pow_inv);
|
||||
auto const_p = context.mark_node(v0::Constant::create(element::f64, Shape{1}, {p}));
|
||||
auto const_p_inv = context.mark_node(v0::Constant::create(element::f64, Shape{1}, {1.0 / p}));
|
||||
auto abs = context.mark_node(std::make_shared<v0::Abs>(input_tensor));
|
||||
auto pow = context.mark_node(std::make_shared<v1::Power>(abs, const_p));
|
||||
auto sum = context.mark_node(std::make_shared<v1::ReduceSum>(pow, dim, keep_dim));
|
||||
res = context.mark_node(std::make_shared<v1::Power>(sum, const_p_inv));
|
||||
}
|
||||
|
||||
return res;
|
||||
return {res};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,6 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,7 +11,8 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_numel(NodeContext& context) {
|
||||
return {numel(context, 0)};
|
||||
num_inputs_check(context, 1, 1);
|
||||
return {numel(context, context.get_input(0))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -2,9 +2,16 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/pad.hpp"
|
||||
|
||||
#include "openvino/core/coordinate_diff.hpp"
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,15 +19,17 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_pad(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 4);
|
||||
auto data = context.get_input(0);
|
||||
auto paddings = context.const_input<std::vector<int64_t>>(1);
|
||||
std::string mode = "constant";
|
||||
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(data, element::i32));
|
||||
auto rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, element::i32));
|
||||
auto reduced_rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank));
|
||||
auto zero = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto zero_f = context.mark_node(opset10::Constant::create(element::f32, Shape{}, {0}));
|
||||
Output<Node> shape;
|
||||
Output<Node> rank;
|
||||
std::tie(shape, rank) = get_shape_rank(context, data);
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
size_t pad_size_half = paddings.size() / 2;
|
||||
std::vector<int64_t> pad_b(pad_size_half, 0);
|
||||
std::vector<int64_t> pad_e(pad_size_half, 0);
|
||||
@ -28,15 +37,13 @@ OutputVector translate_pad(NodeContext& context) {
|
||||
pad_b[i] = paddings[paddings.size() - 2 - 2 * i];
|
||||
pad_e[i] = paddings[paddings.size() - 1 - 2 * i];
|
||||
}
|
||||
auto pads_begin_short = context.mark_node(opset10::Constant::create(element::i32, Shape{pad_size_half}, pad_b));
|
||||
auto pads_end_short = context.mark_node(opset10::Constant::create(element::i32, Shape{pad_size_half}, pad_e));
|
||||
auto pads_short_len = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {pad_size_half}));
|
||||
auto pads_diff = context.mark_node(std::make_shared<opset10::Subtract>(rank, pads_short_len));
|
||||
auto pads_remaining = context.mark_node(std::make_shared<opset10::Broadcast>(zero, pads_diff));
|
||||
auto pads_begins =
|
||||
context.mark_node(std::make_shared<opset10::Concat>(NodeVector{pads_remaining, pads_begin_short}, 0));
|
||||
auto pads_ends =
|
||||
context.mark_node(std::make_shared<opset10::Concat>(NodeVector{pads_remaining, pads_end_short}, 0));
|
||||
auto pads_begin_short = context.mark_node(v0::Constant::create(element::i32, Shape{pad_size_half}, pad_b));
|
||||
auto pads_end_short = context.mark_node(v0::Constant::create(element::i32, Shape{pad_size_half}, pad_e));
|
||||
auto pads_short_len = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {pad_size_half}));
|
||||
auto pads_diff = context.mark_node(std::make_shared<v1::Subtract>(rank, pads_short_len));
|
||||
auto pads_remaining = context.mark_node(std::make_shared<v3::Broadcast>(zero, pads_diff));
|
||||
auto pads_begins = context.mark_node(std::make_shared<v0::Concat>(NodeVector{pads_remaining, pads_begin_short}, 0));
|
||||
auto pads_ends = context.mark_node(std::make_shared<v0::Concat>(NodeVector{pads_remaining, pads_end_short}, 0));
|
||||
if (!context.input_is_none(2)) {
|
||||
mode = context.const_input<std::string>(2);
|
||||
}
|
||||
@ -45,64 +52,54 @@ OutputVector translate_pad(NodeContext& context) {
|
||||
int64_t pad_r;
|
||||
auto pad_last_id = paddings.size();
|
||||
auto cur = data.get_node_shared_ptr();
|
||||
auto step = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto step = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
|
||||
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
for (size_t i = 0; i < pad_size_half; i++) {
|
||||
ov::NodeVector tensors;
|
||||
pad_r = paddings[pad_last_id - (2 * i + 1)];
|
||||
pad_l = paddings[pad_last_id - (2 * i + 2)];
|
||||
auto axes = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2 + i}));
|
||||
auto axes = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + i}));
|
||||
if (pad_l > 0) {
|
||||
auto start =
|
||||
context.mark_node(context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {-pad_l})));
|
||||
auto end = context.mark_node(std::make_shared<opset10::Gather>(
|
||||
shape,
|
||||
context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2 + i})),
|
||||
context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}))));
|
||||
auto start = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-pad_l}));
|
||||
auto end = context.mark_node(std::make_shared<v8::Gather>(shape, axes, zero_1d));
|
||||
|
||||
auto left = context.mark_node(std::make_shared<opset10::Slice>(cur, start, end, step, axes));
|
||||
auto left = context.mark_node(std::make_shared<v8::Slice>(cur, start, end, step, axes));
|
||||
tensors.push_back(left);
|
||||
}
|
||||
if (pad_l < 0 || pad_r < 0) {
|
||||
auto start = context.mark_node(
|
||||
context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {pad_l < 0 ? -pad_l : 0})));
|
||||
auto end = context.mark_node(
|
||||
context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {pad_r < 0 ? pad_r : 0})));
|
||||
auto middle = context.mark_node(std::make_shared<opset10::Slice>(cur, start, end, step, axes));
|
||||
auto start = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {pad_l < 0 ? -pad_l : 0}));
|
||||
auto end = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {pad_r < 0 ? pad_r : 0}));
|
||||
auto middle = context.mark_node(std::make_shared<v8::Slice>(cur, start, end, step, axes));
|
||||
tensors.push_back(middle);
|
||||
} else {
|
||||
tensors.push_back(cur);
|
||||
}
|
||||
if (pad_r > 0) {
|
||||
auto start = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}));
|
||||
auto end = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {pad_r}));
|
||||
auto right = context.mark_node(std::make_shared<opset10::Slice>(cur, start, end, step, axes));
|
||||
auto end = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {pad_r}));
|
||||
auto right = context.mark_node(std::make_shared<v8::Slice>(cur, zero_1d, end, step, axes));
|
||||
tensors.push_back(right);
|
||||
}
|
||||
if (tensors.size()) {
|
||||
cur = context.mark_node(std::make_shared<opset10::Concat>(tensors, 2 + i));
|
||||
cur = context.mark_node(std::make_shared<v0::Concat>(tensors, 2 + i));
|
||||
}
|
||||
}
|
||||
return {cur};
|
||||
}
|
||||
if (mode == "constant") {
|
||||
if (!context.input_is_none(3)) {
|
||||
auto pad_value = context.get_input(3);
|
||||
return {context.mark_node(
|
||||
std::make_shared<opset10::Pad>(data, pads_begins, pads_ends, pad_value, ov::op::PadMode::CONSTANT))};
|
||||
}
|
||||
return {context.mark_node(
|
||||
std::make_shared<opset10::Pad>(data, pads_begins, pads_ends, zero_f, ov::op::PadMode::CONSTANT))};
|
||||
const std::map<std::string, PadMode> pt_to_ov_pad{
|
||||
{"constant", PadMode::CONSTANT},
|
||||
{"reflect", PadMode::REFLECT},
|
||||
{"replicate", PadMode::EDGE},
|
||||
};
|
||||
Output<Node> pad_value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
if (mode == "constant" && !context.input_is_none(3)) {
|
||||
pad_value = context.get_input(3);
|
||||
}
|
||||
if (mode == "reflect") {
|
||||
return {context.mark_node(
|
||||
std::make_shared<opset10::Pad>(data, pads_begins, pads_ends, zero_f, ov::op::PadMode::REFLECT))};
|
||||
}
|
||||
if (mode == "replicate") {
|
||||
return {context.mark_node(
|
||||
std::make_shared<opset10::Pad>(data, pads_begins, pads_ends, zero_f, ov::op::PadMode::EDGE))};
|
||||
}
|
||||
|
||||
FRONT_END_OP_CONVERSION_CHECK(false, "aten::pad conversion doesn't support [ " + mode + " ] padding mode");
|
||||
auto ov_mode = pt_to_ov_pad.find(mode);
|
||||
FRONT_END_OP_CONVERSION_CHECK(ov_mode != pt_to_ov_pad.end(),
|
||||
"aten::pad conversion doesn't support [ ",
|
||||
mode,
|
||||
" ] padding mode");
|
||||
return {context.mark_node(std::make_shared<v1::Pad>(data, pads_begins, pads_ends, pad_value, ov_mode->second))};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
@ -12,7 +12,7 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_pow(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto lhs = context.get_input(0);
|
||||
auto rhs = context.get_input(1);
|
||||
align_eltwise_input_types(context, lhs, rhs, true);
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/power.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,11 +13,14 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_reciprocal(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
auto const_neg_1 = opset10::Constant::create(element::i32, Shape{}, {-1});
|
||||
auto cast = std::make_shared<opset10::ConvertLike>(const_neg_1, x);
|
||||
auto power = std::make_shared<opset10::Power>(x, cast);
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
|
||||
auto cast = context.mark_node(std::make_shared<v1::ConvertLike>(const_neg_1, x));
|
||||
auto power = context.mark_node(std::make_shared<v1::Power>(x, cast));
|
||||
return {context.mark_node(power)};
|
||||
};
|
||||
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/clamp.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,8 +12,9 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_relu6(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
return {context.mark_node(std::make_shared<opset10::Clamp>(x, 0., 6.))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Clamp>(x, 0., 6.))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -7,19 +7,23 @@
|
||||
#include "openvino/op/floor.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_remainder(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
auto div = context.mark_node(std::make_shared<ov::op::v1::Divide>(x, y, true));
|
||||
auto floor = context.mark_node(std::make_shared<ov::op::v0::Floor>(div));
|
||||
auto quo = context.mark_node(std::make_shared<ov::op::v1::Multiply>(floor, y));
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Subtract>(x, quo))};
|
||||
auto div = context.mark_node(std::make_shared<v1::Divide>(x, y, true));
|
||||
auto floor = context.mark_node(std::make_shared<v0::Floor>(div));
|
||||
auto quo = context.mark_node(std::make_shared<v1::Multiply>(floor, y));
|
||||
return {context.mark_node(std::make_shared<v1::Subtract>(x, quo))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,10 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/tile.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,15 +14,18 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_repeat(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto repeats = context.get_input(1);
|
||||
auto one = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto sizes_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(repeats, element::i64));
|
||||
auto expand_shape = context.mark_node(std::make_shared<opset10::Broadcast>(one, sizes_shape));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto sizes_shape = context.mark_node(std::make_shared<v3::ShapeOf>(repeats, element::i64));
|
||||
auto expand_shape = context.mark_node(std::make_shared<v3::Broadcast>(one, sizes_shape));
|
||||
auto expanded_input =
|
||||
context.mark_node(std::make_shared<opset10::Broadcast>(x, expand_shape, ov::op::BroadcastType::BIDIRECTIONAL));
|
||||
return {context.mark_node(std::make_shared<opset10::Tile>(expanded_input, repeats))};
|
||||
context.mark_node(std::make_shared<v3::Broadcast>(x, expand_shape, BroadcastType::BIDIRECTIONAL));
|
||||
return {context.mark_node(std::make_shared<v0::Tile>(expanded_input, repeats))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,15 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/tile.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,13 +19,15 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
OutputVector generate_indices_from_repeats_tensor(std::vector<int32_t> repeats, NodeContext& context) {
|
||||
OutputVector generate_indices_from_repeats_tensor(const NodeContext& context, const std::vector<int32_t>& repeats) {
|
||||
OutputVector all_indices;
|
||||
for (size_t i = 0; i < repeats.size(); i++) {
|
||||
Shape indices_shape{static_cast<size_t>(repeats.at(i))};
|
||||
std::vector<int32_t> indices_vec(repeats.at(i), i);
|
||||
auto indices = context.mark_node(opset10::Constant::create(element::i32, indices_shape, indices_vec));
|
||||
auto indices = context.mark_node(v0::Constant::create(element::i32, indices_shape, indices_vec));
|
||||
all_indices.push_back(indices);
|
||||
}
|
||||
return all_indices;
|
||||
@ -25,62 +35,61 @@ OutputVector generate_indices_from_repeats_tensor(std::vector<int32_t> repeats,
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_repeat_interleave(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
// constants
|
||||
auto const_0 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto const_1_list = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_neg_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto const_1_list = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
|
||||
// inputs
|
||||
auto input = context.get_input(0);
|
||||
std::shared_ptr<ov::Node> result;
|
||||
|
||||
auto repeats_ext_node = context.get_input_from_visible_context(1).get_node_shared_ptr();
|
||||
auto repeats_fw_node = std::dynamic_pointer_cast<opset10::Constant>(repeats_ext_node);
|
||||
auto repeats_fw_node = std::dynamic_pointer_cast<v0::Constant>(repeats_ext_node);
|
||||
if (repeats_fw_node && repeats_fw_node->cast_vector<int32_t>().size() > 1) {
|
||||
// repeats is Constant with more then 1 element
|
||||
auto repeats = repeats_fw_node->cast_vector<int32_t>();
|
||||
if (context.input_is_none(2)) {
|
||||
// case (repeats=tensor, dim=None)
|
||||
auto flat_shape = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto reshape = context.mark_node(std::make_shared<opset10::Reshape>(input, flat_shape, false));
|
||||
OutputVector all_indices = generate_indices_from_repeats_tensor(repeats, context);
|
||||
auto concat = context.mark_node(std::make_shared<opset10::Concat>(all_indices, 0));
|
||||
result = std::make_shared<opset10::Gather>(reshape, concat, const_0);
|
||||
auto flat_shape = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto reshape = context.mark_node(std::make_shared<v1::Reshape>(input, flat_shape, false));
|
||||
OutputVector all_indices = generate_indices_from_repeats_tensor(context, repeats);
|
||||
auto concat = context.mark_node(std::make_shared<v0::Concat>(all_indices, 0));
|
||||
result = std::make_shared<v8::Gather>(reshape, concat, const_0);
|
||||
} else {
|
||||
// case (repeats=tensor, dim=number)
|
||||
auto dimension = context.get_input(2);
|
||||
OutputVector all_indices = generate_indices_from_repeats_tensor(repeats, context);
|
||||
auto concat = context.mark_node(std::make_shared<opset10::Concat>(all_indices, 0));
|
||||
result = std::make_shared<opset10::Gather>(input, concat, dimension);
|
||||
OutputVector all_indices = generate_indices_from_repeats_tensor(context, repeats);
|
||||
auto concat = context.mark_node(std::make_shared<v0::Concat>(all_indices, 0));
|
||||
result = std::make_shared<v8::Gather>(input, concat, dimension);
|
||||
}
|
||||
} else {
|
||||
// repeats is not Constant or single element constant
|
||||
// Curently we support only case when repeats contains only one element. Otherwise next Reshape will fail.
|
||||
auto repeats_input =
|
||||
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(1), const_1_list, false));
|
||||
auto repeats =
|
||||
context.mark_node(std::make_shared<opset10::Concat>(OutputVector{repeats_input, const_1_list}, 0));
|
||||
auto shape_perm = context.mark_node(opset10::Constant::create(element::i32, Shape{2}, {1, 0}));
|
||||
context.mark_node(std::make_shared<v1::Reshape>(context.get_input(1), const_1_list, false));
|
||||
auto repeats = context.mark_node(std::make_shared<v0::Concat>(OutputVector{repeats_input, const_1_list}, 0));
|
||||
auto shape_perm = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
|
||||
if (context.input_is_none(2)) {
|
||||
// case (repeats=number, dim=None)
|
||||
auto flat_shape = context.mark_node(opset10::Constant::create(element::i32, Shape{2}, {1, -1}));
|
||||
auto reshape = context.mark_node(std::make_shared<opset10::Reshape>(input, flat_shape, false));
|
||||
auto tile = context.mark_node(std::make_shared<opset10::Tile>(reshape, repeats));
|
||||
auto transpose = context.mark_node(std::make_shared<opset10::Transpose>(tile, shape_perm));
|
||||
result = std::make_shared<opset10::Reshape>(transpose, const_neg_1, false);
|
||||
auto flat_shape = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, -1}));
|
||||
auto reshape = context.mark_node(std::make_shared<v1::Reshape>(input, flat_shape, false));
|
||||
auto tile = context.mark_node(std::make_shared<v0::Tile>(reshape, repeats));
|
||||
auto transpose = context.mark_node(std::make_shared<v1::Transpose>(tile, shape_perm));
|
||||
result = std::make_shared<v1::Reshape>(transpose, const_neg_1, false);
|
||||
} else {
|
||||
// case (repeats=number, dim=number)
|
||||
auto dimension = context.get_input(2);
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(input, element::i32));
|
||||
auto input_dim_size = context.mark_node(std::make_shared<opset10::Gather>(input_shape, dimension, const_0));
|
||||
auto range =
|
||||
context.mark_node(std::make_shared<opset10::Range>(const_0, input_dim_size, const_1, element::i32));
|
||||
auto range_unsqeezed = context.mark_node(std::make_shared<opset10::Unsqueeze>(range, const_0));
|
||||
auto tile = context.mark_node(std::make_shared<opset10::Tile>(range_unsqeezed, repeats));
|
||||
auto transpose = context.mark_node(std::make_shared<opset10::Transpose>(tile, shape_perm));
|
||||
auto flatten = context.mark_node(std::make_shared<opset10::Reshape>(transpose, const_neg_1, false));
|
||||
result = std::make_shared<opset10::Gather>(input, flatten, dimension);
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
auto input_dim_size = context.mark_node(std::make_shared<v8::Gather>(input_shape, dimension, const_0));
|
||||
auto range = context.mark_node(std::make_shared<v4::Range>(const_0, input_dim_size, const_1, element::i32));
|
||||
auto range_unsqeezed = context.mark_node(std::make_shared<v0::Unsqueeze>(range, const_0));
|
||||
auto tile = context.mark_node(std::make_shared<v0::Tile>(range_unsqeezed, repeats));
|
||||
auto transpose = context.mark_node(std::make_shared<v1::Transpose>(tile, shape_perm));
|
||||
auto flatten = context.mark_node(std::make_shared<v1::Reshape>(transpose, const_neg_1, false));
|
||||
result = std::make_shared<v8::Gather>(input, flatten, dimension);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2,9 +2,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/reshape.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -17,9 +17,9 @@ OutputVector translate_reshape(NodeContext& context) {
|
||||
// Schema: aten::view(Tensor input, int[] shape) -> Tensor
|
||||
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
|
||||
// For shape parameter, int[] is converted into single dimensional Tensor.
|
||||
auto reshape =
|
||||
context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), context.get_input(1), false));
|
||||
return {reshape};
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto reshape = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), context.get_input(1), false);
|
||||
return {context.mark_node(reshape)};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,8 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,10 +13,11 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_reshape_as(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto shape_tesnor = context.get_input(1);
|
||||
auto desired_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(shape_tesnor));
|
||||
return {context.mark_node(std::make_shared<opset10::Reshape>(input_tensor, desired_shape, false))};
|
||||
auto desired_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(shape_tesnor));
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Reshape>(input_tensor, desired_shape, false))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -2,8 +2,12 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/roll.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,7 +15,10 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_roll(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
const auto data = context.get_input(0);
|
||||
const auto shifts = context.get_input(1);
|
||||
const auto axes = context.get_input(2);
|
||||
@ -19,16 +26,16 @@ OutputVector translate_roll(NodeContext& context) {
|
||||
const auto axes_pshape = axes.get_partial_shape();
|
||||
const auto match_dims = axes_pshape.compatible(shifts_pshape);
|
||||
if (!match_dims) {
|
||||
const auto const_minus_1 = opset10::Constant::create(element::i32, Shape{1}, {-1});
|
||||
const auto axis_0 = opset10::Constant::create(element::i32, Shape{1}, {0});
|
||||
const auto flat = std::make_shared<opset10::Reshape>(data, const_minus_1, false);
|
||||
const auto roll = std::make_shared<opset10::Roll>(flat, shifts, axis_0);
|
||||
const auto shape_of_data = std::make_shared<opset10::ShapeOf>(data);
|
||||
const auto reshape = std::make_shared<opset10::Reshape>(roll, shape_of_data, false);
|
||||
const auto const_minus_1 = v0::Constant::create(element::i32, Shape{1}, {-1});
|
||||
const auto axis_0 = v0::Constant::create(element::i32, Shape{1}, {0});
|
||||
const auto flat = std::make_shared<v1::Reshape>(data, const_minus_1, false);
|
||||
const auto roll = std::make_shared<v7::Roll>(flat, shifts, axis_0);
|
||||
const auto shape_of_data = std::make_shared<v3::ShapeOf>(data);
|
||||
const auto reshape = std::make_shared<v1::Reshape>(roll, shape_of_data, false);
|
||||
context.mark_nodes({const_minus_1, flat, roll, shape_of_data, reshape});
|
||||
return {reshape};
|
||||
}
|
||||
return {context.mark_node(std::make_shared<opset10::Roll>(data, shifts, axes))};
|
||||
return {context.mark_node(std::make_shared<v7::Roll>(data, shifts, axes))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,10 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/divide.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/sqrt.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,12 +14,15 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_rsqrt(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto data = context.get_input(0);
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(data));
|
||||
auto one_const = context.mark_node(opset10::Constant::create(element::f32, Shape({}), {1}));
|
||||
auto sqrt_data = context.mark_node(std::make_shared<opset10::Sqrt>(data));
|
||||
return {context.mark_node(std::make_shared<opset10::Divide>(one_const, sqrt_data))};
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data));
|
||||
auto one_const = context.mark_node(v0::Constant::create(element::f32, Shape({}), {1}));
|
||||
auto sqrt_data = context.mark_node(std::make_shared<v0::Sqrt>(data));
|
||||
return {context.mark_node(std::make_shared<v1::Divide>(one_const, sqrt_data))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,15 +13,18 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_rsub(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto self = context.get_input(0);
|
||||
auto other = context.get_input(1);
|
||||
auto alpha = context.get_input(2);
|
||||
align_eltwise_input_types(context, self, other);
|
||||
// reverse aten::sub other - self * alpha
|
||||
auto alpha_casted = context.mark_node(std::make_shared<opset10::ConvertLike>(alpha, self));
|
||||
auto alpha_mul = context.mark_node(std::make_shared<opset10::Multiply>(self, alpha_casted));
|
||||
return {context.mark_node(std::make_shared<opset10::Subtract>(other, alpha_mul))};
|
||||
auto alpha_casted = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, self));
|
||||
auto alpha_mul = context.mark_node(std::make_shared<v1::Multiply>(self, alpha_casted));
|
||||
return {context.mark_node(std::make_shared<v1::Subtract>(other, alpha_mul))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -2,8 +2,15 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/select.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/less.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/slice.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,22 +18,25 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_select(NodeContext& context) {
|
||||
auto const_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_minus_1 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto const_0 = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {0}));
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto dim = context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(1), const_1, false));
|
||||
auto start = context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(2), const_1, false));
|
||||
auto dim = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(1), const_1, false));
|
||||
auto start = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(2), const_1, false));
|
||||
|
||||
auto less = context.mark_node(std::make_shared<opset10::Less>(start, const_0));
|
||||
auto const_1_signed = context.mark_node(std::make_shared<opset10::Select>(less, const_minus_1, const_1));
|
||||
auto stop = context.mark_node(std::make_shared<opset10::Add>(start, const_1_signed));
|
||||
auto less = context.mark_node(std::make_shared<v1::Less>(start, const_0));
|
||||
auto const_1_signed = context.mark_node(std::make_shared<v1::Select>(less, const_minus_1, const_1));
|
||||
auto stop = context.mark_node(std::make_shared<v1::Add>(start, const_1_signed));
|
||||
|
||||
auto slice_node =
|
||||
context.mark_node(std::make_shared<opset10::Slice>(input_tensor, start, stop, const_1_signed, dim));
|
||||
auto slice_node = context.mark_node(std::make_shared<v8::Slice>(input_tensor, start, stop, const_1_signed, dim));
|
||||
|
||||
return {context.mark_node(std::make_shared<opset10::Squeeze>(slice_node, dim))};
|
||||
return {context.mark_node(std::make_shared<v0::Squeeze>(slice_node, dim))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -2,8 +2,11 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/selu.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,15 +14,16 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_selu(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
auto alpha =
|
||||
context.mark_node(opset10::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
|
||||
auto lambda =
|
||||
context.mark_node(opset10::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
|
||||
alpha = context.mark_node(std::make_shared<opset10::ConvertLike>(alpha, x));
|
||||
lambda = context.mark_node(std::make_shared<opset10::ConvertLike>(lambda, x));
|
||||
return {context.mark_node(std::make_shared<opset10::Selu>(x, alpha, lambda))};
|
||||
auto alpha = context.mark_node(v0::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
|
||||
auto lambda = context.mark_node(v0::Constant::create(element::f64, Shape{}, {1.0507009873554804934193349852946}));
|
||||
alpha = context.mark_node(std::make_shared<v1::ConvertLike>(alpha, x));
|
||||
lambda = context.mark_node(std::make_shared<v1::ConvertLike>(lambda, x));
|
||||
return {context.mark_node(std::make_shared<v0::Selu>(x, alpha, lambda))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,13 +13,16 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_size(NodeContext& context) {
|
||||
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(context.get_input(0), element::i32));
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(context.get_input(0), element::i32));
|
||||
if (context.input_is_none(1)) {
|
||||
return shape->outputs();
|
||||
} else {
|
||||
auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
return {context.mark_node(std::make_shared<opset10::Gather>(shape, context.get_input(1), axis_0))};
|
||||
auto axis_0 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
return {context.mark_node(std::make_shared<v8::Gather>(shape, context.get_input(1), axis_0))};
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2,10 +2,13 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/slice.hpp"
|
||||
|
||||
#include <climits>
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -13,6 +16,8 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_slice(NodeContext& context) {
|
||||
// aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])
|
||||
// aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a))
|
||||
@ -20,11 +25,11 @@ OutputVector translate_slice(NodeContext& context) {
|
||||
int start_idx;
|
||||
int end_idx;
|
||||
int step_idx;
|
||||
auto axis_0 = context.mark_node(opset10::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto axis_0 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
if (context.get_input_size() == 5) {
|
||||
dim = context.get_input(1);
|
||||
if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) {
|
||||
dim = context.mark_node(std::make_shared<opset10::Unsqueeze>(dim, axis_0));
|
||||
dim = context.mark_node(std::make_shared<v0::Unsqueeze>(dim, axis_0));
|
||||
}
|
||||
start_idx = 2;
|
||||
end_idx = 3;
|
||||
@ -33,7 +38,7 @@ OutputVector translate_slice(NodeContext& context) {
|
||||
start_idx = 1;
|
||||
end_idx = 2;
|
||||
step_idx = 3;
|
||||
dim = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}));
|
||||
dim = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {0}));
|
||||
} else {
|
||||
FRONT_END_OP_CONVERSION_CHECK(false, "Slice must have either 4 or 5 inputs.");
|
||||
}
|
||||
@ -42,31 +47,31 @@ OutputVector translate_slice(NodeContext& context) {
|
||||
if (!context.input_is_none(start_idx)) {
|
||||
start = context.get_input(start_idx);
|
||||
if (start.get_partial_shape().rank().is_dynamic() || start.get_partial_shape().rank().get_length() == 0) {
|
||||
start = context.mark_node(std::make_shared<opset10::Unsqueeze>(start, axis_0));
|
||||
start = context.mark_node(std::make_shared<v0::Unsqueeze>(start, axis_0));
|
||||
}
|
||||
} else {
|
||||
start = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {0}));
|
||||
start = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {0}));
|
||||
}
|
||||
|
||||
ov::Output<ov::Node> end;
|
||||
if (!context.input_is_none(end_idx)) {
|
||||
end = context.get_input(end_idx);
|
||||
if (end.get_partial_shape().rank().is_dynamic() || end.get_partial_shape().rank().get_length() == 0) {
|
||||
end = context.mark_node(std::make_shared<opset10::Unsqueeze>(end, axis_0));
|
||||
end = context.mark_node(std::make_shared<v0::Unsqueeze>(end, axis_0));
|
||||
}
|
||||
} else {
|
||||
end = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {INT_MAX}));
|
||||
end = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {INT_MAX}));
|
||||
}
|
||||
ov::Output<ov::Node> step;
|
||||
if (!context.input_is_none(step_idx)) {
|
||||
step = context.get_input(step_idx);
|
||||
if (step.get_partial_shape().rank().is_dynamic() || step.get_partial_shape().rank().get_length() == 0) {
|
||||
step = context.mark_node(std::make_shared<opset10::Unsqueeze>(step, axis_0));
|
||||
step = context.mark_node(std::make_shared<v0::Unsqueeze>(step, axis_0));
|
||||
}
|
||||
} else {
|
||||
step = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
|
||||
step = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
|
||||
}
|
||||
return {context.mark_node(std::make_shared<opset10::Slice>(context.get_input(0), start, end, step, dim))};
|
||||
return {context.mark_node(std::make_shared<v8::Slice>(context.get_input(0), start, end, step, dim))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -2,8 +2,9 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/softmax.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,9 +13,10 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_softmax(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto axis = context.const_input<int64_t>(1);
|
||||
return {context.mark_node(std::make_shared<opset10::Softmax>(x, axis))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v8::Softmax>(x, axis))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,8 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/power.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,10 +12,13 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_square(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto input_0 = context.get_input(0);
|
||||
auto const_2 = context.mark_node(opset10::Constant::create(input_0.get_element_type(), Shape{1}, {2}));
|
||||
return {context.mark_node(std::make_shared<opset10::Power>(input_0, const_2))};
|
||||
auto const_2 = context.mark_node(v0::Constant::create(input_0.get_element_type(), Shape{1}, {2}));
|
||||
return {context.mark_node(std::make_shared<v1::Power>(input_0, const_2))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -2,8 +2,10 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
@ -11,13 +13,12 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_squeeze(NodeContext& context) {
|
||||
auto inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 1, "Operation has no inputs.");
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None.");
|
||||
if (inputs.size() == 1 || context.input_is_none(1)) {
|
||||
return {context.mark_node(std::make_shared<opset10::Squeeze>(inputs[0]))};
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto x = context.get_input(0);
|
||||
if (context.input_is_none(1)) {
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Squeeze>(x))};
|
||||
}
|
||||
return {context.mark_node(std::make_shared<opset10::Squeeze>(inputs[0], inputs[1]))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Squeeze>(x, context.get_input(1)))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -8,14 +8,15 @@
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_sub(NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
align_eltwise_input_types(context, x, y);
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/reduce_sum.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -12,10 +12,9 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_sum(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
bool keep_dims = false;
|
||||
ov::Output<ov::Node> axes;
|
||||
Output<Node> cast;
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Operation should have at least 1 input");
|
||||
auto data = context.get_input(0);
|
||||
if (context.input_is_none(1)) {
|
||||
axes = get_axes_range(context, 0);
|
||||
@ -26,7 +25,7 @@ OutputVector translate_sum(NodeContext& context) {
|
||||
keep_dims = context.const_input<bool>(2);
|
||||
}
|
||||
|
||||
return {context.mark_node(std::make_shared<opset10::ReduceSum>(data, axes, keep_dims))};
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::ReduceSum>(data, axes, keep_dims))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -3,7 +3,9 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/convert_like.hpp"
|
||||
#include "pt_framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@ -12,6 +14,8 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_to(NodeContext& context) {
|
||||
int dtype_idx;
|
||||
int memory_format_idx;
|
||||
@ -44,13 +48,13 @@ OutputVector translate_to(NodeContext& context) {
|
||||
Output<Node> cast;
|
||||
if (dtype_fw_node && dtype_fw_node->get_op_type() == "prim::dtype") {
|
||||
auto type_input = dtype_fw_node->input_value(0);
|
||||
cast = context.mark_node(std::make_shared<opset10::ConvertLike>(context.get_input(0), type_input));
|
||||
} else if (const auto dtype_const = std::dynamic_pointer_cast<opset10::Constant>(dtype_ext_node)) {
|
||||
cast = context.mark_node(std::make_shared<v1::ConvertLike>(context.get_input(0), type_input));
|
||||
} else if (const auto dtype_const = std::dynamic_pointer_cast<v0::Constant>(dtype_ext_node)) {
|
||||
auto pt_type = dtype_const->cast_vector<int64_t>()[0];
|
||||
auto dtype = convert_dtype(pt_type);
|
||||
cast = context.mark_node(std::make_shared<opset10::Convert>(context.get_input(0), dtype));
|
||||
cast = context.mark_node(std::make_shared<v0::Convert>(context.get_input(0), dtype));
|
||||
} else {
|
||||
cast = context.mark_node(std::make_shared<opset10::ConvertLike>(context.get_input(0), context.get_input(1)));
|
||||
cast = context.mark_node(std::make_shared<v1::ConvertLike>(context.get_input(0), context.get_input(1)));
|
||||
}
|
||||
return {cast};
|
||||
}
|
||||
|
@ -6,33 +6,37 @@
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_topk(NodeContext& context) {
|
||||
num_inputs_check(context, 5, 5);
|
||||
const auto input_tensor = context.get_input(0);
|
||||
const auto largest = context.const_input<bool>(3);
|
||||
const auto sorted = context.const_input<bool>(4);
|
||||
auto k = context.get_input(1);
|
||||
int64_t axis{-1};
|
||||
auto mode = ov::op::TopKMode::MIN;
|
||||
auto sort = ov::op::TopKSortType::NONE;
|
||||
auto mode = TopKMode::MIN;
|
||||
auto sort = TopKSortType::NONE;
|
||||
|
||||
if (!context.input_is_none(2)) {
|
||||
axis = context.const_input<int64_t>(2);
|
||||
}
|
||||
if (largest) {
|
||||
mode = ov::op::TopKMode::MAX;
|
||||
mode = TopKMode::MAX;
|
||||
}
|
||||
if (sorted) {
|
||||
sort = ov::op::TopKSortType::SORT_VALUES;
|
||||
sort = TopKSortType::SORT_VALUES;
|
||||
}
|
||||
|
||||
auto topk = context.mark_node(std::make_shared<ov::op::v3::TopK>(input_tensor, k, axis, mode, sort));
|
||||
auto indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(topk->output(1), element::i64));
|
||||
auto topk = context.mark_node(std::make_shared<v3::TopK>(input_tensor, k, axis, mode, sort));
|
||||
auto indices = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
|
||||
|
||||
return {topk->output(0), indices};
|
||||
};
|
||||
|
@ -2,43 +2,52 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/transpose.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/scatter_elements_update.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_transpose(NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto dim0 = context.const_input<int64_t>(1);
|
||||
auto dim1 = context.const_input<int64_t>(2);
|
||||
auto shape = std::make_shared<opset10::ShapeOf>(context.get_input(0), element::i32);
|
||||
auto rank_ = std::make_shared<opset10::ShapeOf>(shape, element::i32);
|
||||
auto rank = std::make_shared<opset10::Squeeze>(rank_);
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
|
||||
// Use opset::If for dim normalization
|
||||
auto dim0_node = context.get_input(1);
|
||||
auto dim1_node = context.get_input(2);
|
||||
if (dim0 < 0) {
|
||||
dim0_node = std::make_shared<opset10::Add>(rank, dim0_node);
|
||||
dim0_node = std::make_shared<v1::Add>(rank, dim0_node);
|
||||
}
|
||||
if (dim1 < 0) {
|
||||
dim1_node = std::make_shared<opset10::Add>(rank, dim1_node);
|
||||
dim1_node = std::make_shared<v1::Add>(rank, dim1_node);
|
||||
}
|
||||
auto start = opset10::Constant::create(element::i32, {}, {0});
|
||||
auto step = opset10::Constant::create(element::i32, {}, {1});
|
||||
auto range = std::make_shared<opset10::Range>(start, rank, step, element::i32);
|
||||
auto start = v0::Constant::create(element::i32, {}, {0});
|
||||
auto step = v0::Constant::create(element::i32, {}, {1});
|
||||
auto range = std::make_shared<v4::Range>(start, rank, step, element::i32);
|
||||
|
||||
auto axis_0 = opset10::Constant::create(element::i64, Shape{}, {0});
|
||||
auto dim0_node_ = std::make_shared<opset10::Unsqueeze>(dim0_node, axis_0);
|
||||
auto dim1_node_ = std::make_shared<opset10::Unsqueeze>(dim1_node, axis_0);
|
||||
auto indices = std::make_shared<opset10::Concat>(OutputVector{dim0_node_, dim1_node_}, 0);
|
||||
auto updates = std::make_shared<opset10::Concat>(OutputVector{dim1_node_, dim0_node_}, 0);
|
||||
auto scatter = std::make_shared<opset10::ScatterElementsUpdate>(range, indices, updates, axis_0);
|
||||
context.mark_nodes(
|
||||
{shape, rank_, rank, start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter});
|
||||
auto axis_0 = v0::Constant::create(element::i64, Shape{}, {0});
|
||||
auto dim0_node_ = std::make_shared<v0::Unsqueeze>(dim0_node, axis_0);
|
||||
auto dim1_node_ = std::make_shared<v0::Unsqueeze>(dim1_node, axis_0);
|
||||
auto indices = std::make_shared<v0::Concat>(OutputVector{dim0_node_, dim1_node_}, 0);
|
||||
auto updates = std::make_shared<v0::Concat>(OutputVector{dim1_node_, dim0_node_}, 0);
|
||||
auto scatter = std::make_shared<v3::ScatterElementsUpdate>(range, indices, updates, axis_0);
|
||||
context.mark_nodes({start, step, range, axis_0, dim0_node_, dim1_node_, indices, updates, scatter});
|
||||
|
||||
return {context.mark_node(std::make_shared<opset10::Transpose>(context.get_input(0), scatter))};
|
||||
return {context.mark_node(std::make_shared<v1::Transpose>(context.get_input(0), scatter))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -14,55 +14,58 @@
|
||||
#include "openvino/op/select.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "openvino/op/unsqueeze.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
namespace base {
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_base_triu_tril(NodeContext& context, bool upper) {
|
||||
namespace {
|
||||
OutputVector translate_base_triu_tril(const NodeContext& context, bool upper) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto input_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(input_tensor));
|
||||
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto minus_one = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {-1}));
|
||||
auto minus_two = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {-2}));
|
||||
const auto m = context.mark_node(std::make_shared<ov::op::v7::Gather>(input_shape, minus_one, zero));
|
||||
const auto n = context.mark_node(std::make_shared<ov::op::v7::Gather>(input_shape, minus_two, zero));
|
||||
auto horizontal_range = context.mark_node(std::make_shared<ov::op::v4::Range>(zero, m, one, element::i64));
|
||||
horizontal_range = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(horizontal_range, zero));
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input_tensor));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
|
||||
auto minus_one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}));
|
||||
auto minus_two = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-2}));
|
||||
const auto m = context.mark_node(std::make_shared<v7::Gather>(input_shape, minus_one, zero));
|
||||
const auto n = context.mark_node(std::make_shared<v7::Gather>(input_shape, minus_two, zero));
|
||||
auto horizontal_range = context.mark_node(std::make_shared<v4::Range>(zero, m, one, element::i64));
|
||||
horizontal_range = context.mark_node(std::make_shared<v0::Unsqueeze>(horizontal_range, zero));
|
||||
Output<Node> vertical_range;
|
||||
if (!context.input_is_none(1)) {
|
||||
auto diagonal = context.get_input(1);
|
||||
diagonal = context.mark_node(std::make_shared<ov::op::v0::Convert>(diagonal, element::i64));
|
||||
auto stop = context.mark_node(std::make_shared<ov::op::v1::Add>(n, diagonal));
|
||||
vertical_range = context.mark_node(std::make_shared<ov::op::v4::Range>(diagonal, stop, one, element::i64));
|
||||
diagonal = context.mark_node(std::make_shared<v0::Convert>(diagonal, element::i64));
|
||||
auto stop = context.mark_node(std::make_shared<v1::Add>(n, diagonal));
|
||||
vertical_range = context.mark_node(std::make_shared<v4::Range>(diagonal, stop, one, element::i64));
|
||||
} else {
|
||||
vertical_range = context.mark_node(std::make_shared<ov::op::v4::Range>(zero, n, one, element::i64));
|
||||
vertical_range = context.mark_node(std::make_shared<v4::Range>(zero, n, one, element::i64));
|
||||
}
|
||||
vertical_range = context.mark_node(std::make_shared<ov::op::v0::Unsqueeze>(vertical_range, one));
|
||||
vertical_range = context.mark_node(std::make_shared<v0::Unsqueeze>(vertical_range, one));
|
||||
|
||||
Output<Node> mask;
|
||||
if (upper) {
|
||||
mask = context.mark_node(std::make_shared<ov::op::v1::GreaterEqual>(horizontal_range, vertical_range));
|
||||
mask = context.mark_node(std::make_shared<v1::GreaterEqual>(horizontal_range, vertical_range));
|
||||
} else {
|
||||
mask = context.mark_node(std::make_shared<ov::op::v1::LessEqual>(horizontal_range, vertical_range));
|
||||
mask = context.mark_node(std::make_shared<v1::LessEqual>(horizontal_range, vertical_range));
|
||||
}
|
||||
|
||||
zero = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(zero, input_tensor));
|
||||
zero = context.mark_node(std::make_shared<v1::ConvertLike>(zero, input_tensor));
|
||||
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Select>(mask, input_tensor, zero))};
|
||||
return {context.mark_node(std::make_shared<v1::Select>(mask, input_tensor, zero))};
|
||||
}
|
||||
}; // namespace base
|
||||
}; // namespace
|
||||
|
||||
OutputVector translate_triu(NodeContext& context) {
|
||||
return base::translate_base_triu_tril(context, true);
|
||||
return translate_base_triu_tril(context, true);
|
||||
};
|
||||
|
||||
OutputVector translate_tril(NodeContext& context) {
|
||||
return base::translate_base_triu_tril(context, false);
|
||||
return translate_base_triu_tril(context, false);
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -1,25 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_tuple_construct(NodeContext& context) {
|
||||
auto n_inputs = context.get_input_size();
|
||||
FRONT_END_OP_CONVERSION_CHECK(
|
||||
n_inputs == 1,
|
||||
"prim::TupleConstruct conversion doesn't support cases when the number of inputs is not one.");
|
||||
return {context.get_input(0)};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov::opset10;
|
||||
|
||||
@ -13,6 +14,7 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_unfold(NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
// constants
|
||||
auto const_0 = context.mark_node(Constant::create(element::i32, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(Constant::create(element::i32, Shape{}, {1}));
|
||||
@ -22,8 +24,9 @@ OutputVector translate_unfold(NodeContext& context) {
|
||||
|
||||
// get inputs and prepare auxiliary nodes
|
||||
auto input = context.get_input(0);
|
||||
auto input_shape = context.mark_node(std::make_shared<ShapeOf>(input, element::i32));
|
||||
auto input_rank = context.mark_node(std::make_shared<ShapeOf>(input_shape, element::i32));
|
||||
Output<Node> input_shape;
|
||||
Output<Node> input_rank;
|
||||
std::tie(input_shape, input_rank) = get_shape_rank(context, input);
|
||||
|
||||
auto dimension = context.mark_node(std::make_shared<Unsqueeze>(context.get_input(1), const_0));
|
||||
auto dimension_plus_1 = context.mark_node(std::make_shared<Add>(dimension, const_1_list));
|
||||
|
@ -3,63 +3,69 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/interpolate.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
OutputVector base_translate_upsample2d(NodeContext& context, opset10::Interpolate::InterpolateMode interpolate_mode) {
|
||||
OutputVector base_translate_upsample2d(const NodeContext& context, v4::Interpolate::InterpolateMode interpolate_mode) {
|
||||
num_inputs_check(context, 3, 4);
|
||||
auto data = context.get_input(0);
|
||||
std::vector<size_t> pad{0};
|
||||
auto size_mode = opset10::Interpolate::ShapeCalcMode::SIZES;
|
||||
auto size_mode = v4::Interpolate::ShapeCalcMode::SIZES;
|
||||
bool align_corners = false;
|
||||
int scale_id = 2;
|
||||
if (interpolate_mode != opset10::Interpolate::InterpolateMode::NEAREST) {
|
||||
if (interpolate_mode != v4::Interpolate::InterpolateMode::NEAREST) {
|
||||
scale_id = 3;
|
||||
if (!context.input_is_none(2)) {
|
||||
align_corners = context.const_input<bool>(2);
|
||||
}
|
||||
}
|
||||
auto target_axes = std::make_shared<opset10::Constant>(element::i32, Shape{2}, std::vector<int>({2, 3}));
|
||||
auto target_axes = std::make_shared<v0::Constant>(element::i32, Shape{2}, std::vector<int>({2, 3}));
|
||||
auto scales =
|
||||
context.mark_node(std::make_shared<opset10::Constant>(element::f32, Shape{2}, std::vector<double>({1, 1})));
|
||||
context.mark_node(std::make_shared<v0::Constant>(element::f32, Shape{2}, std::vector<double>({1, 1})));
|
||||
auto output_sizes =
|
||||
context.mark_node(std::make_shared<opset10::Constant>(element::i32, Shape{2}, std::vector<int>({1, 1})));
|
||||
context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{2}, std::vector<int>({1, 1})));
|
||||
if (context.input_is_none(1)) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(scale_id), "Scale or Output size should be provided");
|
||||
auto spatial_scales = context.get_input(scale_id);
|
||||
|
||||
size_mode = opset10::Interpolate::ShapeCalcMode::SCALES;
|
||||
scales = context.mark_node(std::make_shared<opset10::Multiply>(spatial_scales, scales));
|
||||
size_mode = v4::Interpolate::ShapeCalcMode::SCALES;
|
||||
scales = context.mark_node(std::make_shared<v1::Multiply>(spatial_scales, scales));
|
||||
} else {
|
||||
auto out_sizes = context.get_input(1);
|
||||
output_sizes = context.mark_node(std::make_shared<opset10::Multiply>(out_sizes, output_sizes));
|
||||
output_sizes = context.mark_node(std::make_shared<v1::Multiply>(out_sizes, output_sizes));
|
||||
}
|
||||
auto attrs = opset10::Interpolate::InterpolateAttrs(interpolate_mode, size_mode, pad, pad);
|
||||
attrs.coordinate_transformation_mode = opset10::Interpolate::CoordinateTransformMode::ASYMMETRIC;
|
||||
attrs.nearest_mode = opset10::Interpolate::NearestMode::FLOOR;
|
||||
if (attrs.mode != opset10::Interpolate::InterpolateMode::NEAREST) {
|
||||
auto attrs = v4::Interpolate::InterpolateAttrs(interpolate_mode, size_mode, pad, pad);
|
||||
attrs.coordinate_transformation_mode = v4::Interpolate::CoordinateTransformMode::ASYMMETRIC;
|
||||
attrs.nearest_mode = v4::Interpolate::NearestMode::FLOOR;
|
||||
if (attrs.mode != v4::Interpolate::InterpolateMode::NEAREST) {
|
||||
if (align_corners) {
|
||||
attrs.coordinate_transformation_mode = opset10::Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
|
||||
attrs.coordinate_transformation_mode = v4::Interpolate::CoordinateTransformMode::ALIGN_CORNERS;
|
||||
}
|
||||
}
|
||||
return {context.mark_node(std::make_shared<opset10::Interpolate>(data, output_sizes, scales, target_axes, attrs))};
|
||||
return {context.mark_node(std::make_shared<v4::Interpolate>(data, output_sizes, scales, target_axes, attrs))};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_upsample_bilinear2d(NodeContext& context) {
|
||||
return base_translate_upsample2d(context, opset10::Interpolate::InterpolateMode::LINEAR_ONNX);
|
||||
return base_translate_upsample2d(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX);
|
||||
};
|
||||
|
||||
OutputVector translate_upsample_nearest2d(NodeContext& context) {
|
||||
return base_translate_upsample2d(context, opset10::Interpolate::InterpolateMode::NEAREST);
|
||||
return base_translate_upsample2d(context, v4::Interpolate::InterpolateMode::NEAREST);
|
||||
};
|
||||
|
||||
OutputVector translate_upsample_bicubic2d(NodeContext& context) {
|
||||
return base_translate_upsample2d(context, opset10::Interpolate::InterpolateMode::CUBIC);
|
||||
return base_translate_upsample2d(context, v4::Interpolate::InterpolateMode::CUBIC);
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -18,11 +18,14 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_var_mean(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 4);
|
||||
auto data = context.get_input(0);
|
||||
bool unbiased = true;
|
||||
bool keepdims = false;
|
||||
auto num_elements = numel(context, 0);
|
||||
auto num_elements = numel(context, data);
|
||||
bool keepdim_mean;
|
||||
std::shared_ptr<ov::Node> mean, t_mean;
|
||||
ov::Output<ov::Node> axes;
|
||||
@ -30,7 +33,7 @@ OutputVector translate_var_mean(NodeContext& context) {
|
||||
// aten::var_mean(input, unbiased)
|
||||
axes = context.mark_node(get_axes_range(context, 0));
|
||||
unbiased = context.const_input<bool>(1);
|
||||
mean = context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(data, axes, keepdims));
|
||||
mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, keepdims));
|
||||
t_mean = mean;
|
||||
keepdim_mean = keepdims;
|
||||
} else {
|
||||
@ -43,31 +46,31 @@ OutputVector translate_var_mean(NodeContext& context) {
|
||||
}
|
||||
if (context.input_is_none(1)) {
|
||||
axes = context.mark_node(get_axes_range(context, 0));
|
||||
mean = context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(data, axes, keepdims));
|
||||
mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, keepdims));
|
||||
t_mean = mean;
|
||||
} else {
|
||||
axes = context.get_input(1);
|
||||
mean = context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(data, axes, keepdims));
|
||||
t_mean = context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(data, axes, true));
|
||||
auto reduced_dims = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(data));
|
||||
auto zero = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
reduced_dims = context.mark_node(std::make_shared<ov::op::v8::Gather>(reduced_dims, axes, zero));
|
||||
num_elements = context.mark_node(std::make_shared<ov::op::v1::ReduceProd>(reduced_dims, zero, false));
|
||||
mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, keepdims));
|
||||
t_mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, true));
|
||||
auto reduced_dims = context.mark_node(std::make_shared<v3::ShapeOf>(data));
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
|
||||
reduced_dims = context.mark_node(std::make_shared<v8::Gather>(reduced_dims, axes, zero));
|
||||
num_elements = context.mark_node(std::make_shared<v1::ReduceProd>(reduced_dims, zero, false));
|
||||
}
|
||||
keepdim_mean = context.input_is_none(1) ? false : keepdims;
|
||||
}
|
||||
auto sub_v = context.mark_node(std::make_shared<ov::op::v1::Subtract>(data, t_mean));
|
||||
auto sqr_sub = context.mark_node(std::make_shared<ov::op::v1::Multiply>(sub_v, sub_v));
|
||||
auto var = context.mark_node(std::make_shared<ov::op::v1::ReduceMean>(sqr_sub, axes, keepdim_mean));
|
||||
auto sub_v = context.mark_node(std::make_shared<v1::Subtract>(data, t_mean));
|
||||
auto sqr_sub = context.mark_node(std::make_shared<v1::Multiply>(sub_v, sub_v));
|
||||
auto var = context.mark_node(std::make_shared<v1::ReduceMean>(sqr_sub, axes, keepdim_mean));
|
||||
// if unbiased=true Bessel’s correction will be used
|
||||
// Correct bias in calculating variance, by dividing it over (N - 1) instead on N
|
||||
if (unbiased) {
|
||||
num_elements = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(num_elements, data));
|
||||
auto one = context.mark_node(ov::op::v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
one = context.mark_node(std::make_shared<ov::op::v1::ConvertLike>(one, data));
|
||||
auto mul = context.mark_node(std::make_shared<ov::op::v1::Multiply>(var, num_elements));
|
||||
auto n_minus_one = context.mark_node(std::make_shared<ov::op::v1::Subtract>(num_elements, one));
|
||||
var = context.mark_node(std::make_shared<ov::op::v1::Divide>(mul, n_minus_one));
|
||||
num_elements = context.mark_node(std::make_shared<v1::ConvertLike>(num_elements, data));
|
||||
auto one = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
one = context.mark_node(std::make_shared<v1::ConvertLike>(one, data));
|
||||
auto mul = context.mark_node(std::make_shared<v1::Multiply>(var, num_elements));
|
||||
auto n_minus_one = context.mark_node(std::make_shared<v1::Subtract>(num_elements, one));
|
||||
var = context.mark_node(std::make_shared<v1::Divide>(mul, n_minus_one));
|
||||
}
|
||||
return {var, mean};
|
||||
};
|
||||
|
@ -3,7 +3,8 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/select.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
@ -11,13 +12,16 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_where(NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto cond = context.get_input(0);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(1), "aten::where(cond) unsupported");
|
||||
auto bool_cond = context.mark_node(std::make_shared<opset10::Convert>(cond, element::boolean));
|
||||
auto bool_cond = context.mark_node(std::make_shared<v0::Convert>(cond, element::boolean));
|
||||
auto x = context.get_input(1);
|
||||
auto y = context.get_input(2);
|
||||
return {context.mark_node(std::make_shared<opset10::Select>(bool_cond, x, y))};
|
||||
return {context.mark_node(std::make_shared<v1::Select>(bool_cond, x, y))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -25,22 +25,22 @@ OP_CONVERTER(translate_avg_poolnd);
|
||||
OP_CONVERTER(translate_batch_norm);
|
||||
OP_CONVERTER(translate_clamp);
|
||||
OP_CONVERTER(translate_constant);
|
||||
OP_CONVERTER(translate_convnd);
|
||||
OP_CONVERTER(translate_conv_transposend);
|
||||
OP_CONVERTER(translate_convnd);
|
||||
OP_CONVERTER(translate_convolution);
|
||||
OP_CONVERTER(translate_convolution_mode);
|
||||
OP_CONVERTER(translate_dim);
|
||||
OP_CONVERTER(translate_div);
|
||||
OP_CONVERTER(translate_elu);
|
||||
OP_CONVERTER(translate_empty);
|
||||
OP_CONVERTER(translate_embedding);
|
||||
OP_CONVERTER(translate_empty);
|
||||
OP_CONVERTER(translate_expand);
|
||||
OP_CONVERTER(translate_expand_as);
|
||||
OP_CONVERTER(translate_eye);
|
||||
OP_CONVERTER(translate_fill_);
|
||||
OP_CONVERTER(translate_flatten);
|
||||
OP_CONVERTER(translate_floordiv);
|
||||
OP_CONVERTER(translate_floor_divide);
|
||||
OP_CONVERTER(translate_floordiv);
|
||||
OP_CONVERTER(translate_full);
|
||||
OP_CONVERTER(translate_full_like);
|
||||
OP_CONVERTER(translate_gelu);
|
||||
@ -61,19 +61,19 @@ OP_CONVERTER(translate_list_construct);
|
||||
OP_CONVERTER(translate_log);
|
||||
OP_CONVERTER(translate_log2);
|
||||
OP_CONVERTER(translate_loop);
|
||||
OP_CONVERTER(translate_max_poolnd);
|
||||
OP_CONVERTER(translate_max);
|
||||
OP_CONVERTER(translate_masked_fill);
|
||||
OP_CONVERTER(translate_max);
|
||||
OP_CONVERTER(translate_max_poolnd);
|
||||
OP_CONVERTER(translate_mean);
|
||||
OP_CONVERTER(translate_min);
|
||||
OP_CONVERTER(translate_meshgrid);
|
||||
OP_CONVERTER(translate_min);
|
||||
OP_CONVERTER(translate_neg);
|
||||
OP_CONVERTER(translate_nonzero);
|
||||
OP_CONVERTER(translate_norm);
|
||||
OP_CONVERTER(translate_new_full);
|
||||
OP_CONVERTER(translate_new_ones);
|
||||
OP_CONVERTER(translate_new_zeros);
|
||||
OP_CONVERTER(translate_nms);
|
||||
OP_CONVERTER(translate_nonzero);
|
||||
OP_CONVERTER(translate_norm);
|
||||
OP_CONVERTER(translate_numel);
|
||||
OP_CONVERTER(translate_ones);
|
||||
OP_CONVERTER(translate_ones_like);
|
||||
@ -86,9 +86,9 @@ OP_CONVERTER(translate_repeat);
|
||||
OP_CONVERTER(translate_repeat_interleave);
|
||||
OP_CONVERTER(translate_reshape);
|
||||
OP_CONVERTER(translate_reshape_as);
|
||||
OP_CONVERTER(translate_rsub);
|
||||
OP_CONVERTER(translate_roll);
|
||||
OP_CONVERTER(translate_rsqrt);
|
||||
OP_CONVERTER(translate_rsub);
|
||||
OP_CONVERTER(translate_select);
|
||||
OP_CONVERTER(translate_selu);
|
||||
OP_CONVERTER(translate_size);
|
||||
@ -103,7 +103,6 @@ OP_CONVERTER(translate_topk);
|
||||
OP_CONVERTER(translate_transpose);
|
||||
OP_CONVERTER(translate_tril);
|
||||
OP_CONVERTER(translate_triu);
|
||||
OP_CONVERTER(translate_tuple_construct);
|
||||
OP_CONVERTER(translate_unfold);
|
||||
OP_CONVERTER(translate_upsample_bicubic2d);
|
||||
OP_CONVERTER(translate_upsample_bilinear2d);
|
||||
@ -119,8 +118,8 @@ OP_CONVERTER(translate_zeros_like);
|
||||
const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
return {
|
||||
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
|
||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||
{"aten::__getitem__", op::translate_getitem},
|
||||
{"aten::__not__", op::translate_1to1_match_1_inputs<opset10::LogicalNot>},
|
||||
{"aten::_convolution", op::translate_convolution},
|
||||
{"aten::_convolution_mode", op::translate_convolution_mode},
|
||||
{"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
|
||||
@ -136,11 +135,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::addcmul", op::translate_addcmul},
|
||||
{"aten::addmm", op::translate_addmm},
|
||||
{"aten::arange", op::translate_arange},
|
||||
{"aten::as_tensor", op::translate_as_tensor},
|
||||
{"aten::asin", op::translate_1to1_match_1_inputs<opset10::Asin>},
|
||||
{"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Asin>>},
|
||||
{"aten::asinh", op::translate_1to1_match_1_inputs<opset10::Asinh>},
|
||||
{"aten::asinh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Asinh>>},
|
||||
{"aten::as_tensor", op::translate_as_tensor},
|
||||
{"aten::atan", op::translate_1to1_match_1_inputs<opset10::Atan>},
|
||||
{"aten::atan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Atan>>},
|
||||
{"aten::atanh", op::translate_1to1_match_1_inputs<opset10::Atanh>},
|
||||
@ -149,21 +148,22 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::avg_pool2d", op::translate_avg_poolnd},
|
||||
{"aten::avg_pool3d", op::translate_avg_poolnd},
|
||||
{"aten::batch_norm", op::translate_batch_norm},
|
||||
// {"aten::cat", done as transformation},
|
||||
{"aten::clamp", op::translate_clamp},
|
||||
{"aten::clamp_min", op::translate_1to1_match_2_inputs<opset10::Maximum>},
|
||||
{"aten::clamp_max", op::translate_1to1_match_2_inputs<opset10::Minimum>},
|
||||
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
//{"aten::cat", done as transformation},
|
||||
{"aten::ceil", op::translate_1to1_match_1_inputs<opset10::Ceiling>},
|
||||
{"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
|
||||
{"aten::clamp", op::translate_clamp},
|
||||
{"aten::clamp_max", op::translate_1to1_match_2_inputs<opset10::Minimum>},
|
||||
{"aten::clamp_min", op::translate_1to1_match_2_inputs<opset10::Maximum>},
|
||||
{"aten::clone", op::skip_node}, // ignore clone operators that are inserted by PyTorch autograd
|
||||
{"aten::contiguous", op::skip_node}, // In openvino how tensors are stored in memory is internal plugin detail,
|
||||
// we assume all tensors are contiguous
|
||||
{"aten::conv1d", op::translate_convnd},
|
||||
{"aten::conv2d", op::translate_convnd},
|
||||
{"aten::conv3d", op::translate_convnd},
|
||||
{"aten::conv_transpose1d", op::translate_conv_transposend},
|
||||
{"aten::conv_transpose2d", op::translate_conv_transposend},
|
||||
{"aten::conv_transpose3d", op::translate_conv_transposend},
|
||||
{"aten::conv1d", op::translate_convnd},
|
||||
{"aten::conv2d", op::translate_convnd},
|
||||
{"aten::conv3d", op::translate_convnd},
|
||||
{"aten::convolution", op::translate_convolution},
|
||||
{"aten::copy", op::skip_node},
|
||||
{"aten::cos", op::translate_1to1_match_1_inputs<opset10::Cos>},
|
||||
@ -188,60 +188,59 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::flatten", op::translate_flatten},
|
||||
{"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
|
||||
{"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
|
||||
{"aten::floordiv", op::translate_floordiv},
|
||||
{"aten::floor_divide", op::translate_floor_divide},
|
||||
{"aten::floordiv", op::translate_floordiv},
|
||||
{"aten::full", op::translate_full},
|
||||
{"aten::full_like", op::translate_full_like},
|
||||
{"aten::ge", op::translate_1to1_match_2_inputs_align_types<opset10::GreaterEqual>},
|
||||
{"aten::gelu", op::translate_gelu},
|
||||
{"aten::glu", op::translate_glu},
|
||||
{"aten::group_norm", op::translate_group_norm},
|
||||
{"aten::ge", op::translate_1to1_match_2_inputs_align_types<opset10::GreaterEqual>},
|
||||
{"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
|
||||
{"aten::grid_sampler", op::translate_grid_sampler},
|
||||
{"aten::group_norm", op::translate_group_norm},
|
||||
{"aten::gt", op::translate_1to1_match_2_inputs_align_types<opset10::Greater>},
|
||||
{"aten::hardsigmoid", op::translate_1to1_match_1_inputs<opset10::HSigmoid>},
|
||||
{"aten::hardswish", op::translate_1to1_match_1_inputs<opset10::HSwish>},
|
||||
{"aten::hardswish_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::HSwish>>},
|
||||
{"aten::hardtanh", op::translate_hardtanh},
|
||||
{"aten::hardtanh_", op::inplace_op<op::translate_hardtanh>},
|
||||
{"aten::Int", op::translate_int},
|
||||
{"aten::IntImplicit", op::translate_int},
|
||||
{"aten::im2col", op::translate_im2col},
|
||||
{"aten::instance_norm", op::translate_instance_norm},
|
||||
{"aten::Int", op::translate_int},
|
||||
{"aten::IntImplicit", op::translate_int},
|
||||
{"aten::is_grad_enabled", op::return_false_scalar},
|
||||
{"aten::layer_norm", op::translate_layer_norm},
|
||||
{"aten::le", op::translate_1to1_match_2_inputs_align_types<opset10::LessEqual>},
|
||||
{"aten::leaky_relu", op::translate_1to1_match_2_inputs<opset10::PRelu>},
|
||||
{"aten::leaky_relu_", op::inplace_op<op::translate_1to1_match_2_inputs<opset10::PRelu>>},
|
||||
{"aten::len", op::translate_len},
|
||||
{"aten::linear", op::translate_linear},
|
||||
{"aten::le", op::translate_1to1_match_2_inputs_align_types<opset10::LessEqual>},
|
||||
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
|
||||
{"aten::log", op::translate_log},
|
||||
{"aten::log_", op::inplace_op<op::translate_log>},
|
||||
{"aten::log2", op::translate_log2},
|
||||
{"aten::log2_", op::inplace_op<op::translate_log2>},
|
||||
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
|
||||
{"aten::masked_fill", op::translate_masked_fill},
|
||||
{"aten::masked_fill_", op::inplace_op<op::translate_masked_fill>},
|
||||
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::max", op::translate_max},
|
||||
{"aten::max_pool1d", op::translate_max_poolnd},
|
||||
{"aten::max_pool2d", op::translate_max_poolnd},
|
||||
{"aten::max_pool3d", op::translate_max_poolnd},
|
||||
{"aten::max", op::translate_max},
|
||||
{"aten::mean", op::translate_mean},
|
||||
{"aten::meshgrid", op::translate_meshgrid},
|
||||
{"aten::min", op::translate_min},
|
||||
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
|
||||
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
|
||||
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
|
||||
{"aten::neg", op::translate_neg},
|
||||
{"aten::norm", op::translate_norm},
|
||||
{"aten::nonzero", op::translate_nonzero},
|
||||
{"aten::numel", op::translate_numel},
|
||||
{"aten::new_full", op::translate_new_full},
|
||||
{"aten::new_ones", op::translate_new_ones},
|
||||
{"aten::new_zeros", op::translate_new_zeros},
|
||||
{"aten::nonzero", op::translate_nonzero},
|
||||
{"aten::norm", op::translate_norm},
|
||||
{"aten::numel", op::translate_numel},
|
||||
{"aten::ones", op::translate_ones},
|
||||
{"aten::ones_like", op::translate_ones_like},
|
||||
{"aten::pad", op::translate_pad},
|
||||
@ -256,9 +255,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::repeat_interleave", op::translate_repeat_interleave},
|
||||
{"aten::reshape", op::translate_reshape},
|
||||
{"aten::reshape_as", op::translate_reshape_as},
|
||||
{"aten::rsub", op::translate_rsub},
|
||||
{"aten::roll", op::translate_roll},
|
||||
{"aten::rsqrt", op::translate_rsqrt},
|
||||
{"aten::rsub", op::translate_rsub},
|
||||
{"aten::ScalarImplicit", op::skip_node},
|
||||
{"aten::select", op::translate_select},
|
||||
{"aten::selu", op::translate_selu},
|
||||
@ -311,7 +310,6 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"prim::Loop", op::translate_loop},
|
||||
{"prim::NumToTensor", op::skip_node}, // In openvino we already store number as tensor with shape []
|
||||
{"prim::requires_grad", op::return_false_scalar},
|
||||
{"prim::TupleConstruct", op::translate_tuple_construct},
|
||||
{"torchvision::nms", op::translate_nms},
|
||||
};
|
||||
};
|
||||
|
@ -16,7 +16,7 @@ namespace pytorch {
|
||||
|
||||
void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs) {
|
||||
auto inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() > min_inputs, "Got less inputs than expected");
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= min_inputs, "Got less inputs than expected");
|
||||
for (auto i = max_inputs; i < inputs.size(); i++) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
|
||||
}
|
||||
@ -42,7 +42,9 @@ Output<Node> make_optional_bias(const Output<Node>& base_op,
|
||||
}
|
||||
}
|
||||
|
||||
Output<ov::Node> reshape_channelwise(const NodeContext& context, Output<ov::Node> data, Output<ov::Node> shape_source) {
|
||||
Output<Node> reshape_channelwise(const NodeContext& context,
|
||||
const Output<Node>& data,
|
||||
const Output<Node>& shape_source) {
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(shape_source));
|
||||
auto input_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(input_shape));
|
||||
auto one_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
|
||||
@ -113,8 +115,7 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, size_t input_id
|
||||
return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
|
||||
};
|
||||
|
||||
std::shared_ptr<Node> numel(const NodeContext& context, size_t input_id) {
|
||||
auto x = context.get_input(input_id);
|
||||
std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x) {
|
||||
auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x));
|
||||
auto axes = context.mark_node(opset10::Constant::create(element::i64, Shape({1}), {0}));
|
||||
return context.mark_node(std::make_shared<opset10::ReduceProd>(input_shape, axes, false));
|
||||
@ -135,7 +136,7 @@ const std::unordered_map<std::string, ov::op::PadType> TORCH_AUTO_PAD_TO_OV{{"va
|
||||
{"same", ov::op::PadType::SAME_UPPER}};
|
||||
} // namespace
|
||||
|
||||
ov::element::Type convert_dtype(int64_t pt_type) {
|
||||
element::Type convert_dtype(int64_t pt_type) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(TORCH_TO_OV_TYPE.count(pt_type), "Unknown type: ", pt_type);
|
||||
return TORCH_TO_OV_TYPE.at(pt_type);
|
||||
};
|
||||
@ -166,8 +167,9 @@ OutputVector make_framework_node(NodeContext* context) {
|
||||
// Hack. Can indicate mutable inputs, but can it be reliable?
|
||||
if (schema.find('!') != std::string::npos) {
|
||||
// We create additional output for such nodes. It contains new tensor that represents input that was changed.
|
||||
auto fw_node =
|
||||
std::make_shared<PtFrameworkNode>(context->get_decoder(), context->inputs(), context->num_of_outputs() + 1);
|
||||
auto fw_node = std::make_shared<PtFrameworkNode>(context->get_decoder(),
|
||||
context->inputs(),
|
||||
context->get_output_size() + 1);
|
||||
fw_node->set_friendly_name(context->get_op_type());
|
||||
auto outputs = fw_node->outputs();
|
||||
// Usually mutated input index is 0, because it is usually "self" input, so we need to replace this tensor with
|
||||
@ -185,7 +187,7 @@ OutputVector make_framework_node(NodeContext* context) {
|
||||
std::map<size_t, ParameterVector> inputs_map;
|
||||
std::map<size_t, ResultVector> extra_outputs_map;
|
||||
std::set<size_t> input_idxs; // initial inputs
|
||||
std::vector<std::shared_ptr<ov::Model>> bodies;
|
||||
std::vector<std::shared_ptr<Model>> bodies;
|
||||
// We need to remember initial inputs to be able to find extra inputs to body that were created to propagate
|
||||
// external context
|
||||
size_t num_body_outs = 0;
|
||||
@ -221,12 +223,13 @@ OutputVector make_framework_node(NodeContext* context) {
|
||||
// Number of body outputs can be higher then number of pt node outputs, e.g. in case of loop first body output is
|
||||
// condition, we have to skip such outputs.
|
||||
int num_skip_body_outputs =
|
||||
num_body_outs > context->num_of_outputs() ? num_body_outs - context->num_of_outputs() : 0;
|
||||
num_body_outs > context->get_output_size() ? num_body_outs - context->get_output_size() : 0;
|
||||
|
||||
// We need to reduce number of outputs, because some outputs are outputs from body
|
||||
auto fw_node = std::make_shared<PtFrameworkNode>(context->get_decoder(),
|
||||
context->inputs(),
|
||||
context->num_of_outputs() - num_body_outs + num_skip_body_outputs);
|
||||
auto fw_node =
|
||||
std::make_shared<PtFrameworkNode>(context->get_decoder(),
|
||||
context->inputs(),
|
||||
context->get_output_size() - num_body_outs + num_skip_body_outputs);
|
||||
fw_node->set_friendly_name(context->get_op_type());
|
||||
for (size_t i = 0; i < bodies.size(); ++i) {
|
||||
fw_node->set_function(i, bodies[i]);
|
||||
@ -293,9 +296,9 @@ OutputVector convert_node(NodeContext* context) {
|
||||
/// which is visible from nested model. Empty external_tensor_map is used as an indication that this is a main body
|
||||
/// conversion.
|
||||
/// \return fully converted OV Model
|
||||
std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> pytorch_model,
|
||||
const TensorMap& external_tensor_map) {
|
||||
std::shared_ptr<ov::Model> resulting_model; // define here to make a conversion in a nested scope
|
||||
std::shared_ptr<Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> pytorch_model,
|
||||
const TensorMap& external_tensor_map) {
|
||||
std::shared_ptr<Model> resulting_model; // define here to make a conversion in a nested scope
|
||||
{
|
||||
ParameterVector parameters;
|
||||
TensorMap tensor_map; // tensor map of the current context
|
||||
@ -307,8 +310,8 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> p
|
||||
PartialShape ps = pytorch_model->get_input_shape(i);
|
||||
auto type = simplified_type_interpret(pytorch_model->get_input_type(i));
|
||||
// TODO: Use special API to set custom type detalization
|
||||
auto parameter = std::make_shared<opset10::Parameter>(ov::element::dynamic, ps);
|
||||
parameter->get_output_tensor(0).add_names({std::to_string(pytorch_model->input(i))});
|
||||
auto parameter = std::make_shared<opset10::Parameter>(element::dynamic, ps);
|
||||
parameter->get_output_tensor(0).add_names({std::to_string(inputs.at(i))});
|
||||
parameters.push_back(parameter);
|
||||
auto order = pytorch_model->get_input_transpose_order(i);
|
||||
if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
|
||||
@ -322,9 +325,9 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> p
|
||||
auto reshape = std::make_shared<opset10::Reshape>(parameter, shape_const, false);
|
||||
auto order_const = opset10::Constant::create(element::i32, {order.size()}, order);
|
||||
auto transpose = std::make_shared<opset10::Transpose>(reshape, order_const);
|
||||
tensor_map[pytorch_model->input(i)] = transpose;
|
||||
tensor_map[inputs.at(i)] = transpose;
|
||||
} else {
|
||||
tensor_map[pytorch_model->input(i)] = parameter;
|
||||
tensor_map[inputs.at(i)] = parameter;
|
||||
}
|
||||
}
|
||||
|
||||
@ -335,7 +338,7 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> p
|
||||
|
||||
auto raw_inputs = node->inputs();
|
||||
for (size_t i = 0; i < raw_inputs.size(); ++i) {
|
||||
auto input = node->input(i);
|
||||
auto input = raw_inputs.at(i);
|
||||
if (tensor_map.find(input) == tensor_map.end()) {
|
||||
// Input refers value in the outer scope, need to create a new Parameter in the current scope
|
||||
// Linkage to external scope will be performed on the level of the parent operation (if or loop)
|
||||
@ -426,7 +429,7 @@ std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> p
|
||||
results.push_back(std::make_shared<opset10::Result>(tensor_map.at(tensor_id)));
|
||||
}
|
||||
}
|
||||
resulting_model = std::make_shared<ov::Model>(results, parameters);
|
||||
resulting_model = std::make_shared<Model>(results, parameters);
|
||||
// Did a conversion in a nested scope to automatically remove any holders of nodes except those in the graph
|
||||
}
|
||||
|
||||
@ -474,10 +477,7 @@ std::unordered_map<size_t, element::Type> bit_to_int{
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void align_eltwise_input_types(const NodeContext& context,
|
||||
ov::Output<ov::Node>& lhs,
|
||||
ov::Output<ov::Node>& rhs,
|
||||
bool align_scalars) {
|
||||
void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Output<Node>& rhs, bool align_scalars) {
|
||||
const auto& lhs_type = lhs.get_element_type();
|
||||
const auto& rhs_type = rhs.get_element_type();
|
||||
if (lhs_type.is_dynamic() || rhs_type.is_dynamic()) {
|
||||
|
@ -25,9 +25,9 @@ Output<Node> make_optional_bias(const Output<Node>& base_op,
|
||||
size_t bias_input_idx,
|
||||
const std::vector<int>& unsqueeze_dims = {});
|
||||
|
||||
Output<ov::Node> reshape_channelwise(const NodeContext& context,
|
||||
Output<ov::Node> data,
|
||||
Output<ngraph::Node> shape_source);
|
||||
Output<Node> reshape_channelwise(const NodeContext& context,
|
||||
const Output<Node>& data,
|
||||
const Output<Node>& shape_source);
|
||||
|
||||
std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context,
|
||||
const Output<Node>& x,
|
||||
@ -38,26 +38,26 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<N
|
||||
|
||||
std::shared_ptr<Node> get_axes_range(const NodeContext& context, size_t input_id);
|
||||
|
||||
std::shared_ptr<Node> numel(const NodeContext& context, size_t input_id);
|
||||
std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);
|
||||
|
||||
element::Type convert_dtype(int64_t dtype_value);
|
||||
ov::op::PadType convert_pad(const std::string& pt_pad);
|
||||
op::PadType convert_pad(const std::string& pt_pad);
|
||||
|
||||
std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input);
|
||||
|
||||
std::shared_ptr<ov::Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> pytorch_model,
|
||||
const TensorMap& external_tensor_map = {});
|
||||
std::shared_ptr<Model> convert_pytorch_model(std::shared_ptr<TorchDecoder> pytorch_model,
|
||||
const TensorMap& external_tensor_map = {});
|
||||
|
||||
OutputVector convert_node(NodeContext* context);
|
||||
|
||||
std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);
|
||||
std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);
|
||||
|
||||
// TODO: Elimitate the need of this function by implementing more accurate custom data type handling
|
||||
Any simplified_type_interpret(Any type);
|
||||
|
||||
void align_eltwise_input_types(const NodeContext& context,
|
||||
ov::Output<ov::Node>& lhs,
|
||||
ov::Output<ov::Node>& rhs,
|
||||
Output<Node>& lhs,
|
||||
Output<Node>& rhs,
|
||||
bool align_scalars = false);
|
||||
|
||||
namespace op {
|
||||
@ -72,36 +72,24 @@ OutputVector inplace_op(NodeContext& context) {
|
||||
|
||||
template <typename T>
|
||||
OutputVector translate_1to1_match_1_inputs(NodeContext& context) {
|
||||
auto inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 1, "Operation has no inputs.");
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
|
||||
}
|
||||
num_inputs_check(context, 1, 1);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None.");
|
||||
return {context.mark_node(std::make_shared<T>(inputs[0]))};
|
||||
return {context.mark_node(std::make_shared<T>(context.get_input(0)))};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputVector translate_1to1_match_2_inputs(NodeContext& context) {
|
||||
auto inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Operation has less then 2 inputs.");
|
||||
for (size_t i = 2; i < inputs.size(); i++) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
|
||||
}
|
||||
num_inputs_check(context, 2, 2);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
|
||||
return {context.mark_node(std::make_shared<T>(inputs[0], inputs[1]))};
|
||||
return {context.mark_node(std::make_shared<T>(context.get_input(0), context.get_input(1)))};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputVector translate_1to1_match_2_inputs_align_types(NodeContext& context) {
|
||||
auto inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Operation has less then 2 inputs.");
|
||||
for (size_t i = 2; i < inputs.size(); i++) {
|
||||
FRONT_END_OP_CONVERSION_CHECK(context.input_is_none(i), "Got more inputs than expected.");
|
||||
}
|
||||
num_inputs_check(context, 2, 2);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
|
||||
auto lhs = inputs[0];
|
||||
auto rhs = inputs[1];
|
||||
auto lhs = context.get_input(0);
|
||||
auto rhs = context.get_input(1);
|
||||
align_eltwise_input_types(context, lhs, rhs);
|
||||
return {context.mark_node(std::make_shared<T>(lhs, rhs))};
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user