[TF FE] Fix Transpose Sinking for dynamic shapes (#15302)
* TransposeSinking transformation: add support for dynamic shapes * fix e2e tests Co-authored-by: Ivan <ivan.tikhonov@intel.com>
This commit is contained in:
parent
bf98d31393
commit
7ee781dfbf
@ -14,9 +14,7 @@ namespace pass {
|
||||
class TransposeSinking : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::frontend::tensorflow::pass::TransposeSinking");
|
||||
TransposeSinking() {
|
||||
set_property(ov::pass::PassProperty::REQUIRE_STATIC_SHAPE, true);
|
||||
}
|
||||
TransposeSinking() = default;
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& function) override;
|
||||
};
|
||||
|
||||
|
@ -42,6 +42,12 @@ static AxisVector get_default_order(size_t rank) {
|
||||
return default_order;
|
||||
}
|
||||
|
||||
int64_t get_static_rank(const Output<Node>& output) {
|
||||
auto rank = output.get_partial_shape().rank();
|
||||
OPENVINO_ASSERT(rank.is_static(), "Dynamic rank is not supported in TransposeSinking transformation.");
|
||||
return rank.get_length();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static string describe(shared_ptr<Node> node) {
|
||||
// ensure that it's either a reshape or a transpose
|
||||
@ -93,7 +99,7 @@ static shared_ptr<Transpose> read_transposemap(TransposeMap& reorders, const Out
|
||||
}
|
||||
|
||||
static shared_ptr<Transpose> combine_transposes(const shared_ptr<Transpose>& t1, const shared_ptr<Transpose>& t2) {
|
||||
auto default_order = get_default_order(t1->get_shape().size());
|
||||
auto default_order = get_default_order(get_static_rank(t1));
|
||||
auto t1_const = as_type_ptr<Constant>(t1->input_value(1).get_node_shared_ptr());
|
||||
auto t2_const = as_type_ptr<Constant>(t2->input_value(1).get_node_shared_ptr());
|
||||
|
||||
@ -112,7 +118,9 @@ static shared_ptr<Transpose> combine_transposes(const shared_ptr<Transpose>& t1,
|
||||
static void insert_transpose(const shared_ptr<Node>& target, const shared_ptr<Node>& transpose, size_t input_index) {
|
||||
OPENVINO_DEBUG << "Inserting transpose at input " << target->get_name() << " input index " << input_index;
|
||||
auto arg = target->input(input_index).get_source_output();
|
||||
OPENVINO_DEBUG << "Arg shape: " << arg.get_shape();
|
||||
if (arg.get_partial_shape().is_static()) {
|
||||
OPENVINO_DEBUG << "Arg shape: " << arg.get_shape();
|
||||
}
|
||||
auto new_order = as_type_ptr<Constant>(transpose->input_value(1).get_node_shared_ptr());
|
||||
auto new_transpose = make_transpose(arg.get_node_shared_ptr(), new_order->get_axis_vector_val());
|
||||
OPENVINO_DEBUG << "Inserting transpose " << describe<Transpose>(new_transpose) << " at input " << target->get_name()
|
||||
@ -142,7 +150,7 @@ static void mark_transpose_for_deletion(const shared_ptr<Node>& transpose,
|
||||
}
|
||||
|
||||
static shared_ptr<Transpose> create_default_transpose(const Output<Node>& n) {
|
||||
auto default_order = get_default_order(n.get_shape().size());
|
||||
auto default_order = get_default_order(get_static_rank(n));
|
||||
auto order = std::make_shared<Constant>(element::i64, Shape{default_order.size()}, default_order);
|
||||
return make_shared<Transpose>(n, order);
|
||||
}
|
||||
@ -160,27 +168,32 @@ static void convert_binary_to_default_order(const shared_ptr<Node>& binary,
|
||||
auto left = input.get_source_output();
|
||||
auto right_t = read_transposemap(reorders, right);
|
||||
auto right_const = as_type_ptr<Constant>(right_t->input_value(1).get_node_shared_ptr());
|
||||
|
||||
auto perm_to_def = permutation_to_default_order(right_const->get_axis_vector_val());
|
||||
|
||||
// if right input is being implicitly broadcasted, insert a reshape
|
||||
// instead of a transpose
|
||||
shared_ptr<Node> new_node;
|
||||
auto left_shape = left.get_shape();
|
||||
if (left_shape.size() < perm_to_def.size()) {
|
||||
left_shape.insert(left_shape.begin(), perm_to_def.size() - left_shape.size(), 1);
|
||||
auto left_rank = get_static_rank(left);
|
||||
if (left_rank < perm_to_def.size() && left.get_partial_shape().is_static()) {
|
||||
auto left_shape = left.get_shape();
|
||||
left_shape.insert(left_shape.begin(), perm_to_def.size() - left_rank, 1);
|
||||
|
||||
auto new_shape = apply_permutation(left_shape, perm_to_def);
|
||||
new_node = make_reshape(left, new_shape);
|
||||
} else if (left_shape.size() == perm_to_def.size()) {
|
||||
} else if (left_rank == perm_to_def.size()) {
|
||||
new_node = make_transpose(left, perm_to_def);
|
||||
} else {
|
||||
throw runtime_error("case not supported when converting binary to default order");
|
||||
}
|
||||
input.replace_source_output(new_node->output(0));
|
||||
|
||||
OPENVINO_DEBUG << "right = " << ov::util::vector_to_string(right.get_shape()) << ", "
|
||||
<< right.get_node_shared_ptr()->get_name();
|
||||
if (right.get_partial_shape().is_static()) {
|
||||
OPENVINO_DEBUG << "right = " << ov::util::vector_to_string(right.get_shape()) << ", "
|
||||
<< right.get_node_shared_ptr()->get_name();
|
||||
} else {
|
||||
OPENVINO_DEBUG << "right = "
|
||||
<< "dynamic shape, " << right.get_node_shared_ptr()->get_name();
|
||||
}
|
||||
// this should now insert transpose on right
|
||||
mark_transpose_for_deletion(right_t, transposes_to_delete);
|
||||
write_transposemap(reorders, binary, right_t);
|
||||
@ -203,14 +216,15 @@ static void materialize_shapes(const shared_ptr<Node>& n,
|
||||
<< arg.get_node_shared_ptr()->get_name();
|
||||
mark_transpose_for_deletion(arg_transpose, transposes_to_delete);
|
||||
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
|
||||
if (arg_transpose_order->get_axis_vector_val() != get_default_order(arg.get_shape().size())) {
|
||||
if (arg_transpose_order &&
|
||||
arg_transpose_order->get_axis_vector_val() != get_default_order(get_static_rank(arg))) {
|
||||
// Insert if arg needs to be transposed.
|
||||
insert_transpose(n, arg_transpose, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void sink_transpose(const shared_ptr<Transpose>& transpose,
|
||||
static bool sink_transpose(const shared_ptr<Transpose>& transpose,
|
||||
TransposeMap& reorders,
|
||||
set<shared_ptr<Node>>& transposes_to_delete) {
|
||||
OPENVINO_DEBUG << "Sinking Transpose :" << describe<Transpose>(transpose);
|
||||
@ -226,18 +240,25 @@ static void sink_transpose(const shared_ptr<Transpose>& transpose,
|
||||
replace_node(transpose, new_transpose);
|
||||
mark_transpose_for_deletion(new_transpose, transposes_to_delete);
|
||||
write_transposemap(reorders, new_transpose, new_transpose);
|
||||
} else {
|
||||
// combine_transposes failed
|
||||
// transpose remains in the graph
|
||||
OPENVINO_DEBUG << "CombineTranspose has failed. Writing original transpose to the transpose map.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void sink_unary(const shared_ptr<Node>& n,
|
||||
static bool sink_unary(const shared_ptr<Node>& n,
|
||||
TransposeMap& reorders,
|
||||
set<shared_ptr<Node>>& /* transposes_to_delete */) {
|
||||
auto arg_transpose = read_transposemap(reorders, n->input_value(0));
|
||||
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(arg_transpose) << " for " << n->get_name();
|
||||
write_transposemap(reorders, n, arg_transpose);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void sink_binary(const shared_ptr<Node>& binary,
|
||||
static bool sink_binary(const shared_ptr<Node>& binary,
|
||||
TransposeMap& reorders,
|
||||
set<shared_ptr<Node>>& transposes_to_delete) {
|
||||
auto left = binary->input_value(0);
|
||||
@ -246,18 +267,25 @@ static void sink_binary(const shared_ptr<Node>& binary,
|
||||
auto right_t = read_transposemap(reorders, right);
|
||||
auto left_const = as_type_ptr<Constant>(left_t->input_value(1).get_node_shared_ptr());
|
||||
auto right_const = as_type_ptr<Constant>(right_t->input_value(1).get_node_shared_ptr());
|
||||
if (!(left_const && right_const)) {
|
||||
OPENVINO_DEBUG << "TransposeSinking failed for binary op " << binary->get_name()
|
||||
<< "2nd inputs to Transposes must be constants.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto left_order = left_const->get_axis_vector_val();
|
||||
auto right_order = right_const->get_axis_vector_val();
|
||||
|
||||
auto left_mismatch = left_order != get_default_order(left.get_shape().size());
|
||||
auto right_mismatch = right_order != get_default_order(right.get_shape().size());
|
||||
auto left_rank = get_static_rank(left);
|
||||
auto right_rank = get_static_rank(right);
|
||||
auto left_mismatch = left_order != get_default_order(left_rank);
|
||||
auto right_mismatch = right_order != get_default_order(right_rank);
|
||||
|
||||
OPENVINO_DEBUG << "Sink binary " << binary->get_name()
|
||||
<< " left transpose: " << ov::util::vector_to_string(left_order)
|
||||
<< " left default: " << ov::util::vector_to_string(get_default_order(left.get_shape().size()))
|
||||
<< " left default: " << ov::util::vector_to_string(get_default_order(left_rank))
|
||||
<< " right transpose: " << ov::util::vector_to_string(right_order)
|
||||
<< " right default: " << ov::util::vector_to_string(get_default_order(right.get_shape().size()));
|
||||
<< " right default: " << ov::util::vector_to_string(get_default_order(right_rank));
|
||||
|
||||
if ((left_order.size() == right_order.size() && left_order == right_order) || (!left_mismatch && !right_mismatch)) {
|
||||
// Propagate the reshape which matches the shape of the binary node
|
||||
@ -277,95 +305,115 @@ static void sink_binary(const shared_ptr<Node>& binary,
|
||||
}
|
||||
}
|
||||
} catch (const std::exception&) {
|
||||
throw std::runtime_error("");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void sink_pad(shared_ptr<Pad> n, TransposeMap& reorders, set<shared_ptr<Node>>& /* transposes_to_delete */) {
|
||||
static bool sink_pad(shared_ptr<Pad> n, TransposeMap& reorders, set<shared_ptr<Node>>& /* transposes_to_delete */) {
|
||||
auto n_in = n->input_value(0);
|
||||
auto arg_transpose = read_transposemap(reorders, n_in);
|
||||
describe<Transpose>(arg_transpose);
|
||||
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
|
||||
auto order = arg_transpose_order->get_axis_vector_val();
|
||||
// we need the correct input shape to produce the right output shape
|
||||
// we are going to create a label of the right input shape,
|
||||
// so a new pad will have the right shape
|
||||
auto def_order = permutation_to_default_order(order);
|
||||
if (arg_transpose->get_output_partial_shape(0).is_static()) {
|
||||
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
|
||||
auto order = arg_transpose_order->get_axis_vector_val();
|
||||
// we need the correct input shape to produce the right output shape
|
||||
// we are going to create a label of the right input shape,
|
||||
// so a new pad will have the right shape
|
||||
auto def_order = permutation_to_default_order(order);
|
||||
|
||||
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
|
||||
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
|
||||
|
||||
auto dummy_correct_shape =
|
||||
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
|
||||
auto dummy_correct_shape =
|
||||
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
|
||||
|
||||
auto pad_begin = apply_permutation(n->get_pads_begin(), def_order);
|
||||
auto pad_end = apply_permutation(n->get_pads_end(), def_order);
|
||||
auto pad_begin = apply_permutation(n->get_pads_begin(), def_order);
|
||||
auto pad_end = apply_permutation(n->get_pads_end(), def_order);
|
||||
|
||||
auto new_begin = make_shared<Constant>(element::i64, Shape{pad_begin.size()}, pad_begin);
|
||||
auto new_end = make_shared<Constant>(element::i64, Shape{pad_end.size()}, pad_end);
|
||||
auto new_pad = make_shared<Pad>(dummy_correct_shape, new_begin, new_end, n->input_value(3), n->get_pad_mode());
|
||||
replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr());
|
||||
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
|
||||
replace_node(n, new_pad);
|
||||
auto new_transpose = make_transpose(new_pad, order);
|
||||
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
|
||||
write_transposemap(reorders, new_pad, new_transpose);
|
||||
auto new_begin = make_shared<Constant>(element::i64, Shape{pad_begin.size()}, pad_begin);
|
||||
auto new_end = make_shared<Constant>(element::i64, Shape{pad_end.size()}, pad_end);
|
||||
auto new_pad = make_shared<Pad>(dummy_correct_shape, new_begin, new_end, n->input_value(3), n->get_pad_mode());
|
||||
replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr());
|
||||
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
|
||||
replace_node(n, new_pad);
|
||||
auto new_transpose = make_transpose(new_pad, order);
|
||||
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
|
||||
write_transposemap(reorders, new_pad, new_transpose);
|
||||
} else {
|
||||
OPENVINO_DEBUG << "TransposeSinking failed for Pad op " << n->get_name()
|
||||
<< " . Output shape of Transpose op must be static.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static void sink_concat(const shared_ptr<Concat>& n,
|
||||
static bool sink_concat(const shared_ptr<Concat>& n,
|
||||
TransposeMap& reorders,
|
||||
set<shared_ptr<Node>>& transposes_to_delete) {
|
||||
auto n_in = n->input_value(0);
|
||||
auto arg_transpose = read_transposemap(reorders, n_in);
|
||||
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
|
||||
auto order = arg_transpose_order->get_axis_vector_val();
|
||||
// we need the correct input shape to produce the right output shape
|
||||
// we are going to create a label of the right input shape,
|
||||
// so a new concat will have the right shape
|
||||
auto def_order = permutation_to_default_order(order);
|
||||
if (arg_transpose->get_output_partial_shape(0).is_static()) {
|
||||
auto arg_transpose_order = as_type_ptr<Constant>(arg_transpose->input_value(1).get_node_shared_ptr());
|
||||
auto order = arg_transpose_order->get_axis_vector_val();
|
||||
// we need the correct input shape to produce the right output shape
|
||||
// we are going to create a label of the right input shape,
|
||||
// so a new concat will have the right shape
|
||||
auto def_order = permutation_to_default_order(order);
|
||||
|
||||
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
|
||||
auto input_shape = apply_permutation(arg_transpose->get_shape(), def_order);
|
||||
|
||||
auto dummy_correct_shape =
|
||||
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
|
||||
auto dummy_correct_shape =
|
||||
make_shared<ov::pass::pattern::op::Label>(arg_transpose->get_element_type(), input_shape);
|
||||
|
||||
NodeVector new_args;
|
||||
new_args.push_back(dummy_correct_shape);
|
||||
NodeVector new_args;
|
||||
new_args.push_back(dummy_correct_shape);
|
||||
|
||||
for (size_t i = 1; i < n->get_input_size(); i++) {
|
||||
auto iarg = n->input_value(i);
|
||||
auto iarg_transpose = read_transposemap(reorders, iarg);
|
||||
auto iarg_transpose_order = as_type_ptr<Constant>(iarg_transpose->input_value(1).get_node_shared_ptr());
|
||||
auto iorder = iarg_transpose_order->get_axis_vector_val();
|
||||
if (iorder != order) {
|
||||
OPENVINO_DEBUG << " input order at " << i << "-th arg is different from first arg";
|
||||
materialize_shapes(n, reorders, transposes_to_delete);
|
||||
return;
|
||||
for (size_t i = 1; i < n->get_input_size(); i++) {
|
||||
auto iarg = n->input_value(i);
|
||||
auto iarg_transpose = read_transposemap(reorders, iarg);
|
||||
auto iarg_transpose_order = as_type_ptr<Constant>(iarg_transpose->input_value(1).get_node_shared_ptr());
|
||||
auto iorder = iarg_transpose_order->get_axis_vector_val();
|
||||
if (iorder != order) {
|
||||
OPENVINO_DEBUG << " input order at " << i << "-th arg is different from first arg";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (iarg_transpose->get_output_partial_shape(0).is_dynamic()) {
|
||||
OPENVINO_DEBUG << "TransposeSinking failed for Concat op " << n->get_name()
|
||||
<< " . Input Transpose ops"
|
||||
" must have static shapes. ";
|
||||
return false;
|
||||
}
|
||||
auto iinput_shape = apply_permutation(iarg_transpose->get_shape(), def_order);
|
||||
|
||||
auto idummy_correct_shape =
|
||||
make_shared<ov::pass::pattern::op::Label>(iarg_transpose->get_element_type(), iinput_shape);
|
||||
new_args.push_back(idummy_correct_shape);
|
||||
}
|
||||
|
||||
auto iinput_shape = apply_permutation(iarg_transpose->get_shape(), def_order);
|
||||
|
||||
auto idummy_correct_shape =
|
||||
make_shared<ov::pass::pattern::op::Label>(iarg_transpose->get_element_type(), iinput_shape);
|
||||
new_args.push_back(idummy_correct_shape);
|
||||
auto new_axis = order.at(n->get_concatenation_axis());
|
||||
auto new_concat = make_shared<Concat>(new_args, new_axis);
|
||||
// put back the original arguments
|
||||
for (size_t i = 0; i < new_concat->get_input_size(); i++) {
|
||||
OPENVINO_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name()
|
||||
<< " input " << i;
|
||||
new_concat->input(i).replace_source_output(n->input_value(i));
|
||||
}
|
||||
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
|
||||
replace_node(n, new_concat);
|
||||
auto new_transpose = make_transpose(new_concat, order);
|
||||
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
|
||||
write_transposemap(reorders, new_concat, new_transpose);
|
||||
} else {
|
||||
OPENVINO_DEBUG << "TransposeSinking failed for Concat op " << n->get_name()
|
||||
<< " . Output shape of Transpose op must be static.";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto new_axis = order.at(n->get_concatenation_axis());
|
||||
auto new_concat = make_shared<Concat>(new_args, new_axis);
|
||||
// put back the original arguments
|
||||
for (size_t i = 0; i < new_concat->get_input_size(); i++) {
|
||||
OPENVINO_DEBUG << "Replacing " << new_concat->get_name() << " input " << i << " with " << n->get_name()
|
||||
<< " input " << i;
|
||||
new_concat->input(i).replace_source_output(n->input_value(i));
|
||||
}
|
||||
OPENVINO_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
|
||||
replace_node(n, new_concat);
|
||||
auto new_transpose = make_transpose(new_concat, order);
|
||||
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(new_transpose) << " for " << n->get_name();
|
||||
write_transposemap(reorders, new_concat, new_transpose);
|
||||
return true;
|
||||
}
|
||||
|
||||
static void sink_prelu(const shared_ptr<PRelu>& prelu,
|
||||
static bool sink_prelu(const shared_ptr<PRelu>& prelu,
|
||||
TransposeMap& reorders,
|
||||
set<shared_ptr<Node>>& transposes_to_delete) {
|
||||
FRONT_END_GENERAL_CHECK(prelu, "Null pointer is given to PRelu node.");
|
||||
@ -377,9 +425,9 @@ static void sink_prelu(const shared_ptr<PRelu>& prelu,
|
||||
OPENVINO_DEBUG << "Propagating " << describe<Transpose>(arg_transpose) << " for " << prelu->get_name();
|
||||
write_transposemap(reorders, prelu, arg_transpose);
|
||||
} else {
|
||||
// TODO: handle other cases with non-scalar slope
|
||||
materialize_shapes(prelu, reorders, transposes_to_delete);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void purge_transposes(const set<shared_ptr<Node>>& transposes_to_delete) {
|
||||
@ -400,7 +448,7 @@ void purge_transposes(const set<shared_ptr<Node>>& transposes_to_delete) {
|
||||
bool ov::frontend::tensorflow::pass::TransposeSinking::run_on_model(const shared_ptr<Model>& f) {
|
||||
TransposeMap reorders;
|
||||
set<shared_ptr<Node>> transposes_to_delete;
|
||||
unordered_map<std::string, Shape> orig_result_out_shape;
|
||||
unordered_map<std::string, PartialShape> orig_result_out_shape;
|
||||
|
||||
// STEP 1 : Sink or Swim transposes away for op clusters
|
||||
try {
|
||||
@ -408,24 +456,28 @@ bool ov::frontend::tensorflow::pass::TransposeSinking::run_on_model(const shared
|
||||
OPENVINO_DEBUG << "Processing " << n->get_name();
|
||||
// collect output shape of all Result nodes for a sanity check
|
||||
if (ov::op::util::is_output(n)) {
|
||||
orig_result_out_shape[n->get_name()] = n->get_output_shape(0);
|
||||
orig_result_out_shape[n->get_name()] = n->get_output_partial_shape(0);
|
||||
}
|
||||
|
||||
bool sink_res = false;
|
||||
if (auto transpose = as_type_ptr<opset8::Transpose>(n)) {
|
||||
sink_transpose(transpose, reorders, transposes_to_delete);
|
||||
sink_res = sink_transpose(transpose, reorders, transposes_to_delete);
|
||||
} else if (ov::op::util::is_unary_elementwise_arithmetic(n) || as_type_ptr<Clamp>(n) ||
|
||||
as_type_ptr<Elu>(n) || as_type_ptr<SoftPlus>(n) || as_type_ptr<LogicalNot>(n)) {
|
||||
// Some unary operations are inherrited from Op class
|
||||
// so we need explicitly to check them
|
||||
sink_unary(n, reorders, transposes_to_delete);
|
||||
sink_res = sink_unary(n, reorders, transposes_to_delete);
|
||||
} else if (ov::op::util::is_binary_elementwise_arithmetic(n)) {
|
||||
sink_binary(n, reorders, transposes_to_delete);
|
||||
sink_res = sink_binary(n, reorders, transposes_to_delete);
|
||||
} else if (auto pad = as_type_ptr<Pad>(n)) {
|
||||
sink_pad(pad, reorders, transposes_to_delete);
|
||||
sink_res = sink_pad(pad, reorders, transposes_to_delete);
|
||||
} else if (auto concat = as_type_ptr<Concat>(n)) {
|
||||
sink_concat(concat, reorders, transposes_to_delete);
|
||||
sink_res = sink_concat(concat, reorders, transposes_to_delete);
|
||||
} else if (auto prelu = as_type_ptr<PRelu>(n)) {
|
||||
sink_prelu(prelu, reorders, transposes_to_delete);
|
||||
} else {
|
||||
sink_res = sink_prelu(prelu, reorders, transposes_to_delete);
|
||||
}
|
||||
|
||||
if (!sink_res) {
|
||||
materialize_shapes(n, reorders, transposes_to_delete);
|
||||
}
|
||||
}
|
||||
@ -448,16 +500,16 @@ bool ov::frontend::tensorflow::pass::TransposeSinking::run_on_model(const shared
|
||||
const ResultVector& results = f->get_results();
|
||||
for (const auto& r : results) {
|
||||
// make sure shapes are always materialized before results
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
r->get_shape() == r->get_input_shape(0) && r->get_element_type() == r->input_value(0).get_element_type(),
|
||||
" op::Result = ",
|
||||
*r,
|
||||
", Arg = ",
|
||||
r->input_value(0).get_node());
|
||||
FRONT_END_GENERAL_CHECK(r->get_output_partial_shape(0) == r->get_input_partial_shape(0) &&
|
||||
r->get_element_type() == r->input_value(0).get_element_type(),
|
||||
" op::Result = ",
|
||||
*r,
|
||||
", Arg = ",
|
||||
r->input_value(0).get_node());
|
||||
|
||||
// make sure that after TransposeSinking pass the output_shape for Result
|
||||
// does not change from the expected output_shape before the pass
|
||||
FRONT_END_GENERAL_CHECK(r->get_output_shape(0) == orig_result_out_shape[r->get_name()],
|
||||
FRONT_END_GENERAL_CHECK(r->get_output_partial_shape(0) == orig_result_out_shape[r->get_name()],
|
||||
" op::Result = ",
|
||||
*r,
|
||||
" expected output shape = ",
|
||||
|
@ -26,10 +26,25 @@ int64_t count_ops_of_type(const shared_ptr<Model>& f) {
|
||||
return cnt;
|
||||
}
|
||||
|
||||
TEST(TransposeSinkingTest, PassProperty) {
|
||||
auto pass = std::make_shared<TransposeSinking>();
|
||||
ASSERT_TRUE(pass->get_property(ov::pass::PassProperty::REQUIRE_STATIC_SHAPE));
|
||||
ASSERT_FALSE(pass->get_property(ov::pass::PassProperty::CHANGE_DYNAMIC_STATE));
|
||||
TEST(TransposeSinkingTest, DynamicShape) {
|
||||
ov::PartialShape shape_nhwc(vector<Dimension>(4, Dimension::dynamic()));
|
||||
auto a = make_shared<Parameter>(ngraph::element::i32, shape_nhwc);
|
||||
auto ng_order = std::make_shared<Constant>(ngraph::element::u64, ngraph::Shape{4}, ngraph::Shape{0, 3, 1, 2});
|
||||
auto transpose = make_shared<Transpose>(a, ng_order);
|
||||
auto absn = make_shared<Abs>(transpose);
|
||||
auto absn2 = make_shared<Abs>(absn);
|
||||
absn2->output(0).set_names({"out_name"});
|
||||
auto res = make_shared<Result>(absn2);
|
||||
auto func = make_shared<ngraph::Function>(ngraph::OutputVector{res}, ngraph::ParameterVector{a});
|
||||
|
||||
ov::pass::Manager pass_manager;
|
||||
pass_manager.register_pass<TransposeSinking>();
|
||||
pass_manager.run_passes(func);
|
||||
|
||||
auto new_transpose =
|
||||
ngraph::as_type_ptr<Transpose>(func->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(new_transpose);
|
||||
EXPECT_EQ(new_transpose->output(0).get_names(), std::unordered_set<std::string>({"out_name"}));
|
||||
}
|
||||
|
||||
TEST(TransposeSinkingTest, TensorNames) {
|
||||
|
Loading…
Reference in New Issue
Block a user