From 7cc7e204f84bdea926ea3fabc5178616fb1c846e Mon Sep 17 00:00:00 2001
From: Surya Siddharth Pemmaraju
Date: Wed, 1 Nov 2023 00:36:19 -0700
Subject: [PATCH] Fixed issue with cat in fx backend (#20744)

* Added fix for cat in torchfx

* Added batch_norm_legit_no_training op

* Fixed coding style

* Fixed clang format

* Addressed PR comments
---
 .../pytorch/torchdynamo/op_support.py       |  2 ++
 src/frontends/pytorch/src/op/batch_norm.cpp | 34 +++++++++++++++----
 src/frontends/pytorch/src/op/cat.cpp        | 13 +++----
 src/frontends/pytorch/src/op_table.cpp      |  7 ++--
 4 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
index 4a76d90b160..a6fb4de094d 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py
@@ -82,6 +82,8 @@ class OperatorSupport(OperatorSupport):
             "torch.ops.aten.mul.Scalar": None,
             "torch.ops.aten.mul.Tensor": None,
             "torch.ops.aten.native_batch_norm.default": None,
+            "torch.ops.aten._native_batch_norm_legit.default": None,
+            "torch.ops.aten._native_batch_norm_legit_no_training.default": None,
             "torch.ops.aten.native_group_norm.default": None,
             "torch.ops.aten.native_layer_norm.default": None,
             "torch.ops.aten.neg.default": None,
diff --git a/src/frontends/pytorch/src/op/batch_norm.cpp b/src/frontends/pytorch/src/op/batch_norm.cpp
index 1c7528e8ed3..126588eb952 100644
--- a/src/frontends/pytorch/src/op/batch_norm.cpp
+++ b/src/frontends/pytorch/src/op/batch_norm.cpp
@@ -39,10 +39,13 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
 }
 }  // namespace
 
-OutputVector translate_batch_norm(const NodeContext& context) {
+OutputVector translate_batch_norm_common(const NodeContext& context, bool training) {
     // Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var,
     // bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
-    num_inputs_check(context, 8, 9);
+
+    // batch_norm_legit_no_training Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor?
+    // running_mean, Tensor? running_var, float momentum, float eps) -> Tensor
+
     auto input = context.get_input(0);
     Output<Node> weight;
     Output<Node> bias;
@@ -63,7 +66,6 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         bias = broadcast_const_to_channel_dim(context, input, zero_f);
     }
     // index 3 running_mean and index 4 running_var can be none for training case only, check that not training before
-    auto training = context.const_input<bool>(5);
     // if training for batch norm activated, but model in eval mode, it uses current statistics instead of running
     if (training) {
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -92,14 +94,34 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         running_var = current_var;
     }
     // Input with index 6 is momentum, it is used only for updating running_mean accumulation during training
-    auto epsilon = context.const_input<float>(7);
+    // In batch_norm_legit_no_training, momentum is index 5 and epsilon is 6
+    float epsilon;
+    if (context.get_input_size() == 7) {
+        epsilon = context.const_input<float>(6);
+    } else {
+        epsilon = context.const_input<float>(7);
+    }
     // Input with index 8 is flag "cudnn_enabled" we can ignore it
     return {context.mark_node(
         std::make_shared<v5::BatchNormInference>(input, weight, bias, running_mean, running_var, epsilon))};
 };
 
-OutputVector translate_batch_norm_fx(const NodeContext& context) {
-    auto output = translate_batch_norm(context);
+OutputVector translate_batch_norm(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    return translate_batch_norm_common(context, training);
+}
+
+OutputVector translate_batch_norm_legit_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    auto output = translate_batch_norm_common(context, training);
+    return {context.mark_node(make_list_construct(output))};
+}
+
+OutputVector translate_batch_norm_legit_no_training_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto output = translate_batch_norm_common(context, false);
     return {context.mark_node(make_list_construct(output))};
 }
diff --git a/src/frontends/pytorch/src/op/cat.cpp b/src/frontends/pytorch/src/op/cat.cpp
index 63e61734544..9476979a118 100644
--- a/src/frontends/pytorch/src/op/cat.cpp
+++ b/src/frontends/pytorch/src/op/cat.cpp
@@ -22,7 +22,8 @@ using namespace ov::op;
 
 OutputVector translate_cat_common(const NodeContext& context,
                                   const std::deque<Output<Node>>& list_elems,
-                                  int64_t axis) {
+                                  int64_t axis,
+                                  bool is_fx) {
     if (list_elems.empty()) {
         // couldn't get list elements
         auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
@@ -39,8 +40,8 @@ OutputVector translate_cat_common(const NodeContext& context,
                                   "<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
                                   "This case is not supported.");
     if (list_elems.size() == 1 &&
-        !std::dynamic_pointer_cast(context.get_input(0).get_node_shared_ptr())) {
-        // Case when list was merged into tensor
+        !std::dynamic_pointer_cast(context.get_input(0).get_node_shared_ptr()) && !is_fx) {
+        // Case when list was merged into tensor. This case doesn't work with torchfx
         auto tensor = list_elems[0];
         auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(tensor, element::i32));
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -63,7 +64,7 @@ OutputVector translate_cat(const NodeContext& context) {
     num_inputs_check(context, 2, 3);
     const auto&& list_elems = get_list_as_outputs(context.get_input(0));
     auto axis = context.const_input<int64_t>(1);
-    auto out = translate_cat_common(context, list_elems, axis);
+    auto out = translate_cat_common(context, list_elems, axis, false);
     if (!context.input_is_none(2)) {
         context.mutate_input(2, out[0]);
     }
@@ -78,7 +79,7 @@ OutputVector translate_cat_fx(const NodeContext& context) {
         list_elems.push_back(context.get_input(static_cast<int>(i)));
     }
     auto axis = context.const_input<int64_t>(context.get_input_size() - 1);
-    return translate_cat_common(context, list_elems, axis);
+    return translate_cat_common(context, list_elems, axis, true);
 };
 
 OutputVector translate_quantized_cat(const NodeContext& context) {
@@ -87,7 +88,7 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
     auto axis = context.const_input<int64_t>(1);
     FRONT_END_OP_CONVERSION_CHECK(!list_elems.empty(), "Couldn't find quantized input for quantized::cat operation.");
     return {quantize(context,
-                     translate_cat_common(context, list_elems, axis)[0],
+                     translate_cat_common(context, list_elems, axis, false)[0],
                      context.get_input(2),
                      context.get_input(3),
                      list_elems.front())};
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 3f71d22e428..933f9a48eeb 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -213,7 +213,8 @@ OP_CONVERTER(translate_quantized_linear);
 OP_CONVERTER(translate_xor);
 // Torch FX Translations
 OP_CONVERTER(translate_arange_fx);
-OP_CONVERTER(translate_batch_norm_fx);
+OP_CONVERTER(translate_batch_norm_legit_fx);
+OP_CONVERTER(translate_batch_norm_legit_no_training_fx);
 OP_CONVERTER(translate_cat_fx);
 OP_CONVERTER(translate_chunk_fx);
 OP_CONVERTER(translate_expand_fx);
@@ -612,7 +613,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.mm.default", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten.mul.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
         {"aten.mul.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
-        {"aten.native_batch_norm.default", op::translate_batch_norm_fx},
+        {"aten.native_batch_norm.default", op::translate_batch_norm_legit_fx},
+        {"aten._native_batch_norm_legit.default", op::translate_batch_norm_legit_fx},
+        {"aten._native_batch_norm_legit_no_training.default", op::translate_batch_norm_legit_no_training_fx},
         {"aten.native_group_norm.default", op::translate_group_norm_fx},
         {"aten.native_layer_norm.default", op::translate_layer_norm_fx},
         {"aten.neg.default", op::translate_neg},
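
Usage sketch (illustrative note, not part of the patch): one way to exercise both code
paths touched above -- torch.cat going through the FX translation and eval-mode batch
norm lowered to aten._native_batch_norm_legit_no_training.default -- is shown below.
It assumes torch and openvino are installed and that importing openvino.torch registers
the "openvino" torch.compile backend; the module name, shapes, and the exact op emitted
by a given PyTorch version are assumptions, not something this patch adds.

    import torch
    import torch.nn as nn
    import openvino.torch  # noqa: F401  (assumption: registers the "openvino" backend)

    class CatBn(nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = nn.BatchNorm2d(3)

        def forward(self, x, y):
            # torch.cat is lowered through the FX cat translation; eval-mode BatchNorm
            # is typically traced as aten._native_batch_norm_legit_no_training.default
            return self.bn(torch.cat([x, y], dim=1))

    model = CatBn().eval()
    compiled = torch.compile(model, backend="openvino")
    with torch.no_grad():
        out = compiled(torch.randn(1, 2, 8, 8), torch.randn(1, 1, 8, 8))
    print(out.shape)  # torch.Size([1, 3, 8, 8])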