[PT FE] Refactor aten::flatten and aten::transpose conversion (#19098)

* [PT FE] Refactor aten::flatten and aten::transpose conversion

* Fix code style

* Fix codestyle
Maxim Vafin 2023-08-10 10:28:36 +02:00 committed by GitHub
parent 9deef1480a
commit f5f221a3a9
9 changed files with 21 additions and 42 deletions

@@ -67,7 +67,7 @@ OutputVector translate_fake_quantize_per_channel_affine(const NodeContext& context) {
     auto rank = std::get<1>(get_shape_rank(context, input_node));
     auto ones = std::make_shared<v3::Broadcast>(const_1, rank);
-    auto normalized_axis = normalize_axis(context, axis, input_node);
+    auto normalized_axis = normalize_axis(context, axis, rank);
     // Create vector of length of rank filled with ones, except single -1 value at place selected by axis element.
     auto new_shape = std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, const_neg_1, const_0);
     // Reshape scale and zero point to tensor of the same rank as input, having shape 1 everywhere except dimension
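
Note: the ScatterElementsUpdate above only builds a shape of ones with -1 at the channel axis, so that the per-channel scale and zero point broadcast against the input. A minimal NumPy sketch of that shape arithmetic (rank, axis and the scale shape below are illustrative assumptions, not values from the converter):

import numpy as np

rank = 4                    # example rank of the quantized input (e.g. NCHW)
axis = 1                    # channel axis, already normalized to be non-negative
scale = np.ones(8)          # one scale value per channel
new_shape = [1] * rank
new_shape[axis] = -1        # [1, -1, 1, 1]
scale_b = scale.reshape(new_shape)   # shape (1, 8, 1, 1): broadcasts against the input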

@@ -21,14 +21,6 @@ using namespace ov::op;
 OutputVector translate_flatten(const NodeContext& context) {
     num_inputs_check(context, 1, 3);
     auto x = context.get_input(0);
-    int64_t start_dim = 0;
-    int64_t end_dim = -1;
-    if (!context.input_is_none(1)) {
-        start_dim = context.const_input<int64_t>(1);
-    }
-    if (!context.input_is_none(2)) {
-        end_dim = context.const_input<int64_t>(2);
-    }
     Output<Node> shape;
     Output<Node> rank;
     std::tie(shape, rank) = get_shape_rank(context, x, true);
@@ -38,20 +30,16 @@ OutputVector translate_flatten(const NodeContext& context) {
     if (!context.input_is_none(1)) {
         start_dim_node = context.get_input(1);
     } else {
-        start_dim_node = v0::Constant::create(element::i32, Shape{}, {start_dim});
+        start_dim_node = v0::Constant::create(element::i32, Shape{}, {0});
     }
     if (!context.input_is_none(2)) {
         end_dim_node = context.get_input(2);
     } else {
-        end_dim_node = v0::Constant::create(element::i32, Shape{}, {end_dim});
+        end_dim_node = v0::Constant::create(element::i32, Shape{}, {-1});
     }
-    if (start_dim < 0) {
-        start_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, start_dim_node));
-    }
-    if (end_dim < 0) {
-        end_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, end_dim_node));
-    }
-    // Slice shape from begin and end, then concat with -1, if slice return empty tensor concat shuold still be able to
+    start_dim_node = normalize_axis(context, start_dim_node, rank);
+    end_dim_node = normalize_axis(context, end_dim_node, rank);
+    // Slice shape from begin and end, then concat with -1, if slice return empty tensor concat should still be able to
     // work with it
     auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
     auto one = v0::Constant::create(element::i32, Shape{1}, {1});
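
Note: start_dim/end_dim are no longer folded to host-side constants; both are normalized as graph nodes, and the target shape is built by slicing the input shape around the flattened span and concatenating a -1. A rough NumPy sketch of that shape logic (function and variable names are illustrative only, not the converter API):

import numpy as np

def flatten_like_converter(x, start_dim=0, end_dim=-1):
    rank = x.ndim
    # normalize negative dims the same way normalize_axis does
    start = start_dim + rank if start_dim < 0 else start_dim
    end = end_dim + rank if end_dim < 0 else end_dim
    shape = list(x.shape)
    # keep dims before start and after end, collapse the middle into a single -1 dim
    new_shape = shape[:start] + [-1] + shape[end + 1:]
    return x.reshape(new_shape)

x = np.zeros((2, 3, 4, 5))
assert flatten_like_converter(x, 1, -1).shape == (2, 60)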

@@ -24,19 +24,12 @@ using namespace ov::op;
 OutputVector translate_transpose(const NodeContext& context) {
     num_inputs_check(context, 3, 3);
-    auto dim0 = context.const_input<int64_t>(1);
-    auto dim1 = context.const_input<int64_t>(2);
     Output<Node> rank;
     std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
-    // Use opset::If for dim normalization
     auto dim0_node = context.get_input(1);
     auto dim1_node = context.get_input(2);
-    if (dim0 < 0) {
-        dim0_node = std::make_shared<v1::Add>(rank, dim0_node);
-    }
-    if (dim1 < 0) {
-        dim1_node = std::make_shared<v1::Add>(rank, dim1_node);
-    }
+    dim0_node = normalize_axis(context, dim0_node, rank);
+    dim1_node = normalize_axis(context, dim1_node, rank);
     auto start = v0::Constant::create(element::i32, {}, {0});
     auto step = v0::Constant::create(element::i32, {}, {1});
     auto range = std::make_shared<v4::Range>(start, rank, step, element::i32);
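
Note: as with flatten, the host-side dim0/dim1 constants are gone; both dims are normalized as graph nodes and swapped inside the 0..rank-1 range that feeds Transpose. An equivalent NumPy sketch of the permutation being built (illustrative names, not the converter API):

import numpy as np

def transpose_like_converter(x, dim0, dim1):
    rank = x.ndim
    d0 = dim0 + rank if dim0 < 0 else dim0
    d1 = dim1 + rank if dim1 < 0 else dim1
    # build the identity permutation 0..rank-1, then swap the two normalized dims
    order = list(range(rank))
    order[d0], order[d1] = order[d1], order[d0]
    return x.transpose(order)

x = np.zeros((2, 3, 4))
assert transpose_like_converter(x, -1, 0).shape == (4, 3, 2)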

@@ -28,11 +28,13 @@ OutputVector translate_unflatten(const NodeContext& context) {
     if (context.get_input_type(2).is<type::List>()) {
         sizes = concat_list_construct(sizes);
     }
-    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
+    Output<Node> input_shape;
+    Output<Node> rank;
+    std::tie(input_shape, rank) = get_shape_rank(context, input);
     auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
     auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
     dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
-    dim = normalize_axis(context, dim, input);
+    dim = normalize_axis(context, dim, rank);
     sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
     auto max_int = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int>::max()}));
     auto dim_plus_one = context.mark_node(std::make_shared<v1::Add>(dim, one_1d));
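
Note: unflatten is the same idea in reverse; the shape/rank pair from get_shape_rank feeds normalize_axis, and the single dim is replaced by the requested sizes. A small NumPy sketch of the intended result (names and shapes are example assumptions):

import numpy as np

def unflatten_like_converter(x, dim, sizes):
    rank = x.ndim
    d = dim + rank if dim < 0 else dim
    shape = list(x.shape)
    # replace the single dim with the requested sizes, keeping everything before and after
    new_shape = shape[:d] + list(sizes) + shape[d + 1:]
    return x.reshape(new_shape)

x = np.zeros((2, 12, 5))
assert unflatten_like_converter(x, 1, (3, 4)).shape == (2, 3, 4, 5)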

@@ -117,11 +117,7 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
     return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
 };
-std::shared_ptr<Node> normalize_axis(const NodeContext& context,
-                                     const Output<Node>& axis,
-                                     const Output<Node>& input_node) {
-    Output<Node> rank;
-    std::tie(std::ignore, rank) = get_shape_rank(context, input_node);
+Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& rank) {
    auto axis_rank = context.mark_node(std::make_shared<opset10::Add>(axis, rank));
    auto is_less = context.mark_node(std::make_shared<opset10::Less>(axis_rank, rank));
    auto new_axis = context.mark_node(std::make_shared<opset10::Select>(is_less, axis_rank, axis));
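
Note: with the precomputed rank passed in, normalize_axis reduces to the Add/Less/Select pattern above. In plain Python the same selection looks like this (scalar sketch for illustration; the real function operates on graph nodes):

def normalize_axis_scalar(axis, rank):
    shifted = axis + rank
    # a negative axis satisfies axis + rank < rank, so the shifted value is chosen;
    # a non-negative axis falls through unchanged
    return shifted if shifted < rank else axis

assert normalize_axis_scalar(-1, 4) == 3
assert normalize_axis_scalar(2, 4) == 2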

@@ -39,9 +39,7 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<N
 std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id);
-std::shared_ptr<Node> normalize_axis(const NodeContext& context,
-                                     const Output<Node>& axis,
-                                     const Output<Node>& input_node);
+Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& input_node);
 std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);

@@ -90,7 +90,7 @@ Output<Node> quantize(const NodeContext& context,
     const auto rank = std::get<1>(get_shape_rank(context, input_convert, false, element::i32));
     const auto ones = context.mark_node(std::make_shared<v3::Broadcast>(one, rank));
-    const auto normalized_axis = normalize_axis(context, axis_convert, input_convert);
+    const auto normalized_axis = normalize_axis(context, axis_convert, rank);
     const auto new_shape =
         context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, neg_one, zero));

@@ -27,7 +27,9 @@ class TestFlatten(PytorchLayerTest):
         return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
-    @pytest.mark.parametrize("dim0,dim1", [[0, 1],
+    @pytest.mark.parametrize("dim0,dim1", [[0, -1],
+                                           [-2, -1],
+                                           [0, 1],
                                            [0, 2],
                                            [0, 3],
                                            [1, 2],
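
Note: the new [0, -1] and [-2, -1] cases exercise negative dims end to end. For reference, the PyTorch behaviour they correspond to (standalone example, not the test's input):

import torch

x = torch.zeros(2, 3, 4)
assert torch.flatten(x, 0, -1).shape == (24,)     # collapse everything
assert torch.flatten(x, -2, -1).shape == (2, 12)  # collapse the last two dims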

@@ -31,13 +31,13 @@ class aten_native_multi_head_attention(torch.nn.Module):
         # Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
         # which later returns NaNs as MHA's output
         if mask == 0:
-            self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
             self.mask_type = 0
         elif mask == 1:
-            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(np.bool))
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype("bool"))
             self.mask_type = 1
         elif mask == 2:
-            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
             self.mask_type = 2
         else:
             self.mask = None
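
Note: this test change is unrelated to the refactor itself; the np.bool alias was deprecated and later removed from NumPy, so the masks are now cast via a string dtype. Either spelling keeps the cast working:

import numpy as np

mask = np.random.randint(0, 2, (4, 4)).astype("bool")     # string dtype
mask2 = np.random.randint(0, 2, (4, 4)).astype(np.bool_)  # NumPy scalar type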