[PT FE] Make NodeContext constant inside conversion rules (#16165)
* Make NodeContext constant inside conversion rules * Use shared_ptr * Fix ptr * Fix logical not
This commit is contained in:
parent
4ffecce63f
commit
7f8786d9aa
@ -60,7 +60,7 @@ protected:
|
||||
bool supported_impl(const std::vector<ov::Any>& variants) const override;
|
||||
ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override;
|
||||
|
||||
std::map<std::string, PytorchCreatorFunction> m_op_translators;
|
||||
std::map<std::string, CreatorFunction> m_op_translators;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
|
@ -19,20 +19,22 @@ typedef std::unordered_map<size_t, Output<Node>> TensorMap;
|
||||
class NodeContext : public frontend::NodeContext {
|
||||
public:
|
||||
NodeContext(std::shared_ptr<TorchDecoder> decoder,
|
||||
TensorMap* tensor_map,
|
||||
ParameterVector* external_parameters,
|
||||
const TensorMap& ext_tensor_map,
|
||||
std::shared_ptr<TensorMap> tensor_map,
|
||||
std::shared_ptr<ParameterVector> external_parameters,
|
||||
std::shared_ptr<std::set<size_t>> mutated_tensors,
|
||||
TranslateSession* translate_session)
|
||||
: frontend::NodeContext(decoder->get_op_type()),
|
||||
m_decoder(decoder),
|
||||
m_tensor_map(tensor_map),
|
||||
m_ext_tensor_map(ext_tensor_map),
|
||||
m_tensor_map(tensor_map),
|
||||
m_external_parameters(external_parameters),
|
||||
m_mutated_tensors(mutated_tensors),
|
||||
m_translate_session(translate_session),
|
||||
m_decoder_inputs(decoder->inputs()),
|
||||
m_decoder_outputs(decoder->outputs()) {
|
||||
FRONT_END_GENERAL_CHECK(tensor_map != nullptr && external_parameters != nullptr &&
|
||||
translate_session != nullptr);
|
||||
FRONT_END_GENERAL_CHECK(m_tensor_map != nullptr && m_external_parameters != nullptr &&
|
||||
m_mutated_tensors != nullptr && m_translate_session != nullptr);
|
||||
}
|
||||
|
||||
// Do not search for input in tensor map; try to access it as a constant of specified type T and return its value
|
||||
@ -106,11 +108,7 @@ public:
|
||||
"There is no any named attributes in PyTorch node, query by attribute name is not implemented");
|
||||
}
|
||||
|
||||
void mutate_input(size_t index, Output<Node> ov_output);
|
||||
|
||||
std::set<size_t> get_mutated_tensors() const {
|
||||
return m_mutated_tensors;
|
||||
}
|
||||
void mutate_input(size_t index, Output<Node> ov_output) const;
|
||||
|
||||
std::shared_ptr<TorchDecoder> get_decoder() const {
|
||||
return m_decoder;
|
||||
@ -120,7 +118,7 @@ public:
|
||||
return m_translate_session;
|
||||
}
|
||||
|
||||
void add_tensor_to_context(size_t index, Output<Node> ov_output);
|
||||
void add_tensor_to_context(size_t index, Output<Node> ov_output) const;
|
||||
|
||||
Output<Node> get_tensor_from_model(size_t index) const {
|
||||
if (m_tensor_map->find(index) != m_tensor_map->end()) {
|
||||
@ -130,22 +128,22 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
Output<Node> get_tensor_from_model_or_create_input(size_t index);
|
||||
Output<Node> get_tensor_from_model_or_create_input(size_t index) const;
|
||||
Output<Node> get_input_from_visible_context(size_t index) const;
|
||||
std::shared_ptr<ov::Model> convert_subgraph(size_t index);
|
||||
std::shared_ptr<ov::Model> convert_subgraph(size_t index) const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<TorchDecoder> m_decoder;
|
||||
std::set<size_t> m_mutated_tensors;
|
||||
TensorMap* m_tensor_map;
|
||||
const TensorMap& m_ext_tensor_map;
|
||||
ParameterVector* m_external_parameters;
|
||||
TranslateSession* m_translate_session;
|
||||
std::shared_ptr<TensorMap> m_tensor_map;
|
||||
std::shared_ptr<ParameterVector> m_external_parameters;
|
||||
std::shared_ptr<std::set<size_t>> m_mutated_tensors;
|
||||
TranslateSession* m_translate_session = nullptr;
|
||||
const std::vector<size_t> m_decoder_inputs;
|
||||
const std::vector<size_t> m_decoder_outputs;
|
||||
};
|
||||
|
||||
using PytorchCreatorFunction = std::function<OutputVector(NodeContext&)>;
|
||||
using CreatorFunction = std::function<ov::OutputVector(const ov::frontend::pytorch::NodeContext&)>;
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -42,16 +42,16 @@ std::shared_ptr<Node> NodeContext::mark_node(std::shared_ptr<Node> ov_node) cons
|
||||
return m_decoder->mark_node(ov_node);
|
||||
}
|
||||
|
||||
void NodeContext::mutate_input(size_t index, Output<Node> ov_output) {
|
||||
void NodeContext::mutate_input(size_t index, Output<Node> ov_output) const {
|
||||
FRONT_END_GENERAL_CHECK(!m_decoder->input_is_none(index), "Input is none with index: ", index);
|
||||
auto input_id = m_decoder_inputs.at(index);
|
||||
FRONT_END_GENERAL_CHECK(m_tensor_map->count(input_id), "No tensor corresponding input: ", input_id, " exist.");
|
||||
m_translate_session->encode_tensor_name(ov_output, input_id, m_decoder->get_input_debug_name(index));
|
||||
(*m_tensor_map)[input_id] = ov_output;
|
||||
m_mutated_tensors.insert(input_id);
|
||||
m_mutated_tensors->insert(input_id);
|
||||
}
|
||||
|
||||
void NodeContext::add_tensor_to_context(size_t index, Output<Node> ov_output) {
|
||||
void NodeContext::add_tensor_to_context(size_t index, Output<Node> ov_output) const {
|
||||
if (m_tensor_map->count(index)) {
|
||||
OPENVINO_DEBUG << "[ WARNING ] Current context has tensor. Rewriting.\n";
|
||||
}
|
||||
@ -59,7 +59,7 @@ void NodeContext::add_tensor_to_context(size_t index, Output<Node> ov_output) {
|
||||
(*m_tensor_map)[index] = ov_output;
|
||||
}
|
||||
|
||||
Output<Node> NodeContext::get_tensor_from_model_or_create_input(size_t index) {
|
||||
Output<Node> NodeContext::get_tensor_from_model_or_create_input(size_t index) const {
|
||||
if (m_tensor_map->find(index) != m_tensor_map->end()) {
|
||||
return m_tensor_map->at(index);
|
||||
} else {
|
||||
@ -87,7 +87,7 @@ Output<Node> NodeContext::get_input_from_visible_context(size_t index) const {
|
||||
return input_tensor;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) {
|
||||
std::shared_ptr<ov::Model> NodeContext::convert_subgraph(size_t index) const {
|
||||
auto subgraph_decoder = m_decoder->get_subgraph_decoder(index);
|
||||
|
||||
// Extend external context with internal tensors except Parameter nodes, because internal Parameters are created to
|
||||
|
@ -19,7 +19,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_adaptive_avg_pool3d(NodeContext& context) {
|
||||
OutputVector translate_adaptive_avg_pool3d(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto const_tile_params = context.mark_node(v0::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1}));
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_adaptive_max_pool2d(NodeContext& context) {
|
||||
OutputVector translate_adaptive_max_pool2d(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -15,7 +15,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_add(NodeContext& context) {
|
||||
OutputVector translate_add(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto lhs = context.get_input(0);
|
||||
auto rhs = context.get_input(1);
|
||||
|
@ -17,7 +17,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_addcmul(NodeContext& context) {
|
||||
OutputVector translate_addcmul(const NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
const auto eltwise_mult = std::make_shared<v1::Multiply>(context.get_input(1), context.get_input(2));
|
||||
const auto value = context.get_input(3);
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_addmm(NodeContext& context) {
|
||||
OutputVector translate_addmm(const NodeContext& context) {
|
||||
num_inputs_check(context, 5, 5);
|
||||
auto input = context.get_input(0);
|
||||
auto m1 = context.get_input(1);
|
||||
|
@ -17,7 +17,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_arange(NodeContext& context) {
|
||||
OutputVector translate_arange(const NodeContext& context) {
|
||||
auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
int dtype_port = -1;
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_as_tensor(NodeContext& context) {
|
||||
OutputVector translate_as_tensor(const NodeContext& context) {
|
||||
// aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor
|
||||
num_inputs_check(context, 1, 4);
|
||||
auto dtype = element::f32;
|
||||
|
@ -18,7 +18,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_avg_poolnd(NodeContext& context) {
|
||||
OutputVector translate_avg_poolnd(const NodeContext& context) {
|
||||
num_inputs_check(context, 6, 7);
|
||||
auto input = context.get_input(0);
|
||||
auto kernel = context.const_input<Shape>(1);
|
||||
|
@ -32,7 +32,7 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_batch_norm(NodeContext& context) {
|
||||
OutputVector translate_batch_norm(const NodeContext& context) {
|
||||
// Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var,
|
||||
// bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
|
||||
num_inputs_check(context, 8, 9);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_bitwise_not(NodeContext& context) {
|
||||
OutputVector translate_bitwise_not(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto x = context.get_input(0);
|
||||
FRONT_END_OP_CONVERSION_CHECK(x.get_element_type().compatible(element::boolean),
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_bool(NodeContext& context) {
|
||||
OutputVector translate_bool(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::boolean))};
|
||||
};
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_cat(NodeContext& context) {
|
||||
OutputVector translate_cat(const NodeContext& context) {
|
||||
// This translator is only needed to get axis as constant from external scope
|
||||
num_inputs_check(context, 2, 2);
|
||||
const auto&& list_elems = get_list_as_outputs(context.get_input(0));
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_clamp(NodeContext& context) {
|
||||
OutputVector translate_clamp(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto x = context.get_input(0);
|
||||
if (!context.input_is_none(1)) {
|
||||
|
@ -9,7 +9,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_constant(NodeContext& context) {
|
||||
OutputVector translate_constant(const NodeContext& context) {
|
||||
return context.as_constant();
|
||||
};
|
||||
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_conv_transposend(NodeContext& context) {
|
||||
OutputVector translate_conv_transposend(const NodeContext& context) {
|
||||
num_inputs_check(context, 8, 8);
|
||||
auto strides = context.const_input<Strides>(3);
|
||||
// PyTorch support only symmetric padding, padding sizes are the same for begins and ends for each dimension
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_convnd(NodeContext& context) {
|
||||
OutputVector translate_convnd(const NodeContext& context) {
|
||||
num_inputs_check(context, 7, 7);
|
||||
auto strides = context.const_input<Strides>(3);
|
||||
// In torch pads at beginning are same as at end
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_convolution(NodeContext& context) {
|
||||
OutputVector translate_convolution(const NodeContext& context) {
|
||||
// Schema: aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[]
|
||||
// dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool
|
||||
// cudnn_enabled, bool allow_tf32) -> Tensor
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_convolution_mode(NodeContext& context) {
|
||||
OutputVector translate_convolution_mode(const NodeContext& context) {
|
||||
// Schema: aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, int[] stride, str padding, int[]
|
||||
// dilation, int groups) -> Tensor
|
||||
num_inputs_check(context, 7, 7);
|
||||
|
@ -13,7 +13,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_cumsum(NodeContext& context) {
|
||||
OutputVector translate_cumsum(const NodeContext& context) {
|
||||
// aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None, Tensor out=None)
|
||||
num_inputs_check(context, 2, 4);
|
||||
auto x = context.get_input(0);
|
||||
|
@ -12,7 +12,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_dim(NodeContext& context) {
|
||||
OutputVector translate_dim(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
|
||||
|
@ -17,7 +17,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_div(NodeContext& context) {
|
||||
OutputVector translate_div(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_elu(NodeContext& context) {
|
||||
OutputVector translate_elu(const NodeContext& context) {
|
||||
// aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
|
||||
num_inputs_check(context, 2, 4);
|
||||
auto x = context.get_input(0);
|
||||
|
@ -13,7 +13,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_embedding(NodeContext& context) {
|
||||
OutputVector translate_embedding(const NodeContext& context) {
|
||||
// aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool
|
||||
// sparse=False)
|
||||
num_inputs_check(context, 5, 5);
|
||||
|
@ -30,7 +30,7 @@ OutputVector base_expand(const NodeContext& context, const Output<Node>& x, cons
|
||||
};
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_expand(NodeContext& context) {
|
||||
OutputVector translate_expand(const NodeContext& context) {
|
||||
// aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
@ -41,7 +41,7 @@ OutputVector translate_expand(NodeContext& context) {
|
||||
return base_expand(context, x, sizes);
|
||||
};
|
||||
|
||||
OutputVector translate_expand_as(NodeContext& context) {
|
||||
OutputVector translate_expand_as(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_eye(NodeContext& context) {
|
||||
OutputVector translate_eye(const NodeContext& context) {
|
||||
size_t num_inputs = context.get_input_size();
|
||||
auto x = context.get_input(0);
|
||||
// num rows and cols should be integer, but at the moment conversion their data type can be unknown yet
|
||||
|
@ -18,7 +18,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_flatten(NodeContext& context) {
|
||||
OutputVector translate_flatten(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto x = context.get_input(0);
|
||||
int64_t start_dim = 0;
|
||||
|
@ -14,7 +14,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_floor_divide(NodeContext& context) {
|
||||
OutputVector translate_floor_divide(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_floordiv(NodeContext& context) {
|
||||
OutputVector translate_floordiv(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -42,7 +42,7 @@ Output<Node> base_translate_full_with_convert(const NodeContext& context,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_full(NodeContext& context) {
|
||||
OutputVector translate_full(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 6);
|
||||
auto sizes = context.get_input(0);
|
||||
auto value = context.get_input(1);
|
||||
@ -59,7 +59,7 @@ OutputVector translate_full(NodeContext& context) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
|
||||
};
|
||||
|
||||
OutputVector translate_full_like(NodeContext& context) {
|
||||
OutputVector translate_full_like(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 7);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.get_input(1);
|
||||
@ -71,7 +71,7 @@ OutputVector translate_full_like(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, out)};
|
||||
};
|
||||
|
||||
OutputVector translate_fill_(NodeContext& context) {
|
||||
OutputVector translate_fill_(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.get_input(1);
|
||||
@ -79,7 +79,7 @@ OutputVector translate_fill_(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
||||
};
|
||||
|
||||
OutputVector translate_new_full(NodeContext& context) {
|
||||
OutputVector translate_new_full(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 7);
|
||||
auto input = context.get_input(0);
|
||||
auto sizes = context.get_input(1);
|
||||
@ -90,7 +90,7 @@ OutputVector translate_new_full(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
||||
};
|
||||
|
||||
OutputVector translate_zeros(NodeContext& context) {
|
||||
OutputVector translate_zeros(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 5);
|
||||
auto sizes = context.get_input(0);
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
@ -107,7 +107,7 @@ OutputVector translate_zeros(NodeContext& context) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
|
||||
};
|
||||
|
||||
OutputVector translate_zeros_like(NodeContext& context) {
|
||||
OutputVector translate_zeros_like(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
|
||||
@ -119,7 +119,7 @@ OutputVector translate_zeros_like(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, out)};
|
||||
};
|
||||
|
||||
OutputVector translate_new_zeros(NodeContext& context) {
|
||||
OutputVector translate_new_zeros(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto sizes = context.get_input(1);
|
||||
@ -130,7 +130,7 @@ OutputVector translate_new_zeros(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
||||
};
|
||||
|
||||
OutputVector translate_ones(NodeContext& context) {
|
||||
OutputVector translate_ones(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 5);
|
||||
auto sizes = context.get_input(0);
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
@ -147,7 +147,7 @@ OutputVector translate_ones(NodeContext& context) {
|
||||
return {base_translate_full_with_convert(context, sizes, value, dtype_id)};
|
||||
};
|
||||
|
||||
OutputVector translate_ones_like(NodeContext& context) {
|
||||
OutputVector translate_ones_like(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
|
||||
@ -159,7 +159,7 @@ OutputVector translate_ones_like(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, out)};
|
||||
};
|
||||
|
||||
OutputVector translate_new_ones(NodeContext& context) {
|
||||
OutputVector translate_new_ones(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 6);
|
||||
auto input = context.get_input(0);
|
||||
auto sizes = context.get_input(1);
|
||||
@ -170,7 +170,7 @@ OutputVector translate_new_ones(NodeContext& context) {
|
||||
return {base_translate_full_with_convertlike(context, sizes, value, input)};
|
||||
};
|
||||
|
||||
OutputVector translate_empty(NodeContext& context) {
|
||||
OutputVector translate_empty(const NodeContext& context) {
|
||||
// aten::empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
|
||||
// pin_memory=None, MemoryFormat? memory_format=None) -> Tensor layout, device and work with memory ignored on our
|
||||
// side, so just skip these parameters
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_gelu(NodeContext& context) {
|
||||
OutputVector translate_gelu(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto approximate = context.const_input<std::string>(1);
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_get_attr(NodeContext& context) {
|
||||
OutputVector translate_get_attr(const NodeContext& context) {
|
||||
auto res = context.get_decoder()->try_decode_get_attr();
|
||||
FRONT_END_OP_CONVERSION_CHECK(res.size() > 0, "GetAttr must have at least one output.");
|
||||
return res;
|
||||
|
@ -13,7 +13,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_getitem(NodeContext& context) {
|
||||
OutputVector translate_getitem(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto input = context.get_input(0);
|
||||
if (std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(input.get_node_shared_ptr())) {
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_glu(NodeContext& context) {
|
||||
OutputVector translate_glu(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto dim = context.input_is_none(1) ? context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}))
|
||||
|
@ -13,7 +13,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_grid_sampler(NodeContext& context) {
|
||||
OutputVector translate_grid_sampler(const NodeContext& context) {
|
||||
num_inputs_check(context, 4, 5);
|
||||
auto x = context.get_input(0);
|
||||
auto grid = context.get_input(1);
|
||||
|
@ -20,7 +20,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_group_norm(NodeContext& context) {
|
||||
OutputVector translate_group_norm(const NodeContext& context) {
|
||||
// aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float
|
||||
// eps=1.0000000000000001e-05, bool cudnn_enabled=True) -> Tensor
|
||||
num_inputs_check(context, 2, 6);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_hardtanh(NodeContext& context) {
|
||||
OutputVector translate_hardtanh(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
float min = -1;
|
||||
float max = 1;
|
||||
|
@ -13,7 +13,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_if(NodeContext& context) {
|
||||
OutputVector translate_if(const NodeContext& context) {
|
||||
auto if_node = std::make_shared<opset10::If>(context.get_input(0));
|
||||
context.mark_node(if_node);
|
||||
auto decoder = context.get_decoder();
|
||||
|
@ -56,7 +56,7 @@ std::shared_ptr<Node> get_im2col_indices_along_dim(const NodeContext& context,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_im2col(NodeContext& context) {
|
||||
OutputVector translate_im2col(const NodeContext& context) {
|
||||
num_inputs_check(context, 5, 5);
|
||||
auto input = context.get_input(0);
|
||||
auto kernel_size = context.const_input<std::vector<int64_t>>(1);
|
||||
|
@ -10,9 +10,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_index_put_(NodeContext& context) {
|
||||
OutputVector translate_index_put_(const NodeContext& context) {
|
||||
// Pass as PtFrameworkNode to register as `inplace_op`. Conversion to OV operators is done as transformation.
|
||||
auto node = std::make_shared<PtFrameworkNode>(context.get_decoder(), context.inputs());
|
||||
return {context.mark_node(node)};
|
||||
|
@ -88,7 +88,7 @@ OutputVector translate_instance_norm_train(const NodeContext& context,
|
||||
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_instance_norm(NodeContext& context) {
|
||||
OutputVector translate_instance_norm(const NodeContext& context) {
|
||||
num_inputs_check(context, 8, 9);
|
||||
auto input = context.get_input(0);
|
||||
auto eps = context.const_input<float>(7);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_int(NodeContext& context) {
|
||||
OutputVector translate_int(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(0), element::i32))};
|
||||
};
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_layer_norm(NodeContext& context) {
|
||||
OutputVector translate_layer_norm(const NodeContext& context) {
|
||||
num_inputs_check(context, 5, 6);
|
||||
auto eps = context.const_input<float>(4);
|
||||
auto normalized_shape = context.const_input<Shape>(1);
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_len(NodeContext& context) {
|
||||
OutputVector translate_len(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_linear(NodeContext& context) {
|
||||
OutputVector translate_linear(const NodeContext& context) {
|
||||
// schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_list_construct(NodeContext& context) {
|
||||
OutputVector translate_list_construct(const NodeContext& context) {
|
||||
// Process the case when prim::ListConstruct has all inputs constant
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
ov::OutputVector consts;
|
||||
|
@ -17,7 +17,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_log(NodeContext& context) {
|
||||
OutputVector translate_log(const NodeContext& context) {
|
||||
// torch.log returns a tensor with the natural logarithm of the elements of input.
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
@ -26,7 +26,7 @@ OutputVector translate_log(NodeContext& context) {
|
||||
return {log};
|
||||
};
|
||||
|
||||
OutputVector translate_log2(NodeContext& context) {
|
||||
OutputVector translate_log2(const NodeContext& context) {
|
||||
// torch.log2 returns a tensor with the logarithm to the base 2 of the elements of input.
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
|
@ -13,7 +13,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_loop(NodeContext& context) {
|
||||
OutputVector translate_loop(const NodeContext& context) {
|
||||
const auto& inputs = context.inputs();
|
||||
FRONT_END_OP_CONVERSION_CHECK(inputs.size() >= 2, "Loop must have at least 2 inputs.");
|
||||
auto loop = std::make_shared<ov::op::v5::Loop>(inputs[0], inputs[1]);
|
||||
|
@ -18,7 +18,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_masked_fill(NodeContext& context) {
|
||||
OutputVector translate_masked_fill(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto data = context.get_input(0);
|
||||
auto mask = context.get_input(1);
|
||||
|
@ -13,7 +13,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_max_poolnd(NodeContext& context) {
|
||||
OutputVector translate_max_poolnd(const NodeContext& context) {
|
||||
num_inputs_check(context, 6, 6);
|
||||
auto kernel = context.const_input<Shape>(1);
|
||||
auto strides = context.const_input<Strides>(2);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_mean(NodeContext& context) {
|
||||
OutputVector translate_mean(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 4);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -10,7 +10,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_meshgrid(NodeContext& context) {
|
||||
OutputVector translate_meshgrid(const NodeContext& context) {
|
||||
std::string indexing = "ij";
|
||||
if (!context.input_is_none(1)) {
|
||||
indexing = context.const_input<std::string>(1);
|
||||
|
@ -20,7 +20,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_max(NodeContext& context) {
|
||||
OutputVector translate_max(const NodeContext& context) {
|
||||
// torch.max (same for torch.min) actually has two interfaces smashed together:
|
||||
// torch.max(x, dim, keepdim) and torch.max(x, y)
|
||||
num_inputs_check(context, 1, 3);
|
||||
@ -49,7 +49,7 @@ OutputVector translate_max(NodeContext& context) {
|
||||
return {values, indicies};
|
||||
};
|
||||
|
||||
OutputVector translate_min(NodeContext& context) {
|
||||
OutputVector translate_min(const NodeContext& context) {
|
||||
// torch.min (same for torch.max) actually has two interfaces smashed together:
|
||||
// torch.min(x, dim, keepdim) and torch.min(x, y)
|
||||
num_inputs_check(context, 1, 3);
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_narrow(NodeContext& context) {
|
||||
OutputVector translate_narrow(const NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_neg(NodeContext& context) {
|
||||
OutputVector translate_neg(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
|
||||
|
@ -18,7 +18,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_nms(NodeContext& context) {
|
||||
OutputVector translate_nms(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_nonzero(NodeContext& context) {
|
||||
OutputVector translate_nonzero(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto cond = context.get_input(0);
|
||||
auto non_zero = context.mark_node(std::make_shared<v3::NonZero>(cond));
|
||||
|
@ -20,7 +20,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_norm(NodeContext& context) {
|
||||
OutputVector translate_norm(const NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto p = context.const_input<float>(1);
|
||||
|
@ -10,7 +10,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_numel(NodeContext& context) {
|
||||
OutputVector translate_numel(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
return {numel(context, context.get_input(0))};
|
||||
};
|
||||
|
@ -22,7 +22,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_pad(NodeContext& context) {
|
||||
OutputVector translate_pad(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 4);
|
||||
auto data = context.get_input(0);
|
||||
auto paddings = context.const_input<std::vector<int64_t>>(1);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_pow(NodeContext& context) {
|
||||
OutputVector translate_pow(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto lhs = context.get_input(0);
|
||||
auto rhs = context.get_input(1);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_pythonop(NodeContext& context) {
|
||||
OutputVector translate_pythonop(const NodeContext& context) {
|
||||
auto decoder = context.get_decoder();
|
||||
FRONT_END_OP_CONVERSION_CHECK(decoder->get_subgraph_size() == 1,
|
||||
"PythonOp must have 1 subgraph to be able to translate it to OV.");
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_reciprocal(NodeContext& context) {
|
||||
OutputVector translate_reciprocal(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_relu6(NodeContext& context) {
|
||||
OutputVector translate_relu6(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
return {context.mark_node(std::make_shared<ov::op::v0::Clamp>(x, 0., 6.))};
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_remainder(NodeContext& context) {
|
||||
OutputVector translate_remainder(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_repeat(NodeContext& context) {
|
||||
OutputVector translate_repeat(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto x = context.get_input(0);
|
||||
auto repeats = context.get_input(1);
|
||||
|
@ -34,7 +34,7 @@ OutputVector generate_indices_from_repeats_tensor(const NodeContext& context, co
|
||||
};
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_repeat_interleave(NodeContext& context) {
|
||||
OutputVector translate_repeat_interleave(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
// constants
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_reshape(NodeContext& context) {
|
||||
OutputVector translate_reshape(const NodeContext& context) {
|
||||
// Translation is used by both aten::view and aten::reshape.
|
||||
// Schema: aten::view(Tensor input, int[] shape) -> Tensor
|
||||
// Schema: aten::reshape(Tensor input, int[] shape) -> Tensor
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_reshape_as(NodeContext& context) {
|
||||
OutputVector translate_reshape_as(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
auto input_tensor = context.get_input(0);
|
||||
auto shape_tesnor = context.get_input(1);
|
||||
|
@ -19,7 +19,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_roi_align(NodeContext& context) {
|
||||
OutputVector translate_roi_align(const NodeContext& context) {
|
||||
num_inputs_check(context, 7, 7);
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
|
@ -17,7 +17,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_roll(NodeContext& context) {
|
||||
OutputVector translate_roll(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
const auto data = context.get_input(0);
|
||||
const auto shifts = context.get_input(1);
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_rsqrt(NodeContext& context) {
|
||||
OutputVector translate_rsqrt(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto data = context.get_input(0);
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_rsub(NodeContext& context) {
|
||||
OutputVector translate_rsub(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto self = context.get_input(0);
|
||||
auto other = context.get_input(1);
|
||||
|
@ -20,7 +20,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_select(NodeContext& context) {
|
||||
OutputVector translate_select(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_minus_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_selu(NodeContext& context) {
|
||||
OutputVector translate_selu(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto x = context.get_input(0);
|
||||
auto alpha = context.mark_node(v0::Constant::create(element::f64, Shape{}, {1.6732632423543772848170429916717}));
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_set_item(NodeContext& context) {
|
||||
OutputVector translate_set_item(const NodeContext& context) {
|
||||
// schema: aten::_set_item.t(t[](a!) l, int idx, t(b -> *) el) -> t[](a!)
|
||||
// _set_item inserts element in list
|
||||
num_inputs_check(context, 3, 3);
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_size(NodeContext& context) {
|
||||
OutputVector translate_size(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(context.get_input(0), element::i32));
|
||||
if (context.input_is_none(1)) {
|
||||
|
@ -18,7 +18,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_slice(NodeContext& context) {
|
||||
OutputVector translate_slice(const NodeContext& context) {
|
||||
// aten::slice.t(t[] l, int? start=None, int? end=None, int step=1) -> (t[])
|
||||
// aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a))
|
||||
ov::Output<ov::Node> dim;
|
||||
|
@ -13,7 +13,7 @@ namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
OutputVector translate_softmax(NodeContext& context) {
|
||||
OutputVector translate_softmax(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto axis = context.const_input<int64_t>(1);
|
||||
|
@ -9,7 +9,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_sort(NodeContext& context) {
|
||||
OutputVector translate_sort(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 4);
|
||||
const auto input_tensor = context.get_input(0);
|
||||
bool stable, descending;
|
||||
@ -40,7 +40,7 @@ OutputVector translate_sort(NodeContext& context) {
|
||||
return topk->outputs();
|
||||
};
|
||||
|
||||
OutputVector translate_argsort(NodeContext& context) {
|
||||
OutputVector translate_argsort(const NodeContext& context) {
|
||||
auto sort = translate_sort(context);
|
||||
return {sort[1]};
|
||||
};
|
||||
|
@ -14,7 +14,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_square(NodeContext& context) {
|
||||
OutputVector translate_square(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
auto input_0 = context.get_input(0);
|
||||
auto const_2 = context.mark_node(v0::Constant::create(input_0.get_element_type(), Shape{1}, {2}));
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_squeeze(NodeContext& context) {
|
||||
OutputVector translate_squeeze(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto x = context.get_input(0);
|
||||
if (context.input_is_none(1)) {
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_sub(NodeContext& context) {
|
||||
OutputVector translate_sub(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
|
@ -11,7 +11,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_sum(NodeContext& context) {
|
||||
OutputVector translate_sum(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
bool keep_dims = false;
|
||||
ov::Output<ov::Node> axes;
|
||||
|
@ -16,7 +16,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_to(NodeContext& context) {
|
||||
OutputVector translate_to(const NodeContext& context) {
|
||||
int dtype_idx;
|
||||
int memory_format_idx;
|
||||
if (context.get_input_size() == 5) {
|
||||
|
@ -15,7 +15,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_topk(NodeContext& context) {
|
||||
OutputVector translate_topk(const NodeContext& context) {
|
||||
num_inputs_check(context, 5, 5);
|
||||
const auto input_tensor = context.get_input(0);
|
||||
const auto largest = context.const_input<bool>(3);
|
||||
|
@ -20,7 +20,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_transpose(NodeContext& context) {
|
||||
OutputVector translate_transpose(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto dim0 = context.const_input<int64_t>(1);
|
||||
auto dim1 = context.const_input<int64_t>(2);
|
||||
|
@ -60,11 +60,11 @@ OutputVector translate_base_triu_tril(const NodeContext& context, bool upper) {
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
OutputVector translate_triu(NodeContext& context) {
|
||||
OutputVector translate_triu(const NodeContext& context) {
|
||||
return translate_base_triu_tril(context, true);
|
||||
};
|
||||
|
||||
OutputVector translate_tril(NodeContext& context) {
|
||||
OutputVector translate_tril(const NodeContext& context) {
|
||||
return translate_base_triu_tril(context, false);
|
||||
};
|
||||
|
||||
|
@ -13,7 +13,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_unfold(NodeContext& context) {
|
||||
OutputVector translate_unfold(const NodeContext& context) {
|
||||
num_inputs_check(context, 4, 4);
|
||||
// constants
|
||||
auto const_0 = context.mark_node(Constant::create(element::i32, Shape{}, {0}));
|
||||
|
@ -69,32 +69,32 @@ OutputVector base_translate_upsample(const NodeContext& context,
|
||||
};
|
||||
} // namespace
|
||||
|
||||
OutputVector translate_upsample_linear1d(NodeContext& context) {
|
||||
OutputVector translate_upsample_linear1d(const NodeContext& context) {
|
||||
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 1);
|
||||
};
|
||||
|
||||
OutputVector translate_upsample_bilinear2d(NodeContext& context) {
|
||||
OutputVector translate_upsample_bilinear2d(const NodeContext& context) {
|
||||
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 2);
|
||||
};
|
||||
|
||||
OutputVector translate_upsample_trilinear3d(NodeContext& context) {
|
||||
OutputVector translate_upsample_trilinear3d(const NodeContext& context) {
|
||||
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::LINEAR_ONNX, 3);
|
||||
};
|
||||
|
||||
OutputVector translate_upsample_nearest1d(NodeContext& context) {
|
||||
OutputVector translate_upsample_nearest1d(const NodeContext& context) {
|
||||
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 1);
|
||||
};
|
||||
|
||||
OutputVector translate_upsample_nearest2d(NodeContext& context) {
|
||||
OutputVector translate_upsample_nearest2d(const NodeContext& context) {
|
||||
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 2);
|
||||
};
|
||||
|
||||
OutputVector translate_upsample_nearest3d(NodeContext& context) {
|
||||
OutputVector translate_upsample_nearest3d(const NodeContext& context) {
|
||||
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::NEAREST, 3);
|
||||
};
|
||||
|
||||
// bicubic is only supported for 2d in pytorch
|
||||
OutputVector translate_upsample_bicubic2d(NodeContext& context) {
|
||||
OutputVector translate_upsample_bicubic2d(const NodeContext& context) {
|
||||
return base_translate_upsample(context, v4::Interpolate::InterpolateMode::CUBIC, 2);
|
||||
};
|
||||
|
||||
|
@ -20,7 +20,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_var_mean(NodeContext& context) {
|
||||
OutputVector translate_var_mean(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 4);
|
||||
auto data = context.get_input(0);
|
||||
bool unbiased = true;
|
||||
@ -75,7 +75,7 @@ OutputVector translate_var_mean(NodeContext& context) {
|
||||
return {var, mean};
|
||||
};
|
||||
|
||||
OutputVector translate_var(NodeContext& context) {
|
||||
OutputVector translate_var(const NodeContext& context) {
|
||||
auto res = translate_var_mean(context);
|
||||
return {res[0]};
|
||||
}
|
||||
|
@ -14,7 +14,7 @@ namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_where(NodeContext& context) {
|
||||
OutputVector translate_where(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto cond = context.get_input(0);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(1), "aten::where(cond) unsupported");
|
||||
|
@ -12,7 +12,7 @@ namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
#define OP_CONVERTER(op) OutputVector op(NodeContext& node)
|
||||
#define OP_CONVERTER(op) OutputVector op(const NodeContext& node)
|
||||
|
||||
OP_CONVERTER(translate_adaptive_avg_pool3d);
|
||||
OP_CONVERTER(translate_adaptive_max_pool2d);
|
||||
@ -130,7 +130,7 @@ OP_CONVERTER(translate_zeros_like);
|
||||
|
||||
} // namespace op
|
||||
|
||||
const std::map<std::string, PytorchCreatorFunction> get_supported_ops() {
|
||||
const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
return {
|
||||
{"aten::__and__", op::translate_1to1_match_2_inputs<opset10::LogicalAnd>}, // TODO: cover numerical cases
|
||||
{"aten::__getitem__", op::translate_getitem},
|
||||
|
@ -10,7 +10,7 @@ namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
|
||||
const std::map<std::string, PytorchCreatorFunction> get_supported_ops();
|
||||
const std::map<std::string, CreatorFunction> get_supported_ops();
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
@ -20,7 +20,7 @@ namespace pytorch {
|
||||
using namespace ov::op;
|
||||
|
||||
TranslateSession::TranslateSession(const ov::frontend::InputModel::Ptr& input_model,
|
||||
const std::map<std::string, PytorchCreatorFunction>& translator_map)
|
||||
const std::map<std::string, CreatorFunction>& translator_map)
|
||||
: m_input_model(input_model),
|
||||
m_translator_map(translator_map),
|
||||
m_ov_model(nullptr) {}
|
||||
@ -45,9 +45,9 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
const std::unordered_map<size_t, PlaceDesc>& external_descriptors) {
|
||||
std::shared_ptr<Model> resulting_model; // define here to make a conversion in a nested scope
|
||||
{
|
||||
ParameterVector parameters;
|
||||
TensorMap tensor_map; // tensor map of the current context
|
||||
std::set<size_t> mutated_tensors;
|
||||
auto parameters = std::make_shared<ParameterVector>();
|
||||
auto tensor_map = std::make_shared<TensorMap>(); // tensor map of the current context
|
||||
auto mutated_tensors = std::make_shared<std::set<size_t>>();
|
||||
|
||||
// Go over all pytorch_model inputs and register them in the tensor map:
|
||||
auto inputs = pytorch_model->inputs();
|
||||
@ -74,7 +74,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
if (!input_node) {
|
||||
auto parameter = std::make_shared<v0::Parameter>(type, pshape);
|
||||
encode_tensor_name(parameter->output(0), inputs.at(i), pytorch_model->get_input_debug_name(i));
|
||||
parameters.push_back(parameter);
|
||||
parameters->push_back(parameter);
|
||||
input_node = parameter;
|
||||
auto order = pytorch_model->get_input_transpose_order(i);
|
||||
if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
|
||||
@ -91,7 +91,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
input_node = transpose;
|
||||
}
|
||||
}
|
||||
tensor_map[inputs.at(i)] = input_node;
|
||||
(*tensor_map)[inputs.at(i)] = input_node;
|
||||
}
|
||||
|
||||
auto node_visitor = [&](std::shared_ptr<TorchDecoder> node) {
|
||||
@ -102,7 +102,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
auto raw_inputs = node->inputs();
|
||||
for (size_t i = 0; i < raw_inputs.size(); ++i) {
|
||||
auto input = raw_inputs.at(i);
|
||||
if (tensor_map.find(input) == tensor_map.end()) {
|
||||
if (tensor_map->find(input) == tensor_map->end()) {
|
||||
// Input refers value in the outer scope, need to create a new Parameter in the current scope
|
||||
// Linkage to external scope will be performed on the level of the parent operation (if or loop)
|
||||
// TODO: Eliminate duplication with the main code for Parameters creation
|
||||
@ -111,18 +111,15 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
// TODO: Use special API to set custom type specification
|
||||
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, ps);
|
||||
// TODO: Missing get_input_transpose_order handling for not trivial layouts
|
||||
tensor_map[input] = parameter;
|
||||
(*tensor_map)[input] = parameter;
|
||||
// set name of parameter to the index of node in the model
|
||||
encode_tensor_name(parameter->output(0), input);
|
||||
parameters.push_back(parameter);
|
||||
parameters->push_back(parameter);
|
||||
}
|
||||
}
|
||||
auto context = NodeContext(node, &tensor_map, ¶meters, external_tensor_map, this);
|
||||
auto context = NodeContext(node, external_tensor_map, tensor_map, parameters, mutated_tensors, this);
|
||||
auto converted_outputs = convert_node(context);
|
||||
|
||||
auto mutated_t = context.get_mutated_tensors();
|
||||
mutated_tensors.insert(mutated_t.begin(), mutated_t.end());
|
||||
|
||||
auto fw_outputs = node->outputs();
|
||||
// Ops with subgraphs or with mutated inputs may have more outputs after conversion compared to pytorch ones
|
||||
FRONT_END_OP_CONVERSION_CHECK(fw_outputs.size() <= converted_outputs.size(),
|
||||
@ -134,10 +131,10 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
// FIXME: Now it is not true for at least prim::Constant
|
||||
for (size_t i = 0; i < fw_outputs.size(); ++i) {
|
||||
size_t fw_tensor_id = node->output(i);
|
||||
FRONT_END_GENERAL_CHECK(tensor_map.find(fw_tensor_id) == tensor_map.end(),
|
||||
FRONT_END_GENERAL_CHECK(tensor_map->find(fw_tensor_id) == tensor_map->end(),
|
||||
"Duplicated producer for PT value with unique ID: ",
|
||||
fw_tensor_id);
|
||||
tensor_map[fw_tensor_id] = converted_outputs[i];
|
||||
(*tensor_map)[fw_tensor_id] = converted_outputs[i];
|
||||
encode_tensor_name(converted_outputs[i], fw_tensor_id, node->get_output_debug_name(i));
|
||||
}
|
||||
};
|
||||
@ -148,14 +145,14 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
ResultVector results;
|
||||
for (size_t i = 0; i < pytorch_model->num_of_outputs(); ++i) {
|
||||
size_t id = pytorch_model->output(i);
|
||||
if (tensor_map.find(id) == tensor_map.end()) {
|
||||
if (tensor_map->find(id) == tensor_map->end()) {
|
||||
// Not found in this scope, adding Parameter to connect to external scope
|
||||
auto parameter = std::make_shared<v0::Parameter>(element::dynamic, PartialShape::dynamic());
|
||||
encode_tensor_name(parameter->output(0), id);
|
||||
parameters.push_back(parameter);
|
||||
tensor_map[id] = parameter;
|
||||
parameters->push_back(parameter);
|
||||
(*tensor_map)[id] = parameter;
|
||||
}
|
||||
auto ov_output = tensor_map[id];
|
||||
auto ov_output = tensor_map->at(id);
|
||||
auto order = pytorch_model->get_output_transpose_order(i);
|
||||
FRONT_END_GENERAL_CHECK(order.size() == 0 || std::is_sorted(order.begin(), order.end()),
|
||||
"Output strides have wrong order.");
|
||||
@ -168,32 +165,32 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
|
||||
// Since parameters can be added we need to list all current parameters
|
||||
std::set<size_t> param_names;
|
||||
for (const auto& param : parameters) {
|
||||
for (const auto& param : *parameters) {
|
||||
auto input_idx = decode_tensor_name(param->output(0));
|
||||
param_names.insert(input_idx);
|
||||
}
|
||||
for (const auto& tensor_id : mutated_tensors) {
|
||||
for (const auto& tensor_id : *mutated_tensors) {
|
||||
if (param_names.count(tensor_id)) {
|
||||
FRONT_END_GENERAL_CHECK(tensor_map.count(tensor_id),
|
||||
FRONT_END_GENERAL_CHECK(tensor_map->count(tensor_id),
|
||||
"Tensor with id: ",
|
||||
tensor_id,
|
||||
" doesn't exist in tensor map.");
|
||||
// model input was mutated we need to make a result for it
|
||||
auto mutated_tensor = tensor_map.at(tensor_id);
|
||||
auto mutated_tensor = tensor_map->at(tensor_id);
|
||||
// empty external_tensor_map means this is main body of the model and we don't want to create
|
||||
// additional outputs in that case.
|
||||
if (mutated_tensor.get_target_inputs().empty() && !external_tensor_map.empty())
|
||||
results.push_back(std::make_shared<v0::Result>(tensor_map.at(tensor_id)));
|
||||
results.push_back(std::make_shared<v0::Result>(tensor_map->at(tensor_id)));
|
||||
}
|
||||
}
|
||||
resulting_model = std::make_shared<Model>(results, parameters);
|
||||
resulting_model = std::make_shared<Model>(results, *parameters);
|
||||
// Did a conversion in a nested scope to automatically remove any holders of nodes except those in the graph
|
||||
}
|
||||
|
||||
return resulting_model;
|
||||
}
|
||||
|
||||
OutputVector TranslateSession::convert_node(NodeContext& context) {
|
||||
OutputVector TranslateSession::convert_node(const NodeContext& context) {
|
||||
try {
|
||||
auto it = m_translator_map.find(context.get_op_type());
|
||||
if (it != m_translator_map.end()) {
|
||||
|
@ -17,7 +17,7 @@ namespace pytorch {
|
||||
class TranslateSession {
|
||||
public:
|
||||
TranslateSession(const frontend::InputModel::Ptr& input_model,
|
||||
const std::map<std::string, PytorchCreatorFunction>& translator_map);
|
||||
const std::map<std::string, CreatorFunction>& translator_map);
|
||||
std::shared_ptr<Model> get_converted_model();
|
||||
std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model);
|
||||
|
||||
@ -38,10 +38,10 @@ public:
|
||||
size_t m_friendly_name_counter = 0;
|
||||
|
||||
private:
|
||||
OutputVector convert_node(NodeContext& context);
|
||||
OutputVector convert_node(const NodeContext& context);
|
||||
|
||||
const frontend::InputModel::Ptr m_input_model;
|
||||
const std::map<std::string, PytorchCreatorFunction>& m_translator_map;
|
||||
const std::map<std::string, CreatorFunction>& m_translator_map;
|
||||
|
||||
std::shared_ptr<Model> m_ov_model;
|
||||
std::map<size_t, std::pair<size_t, Output<Node>>> m_counter_map;
|
||||
|
@ -177,7 +177,7 @@ std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input) {
|
||||
return input;
|
||||
}
|
||||
|
||||
OutputVector make_framework_node(NodeContext& context) {
|
||||
OutputVector make_framework_node(const NodeContext& context) {
|
||||
auto schema = context.get_schema();
|
||||
// TODO: properly process schema to get the actual position of mutable input
|
||||
// Hack. Can indicate mutable inputs, but can it be reliable?
|
||||
|
@ -48,7 +48,7 @@ op::PadType convert_pad(const std::string& pt_pad);
|
||||
|
||||
std::shared_ptr<Node> concat_list_construct(std::shared_ptr<Node> input);
|
||||
|
||||
OutputVector make_framework_node(NodeContext& context);
|
||||
OutputVector make_framework_node(const NodeContext& context);
|
||||
|
||||
std::shared_ptr<op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type);
|
||||
|
||||
@ -63,8 +63,8 @@ void align_eltwise_input_types(const NodeContext& context,
|
||||
std::deque<Output<Node>> get_list_as_outputs(const Output<Node>& start);
|
||||
|
||||
namespace op {
|
||||
template <OutputVector (*T)(NodeContext&), size_t idx = 0>
|
||||
OutputVector inplace_op(NodeContext& context) {
|
||||
template <OutputVector (*T)(const NodeContext&), size_t idx = 0>
|
||||
OutputVector inplace_op(const NodeContext& context) {
|
||||
auto translation_res = T(context);
|
||||
FRONT_END_OP_CONVERSION_CHECK(translation_res.size() == 1,
|
||||
"inplace_op function must be used on single output translators");
|
||||
@ -73,21 +73,21 @@ OutputVector inplace_op(NodeContext& context) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputVector translate_1to1_match_1_inputs(NodeContext& context) {
|
||||
OutputVector translate_1to1_match_1_inputs(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0), "Input should not be None.");
|
||||
return {context.mark_node(std::make_shared<T>(context.get_input(0)))};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputVector translate_1to1_match_2_inputs(NodeContext& context) {
|
||||
OutputVector translate_1to1_match_2_inputs(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
|
||||
return {context.mark_node(std::make_shared<T>(context.get_input(0), context.get_input(1)))};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
OutputVector translate_1to1_match_2_inputs_align_types(NodeContext& context) {
|
||||
OutputVector translate_1to1_match_2_inputs_align_types(const NodeContext& context) {
|
||||
num_inputs_check(context, 2, 2);
|
||||
FRONT_END_OP_CONVERSION_CHECK(!context.input_is_none(0) && !context.input_is_none(1), "Inputs should not be None.");
|
||||
auto lhs = context.get_input(0);
|
||||
@ -96,11 +96,11 @@ OutputVector translate_1to1_match_2_inputs_align_types(NodeContext& context) {
|
||||
return {context.mark_node(std::make_shared<T>(lhs, rhs))};
|
||||
}
|
||||
|
||||
inline OutputVector return_false_scalar(NodeContext& context) {
|
||||
inline OutputVector return_false_scalar(const NodeContext& context) {
|
||||
return {context.mark_node(ov::op::v0::Constant::create(element::boolean, Shape{}, {false}))};
|
||||
}
|
||||
|
||||
inline OutputVector skip_node(NodeContext& context) {
|
||||
inline OutputVector skip_node(const NodeContext& context) {
|
||||
return {context.get_input(0).get_node_shared_ptr()};
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user