[TF FE] Optimize DynamicPartition translator (#15750)

* [TF FE] Optimize DynamicPartition translator

It avoids quadratic complexity when computing split lengths for each partition

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Work around a bug in the Unique operation: do not use axis

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev
2023-02-16 15:48:00 +01:00
committed by GitHub
parent 3e32a9f2af
commit 4a5452f29b

View File

@@ -5,11 +5,11 @@
#include <limits>
#include "common_op_table.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/opsets/opset10.hpp"
using namespace std;
using namespace ov;
using namespace ov::opset9;
using namespace ov::opset10;
namespace ov {
namespace frontend {
@@ -29,23 +29,18 @@ OutputVector translate_dynamic_partition_op(const NodeContext& node) {
auto num_partitions = node.get_attribute<int64_t>("num_partitions");
// compute how many slices are collected for each partition
// that will be used as split_legths
auto start = make_shared<Constant>(partitions_type, Shape{}, 0);
auto end = make_shared<Constant>(partitions_type, Shape{}, num_partitions);
auto step = make_shared<Constant>(partitions_type, Shape{}, 1);
auto range_num_parts = make_shared<Range>(start, end, step, partitions_type);
auto unsqueeze_axis1 = make_shared<Constant>(element::i64, Shape{1}, 0);
auto unsqueeze_partitions = make_shared<Unsqueeze>(norm_partitions, unsqueeze_axis1);
auto unsqueeze_axis2 = make_shared<Constant>(element::i64, Shape{1}, 1);
auto unsqueeze_range = make_shared<Unsqueeze>(range_num_parts, unsqueeze_axis2);
auto mask = make_shared<Equal>(unsqueeze_range, unsqueeze_partitions);
auto mask_0_1 = make_shared<Select>(mask,
make_shared<Constant>(partitions_type, Shape{1}, 1),
make_shared<Constant>(partitions_type, Shape{1}, 0));
auto reduction_axis = make_shared<Constant>(element::i64, Shape{1}, 1);
auto split_legths = make_shared<ReduceSum>(mask_0_1, reduction_axis);
// 1. initially assume that we collect zero slices for each partition
auto const_zero = make_shared<Constant>(element::i64, Shape{}, 0);
auto target_shape = make_shared<Constant>(element::i64, Shape{1}, num_partitions);
Output<Node> split_legths = make_shared<Broadcast>(const_zero, target_shape);
// 2. compute unique partition indices and their occurrences
auto axis = make_shared<Constant>(element::i32, Shape{1}, 0);
auto unique_partition_inds = make_shared<Unique>(partitions);
// 3. update split_lengths with the number of occurrences of each partition index
split_legths = make_shared<ScatterUpdate>(split_legths,
unique_partition_inds->output(0),
unique_partition_inds->output(3),
axis);
// for stable sorting using TopK operation, we have to re-scale partition indices by the formula:
// partition = partition * scale + partition_ind, where delta = max_int / num_partitions