[PT FE] Refactor aten::flatten and aten::transpose conversion (#19098)
* [PT FE] Refactor aten::flatten and aten::transpose conversion * Fix code style * Fix codestyle
This commit is contained in:
parent
9deef1480a
commit
f5f221a3a9
@ -67,7 +67,7 @@ OutputVector translate_fake_quantize_per_channel_affine(const NodeContext& conte
|
||||
|
||||
auto rank = std::get<1>(get_shape_rank(context, input_node));
|
||||
auto ones = std::make_shared<v3::Broadcast>(const_1, rank);
|
||||
auto normalized_axis = normalize_axis(context, axis, input_node);
|
||||
auto normalized_axis = normalize_axis(context, axis, rank);
|
||||
// Create vector of length of rank filled with ones, except single -1 value at place selected by axis element.
|
||||
auto new_shape = std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, const_neg_1, const_0);
|
||||
// Reshape scale and zero point to tensor of the same rank as input, having shape 1 everywhere except dimension
|
||||
|
@ -21,14 +21,6 @@ using namespace ov::op;
|
||||
OutputVector translate_flatten(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 3);
|
||||
auto x = context.get_input(0);
|
||||
int64_t start_dim = 0;
|
||||
int64_t end_dim = -1;
|
||||
if (!context.input_is_none(1)) {
|
||||
start_dim = context.const_input<int64_t>(1);
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
end_dim = context.const_input<int64_t>(2);
|
||||
}
|
||||
Output<Node> shape;
|
||||
Output<Node> rank;
|
||||
std::tie(shape, rank) = get_shape_rank(context, x, true);
|
||||
@ -38,20 +30,16 @@ OutputVector translate_flatten(const NodeContext& context) {
|
||||
if (!context.input_is_none(1)) {
|
||||
start_dim_node = context.get_input(1);
|
||||
} else {
|
||||
start_dim_node = v0::Constant::create(element::i32, Shape{}, {start_dim});
|
||||
start_dim_node = v0::Constant::create(element::i32, Shape{}, {0});
|
||||
}
|
||||
if (!context.input_is_none(2)) {
|
||||
end_dim_node = context.get_input(2);
|
||||
} else {
|
||||
end_dim_node = v0::Constant::create(element::i32, Shape{}, {end_dim});
|
||||
end_dim_node = v0::Constant::create(element::i32, Shape{}, {-1});
|
||||
}
|
||||
if (start_dim < 0) {
|
||||
start_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, start_dim_node));
|
||||
}
|
||||
if (end_dim < 0) {
|
||||
end_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, end_dim_node));
|
||||
}
|
||||
// Slice shape from begin and end, then concat with -1, if slice return empty tensor concat shuold still be able to
|
||||
start_dim_node = normalize_axis(context, start_dim_node, rank);
|
||||
end_dim_node = normalize_axis(context, end_dim_node, rank);
|
||||
// Slice shape from begin and end, then concat with -1, if slice return empty tensor concat should still be able to
|
||||
// work with it
|
||||
auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
|
||||
auto one = v0::Constant::create(element::i32, Shape{1}, {1});
|
||||
|
@ -24,19 +24,12 @@ using namespace ov::op;
|
||||
|
||||
OutputVector translate_transpose(const NodeContext& context) {
|
||||
num_inputs_check(context, 3, 3);
|
||||
auto dim0 = context.const_input<int64_t>(1);
|
||||
auto dim1 = context.const_input<int64_t>(2);
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, context.get_input(0), true);
|
||||
// Use opset::If for dim normalization
|
||||
auto dim0_node = context.get_input(1);
|
||||
auto dim1_node = context.get_input(2);
|
||||
if (dim0 < 0) {
|
||||
dim0_node = std::make_shared<v1::Add>(rank, dim0_node);
|
||||
}
|
||||
if (dim1 < 0) {
|
||||
dim1_node = std::make_shared<v1::Add>(rank, dim1_node);
|
||||
}
|
||||
dim0_node = normalize_axis(context, dim0_node, rank);
|
||||
dim1_node = normalize_axis(context, dim1_node, rank);
|
||||
auto start = v0::Constant::create(element::i32, {}, {0});
|
||||
auto step = v0::Constant::create(element::i32, {}, {1});
|
||||
auto range = std::make_shared<v4::Range>(start, rank, step, element::i32);
|
||||
|
@ -28,11 +28,13 @@ OutputVector translate_unflatten(const NodeContext& context) {
|
||||
if (context.get_input_type(2).is<type::List>()) {
|
||||
sizes = concat_list_construct(sizes);
|
||||
}
|
||||
auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i32));
|
||||
Output<Node> input_shape;
|
||||
Output<Node> rank;
|
||||
std::tie(input_shape, rank) = get_shape_rank(context, input);
|
||||
auto zero_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto one_1d = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
dim = context.mark_node(std::make_shared<v0::Convert>(dim, element::i32));
|
||||
dim = normalize_axis(context, dim, input);
|
||||
dim = normalize_axis(context, dim, rank);
|
||||
sizes = context.mark_node(std::make_shared<v0::Convert>(sizes, element::i32));
|
||||
auto max_int = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int>::max()}));
|
||||
auto dim_plus_one = context.mark_node(std::make_shared<v1::Add>(dim, one_1d));
|
||||
|
@ -117,11 +117,7 @@ std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id) {
|
||||
return context.mark_node(std::make_shared<opset10::Range>(start, reduced_rank, step, element::i32));
|
||||
};
|
||||
|
||||
std::shared_ptr<Node> normalize_axis(const NodeContext& context,
|
||||
const Output<Node>& axis,
|
||||
const Output<Node>& input_node) {
|
||||
Output<Node> rank;
|
||||
std::tie(std::ignore, rank) = get_shape_rank(context, input_node);
|
||||
Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& rank) {
|
||||
auto axis_rank = context.mark_node(std::make_shared<opset10::Add>(axis, rank));
|
||||
auto is_less = context.mark_node(std::make_shared<opset10::Less>(axis_rank, rank));
|
||||
auto new_axis = context.mark_node(std::make_shared<opset10::Select>(is_less, axis_rank, axis));
|
||||
|
@ -39,9 +39,7 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<N
|
||||
|
||||
std::shared_ptr<Node> get_axes_range(const NodeContext& context, int input_id);
|
||||
|
||||
std::shared_ptr<Node> normalize_axis(const NodeContext& context,
|
||||
const Output<Node>& axis,
|
||||
const Output<Node>& input_node);
|
||||
Output<Node> normalize_axis(const NodeContext& context, const Output<Node>& axis, const Output<Node>& input_node);
|
||||
|
||||
std::shared_ptr<Node> numel(const NodeContext& context, const Output<Node>& x);
|
||||
|
||||
|
@ -90,7 +90,7 @@ Output<Node> quantize(const NodeContext& context,
|
||||
const auto rank = std::get<1>(get_shape_rank(context, input_convert, false, element::i32));
|
||||
const auto ones = context.mark_node(std::make_shared<v3::Broadcast>(one, rank));
|
||||
|
||||
const auto normalized_axis = normalize_axis(context, axis_convert, input_convert);
|
||||
const auto normalized_axis = normalize_axis(context, axis_convert, rank);
|
||||
const auto new_shape =
|
||||
context.mark_node(std::make_shared<v3::ScatterElementsUpdate>(ones, normalized_axis, neg_one, zero));
|
||||
|
||||
|
@ -27,7 +27,9 @@ class TestFlatten(PytorchLayerTest):
|
||||
|
||||
return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
|
||||
|
||||
@pytest.mark.parametrize("dim0,dim1", [[0, 1],
|
||||
@pytest.mark.parametrize("dim0,dim1", [[0, -1],
|
||||
[-2, -1],
|
||||
[0, 1],
|
||||
[0, 2],
|
||||
[0, 3],
|
||||
[1, 2],
|
||||
|
@ -31,13 +31,13 @@ class aten_native_multi_head_attention(torch.nn.Module):
|
||||
# Float masks raise a warning in PyTorch and are (incorrectly) converted to bool,
|
||||
# which later returns NaNs as MHA's output
|
||||
if mask == 0:
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
|
||||
self.mask_type = 0
|
||||
elif mask == 1:
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype(np.bool))
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, SEQ_LENGTH)).astype("bool"))
|
||||
self.mask_type = 1
|
||||
elif mask == 2:
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype(np.bool))
|
||||
self.mask = torch.from_numpy(np.random.randint(0, 2, (BATCH_SIZE, NUM_HEADS, SEQ_LENGTH, SEQ_LENGTH)).astype("bool"))
|
||||
self.mask_type = 2
|
||||
else:
|
||||
self.mask = None
|
||||
|
Loading…
Reference in New Issue
Block a user