[TF FE] Optimize DynamicPartition translator (#15750)
* [TF FE] Optimize DynamicPartition translator: it avoids quadratic complexity when computing split lengths for each partition Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Work around a bug in the Unique operation: do not use axis --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
@@ -5,11 +5,11 @@
|
||||
#include <limits>
|
||||
|
||||
#include "common_op_table.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
using namespace ov::opset10;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
@@ -29,23 +29,18 @@ OutputVector translate_dynamic_partition_op(const NodeContext& node) {
|
||||
auto num_partitions = node.get_attribute<int64_t>("num_partitions");
|
||||
|
||||
// compute how many slices are collected for each partition
|
||||
// that will be used as split_legths
|
||||
auto start = make_shared<Constant>(partitions_type, Shape{}, 0);
|
||||
auto end = make_shared<Constant>(partitions_type, Shape{}, num_partitions);
|
||||
auto step = make_shared<Constant>(partitions_type, Shape{}, 1);
|
||||
auto range_num_parts = make_shared<Range>(start, end, step, partitions_type);
|
||||
|
||||
auto unsqueeze_axis1 = make_shared<Constant>(element::i64, Shape{1}, 0);
|
||||
auto unsqueeze_partitions = make_shared<Unsqueeze>(norm_partitions, unsqueeze_axis1);
|
||||
auto unsqueeze_axis2 = make_shared<Constant>(element::i64, Shape{1}, 1);
|
||||
auto unsqueeze_range = make_shared<Unsqueeze>(range_num_parts, unsqueeze_axis2);
|
||||
|
||||
auto mask = make_shared<Equal>(unsqueeze_range, unsqueeze_partitions);
|
||||
auto mask_0_1 = make_shared<Select>(mask,
|
||||
make_shared<Constant>(partitions_type, Shape{1}, 1),
|
||||
make_shared<Constant>(partitions_type, Shape{1}, 0));
|
||||
auto reduction_axis = make_shared<Constant>(element::i64, Shape{1}, 1);
|
||||
auto split_legths = make_shared<ReduceSum>(mask_0_1, reduction_axis);
|
||||
// 1. initially assume that we collect zero slices for each partition
|
||||
auto const_zero = make_shared<Constant>(element::i64, Shape{}, 0);
|
||||
auto target_shape = make_shared<Constant>(element::i64, Shape{1}, num_partitions);
|
||||
Output<Node> split_legths = make_shared<Broadcast>(const_zero, target_shape);
|
||||
// 2. compute unique partition indices and their occurrences
|
||||
auto axis = make_shared<Constant>(element::i32, Shape{1}, 0);
|
||||
auto unique_partition_inds = make_shared<Unique>(partitions);
|
||||
// 3. update split_lengths with a number of occurrences by each partition index
|
||||
split_legths = make_shared<ScatterUpdate>(split_legths,
|
||||
unique_partition_inds->output(0),
|
||||
unique_partition_inds->output(3),
|
||||
axis);
|
||||
|
||||
// for stable sorting using TopK operation, we have to re-scale partition indices by the formula:
|
||||
// partition = partition * scale + partition_ind, where scale = max_int / num_partitions
|
||||
|
||||
Reference in New Issue
Block a user