[PT FE] Refactor aten::flatten and aten::transpose conversion (#19098)
* [PT FE] Refactor aten::flatten and aten::transpose conversion
* Fix code style
* Fix codestyle
This commit is contained in:
parent 9deef1480a
commit f5f221a3a9
@@ -67,7 +67,7 @@ OutputVector translate_fake_quantize_per_channel_affine(const NodeContext& conte
 
     auto rank = std::get<1>(get_shape_rank(context, input_node));
     auto ones = std::make_shared<v3::Broadcast>(const_1, rank);
-    auto normalized_axis = normalize_axis(context, axis, input_node);
+    auto normalized_axis = normalize_axis(context, axis, rank);
    // Create vector of length of rank filled with ones, except single -1 value at place selected by axis element.
     auto new_shape = std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, const_neg_1, const_0);
     // Reshape scale and zero point to tensor of the same rank as input, having shape 1 everywhere except dimension
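For reference, the Broadcast + ScatterElementsUpdate pair above builds a shape of all ones with a single -1 at the (normalized) channel axis, which is then used to reshape the per-channel scale and zero point so they broadcast against the input. A standalone sketch of that shape arithmetic on plain integers (illustrative only; the function names here are invented and are not part of the frontend):

// per_channel_shape_sketch.cpp -- illustrative analogue of Broadcast + ScatterElementsUpdate above.
#include <cstdint>
#include <iostream>
#include <vector>

// Shape used to reshape a per-channel scale/zero-point so it broadcasts against the input:
// 1 in every dimension except -1 at the (normalized) channel axis.
static std::vector<int64_t> per_channel_broadcast_shape(int64_t rank, int64_t axis) {
    if (axis < 0) axis += rank;           // what normalize_axis does on the graph side
    std::vector<int64_t> shape(rank, 1);  // Broadcast(const_1, rank)
    shape[axis] = -1;                     // ScatterElementsUpdate(ones, axis, -1, 0)
    return shape;
}

int main() {
    // rank = 4, axis = 1 -> [1, -1, 1, 1]
    for (int64_t d : per_channel_broadcast_shape(4, 1))
        std::cout << d << ' ';
    std::cout << '\n';
}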
@@ -21,14 +21,6 @@ using namespace ov::op;
 OutputVector translate_flatten(const NodeContext& context) {
     num_inputs_check(context, 1, 3);
     auto x = context.get_input(0);
-    int64_t start_dim = 0;
-    int64_t end_dim = -1;
-    if (!context.input_is_none(1)) {
-        start_dim = context.const_input<int64_t>(1);
-    }
-    if (!context.input_is_none(2)) {
-        end_dim = context.const_input<int64_t>(2);
-    }
     Output<Node> shape;
     Output<Node> rank;
     std::tie(shape, rank) = get_shape_rank(context, x, true);
@@ -38,20 +30,16 @@ OutputVector translate_flatten(const NodeContext& context) {
     if (!context.input_is_none(1)) {
         start_dim_node = context.get_input(1);
     } else {
-        start_dim_node = v0::Constant::create(element::i32, Shape{}, {start_dim});
+        start_dim_node = v0::Constant::create(element::i32, Shape{}, {0});
     }
     if (!context.input_is_none(2)) {
         end_dim_node = context.get_input(2);
     } else {
-        end_dim_node = v0::Constant::create(element::i32, Shape{}, {end_dim});
+        end_dim_node = v0::Constant::create(element::i32, Shape{}, {-1});
     }
-    if (start_dim < 0) {
-        start_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, start_dim_node));
-    }
-    if (end_dim < 0) {
-        end_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, end_dim_node));
-    }
-    // Slice shape from begin and end, then concat with -1, if slice return empty tensor concat shuold still be able to
+    start_dim_node = normalize_axis(context, start_dim_node, rank);
+    end_dim_node = normalize_axis(context, end_dim_node, rank);
+    // Slice shape from begin and end, then concat with -1, if slice return empty tensor concat should still be able to
     // work with it
     auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
     auto one = v0::Constant::create(element::i32, Shape{1}, {1});
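Roughly, the subgraph built here slices the input shape before start_dim and after end_dim and concatenates a -1 in between, so Reshape infers the collapsed extent. A standalone sketch of that target-shape arithmetic for a statically known shape (plain C++, not the converter itself; flatten_target_shape and normalize_dim are invented names for illustration):

// flatten_sketch.cpp -- illustrative only; the converter builds the same result dynamically in the graph.
#include <cstdint>
#include <iostream>
#include <vector>

// Normalize a possibly negative dim against the rank, mirroring normalize_axis.
static int64_t normalize_dim(int64_t dim, int64_t rank) {
    return dim < 0 ? dim + rank : dim;
}

// Target shape for flatten(start_dim, end_dim): keep dims before start_dim and
// after end_dim, and collapse the range [start_dim, end_dim] into a single -1.
static std::vector<int64_t> flatten_target_shape(const std::vector<int64_t>& shape,
                                                 int64_t start_dim,
                                                 int64_t end_dim) {
    const int64_t rank = static_cast<int64_t>(shape.size());
    start_dim = normalize_dim(start_dim, rank);
    end_dim = normalize_dim(end_dim, rank);
    std::vector<int64_t> result(shape.begin(), shape.begin() + start_dim);
    result.push_back(-1);  // Reshape infers the collapsed extent
    result.insert(result.end(), shape.begin() + end_dim + 1, shape.end());
    return result;
}

int main() {
    // flatten([2, 3, 4, 5], start_dim=1, end_dim=-1) -> [2, -1], i.e. [2, 60] after Reshape.
    for (int64_t d : flatten_target_shape({2, 3, 4, 5}, 1, -1))
        std::cout << d << ' ';
    std::cout << '\n';
}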
@@ -24,19 +24,12 @@ using namespace ov::op;
 
 OutputVector translate_transpose(const NodeContext& context) {
     num_inputs_check(context, 3, 3);
-    auto dim0 = context.const_input<int64_t>(1);
-    auto dim1 = context.const_input<int64_t>(2);
     Output<Node> rank;
     std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
-    // Use opset::If for dim normalization
     auto dim0_node = context.get_input(1);
     auto dim1_node = context.get_input(2);
-    if (dim0 < 0) {
-        dim0_node = std::make_shared<v1::Add>(rank, dim0_node);
-    }
-    if (dim1 < 0) {
-        dim1_node = std::make_shared<v1::Add>(rank, dim1_node);
-    }
+    dim0_node = normalize_axis(context, dim0_node, rank);
+    dim1_node = normalize_axis(context, dim1_node, rank);
     auto start = v0::Constant::create(element::i32, {}, {0});
     auto step = v0::Constant::create(element::i32, {}, {1});
     auto range = std::make_shared<v4::Range>(start, rank, step, element::i32);
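aten::transpose(x, dim0, dim1) swaps exactly two dimensions, so the permutation fed to Transpose is the identity order [0, rank) produced by the Range above with the two normalized dims exchanged. A standalone sketch of that arithmetic (plain C++, illustrative only; transpose_order is an invented name):

// transpose_order_sketch.cpp -- illustrative; identifiers are not part of the frontend.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// aten::transpose(x, dim0, dim1) permutes axes by swapping dim0 and dim1
// in the identity permutation [0, 1, ..., rank - 1].
static std::vector<int64_t> transpose_order(int64_t rank, int64_t dim0, int64_t dim1) {
    if (dim0 < 0) dim0 += rank;  // same effect as normalize_axis on a constant dim
    if (dim1 < 0) dim1 += rank;
    std::vector<int64_t> order(rank);
    std::iota(order.begin(), order.end(), 0);  // Range(0, rank, 1)
    std::swap(order[dim0], order[dim1]);
    return order;
}

int main() {
    // transpose with rank=4, dim0=1, dim1=-1 -> permutation [0, 3, 2, 1]
    for (int64_t d : transpose_order(4, 1, -1))
        std::cout << d << ' ';
    std::cout << '\n';
}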
@@ -28,11 +28,13 @@ OutputVector translate_unflatten(const NodeContext& context) {
     if (context.get_input_type(2).is<type::List>()) {
         sizes = concat_list_construct(sizes);
     }
-    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
+    Output<Node> input_shape;
+    Output<Node> rank;
+    std::tie(input_shape, rank) = get_shape_rank(context, input);
     auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
     auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
     dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
-    dim = normalize_axis(context, dim, input);
+    dim = normalize_axis(context, dim, rank);
     sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
     auto max_int = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int>::max()}));
     auto dim_plus_one = context.mark_node(std::make_shared<v1::Add>(dim, one_1d));
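For context, aten::unflatten(input, dim, sizes) replaces the dimension at dim with the listed sizes; the Slice/Concat math around this hunk assembles that shape in the graph. A plain-integer sketch of the target shape (illustrative only; unflatten_shape is an invented name):

// unflatten_shape_sketch.cpp -- plain-integer analogue of the shape math around this hunk.
#include <cstdint>
#include <iostream>
#include <vector>

// unflatten(dim, sizes) replaces one dimension with the given sizes:
// new_shape = shape[:dim] + sizes + shape[dim + 1:]
static std::vector<int64_t> unflatten_shape(const std::vector<int64_t>& shape,
                                            int64_t dim,
                                            const std::vector<int64_t>& sizes) {
    const int64_t rank = static_cast<int64_t>(shape.size());
    if (dim < 0) dim += rank;  // normalize_axis equivalent
    std::vector<int64_t> result(shape.begin(), shape.begin() + dim);
    result.insert(result.end(), sizes.begin(), sizes.end());
    result.insert(result.end(), shape.begin() + dim + 1, shape.end());
    return result;
}

int main() {
    // unflatten([2, 12, 5], dim=1, sizes=[3, 4]) -> [2, 3, 4, 5]
    for (int64_t d : unflatten_shape({2, 12, 5}, 1, {3, 4}))
        std::cout << d << ' ';
    std::cout << '\n';
}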
@@ -117,11 +117,7 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
     return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
 };
 
-std::shared_ptr<Node> normalize_axis(const NodeContext& context,
-                                     const Output<Node>& axis,
-                                     const Output<Node>& input_node) {
-    Output<Node> rank;
-    std::tie(std::ignore, rank) = get_shape_rank(context, input_node);
+Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& rank) {
     auto axis_rank = context.mark_node(std::make_shared<opset10::Add>(axis, rank));
     auto is_less = context.mark_node(std::make_shared<opset10::Less>(axis_rank, rank));
     auto new_axis = context.mark_node(std::make_shared<opset10::Select>(is_less, axis_rank, axis));
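The Add/Less/Select trio above is a branch-free form of the usual negative-index rule: axis + rank is below rank exactly when axis was negative, and only then is it selected. A scalar sketch of the same logic, for reference only (the function name is invented):

// normalize_axis_sketch.cpp -- scalar analogue of the Add/Less/Select subgraph above.
#include <cstdint>
#include <iostream>

// axis + rank < rank exactly when axis < 0, so selecting between (axis + rank)
// and axis on that condition maps negative axes into [0, rank).
static int64_t normalize_axis_scalar(int64_t axis, int64_t rank) {
    const int64_t axis_plus_rank = axis + rank;   // Add
    const bool is_less = axis_plus_rank < rank;   // Less: true iff axis was negative
    return is_less ? axis_plus_rank : axis;       // Select
}

int main() {
    std::cout << normalize_axis_scalar(-1, 4) << '\n';  // 3
    std::cout << normalize_axis_scalar(2, 4) << '\n';   // 2
}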
@@ -39,9 +39,7 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<N
 
 std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id);
 
-std::shared_ptr<Node> normalize_axis(const NodeContext& context,
-                                     const Output<Node>& axis,
-                                     const Output<Node>& input_node);
+Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& input_node);
 
 std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);
 
@@ -90,7 +90,7 @@ Output<Node> quantize(const NodeContext& context,
     const auto rank = std::get<1>(get_shape_rank(context, input_convert, false, element::i32));
     const auto ones = context.mark_node(std::make_shared<v3::Broadcast>(one, rank));
 
-    const auto normalized_axis = normalize_axis(context, axis_convert, input_convert);
+    const auto normalized_axis = normalize_axis(context, axis_convert, rank);
     const auto new_shape =
         context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, neg_one, zero));
 
@@ -27,7 +27,9 @@ class TestFlatten(PytorchLayerTest):
 
         return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
 
-    @pytest.mark.parametrize("dim0,dim1", [[0, 1],
+    @pytest.mark.parametrize("dim0,dim1", [[0, -1],
+                                           [-2, -1],
+                                           [0, 1],
                                            [0, 2],
                                            [0, 3],
                                            [1, 2],
@@ -31,13 +31,13 @@ class aten_native_multi_head_attention(torch.nn.Module):
         # Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
         # which later returns NaNs as MHA's output
         if mask == 0:
-            self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
             self.mask_type = 0
         elif mask == 1:
-            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(np.bool))
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype("bool"))
             self.mask_type = 1
         elif mask == 2:
-            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
+            self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
             self.mask_type = 2
         else:
             self.mask = None
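Background for this hunk: np.bool was a deprecated alias for the Python builtin bool (deprecated in NumPy 1.20 and removed in NumPy 1.24), so astype(np.bool) fails on recent NumPy; passing the dtype name as the string "bool" is the portable spelling used here.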