Fixed issue with cat in fx backend (#20744)

* Added fix for cat in torchfx

* Added batch_norm_legit_no_training op

* Fixed coding style

* Fixed clang format

* Addressed PR comments
Surya Siddharth Pemmaraju, 2023-11-01 00:36:19 -07:00 (committed by GitHub)
parent 26c9c41b8e
commit bb0e4f8ecf
4 changed files with 42 additions and 14 deletions

File 1 of 4 (Torch FX operator support, Python):
@@ -82,6 +82,8 @@ class OperatorSupport(OperatorSupport):
             "torch.ops.aten.mul.Scalar": None,
             "torch.ops.aten.mul.Tensor": None,
             "torch.ops.aten.native_batch_norm.default": None,
+            "torch.ops.aten._native_batch_norm_legit.default": None,
+            "torch.ops.aten._native_batch_norm_legit_no_training.default": None,
             "torch.ops.aten.native_group_norm.default": None,
             "torch.ops.aten.native_layer_norm.default": None,
             "torch.ops.aten.neg.default": None,

File 2 of 4 (batch_norm translation, C++):
@@ -39,10 +39,13 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
     }
 } // namespace
-OutputVector translate_batch_norm(const NodeContext& context) {
+OutputVector translate_batch_norm_common(const NodeContext& context, bool training) {
     // Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var,
     // bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
-    num_inputs_check(context, 8, 9);
+    // batch_norm_legit_no_training Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor?
+    // running_mean, Tensor? running_var, float momentum, float eps) -> Tensor
     auto input = context.get_input(0);
     Output<Node> weight;
     Output<Node> bias;
@@ -63,7 +66,6 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         bias = broadcast_const_to_channel_dim(context, input, zero_f);
     }
     // index 3 running_mean and index 4 running_var can be none for training case only, check that not training before
-    auto training = context.const_input<bool>(5);
     // if training for batch norm activated, but model in eval mode, it uses current statistics instead of running
     if (training) {
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -92,14 +94,34 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         running_var = current_var;
     }
     // Input with index 6 is momentum, it is used only for updating running_mean accumulation during training
-    auto epsilon = context.const_input<float>(7);
+    // In batch_norm_legit_no_training, momentum is index 5 and epsilon is 6
+    float epsilon;
+    if (context.get_input_size() == 7) {
+        epsilon = context.const_input<float>(6);
+    } else {
+        epsilon = context.const_input<float>(7);
+    }
     // Input with index 8 is flag "cudnn_enabled" we can ignore it
     return {context.mark_node(
         std::make_shared<v5::BatchNormInference>(input, weight, bias, running_mean, running_var, epsilon))};
 };
-OutputVector translate_batch_norm_fx(const NodeContext& context) {
-    auto output = translate_batch_norm(context);
+OutputVector translate_batch_norm(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    return translate_batch_norm_common(context, training);
+}
+OutputVector translate_batch_norm_legit_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    auto output = translate_batch_norm_common(context, training);
     return {context.mark_node(make_list_construct(output))};
 }
+OutputVector translate_batch_norm_legit_no_training_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto output = translate_batch_norm_common(context, false);
+    return {context.mark_node(make_list_construct(output))};
+}
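
The index shuffle handled above follows from the aten schemas: _native_batch_norm_legit takes a training flag at index 5 (momentum at 6, eps at 7), while _native_batch_norm_legit_no_training drops the flag, shifting momentum to index 5 and eps to index 6. A sketch comparing the two calls directly (argument layouts follow PyTorch's native_functions declarations; exact overload behavior may vary by PyTorch version):

import torch

x = torch.randn(2, 3, 4, 4)
w, b = torch.ones(3), torch.zeros(3)
mean, var = torch.zeros(3), torch.ones(3)

# 8 inputs: training at index 5, momentum at 6, eps at 7
out_legit, _, _ = torch.ops.aten._native_batch_norm_legit(
    x, w, b, mean, var, False, 0.1, 1e-5)

# 7 inputs: momentum at index 5, eps at 6; this is the case
# translate_batch_norm_common() detects via get_input_size() == 7
out_no_train, _, _ = torch.ops.aten._native_batch_norm_legit_no_training(
    x, w, b, mean, var, 0.1, 1e-5)

# Both evaluate batch norm with the same running statistics
assert torch.allclose(out_legit, out_no_train)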

File 3 of 4 (cat translation, C++):
@@ -22,7 +22,8 @@ using namespace ov::op;
 OutputVector translate_cat_common(const NodeContext& context,
                                   const std::deque<ov::Output<ov::Node>>& list_elems,
-                                  int64_t axis) {
+                                  int64_t axis,
+                                  bool is_fx) {
     if (list_elems.empty()) {
         // couldn't get list elements
         auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
@@ -39,8 +40,8 @@ OutputVector translate_cat_common(const NodeContext& context,
                       "<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
                       "This case is not supported.");
     if (list_elems.size() == 1 &&
-        !std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr())) {
-        // Case when list was merged into tensor
+        !std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr()) && !is_fx) {
+        // Case when list was merged into tensor. This case doesn't work with torchfx
         auto tensor = list_elems[0];
         auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(tensor, element::i32));
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -63,7 +64,7 @@ OutputVector translate_cat(const NodeContext& context) {
     num_inputs_check(context, 2, 3);
     const auto&& list_elems = get_list_as_outputs(context.get_input(0));
     auto axis = context.const_input<int64_t>(1);
-    auto out = translate_cat_common(context, list_elems, axis);
+    auto out = translate_cat_common(context, list_elems, axis, false);
     if (!context.input_is_none(2)) {
         context.mutate_input(2, out[0]);
     }
@@ -78,7 +79,7 @@ OutputVector translate_cat_fx(const NodeContext& context) {
         list_elems.push_back(context.get_input(static_cast<int>(i)));
     }
     auto axis = context.const_input<int64_t>(context.get_input_size() - 1);
-    return translate_cat_common(context, list_elems, axis);
+    return translate_cat_common(context, list_elems, axis, true);
 };
 OutputVector translate_quantized_cat(const NodeContext& context) {
@@ -87,7 +88,7 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
     auto axis = context.const_input<int64_t>(1);
     FRONT_END_OP_CONVERSION_CHECK(!list_elems.empty(), "Couldn't find quantized input for quantized::cat operation.");
     return {quantize(context,
-                     translate_cat_common(context, list_elems, axis)[0],
+                     translate_cat_common(context, list_elems, axis, false)[0],
                      context.get_input(2),
                      context.get_input(3),
                      list_elems.front())};
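
Why the FX path needs the is_fx flag: in a TorchScript graph, cat receives a single list input (a prim::ListConstruct) that may have been merged into a tensor, while in an FX graph the list elements arrive as plain node arguments, so the single-element shortcut above does not apply (per the code comment, that case doesn't work with torchfx). A minimal sketch, using plain torch.fx tracing, showing the list arriving as a Python list of nodes:

import torch

def f(x, y):
    return torch.cat([x, y], dim=1)

gm = torch.fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function":
        # node.args[0] is a Python list of fx.Node inputs, e.g. [x, y],
        # not a tensor produced by a merged list
        print(node.target, node.args, node.kwargs)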

File 4 of 4 (op conversion table, C++):
@@ -213,7 +213,8 @@ OP_CONVERTER(translate_quantized_linear);
 OP_CONVERTER(translate_xor);
 // Torch FX Translations
 OP_CONVERTER(translate_arange_fx);
-OP_CONVERTER(translate_batch_norm_fx);
+OP_CONVERTER(translate_batch_norm_legit_fx);
+OP_CONVERTER(translate_batch_norm_legit_no_training_fx);
 OP_CONVERTER(translate_cat_fx);
 OP_CONVERTER(translate_chunk_fx);
 OP_CONVERTER(translate_expand_fx);
@@ -612,7 +613,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.mm.default", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten.mul.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
         {"aten.mul.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
-        {"aten.native_batch_norm.default", op::translate_batch_norm_fx},
+        {"aten.native_batch_norm.default", op::translate_batch_norm_legit_fx},
+        {"aten._native_batch_norm_legit.default", op::translate_batch_norm_legit_fx},
+        {"aten._native_batch_norm_legit_no_training.default", op::translate_batch_norm_legit_no_training_fx},
        {"aten.native_group_norm.default", op::translate_group_norm_fx},
        {"aten.native_layer_norm.default", op::translate_layer_norm_fx},
        {"aten.neg.default", op::translate_neg},