[PT FE] Refactor aten::flatten and aten::transpose conversion (#19098)

* [PT FE] Refactor aten::flatten and aten::transpose conversion

* Fix code style

* Fix codestyle
This commit is contained in:
Maxim Vafin 2023-08-10 10:28:36 +02:00 committed by GitHub
parent 9deef1480a
commit f5f221a3a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 21 additions and 42 deletions

View File

@ -67,7 +67,7 @@ OutputVector translate_fake_quantize_per_channel_affine(const NodeContext& conte
auto rank = std::get<1>(get_shape_rank(context, input_node));
auto ones = std::make_shared<v3::Broadcast>(const_1, rank);
auto normalized_axis = normalize_axis(context, axis, input_node);
auto normalized_axis = normalize_axis(context, axis, rank);
// Create vector of length of rank filled with ones, except single -1 value at place selected by axis element.
auto new_shape = std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, const_neg_1, const_0);
// Reshape scale and zero point to tensor of the same rank as input, having shape 1 everywhere except dimension

View File

@ -21,14 +21,6 @@ using namespace ov::op;
OutputVector translate_flatten(const NodeContext& context) {
num_inputs_check(context, 1, 3);
auto x = context.get_input(0);
int64_t start_dim = 0;
int64_t end_dim = -1;
if (!context.input_is_none(1)) {
start_dim = context.const_input<int64_t>(1);
}
if (!context.input_is_none(2)) {
end_dim = context.const_input<int64_t>(2);
}
Output<Node> shape;
Output<Node> rank;
std::tie(shape, rank) = get_shape_rank(context, x, true);
@ -38,20 +30,16 @@ OutputVector translate_flatten(const NodeContext& context) {
if (!context.input_is_none(1)) {
start_dim_node = context.get_input(1);
} else {
start_dim_node = v0::Constant::create(element::i32, Shape{}, {start_dim});
start_dim_node = v0::Constant::create(element::i32, Shape{}, {0});
}
if (!context.input_is_none(2)) {
end_dim_node = context.get_input(2);
} else {
end_dim_node = v0::Constant::create(element::i32, Shape{}, {end_dim});
end_dim_node = v0::Constant::create(element::i32, Shape{}, {-1});
}
if (start_dim < 0) {
start_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, start_dim_node));
}
if (end_dim < 0) {
end_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, end_dim_node));
}
// Slice shape from begin and end, then concat with -1, if slice return empty tensor concat shuold still be able to
start_dim_node = normalize_axis(context, start_dim_node, rank);
end_dim_node = normalize_axis(context, end_dim_node, rank);
// Slice shape from begin and end, then concat with -1, if slice return empty tensor concat should still be able to
// work with it
auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
auto one = v0::Constant::create(element::i32, Shape{1}, {1});

View File

@ -24,19 +24,12 @@ using namespace ov::op;
OutputVector translate_transpose(const NodeContext& context) {
num_inputs_check(context, 3, 3);
auto dim0 = context.const_input<int64_t>(1);
auto dim1 = context.const_input<int64_t>(2);
Output<Node> rank;
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
// Use opset::If for dim normalization
auto dim0_node = context.get_input(1);
auto dim1_node = context.get_input(2);
if (dim0 < 0) {
dim0_node = std::make_shared<v1::Add>(rank, dim0_node);
}
if (dim1 < 0) {
dim1_node = std::make_shared<v1::Add>(rank, dim1_node);
}
dim0_node = normalize_axis(context, dim0_node, rank);
dim1_node = normalize_axis(context, dim1_node, rank);
auto start = v0::Constant::create(element::i32, {}, {0});
auto step = v0::Constant::create(element::i32, {}, {1});
auto range = std::make_shared<v4::Range>(start, rank, step, element::i32);

View File

@ -28,11 +28,13 @@ OutputVector translate_unflatten(const NodeContext& context) {
if (context.get_input_type(2).is<type::List>()) {
sizes = concat_list_construct(sizes);
}
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
Output<Node> input_shape;
Output<Node> rank;
std::tie(input_shape, rank) = get_shape_rank(context, input);
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
dim = normalize_axis(context, dim, input);
dim = normalize_axis(context, dim, rank);
sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
auto max_int = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int>::max()}));
auto dim_plus_one = context.mark_node(std::make_shared<v1::Add>(dim, one_1d));

View File

@ -117,11 +117,7 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
};
std::shared_ptr<Node> normalize_axis(const NodeContext& context,
const Output<Node>& axis,
const Output<Node>& input_node) {
Output<Node> rank;
std::tie(std::ignore, rank) = get_shape_rank(context, input_node);
Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& rank) {
auto axis_rank = context.mark_node(std::make_shared<opset10::Add>(axis, rank));
auto is_less = context.mark_node(std::make_shared<opset10::Less>(axis_rank, rank));
auto new_axis = context.mark_node(std::make_shared<opset10::Select>(is_less, axis_rank, axis));

View File

@ -39,9 +39,7 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<N
std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id);
std::shared_ptr<Node> normalize_axis(const NodeContext& context,
const Output<Node>& axis,
const Output<Node>& input_node);
Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& input_node);
std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);

View File

@ -90,7 +90,7 @@ Output<Node> quantize(const NodeContext& context,
const auto rank = std::get<1>(get_shape_rank(context, input_convert, false, element::i32));
const auto ones = context.mark_node(std::make_shared<v3::Broadcast>(one, rank));
const auto normalized_axis = normalize_axis(context, axis_convert, input_convert);
const auto normalized_axis = normalize_axis(context, axis_convert, rank);
const auto new_shape =
context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, neg_one, zero));

View File

@ -27,7 +27,9 @@ class TestFlatten(PytorchLayerTest):
return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
@pytest.mark.parametrize("dim0,dim1", [[0, 1],
@pytest.mark.parametrize("dim0,dim1", [[0, -1],
[-2, -1],
[0, 1],
[0, 2],
[0, 3],
[1, 2],

View File

@ -31,13 +31,13 @@ class aten_native_multi_head_attention(torch.nn.Module):
# Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
# which later returns NaNs as MHA's output
if mask == 0:
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
self.mask_type = 0
elif mask == 1:
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(np.bool))
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype("bool"))
self.mask_type = 1
elif mask == 2:
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
self.mask_type = 2
else:
self.mask = None