Use i32 across all PyTorch frontend (#15896)
* Use i32 across all PyTorch frontend
* Fix corner cases
* Fix tests
parent 112c763256
commit b1d0e152e3
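The whole commit follows one mechanical pattern: index-like constants and ShapeOf results that were previously created as element::i64 are now created as element::i32 (and, where a maximum value is materialized, int64_t limits become int32_t limits). Below is a minimal standalone sketch of that pattern using the public OpenVINO opset API rather than the frontend's NodeContext helpers; the function and variable names are illustrative only, not code from this commit.

```cpp
#include <memory>
#include <openvino/op/constant.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/shape_of.hpp>

using namespace ov;

// Extract the channel dimension of an NCHW input as a scalar node.
// Before this commit the frontend emitted such subgraphs in i64; now it emits i32.
std::shared_ptr<Node> channel_dim_i32(const Output<Node>& input) {
    // ShapeOf can produce its result directly in i32 instead of the default i64.
    auto shape = std::make_shared<op::v3::ShapeOf>(input, element::i32);
    // Index constants are likewise created as i32 scalars.
    auto zero = op::v0::Constant::create(element::i32, Shape{}, {0});
    auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
    // Gather dimension 1 (channels) from the shape vector along axis 0.
    return std::make_shared<op::v8::Gather>(shape, one, zero);
}
```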
@@ -128,18 +128,18 @@ std::vector<int64_t> NodeContext::const_input<std::vector<int64_t>>(size_t index
 }
 
 template <>
-ngraph::Strides NodeContext::const_input<ngraph::Strides>(size_t index) const {
-return get_constant_at_input(*this, index)->cast_vector<ngraph::Strides::value_type>();
+Strides NodeContext::const_input<Strides>(size_t index) const {
+return get_constant_at_input(*this, index)->cast_vector<Strides::value_type>();
 }
 
 template <>
-ngraph::CoordinateDiff NodeContext::const_input<ngraph::CoordinateDiff>(size_t index) const {
-return get_constant_at_input(*this, index)->cast_vector<ngraph::CoordinateDiff::value_type>();
+CoordinateDiff NodeContext::const_input<CoordinateDiff>(size_t index) const {
+return get_constant_at_input(*this, index)->cast_vector<CoordinateDiff::value_type>();
 }
 
 template <>
-ngraph::Shape NodeContext::const_input<ngraph::Shape>(size_t index) const {
-return get_constant_at_input(*this, index)->cast_vector<ngraph::Shape::value_type>();
+Shape NodeContext::const_input<Shape>(size_t index) const {
+return get_constant_at_input(*this, index)->cast_vector<Shape::value_type>();
 }
 
 template <>
@@ -23,9 +23,9 @@ namespace {
 Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
 const Output<Node>& input,
 const Output<Node>& value) {
-auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
-auto zero_i = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
-auto one_i = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
+auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
+auto zero_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+auto one_i = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
 auto channel_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, one_i, zero_i));
 auto channel_dim_exp = context.mark_node(std::make_shared<v0::Unsqueeze>(channel_dim, zero_i));
 return context.mark_node(std::make_shared<v3::Broadcast>(value, channel_dim_exp));
@@ -41,7 +41,7 @@ OutputVector translate_div(NodeContext& context) {
 if (rounding_mode == "floor") {
 res = context.mark_node(std::make_shared<v0::Floor>(res));
 } else if (rounding_mode == "trunc") {
-const auto convert = context.mark_node(std::make_shared<v0::Convert>(res, element::i64));
+const auto convert = context.mark_node(std::make_shared<v0::Convert>(res, element::i32));
 res = context.mark_node(std::make_shared<v1::ConvertLike>(convert, x));
 }
 return {res};
@@ -19,13 +19,13 @@ OutputVector translate_embedding(NodeContext& context) {
 num_inputs_check(context, 5, 5);
 auto data = context.get_input(0);
 auto indices = context.get_input(1);
-indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i64));
+indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
 // skip parameters 2, 3, 4 used only during trainig:
 // padding_idx - if specified, the entries at padding_idx do not contribute to the gradient
 // scale_grad_by_freq - if given, this will scale gradients by the inverse of frequency of
 // the words in the mini-batch.
 // sparse - if True, gradient will be represented as sparse tensor
-auto axis_0 = context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {0}));
+auto axis_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));
 return {context.mark_node(std::make_shared<ov::op::v8::Gather>(data, indices, axis_0))};
 };
 
@@ -20,12 +20,12 @@ OutputVector translate_eye(NodeContext& context) {
 size_t num_inputs = context.get_input_size();
 auto x = context.get_input(0);
 // num rows and cols should be integer, but at the moment conversion their data type can be unknown yet
-x = context.mark_node(std::make_shared<v0::Convert>(x, element::i64));
+x = context.mark_node(std::make_shared<v0::Convert>(x, element::i32));
 Output<Node> y;
 int dtype_id;
 auto dtype = element::f32;
 // aten::eye support only main diagonal
-auto diagonal = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
+auto diagonal = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
 if (num_inputs == 5) {
 // aten::eye(n, dtype, layout, device, pin_memory)
 y = x;
@@ -33,7 +33,7 @@ OutputVector translate_eye(NodeContext& context) {
 } else if (num_inputs == 6) {
 // aten::eye(n, m, dtype, layout, device, pin_memory)
 y = context.get_input(1);
-y = context.mark_node(std::make_shared<v0::Convert>(y, element::i64));
+y = context.mark_node(std::make_shared<v0::Convert>(y, element::i32));
 dtype_id = 2;
 } else {
 FRONT_END_OP_CONVERSION_CHECK(false, "Unsupported number of inputs: ", num_inputs, " for aten::eye");
@@ -64,7 +64,7 @@ OutputVector translate_full_like(NodeContext& context) {
 num_inputs_check(context, 2, 7);
 auto input = context.get_input(0);
 auto value = context.get_input(1);
-auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
+auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
 if (context.get_input_size() == 7) {
 return {base_translate_full_with_convert(context, sizes, value, 2)};
 }
@@ -76,7 +76,7 @@ OutputVector translate_fill_(NodeContext& context) {
 num_inputs_check(context, 2, 2);
 auto input = context.get_input(0);
 auto value = context.get_input(1);
-auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
+auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
 return {base_translate_full_with_convertlike(context, sizes, value, input)};
 };
 
@@ -112,7 +112,7 @@ OutputVector translate_zeros_like(NodeContext& context) {
 num_inputs_check(context, 1, 6);
 auto input = context.get_input(0);
 auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
-auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
+auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
 if (context.get_input_size() == 6) {
 return {base_translate_full_with_convert(context, sizes, value, 1)};
 }
@@ -152,7 +152,7 @@ OutputVector translate_ones_like(NodeContext& context) {
 num_inputs_check(context, 1, 6);
 auto input = context.get_input(0);
 auto value = context.mark_node(v0::Constant::create(element::f32, Shape{}, {1}));
-auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input));
+auto sizes = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
 if (context.get_input_size() == 6) {
 return {base_translate_full_with_convert(context, sizes, value, 1)};
 }
@@ -19,7 +19,7 @@ using namespace ov::op;
 OutputVector translate_glu(NodeContext& context) {
 num_inputs_check(context, 2, 2);
 auto x = context.get_input(0);
-auto dim = context.input_is_none(1) ? context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}))
+auto dim = context.input_is_none(1) ? context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}))
 : context.get_input(1);
 auto split = context.mark_node(std::make_shared<v1::Split>(x, dim, 2));
 auto first = split->output(0);
@@ -30,17 +30,17 @@ OutputVector translate_group_norm(NodeContext& context) {
 auto eps = static_cast<float>(context.const_input<double>(4));
 Output<Node> input_shape;
 Output<Node> input_rank;
-std::tie(input_shape, input_rank) = get_shape_rank(context, data, true, element::i64);
-auto scalar_one = context.mark_node(v0::Constant::create(element::i64, {}, {1}));
+std::tie(input_shape, input_rank) = get_shape_rank(context, data, true, element::i32);
+auto scalar_one = context.mark_node(v0::Constant::create(element::i32, {}, {1}));
 auto shape = context.mark_node(
-std::make_shared<v0::Constant>(element::i64, Shape({3}), std::vector<int64_t>{0, num_groups, -1}));
+std::make_shared<v0::Constant>(element::i32, Shape({3}), std::vector<int64_t>{0, num_groups, -1}));
 auto reshaped_input = context.mark_node(std::make_shared<v1::Reshape>(data, shape, true));
-auto reduction_axes = context.mark_node(v0::Constant::create(element::i64, Shape({1}), std::vector<int64_t>(1, 2)));
+auto reduction_axes = context.mark_node(v0::Constant::create(element::i32, Shape({1}), std::vector<int64_t>(1, 2)));
 auto reshaped_norm = context.mark_node(
 std::make_shared<v6::MVN>(reshaped_input, reduction_axes, true, eps, MVNEpsMode::INSIDE_SQRT));
 auto norm = context.mark_node(std::make_shared<v1::Reshape>(reshaped_norm, input_shape, true));
 auto skip_last = context.mark_node(std::make_shared<v1::Subtract>(input_rank, scalar_one));
-auto axes = context.mark_node(std::make_shared<v4::Range>(scalar_one, skip_last, scalar_one, element::i64));
+auto axes = context.mark_node(std::make_shared<v4::Range>(scalar_one, skip_last, scalar_one, element::i32));
 if (!context.input_is_none(2)) {
 auto weights = context.get_input(2);
 weights = context.mark_node(std::make_shared<v0::Unsqueeze>(weights, axes));
@@ -33,24 +33,24 @@ std::shared_ptr<Node> get_im2col_indices_along_dim(const NodeContext& context,
 int64_t dilation_d,
 int64_t padding_d,
 int64_t stride_d) {
-auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
-auto minus_one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}));
-auto kernel_size = context.mark_node(v0::Constant::create(element::i64, Shape{}, {kernel_size_d}));
-auto padding_2 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {padding_d * 2}));
-auto stride = context.mark_node(v0::Constant::create(element::i64, Shape{}, {stride_d}));
+auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+auto kernel_size = context.mark_node(v0::Constant::create(element::i32, Shape{}, {kernel_size_d}));
+auto padding_2 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {padding_d * 2}));
+auto stride = context.mark_node(v0::Constant::create(element::i32, Shape{}, {stride_d}));
 auto input_d_squeezed = context.mark_node(std::make_shared<v0::Squeeze>(input_d, zero));
 auto blocks_d = context.mark_node(std::make_shared<v1::Add>(input_d_squeezed, padding_2));
 auto subtrahend =
-context.mark_node(v0::Constant::create(element::i64, Shape{}, {dilation_d * (kernel_size_d - 1)}));
+context.mark_node(v0::Constant::create(element::i32, Shape{}, {dilation_d * (kernel_size_d - 1)}));
 blocks_d = context.mark_node(std::make_shared<v1::Subtract>(blocks_d, subtrahend));
-auto blocks_d_indices = context.mark_node(std::make_shared<v4::Range>(zero, blocks_d, stride, element::i64));
+auto blocks_d_indices = context.mark_node(std::make_shared<v4::Range>(zero, blocks_d, stride, element::i32));
 blocks_d_indices = context.mark_node(std::make_shared<v0::Unsqueeze>(blocks_d_indices, zero));
 std::vector<int64_t> rng;
 for (int64_t i = 0; i < kernel_size_d * dilation_d; i += dilation_d) {
 rng.push_back(i);
 }
 
-auto kernel_grid = context.mark_node(v0::Constant::create(element::i64, Shape{rng.size()}, rng));
+auto kernel_grid = context.mark_node(v0::Constant::create(element::i32, Shape{rng.size()}, rng));
 auto kernel_mask = context.mark_node(std::make_shared<v0::Unsqueeze>(kernel_grid, minus_one));
 return context.mark_node(std::make_shared<v1::Add>(blocks_d_indices, kernel_mask));
 }
@@ -67,12 +67,12 @@ OutputVector translate_im2col(NodeContext& context) {
 FRONT_END_OP_CONVERSION_CHECK(kernel_size.size() == 2, "padding should contains 2 elements");
 auto stride = context.const_input<std::vector<int64_t>>(4);
 FRONT_END_OP_CONVERSION_CHECK(kernel_size.size() == 2, "stride should contains 2 elements");
-auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
-auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
+auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
 auto zero_f = context.mark_node(v0::Constant::create(element::f32, Shape{}, {0}));
-auto minus_one = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-1}));
-auto two = context.mark_node(v0::Constant::create(element::i64, Shape{}, {2}));
-auto four = context.mark_node(v0::Constant::create(element::i64, Shape{}, {4}));
+auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
+auto four = context.mark_node(v0::Constant::create(element::i32, Shape{}, {4}));
 auto input_shape_split = context.mark_node(std::make_shared<v1::Split>(input_shape, zero, 4));
 auto input_b = input_shape_split->output(0);
 auto input_c = input_shape_split->output(1);
@@ -88,20 +88,20 @@ OutputVector translate_im2col(NodeContext& context) {
 auto kernel_w = kernel_size[1];
 auto blocks_row_indices = get_im2col_indices_along_dim(context, input_h, kernel_h, dilation_h, padding_h, stride_h);
 auto blocks_col_indices = get_im2col_indices_along_dim(context, input_w, kernel_w, dilation_w, padding_w, stride_w);
-auto kernel_window = context.mark_node(v0::Constant::create(element::i64, Shape{}, {kernel_h * kernel_w}));
+auto kernel_window = context.mark_node(v0::Constant::create(element::i32, Shape{}, {kernel_h * kernel_w}));
 auto input_c_squeezed = context.mark_node(std::make_shared<v0::Squeeze>(input_c, zero));
 auto channel_unfolded = context.mark_node(std::make_shared<v1::Multiply>(input_c_squeezed, kernel_window));
 auto channel_unfolded_unsqueezed = context.mark_node(std::make_shared<v0::Unsqueeze>(channel_unfolded, zero));
 auto output_shape = context.mark_node(
 std::make_shared<v0::Concat>(OutputVector{input_b, channel_unfolded_unsqueezed, minus_one}, 0));
 auto pads = context.mark_node(
-v0::Constant::create(element::i64, Shape{4}, std::vector<int64_t>{0, 0, padding_h, padding_w}));
+v0::Constant::create(element::i32, Shape{4}, std::vector<int64_t>{0, 0, padding_h, padding_w}));
 auto padded_input =
 context.mark_node(std::make_shared<v1::Pad>(input, pads, pads, zero_f, ov::op::PadMode::CONSTANT));
 auto output = context.mark_node(std::make_shared<v8::Gather>(padded_input, blocks_row_indices, two));
 output = context.mark_node(std::make_shared<v8::Gather>(output, blocks_col_indices, four));
 auto permutation_dims =
-context.mark_node(v0::Constant::create(element::i64, Shape{6}, std::vector<int64_t>{0, 1, 2, 4, 3, 5}));
+context.mark_node(v0::Constant::create(element::i32, Shape{6}, std::vector<int64_t>{0, 1, 2, 4, 3, 5}));
 output = context.mark_node(std::make_shared<v1::Transpose>(output, permutation_dims));
 return {context.mark_node(std::make_shared<v1::Reshape>(output, output_shape, false))};
 };
@@ -49,14 +49,14 @@ OutputVector translate_instance_norm_train(const NodeContext& context,
 const Output<Node>& input,
 const Output<Node>& reduction_axes,
 float eps) {
-auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
-auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
-auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input));
+auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
 auto batch_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, zero, zero));
 auto channel_dim = context.mark_node(std::make_shared<v8::Gather>(input_shape, one, zero));
 auto batch_dim_1d = context.mark_node(std::make_shared<v0::Unsqueeze>(batch_dim, zero));
 auto batch_norm_channels_1d = context.mark_node(std::make_shared<v1::Multiply>(batch_dim_1d, channel_dim));
-auto one_1d = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
+auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
 auto tail_shape = context.mark_node(std::make_shared<v8::Gather>(input_shape, reduction_axes, zero));
 auto reshape_shape =
 context.mark_node(std::make_shared<v0::Concat>(OutputVector{one_1d, batch_norm_channels_1d, tail_shape}, 0));
@@ -93,10 +93,10 @@ OutputVector translate_instance_norm(NodeContext& context) {
 auto input = context.get_input(0);
 auto eps = context.const_input<float>(7);
 Output<Node> rank;
-std::tie(std::ignore, rank) = get_shape_rank(context, input, true, element::i64);
-auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
-auto two = context.mark_node(v0::Constant::create(element::i64, Shape{}, {2}));
-auto reduction_axes = context.mark_node(std::make_shared<v4::Range>(two, rank, one, element::i64));
+std::tie(std::ignore, rank) = get_shape_rank(context, input, true, element::i32);
+auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+auto two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {2}));
+auto reduction_axes = context.mark_node(std::make_shared<v4::Range>(two, rank, one, element::i32));
 if (context.input_is_none(3) && context.input_is_none(4)) {
 return translate_instance_norm_inference(context, input, reduction_axes, eps);
 }
@@ -24,7 +24,7 @@ OutputVector translate_layer_norm(NodeContext& context) {
 "Translation for aten::layer_norm supports only single normalized_shape value, "
 "which means normalizing over the last dimension.");
 // TODO: support any dimention
-auto axes = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-1}));
+auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
 auto out_node =
 context.mark_node(std::make_shared<v6::MVN>(context.get_input(0), axes, true, eps, MVNEpsMode::INSIDE_SQRT));
 if (!context.input_is_none(2)) {
@@ -18,10 +18,10 @@ using namespace ov::op;
 
 OutputVector translate_len(NodeContext& context) {
 num_inputs_check(context, 1, 1);
-auto const_0 = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {0}));
-auto const_1 = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
+auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
+auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
 auto input = context.get_input(0);
-auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i64));
+auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
 
 auto slice = context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, const_1, const_1));
 // Slice will return empty tensor for empty lists, we use the fact that ReduceSum(empty tensor) = 0
@@ -23,7 +23,7 @@ OutputVector translate_masked_fill(NodeContext& context) {
 auto data = context.get_input(0);
 auto mask = context.get_input(1);
 auto value = context.get_input(2);
-auto data_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data));
+auto data_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
 value = context.mark_node(std::make_shared<v1::ConvertLike>(value, data));
 auto broadcasted_value = context.mark_node(std::make_shared<v3::Broadcast>(value, data_shape));
 auto bool_mask = context.mark_node(std::make_shared<v0::Convert>(mask, element::boolean));
@@ -40,7 +40,7 @@ OutputVector translate_max(NodeContext& context) {
 auto axis_const = context.const_input<int64_t>(1);
 auto keepdims = context.const_input<bool>(2);
 auto values = context.mark_node(std::make_shared<v1::ReduceMax>(x, axes_node, keepdims));
-auto k = context.mark_node(std::make_shared<v0::Constant>(element::i64, Shape{}, 1));
+auto k = context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{}, 1));
 auto topk = std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MAX, v3::TopK::SortType::NONE);
 auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
 if (!keepdims) {
@@ -69,7 +69,7 @@ OutputVector translate_min(NodeContext& context) {
 auto axis_const = context.const_input<int64_t>(1);
 auto keepdims = context.const_input<bool>(2);
 auto values = context.mark_node(std::make_shared<v1::ReduceMin>(x, axes_node, keepdims));
-auto k = context.mark_node(std::make_shared<v0::Constant>(element::i64, Shape{}, 1));
+auto k = context.mark_node(std::make_shared<v0::Constant>(element::i32, Shape{}, 1));
 auto topk = std::make_shared<v3::TopK>(x, k, axis_const, v3::TopK::Mode::MIN, v3::TopK::SortType::NONE);
 auto indicies = context.mark_node(std::make_shared<v0::Convert>(topk->output(1), element::i64));
 if (!keepdims) {
@@ -20,18 +20,18 @@ using namespace ov::op;
 
 OutputVector translate_nms(NodeContext& context) {
 num_inputs_check(context, 3, 3);
-auto const_0 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
-auto const_1 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
-auto const_2 = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2}));
+auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+auto const_2 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2}));
 // the shape that is required by PyTorch operator differs from the shape required in OpenVino
-auto boxes_shape = context.mark_node(v0::Constant::create(element::i64, Shape{3}, {1, -1, 4}));
+auto boxes_shape = context.mark_node(v0::Constant::create(element::i32, Shape{3}, {1, -1, 4}));
 
 auto boxes = context.mark_node(std::make_shared<v1::Reshape>(context.get_input(0), boxes_shape, false));
 // Unsqueeze operator is also used to align shapes required by PyTorch and OpenVino
-auto axis_01 = context.mark_node(v0::Constant::create(element::i64, Shape{2}, {0, 1}));
+auto axis_01 = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {0, 1}));
 auto scores = context.mark_node(std::make_shared<v0::Unsqueeze>(context.get_input(1), axis_01));
 auto max_output_per_class =
-context.mark_node(v0::Constant::create(element::i64, Shape{1}, {std::numeric_limits<int64_t>::max()}));
+context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()}));
 auto iou_threshold = context.get_input(2);
 
 auto nms_out =
@@ -19,7 +19,7 @@ OutputVector translate_nonzero(NodeContext& context) {
 num_inputs_check(context, 1, 1);
 auto cond = context.get_input(0);
 auto non_zero = context.mark_node(std::make_shared<v3::NonZero>(cond));
-auto input_order = context.mark_node(v0::Constant::create(element::i64, Shape{2}, {1, 0}));
+auto input_order = context.mark_node(v0::Constant::create(element::i32, Shape{2}, {1, 0}));
 return {context.mark_node(std::make_shared<v1::Transpose>(non_zero, input_order))};
 };
 
@@ -52,30 +52,30 @@ OutputVector translate_pad(NodeContext& context) {
 int64_t pad_r;
 auto pad_last_id = paddings.size();
 auto cur = data.get_node_shared_ptr();
-auto step = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
+auto step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
 auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
 for (size_t i = 0; i < pad_size_half; i++) {
 ov::NodeVector tensors;
 pad_r = paddings[pad_last_id - (2 * i + 1)];
 pad_l = paddings[pad_last_id - (2 * i + 2)];
-auto axes = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {2 + i}));
+auto axes = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {2 + i}));
 if (pad_l > 0) {
-auto start = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {-pad_l}));
+auto start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-pad_l}));
 auto end = context.mark_node(std::make_shared<v8::Gather>(shape, axes, zero_1d));
 
 auto left = context.mark_node(std::make_shared<v8::Slice>(cur, start, end, step, axes));
 tensors.push_back(left);
 }
 if (pad_l < 0 || pad_r < 0) {
-auto start = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {pad_l < 0 ? -pad_l : 0}));
-auto end = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {pad_r < 0 ? pad_r : 0}));
+auto start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {pad_l < 0 ? -pad_l : 0}));
+auto end = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {pad_r < 0 ? pad_r : 0}));
 auto middle = context.mark_node(std::make_shared<v8::Slice>(cur, start, end, step, axes));
 tensors.push_back(middle);
 } else {
 tensors.push_back(cur);
 }
 if (pad_r > 0) {
-auto end = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {pad_r}));
+auto end = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {pad_r}));
 auto right = context.mark_node(std::make_shared<v8::Slice>(cur, zero_1d, end, step, axes));
 tensors.push_back(right);
 }
@@ -20,8 +20,8 @@ OutputVector translate_repeat(NodeContext& context) {
 num_inputs_check(context, 2, 2);
 auto x = context.get_input(0);
 auto repeats = context.get_input(1);
-auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
-auto sizes_shape = context.mark_node(std::make_shared<v3::ShapeOf>(repeats, element::i64));
+auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+auto sizes_shape = context.mark_node(std::make_shared<v3::ShapeOf>(repeats, element::i32));
 auto expand_shape = context.mark_node(std::make_shared<v3::Broadcast>(one, sizes_shape));
 auto expanded_input =
 context.mark_node(std::make_shared<v3::Broadcast>(x, expand_shape, BroadcastType::BIDIRECTIONAL));
@@ -16,7 +16,7 @@ OutputVector translate_reshape_as(NodeContext& context) {
 num_inputs_check(context, 2, 2);
 auto input_tensor = context.get_input(0);
 auto shape_tesnor = context.get_input(1);
-auto desired_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(shape_tesnor));
+auto desired_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(shape_tesnor, element::i32));
 return {context.mark_node(std::make_shared<ov::op::v1::Reshape>(input_tensor, desired_shape, false))};
 };
 
@@ -30,7 +30,7 @@ OutputVector translate_roll(NodeContext& context) {
 const auto axis_0 = v0::Constant::create(element::i32, Shape{1}, {0});
 const auto flat = std::make_shared<v1::Reshape>(data, const_minus_1, false);
 const auto roll = std::make_shared<v7::Roll>(flat, shifts, axis_0);
-const auto shape_of_data = std::make_shared<v3::ShapeOf>(data);
+const auto shape_of_data = std::make_shared<v3::ShapeOf>(data, element::i32);
 const auto reshape = std::make_shared<v1::Reshape>(roll, shape_of_data, false);
 context.mark_nodes({const_minus_1, flat, roll, shape_of_data, reshape});
 return {reshape};
@@ -19,7 +19,7 @@ using namespace ov::op;
 OutputVector translate_rsqrt(NodeContext& context) {
 num_inputs_check(context, 1, 1);
 auto data = context.get_input(0);
-auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data));
+auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
 auto one_const = context.mark_node(v0::Constant::create(element::f32, Shape({}), {1}));
 auto sqrt_data = context.mark_node(std::make_shared<v0::Sqrt>(data));
 return {context.mark_node(std::make_shared<v1::Divide>(one_const, sqrt_data))};
@@ -21,7 +21,7 @@ OutputVector translate_size(NodeContext& context) {
 if (context.input_is_none(1)) {
 return shape->outputs();
 } else {
-auto axis_0 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
+auto axis_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
 return {context.mark_node(std::make_shared<v8::Gather>(shape, context.get_input(1), axis_0))};
 }
 };
@@ -25,7 +25,7 @@ OutputVector translate_slice(NodeContext& context) {
 int start_idx;
 int end_idx;
 int step_idx;
-auto axis_0 = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
+auto axis_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
 if (context.get_input_size() == 5) {
 dim = context.get_input(1);
 if (dim.get_partial_shape().rank().is_dynamic() || dim.get_partial_shape().rank().get_length() == 0) {
@@ -38,7 +38,7 @@ OutputVector translate_slice(NodeContext& context) {
 start_idx = 1;
 end_idx = 2;
 step_idx = 3;
-dim = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {0}));
+dim = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
 } else {
 FRONT_END_OP_CONVERSION_CHECK(false, "Slice must have either 4 or 5 inputs.");
 }
@@ -50,7 +50,7 @@ OutputVector translate_slice(NodeContext& context) {
 start = context.mark_node(std::make_shared<v0::Unsqueeze>(start, axis_0));
 }
 } else {
-start = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {0}));
+start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
 }
 
 ov::Output<ov::Node> end;
@@ -60,7 +60,7 @@ OutputVector translate_slice(NodeContext& context) {
 end = context.mark_node(std::make_shared<v0::Unsqueeze>(end, axis_0));
 }
 } else {
-end = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {INT_MAX}));
+end = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {INT_MAX}));
 }
 ov::Output<ov::Node> step;
 if (!context.input_is_none(step_idx)) {
@@ -69,7 +69,7 @@ OutputVector translate_slice(NodeContext& context) {
 step = context.mark_node(std::make_shared<v0::Unsqueeze>(step, axis_0));
 }
 } else {
-step = context.mark_node(v0::Constant::create(element::i64, Shape{1}, {1}));
+step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
 }
 return {context.mark_node(std::make_shared<v8::Slice>(context.get_input(0), start, end, step, dim))};
 };
@@ -39,7 +39,7 @@ OutputVector translate_transpose(NodeContext& context) {
 auto step = v0::Constant::create(element::i32, {}, {1});
 auto range = std::make_shared<v4::Range>(start, rank, step, element::i32);
 
-auto axis_0 = v0::Constant::create(element::i64, Shape{}, {0});
+auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0});
 auto dim0_node_ = std::make_shared<v0::Unsqueeze>(dim0_node, axis_0);
 auto dim1_node_ = std::make_shared<v0::Unsqueeze>(dim1_node, axis_0);
 auto indices = std::make_shared<v0::Concat>(OutputVector{dim0_node_, dim1_node_}, 0);
@@ -27,23 +27,23 @@ namespace {
 OutputVector translate_base_triu_tril(const NodeContext& context, bool upper) {
 num_inputs_check(context, 1, 2);
 auto input_tensor = context.get_input(0);
-auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input_tensor));
-auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
-auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
-auto minus_one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}));
-auto minus_two = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-2}));
+auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input_tensor, element::i32));
+auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+auto one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
+auto minus_two = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-2}));
 const auto m = context.mark_node(std::make_shared<v7::Gather>(input_shape, minus_one, zero));
 const auto n = context.mark_node(std::make_shared<v7::Gather>(input_shape, minus_two, zero));
-auto horizontal_range = context.mark_node(std::make_shared<v4::Range>(zero, m, one, element::i64));
+auto horizontal_range = context.mark_node(std::make_shared<v4::Range>(zero, m, one, element::i32));
 horizontal_range = context.mark_node(std::make_shared<v0::Unsqueeze>(horizontal_range, zero));
 Output<Node> vertical_range;
 if (!context.input_is_none(1)) {
 auto diagonal = context.get_input(1);
-diagonal = context.mark_node(std::make_shared<v0::Convert>(diagonal, element::i64));
+diagonal = context.mark_node(std::make_shared<v0::Convert>(diagonal, element::i32));
 auto stop = context.mark_node(std::make_shared<v1::Add>(n, diagonal));
-vertical_range = context.mark_node(std::make_shared<v4::Range>(diagonal, stop, one, element::i64));
+vertical_range = context.mark_node(std::make_shared<v4::Range>(diagonal, stop, one, element::i32));
 } else {
-vertical_range = context.mark_node(std::make_shared<v4::Range>(zero, n, one, element::i64));
+vertical_range = context.mark_node(std::make_shared<v4::Range>(zero, n, one, element::i32));
 }
 vertical_range = context.mark_node(std::make_shared<v0::Unsqueeze>(vertical_range, one));
 
@@ -52,8 +52,8 @@ OutputVector translate_var_mean(NodeContext& context) {
 axes = context.get_input(1);
 mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, keepdims));
 t_mean = context.mark_node(std::make_shared<v1::ReduceMean>(data, axes, true));
-auto reduced_dims = context.mark_node(std::make_shared<v3::ShapeOf>(data));
-auto zero = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
+auto reduced_dims = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
+auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
 reduced_dims = context.mark_node(std::make_shared<v8::Gather>(reduced_dims, axes, zero));
 num_elements = context.mark_node(std::make_shared<v1::ReduceProd>(reduced_dims, zero, false));
 }
@@ -69,7 +69,7 @@ AppendListUnpackReplacer::AppendListUnpackReplacer() {
 if (index_val[0] < 0) {
 index = inputs.size() + index;
 }
-auto axis_0 = ov::op::v0::Constant::create(element::i64, Shape{}, {0});
+auto axis_0 = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
 auto split = std::make_shared<ov::op::v1::Split>(inputs[index], axis_0, list_unpack->get_output_size());
 NodeVector to_copy_rt{axis_0, split};
 OutputVector res;
@@ -49,10 +49,10 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
 if (rank.get_length() == 0) {
 // Based on slice_size and output index select size.
 // Constants required by transformation.
-auto const_1 = ov::op::v0::Constant::create(element::i64, Shape{1}, {1});
-auto const_1_0d = ov::op::v0::Constant::create(element::i64, Shape{}, {1});
-auto const_0 = ov::op::v0::Constant::create(element::i64, Shape{1}, {0});
-auto const_0_0d = ov::op::v0::Constant::create(element::i64, Shape{}, {0});
+auto const_1 = ov::op::v0::Constant::create(element::i32, Shape{1}, {1});
+auto const_1_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {1});
+auto const_0 = ov::op::v0::Constant::create(element::i32, Shape{1}, {0});
+auto const_0_0d = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
 
 // Load and convert op inputs.
 auto input = torch_split->get_input_source_output(0);
@@ -63,7 +63,7 @@ AtenGetItemReplacer::AtenGetItemReplacer() {
 auto getitem_idx = getitem->input(1).get_source_output();
 
 // Calculate number of splits based on input shape and split_size.
-auto shape = std::make_shared<ov::op::v0::ShapeOf>(input);
+auto shape = std::make_shared<ov::op::v3::ShapeOf>(input, element::i32);
 auto len_to_split = std::make_shared<ov::op::v8::Gather>(shape, axis, const_0);
 // Convert to f64 from int to calculate reminder - last chunk can be smaller if Shape in given axis is
 // not equally divisible.
@@ -41,20 +41,20 @@ std::shared_ptr<Node> flatten(const Output<Node>& value, size_t axis) {
 // [d_{axis}, ..., d_n]
 Output<Node> output_shape;
 if (axis == 0) {
-output_shape = v0::Constant::create(element::i64, Shape{2}, {1, -1});
+output_shape = v0::Constant::create(element::i32, Shape{2}, {1, -1});
 } else if (axis == 1) {
-output_shape = v0::Constant::create(element::i64, Shape{2}, {0, -1});
+output_shape = v0::Constant::create(element::i32, Shape{2}, {0, -1});
 } else {
-const auto value_shape = std::make_shared<v3::ShapeOf>(value);
-const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape);
-const auto axis_node = v0::Constant::create(element::i64, Shape{}, {axis});
-auto start = v0::Constant::create(element::i64, Shape{}, {0});
-auto step = v0::Constant::create(element::i64, Shape{}, {1});
+const auto value_shape = std::make_shared<v3::ShapeOf>(value, element::i32);
+const auto value_rank = std::make_shared<v3::ShapeOf>(value_shape, element::i32);
+const auto axis_node = v0::Constant::create(element::i32, Shape{}, {axis});
+auto start = v0::Constant::create(element::i32, Shape{}, {0});
+auto step = v0::Constant::create(element::i32, Shape{}, {1});
 const auto first_part_dims = std::make_shared<v8::Slice>(value_shape, start, axis_node, step);
-auto zero = v0::Constant::create(element::i64, {}, {0});
+auto zero = v0::Constant::create(element::i32, {}, {0});
 auto first_part_dims_length = std::make_shared<ov::op::v1::ReduceProd>(first_part_dims, zero, true);
 
-auto remaining_part_length = v0::Constant::create(element::i64, {1}, {-1});
+auto remaining_part_length = v0::Constant::create(element::i32, {1}, {-1});
 
 output_shape = std::make_shared<v0::Concat>(OutputVector{first_part_dims_length, remaining_part_length}, 0);
 }
@@ -112,7 +112,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
 if (id_dtype == element::boolean || id_dtype == element::u8) {
 auto idx = std::make_shared<ov::op::v0::Convert>(ids[i], element::u8);
 auto nonzero = std::make_shared<ov::op::v3::NonZero>(idx);
-auto input_order = v0::Constant::create(element::i64, Shape{2}, {1, 0});
+auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
 auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
 masked_indicies.push_back(masked_id);
 is_masked_bool.push_back(true);
@@ -132,14 +132,14 @@ AtenIndexToSelect::AtenIndexToSelect() {
 // perform gather for single element case
 if (advanced_ids.size() == 1) {
 auto index = masked_indicies[advanced_ids[0]];
-index = std::make_shared<v0::Convert>(index, element::i64);
+index = std::make_shared<v0::Convert>(index, element::i32);
 if (is_masked_bool[advanced_ids[0]]) {
 auto gather = std::make_shared<v8::GatherND>(input_node, index);
 copy_runtime_info({index_op, input_node, indicies}, gather);
 replace_node(index_op, gather);
 return true;
 }
-auto dim = v0::Constant::create(element::i64, Shape{}, {advanced_ids[0]});
+auto dim = v0::Constant::create(element::i32, Shape{}, {advanced_ids[0]});
 auto gather = std::make_shared<v8::Gather>(input_node, index, dim);
 copy_runtime_info({index_op, input_node, indicies}, gather);
 replace_node(index_op, gather);
@@ -150,8 +150,8 @@ AtenIndexToSelect::AtenIndexToSelect() {
 if (rank.is_dynamic()) {
 FRONT_END_CHECK_IMPLEMENTED(false, "indexing for tensor with dynamic rank is not implemented ");
 }
-auto input_shape = std::make_shared<v3::ShapeOf>(input_node);
-auto zero = v0::Constant::create(element::i64, Shape{}, {0});
+auto input_shape = std::make_shared<v3::ShapeOf>(input_node, element::i32);
+auto zero = v0::Constant::create(element::i32, Shape{}, {0});
 auto input_dims = std::make_shared<v1::Split>(input_shape, zero, rank.get_length());
 std::vector<size_t> non_used_dims;
 for (auto i = 0; i < rank.get_length(); i++) {
@@ -162,7 +162,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
 std::vector<size_t> permutation_dims;
 permutation_dims.insert(permutation_dims.end(), advanced_ids.begin(), advanced_ids.end());
 permutation_dims.insert(permutation_dims.end(), non_used_dims.begin(), non_used_dims.end());
-auto transpose_dims = v0::Constant::create(element::i64, Shape{permutation_dims.size()}, permutation_dims);
+auto transpose_dims = v0::Constant::create(element::i32, Shape{permutation_dims.size()}, permutation_dims);
 auto transposed_input = std::make_shared<v1::Transpose>(input_node, transpose_dims);
 auto flatten_input = flatten(transposed_input, adv_idx_count);
 auto cum_adv_index = masked_indicies[advanced_ids[adv_idx_count - 1]];
@@ -177,14 +177,14 @@ AtenIndexToSelect::AtenIndexToSelect() {
 OutputVector concat_dims;
 // check if all advanced indices are consecutive.
 std::vector<size_t> consequence_dims;
-auto cum_adv_index_shape_tensor = std::make_shared<v3::ShapeOf>(cum_adv_index);
+auto cum_adv_index_shape_tensor = std::make_shared<v3::ShapeOf>(cum_adv_index, element::i32);
 for (size_t i = advanced_ids[0]; i <= advanced_ids[advanced_ids.size() - 1]; i++) {
 consequence_dims.push_back(i);
 }
 // unfold regular index axes
 if (advanced_ids == consequence_dims) {
 OutputVector folded_adv_idx_shape_vector;
-auto minus_one = v0::Constant::create(element::i64, Shape{1}, {-1});
+auto minus_one = v0::Constant::create(element::i32, Shape{1}, {-1});
 folded_adv_idx_shape_vector.push_back(minus_one);
 for (auto i : non_used_dims) {
 folded_adv_idx_shape_vector.push_back(input_dims->output(i));
@@ -201,7 +201,7 @@ AtenIndexToSelect::AtenIndexToSelect() {
 }
 // Transpose folded advanced indexed axis to its original location.
 auto permute_indicies =
-v0::Constant::create(element::i64, Shape{adv_idx_permute.size()}, adv_idx_permute);
+v0::Constant::create(element::i32, Shape{adv_idx_permute.size()}, adv_idx_permute);
 gather = std::make_shared<v1::Transpose>(gather, permute_indicies);
 // unfold advanced index axes
 for (size_t i = 0; i <= advanced_ids[0]; i++) {
@@ -242,17 +242,17 @@ AtenIndexToSelect::AtenIndexToSelect() {
 auto index_dtype = indicies->get_output_element_type(0);
 if (index_dtype == element::boolean || index_dtype == element::u8) {
 auto nonzero = std::make_shared<v3::NonZero>(indicies);
-auto input_order = v0::Constant::create(element::i64, Shape{2}, {1, 0});
+auto input_order = v0::Constant::create(element::i32, Shape{2}, {1, 0});
 auto masked_id = std::make_shared<v1::Transpose>(nonzero, input_order);
 auto gather = std::make_shared<v8::GatherND>(input_node, masked_id);
 copy_runtime_info({index_op, input_node, indicies}, gather);
 replace_node(index_op, gather);
 return true;
 }
-if (index_dtype != element::i32 && index_dtype != element::i64) {
-indicies = std::make_shared<ov::op::v0::Convert>(indicies, element::i64);
+if (index_dtype != element::i32 && index_dtype != element::i32) {
+indicies = std::make_shared<ov::op::v0::Convert>(indicies, element::i32);
 }
-auto dim = v0::Constant::create(element::i64, Shape{}, {0});
+auto dim = v0::Constant::create(element::i32, Shape{}, {0});
 auto gather = std::make_shared<v8::Gather>(input_node, indicies, dim);
 copy_runtime_info({index_op, input_node, indicies}, gather);
 replace_node(index_op, gather);
@@ -42,7 +42,7 @@ ListConstructReplacer::ListConstructReplacer() {
 auto adapool_op = pattern::wrap_type<v8::AdaptiveAvgPool>({pattern::any_input(), list_construct});
 // replace list construct for aten::expand(tensor, prim::ListConstruct(shapes)) decomposition
 // shape_of + broadcast + equal + select
-auto shape_of_op = pattern::wrap_type<v3::ShapeOf>({list_construct, pattern::any_input()});
+auto shape_of_op = pattern::wrap_type<v3::ShapeOf>({list_construct});
 auto equal_op = pattern::wrap_type<v1::Equal>({list_construct, pattern::any_input()});
 auto select_op = pattern::wrap_type<v1::Select>({pattern::any_input(), pattern::any_input(), list_construct});
 // replace list construct for aten::repeat(tensor, prim::ListConstruct(shapes)))
@@ -67,7 +67,7 @@ ListConstructReplacer::ListConstructReplacer() {
 if (auto list_unpack_node = cast_fw_node(list_construct_node, "prim::ListConstruct")) {
 // Concatenation is possible because all elements in list should be scalar intigers.
 OutputVector inputs;
-auto axis_0 = v0::Constant::create(element::i64, Shape{}, {0});
+auto axis_0 = v0::Constant::create(element::i32, Shape{}, {0});
 for (auto& input : list_construct_node->inputs()) {
 auto rank = input.get_partial_shape().rank();
 FRONT_END_OP_CONVERSION_CHECK(rank.is_dynamic() || rank.get_length() == 0, "Rank must be 0");
@@ -35,12 +35,12 @@ std::shared_ptr<Node> create_padding(std::shared_ptr<Node> input_rank,
 // PyTorch paddings represented as [N_pad_begins, N_pad_ends, N-1_pad_begins, N-1_pad_ends, ... ]
 // if len of paddings not equal to input rank * 2, zero padding added to first rank - N dimensions
 // OV expects paddings separated on begins and ends for each dimension from first to last
-auto minus_two = ov::op::v0::Constant::create(element::i64, Shape{}, {-2});
-auto zero = ov::op::v0::Constant::create(element::i64, Shape{}, {0});
-auto pad_id_range = std::make_shared<ov::op::v4::Range>(start_id, end_id, minus_two, element::i64);
+auto minus_two = ov::op::v0::Constant::create(element::i32, Shape{}, {-2});
+auto zero = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
+auto pad_id_range = std::make_shared<ov::op::v4::Range>(start_id, end_id, minus_two, element::i32);
 auto pads = std::make_shared<ov::op::v8::Gather>(padding, pad_id_range, zero);
 // add left side zero padding for difference between padding size and input rank
-auto pads_short_len = std::make_shared<ov::op::v3::ShapeOf>(pads);
+auto pads_short_len = std::make_shared<ov::op::v3::ShapeOf>(pads, element::i32);
 auto pads_diff = std::make_shared<ov::op::v1::Subtract>(input_rank, pads_short_len);
 auto pads_remaining = std::make_shared<ov::op::v3::Broadcast>(zero, pads_diff);
 auto pads_remaining_c = std::make_shared<ov::op::v1::ConvertLike>(pads_remaining, pads);
@@ -62,18 +62,18 @@ PrimListConstructPadReplacer::PrimListConstructPadReplacer() {
 if (!pad_op) {
 return false;
 }
-auto minus_two = ov::op::v0::Constant::create(element::i64, Shape{}, {-2});
-auto minus_one = ov::op::v0::Constant::create(element::i64, Shape{}, {-1});
-auto zero = ov::op::v0::Constant::create(element::i64, Shape{}, {0});
+auto minus_two = ov::op::v0::Constant::create(element::i32, Shape{}, {-2});
+auto minus_one = ov::op::v0::Constant::create(element::i32, Shape{}, {-1});
+auto zero = ov::op::v0::Constant::create(element::i32, Shape{}, {0});
 auto input_node = pad_op->input_value(0).get_node_shared_ptr();
 auto padding = pad_op->input_value(1).get_node_shared_ptr();
 // for case. when padding is list of scalars, concatenate them into one tensor
 auto pad_values = concat_list_construct(padding);
 std::string mode = "constant";
 auto zero_f = ov::op::v0::Constant::create(element::f32, Shape{}, {0});
-auto input_shape = std::make_shared<ov::op::v3::ShapeOf>(input_node);
-auto input_rank = std::make_shared<ov::op::v3::ShapeOf>(input_shape);
-auto pad_size_1d = std::make_shared<ov::op::v3::ShapeOf>(pad_values);
+auto input_shape = std::make_shared<ov::op::v3::ShapeOf>(input_node, element::i32);
+auto input_rank = std::make_shared<ov::op::v3::ShapeOf>(input_shape, element::i32);
+auto pad_size_1d = std::make_shared<ov::op::v3::ShapeOf>(pad_values, element::i32);
 auto pad_size = std::make_shared<ov::op::v0::Squeeze>(pad_size_1d, zero);
 // get pad_begins and pad_ends indexes starting for end of paddings
 auto start_pad_begins = std::make_shared<ov::op::v1::Add>(pad_size, minus_two);
@@ -105,7 +105,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
 if (auto where = cast_fw_node(input_node, "aten::where")) {
 const auto input = where->get_input_source_output(0);
 auto non_zero = std::make_shared<opset10::NonZero>(input);
-auto axis = opset10::Constant::create(element::i64, Shape{}, {0});
+auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
 const auto num_splits = list_unpack->get_output_size();
 auto split = std::make_shared<opset10::Split>(non_zero, axis, num_splits);
 NodeVector to_copy_rt{split};
@@ -123,7 +123,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
 if (auto nonzero_numpy = cast_fw_node(input_node, "aten::nonzero_numpy")) {
 const auto input = nonzero_numpy->get_input_source_output(0);
 auto non_zero = std::make_shared<opset10::NonZero>(input);
-auto axis = opset10::Constant::create(element::i64, Shape{}, {0});
+auto axis = opset10::Constant::create(element::i32, Shape{}, {0});
 const auto num_splits = list_unpack->get_output_size();
 auto split = std::make_shared<opset10::Split>(non_zero, axis, num_splits);
 NodeVector to_copy_rt{split};
@@ -168,12 +168,12 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
 }
 NodeVector cat_shapes{};
 NodeVector reshapes{};
-auto const_neg_1 = opset10::Constant::create(element::i64, Shape{1}, {-1});
-auto const_1 = opset10::Constant::create(element::i64, Shape{1}, {1});
+auto const_neg_1 = opset10::Constant::create(element::i32, Shape{1}, {-1});
+auto const_1 = opset10::Constant::create(element::i32, Shape{1}, {1});
 int input_idx = 0;
 for (auto& input : meshgrid_inputs) {
 auto reshaped_input = std::make_shared<opset10::Reshape>(input, const_neg_1, false);
-auto shape = std::make_shared<opset10::ShapeOf>(reshaped_input);
+auto shape = std::make_shared<opset10::ShapeOf>(reshaped_input, element::i32);
 cat_shapes.push_back(shape);
 NodeVector cat_inputs(meshgrid_inputs.size(), const_1);
 cat_inputs[input_idx] = shape;
@@ -202,7 +202,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
 if (auto shape_of = std::dynamic_pointer_cast<opset10::ShapeOf>(input_node)) {
 // case aten::size as input
 // Number of ListUnpack outputs should be equal to rank of input shape.
-auto axis_0 = opset10::Constant::create(element::i64, Shape{}, {0});
+auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0});
 auto split = std::make_shared<opset10::Split>(shape_of, axis_0, list_unpack->get_output_size());
 
 NodeVector to_copy_rt{axis_0, split};
@@ -222,7 +222,7 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
 if (auto slice = std::dynamic_pointer_cast<opset10::Slice>(input_node)) {
 // case aten::slice as input
 // Number of ListUnpack outputs should be equal to rank of input shape.
-auto axis_0 = opset10::Constant::create(element::i64, Shape{}, {0});
+auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0});
 auto split = std::make_shared<opset10::Split>(slice, axis_0, list_unpack->get_output_size());
 
 NodeVector to_copy_rt{axis_0, split};
@@ -64,7 +64,7 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(std::shared_ptr<T
 for (size_t i = 0; i < sh.size(); i++) {
 new_shape[order[i]] = sh[i];
 }
-auto shape_const = v0::Constant::create(element::i64, {new_shape.size()}, new_shape);
+auto shape_const = v0::Constant::create(element::i32, {new_shape.size()}, new_shape);
 auto reshape = std::make_shared<v1::Reshape>(parameter, shape_const, false);
 auto order_const = v0::Constant::create(element::i32, {order.size()}, order);
 auto transpose = std::make_shared<v1::Transpose>(reshape, order_const);
@@ -46,13 +46,13 @@ Output<Node> make_optional_bias(const Output<Node>& base_op,
 Output<Node> reshape_channelwise(const NodeContext& context,
 const Output<Node>& data,
 const Output<Node>& shape_source) {
-auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(shape_source));
-auto input_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(input_shape));
-auto one_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {1}));
-auto two_const = context.mark_node(opset10::Constant::create(element::i64, Shape{1}, {2}));
+auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(shape_source, element::i32));
+auto input_rank = context.mark_node(std::make_shared<opset10::ShapeOf>(input_shape, element::i32));
+auto one_const = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {1}));
+auto two_const = context.mark_node(opset10::Constant::create(element::i32, Shape{1}, {2}));
 auto tail_shape_rank = context.mark_node(std::make_shared<opset10::Subtract>(input_rank, two_const));
 auto tail_shape = context.mark_node(std::make_shared<opset10::Broadcast>(one_const, tail_shape_rank));
-auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(data));
+auto channels_dim = context.mark_node(std::make_shared<opset10::ShapeOf>(data, element::i32));
 auto new_shape =
 context.mark_node(std::make_shared<opset10::Concat>(OutputVector{one_const, channels_dim, tail_shape}, 0));
 
@@ -74,19 +74,19 @@ std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context
 Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<Node>& kernel, int64_t groups) {
 using std::make_shared;
 
-auto axis_0 = opset10::Constant::create(element::i64, Shape{}, {0});
-auto groups_const = opset10::Constant::create(element::i64, Shape{1}, {groups});
-auto neg_1_const = opset10::Constant::create(element::i64, Shape{1}, {-1});
+auto axis_0 = opset10::Constant::create(element::i32, Shape{}, {0});
+auto groups_const = opset10::Constant::create(element::i32, Shape{1}, {groups});
+auto neg_1_const = opset10::Constant::create(element::i32, Shape{1}, {-1});
 
-auto kernel_shape = std::make_shared<opset10::ShapeOf>(kernel);
-auto c_out_idx = opset10::Constant::create(element::i64, Shape{}, {0});
+auto kernel_shape = std::make_shared<opset10::ShapeOf>(kernel, element::i32);
+auto c_out_idx = opset10::Constant::create(element::i32, Shape{}, {0});
 auto kernel_shape_0 = make_shared<opset10::Gather>(kernel_shape, c_out_idx, axis_0);
 auto kernel_shape_0_uns = make_shared<opset10::Unsqueeze>(kernel_shape_0, axis_0);
 auto c_out_value = make_shared<opset10::Divide>(kernel_shape_0_uns, groups_const);
 
-auto start = opset10::Constant::create(element::i64, Shape{1}, {2});
-auto stop = opset10::Constant::create(element::i64, Shape{1}, {std::numeric_limits<int64_t>::max()});
-auto step = opset10::Constant::create(element::i64, Shape{1}, {1});
+auto start = opset10::Constant::create(element::i32, Shape{1}, {2});
+auto stop = opset10::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
+auto step = opset10::Constant::create(element::i32, Shape{1}, {1});
 auto remaining_shape = make_shared<opset10::Slice>(kernel_shape, start, stop, step);
 
 auto new_kernel_shape =
@@ -117,8 +117,8 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
 };
 
 std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x) {
-auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x));
-auto axes = context.mark_node(opset10::Constant::create(element::i64, Shape({1}), {0}));
+auto input_shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, element::i32));
+auto axes = context.mark_node(opset10::Constant::create(element::i32, Shape({1}), {0}));
 return context.mark_node(std::make_shared<opset10::ReduceProd>(input_shape, axes, false));
 };