Fixed issue with cat in fx backend (#20744)
* Added fix for cat in torchfx
* Added batch_norm_legit_no_training op
* Fixed coding style
* Fixed clang format
* Addressed PR comments
parent 26c9c41b8e
commit bb0e4f8ecf
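For context (not part of the commit), a minimal reproducer sketch for the kind of graph this change targets; it assumes an OpenVINO build whose torch.compile backend is registered under the name "openvino":

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(4)

    def forward(self, a, b):
        # aten.cat reaches the FX backend with its operands as separate inputs
        return self.bn(torch.cat([a, b], dim=1))

model = M().eval()
# Depending on the OpenVINO version, registering the backend may require
# importing its torchdynamo module first.
compiled = torch.compile(model, backend="openvino")
out = compiled(torch.randn(1, 2, 8, 8), torch.randn(1, 2, 8, 8))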
@@ -82,6 +82,8 @@ class OperatorSupport(OperatorSupport):
             "torch.ops.aten.mul.Scalar": None,
             "torch.ops.aten.mul.Tensor": None,
             "torch.ops.aten.native_batch_norm.default": None,
+            "torch.ops.aten._native_batch_norm_legit.default": None,
+            "torch.ops.aten._native_batch_norm_legit_no_training.default": None,
             "torch.ops.aten.native_group_norm.default": None,
             "torch.ops.aten.native_layer_norm.default": None,
             "torch.ops.aten.neg.default": None,
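A quick way to see why the `_native_batch_norm_legit_no_training` entry is needed (a sketch, not part of the commit; assumes a PyTorch version with torch.export):

import torch

bn = torch.nn.BatchNorm2d(3).eval()
ep = torch.export.export(bn, (torch.randn(1, 3, 4, 4),))
# Eval-mode models trace to aten._native_batch_norm_legit_no_training,
# so the op must be listed as supported for the partitioner to claim it.
print([n.target for n in ep.graph.nodes if n.op == "call_function"])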
@@ -39,10 +39,13 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
 }
 } // namespace
 
-OutputVector translate_batch_norm(const NodeContext& context) {
+OutputVector translate_batch_norm_common(const NodeContext& context, bool training) {
     // Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var,
     //                          bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
-    num_inputs_check(context, 8, 9);
+
+    // batch_norm_legit_no_training Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor?
+    // running_mean, Tensor? running_var, float momentum, float eps) -> Tensor
+
     auto input = context.get_input(0);
     Output<Node> weight;
     Output<Node> bias;
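The two schemas referenced in the comments differ in arity: the no_training variant drops `training` and `cudnn_enabled`. They can be inspected from Python (illustrative sketch; `_schema` is a private PyTorch attribute):

import torch

# (input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
print(torch.ops.aten.batch_norm.default._schema)
# (input, weight, bias, running_mean, running_var, momentum, eps)
print(torch.ops.aten._native_batch_norm_legit_no_training.default._schema)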
@@ -63,7 +66,6 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         bias = broadcast_const_to_channel_dim(context, input, zero_f);
     }
     // index 3 running_mean and index 4 running_var can be none for training case only, check that not training before
-    auto training = context.const_input<bool>(5);
     // if training for batch norm activated, but model in eval mode, it uses current statistics instead of running
     if (training) {
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -92,14 +94,34 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         running_var = current_var;
     }
     // Input with index 6 is momentum, it is used only for updating running_mean accumulation during training
-    auto epsilon = context.const_input<float>(7);
+    // In batch_norm_legit_no_training, momentum is index 5 and epsilon is 6
+    float epsilon;
+    if (context.get_input_size() == 7) {
+        epsilon = context.const_input<float>(6);
+    } else {
+        epsilon = context.const_input<float>(7);
+    }
     // Input with index 8 is flag "cudnn_enabled" we can ignore it
     return {context.mark_node(
         std::make_shared<v5::BatchNormInference>(input, weight, bias, running_mean, running_var, epsilon))};
 };
 
-OutputVector translate_batch_norm_fx(const NodeContext& context) {
-    auto output = translate_batch_norm(context);
+OutputVector translate_batch_norm(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    return translate_batch_norm_common(context, training);
+}
+
+OutputVector translate_batch_norm_legit_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    auto output = translate_batch_norm_common(context, training);
     return {context.mark_node(make_list_construct(output))};
 }
+
+OutputVector translate_batch_norm_legit_no_training_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto output = translate_batch_norm_common(context, false);
+    return {context.mark_node(make_list_construct(output))};
+}
 
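The three wrappers differ only in how `training` is obtained and how the result is packed: the `_fx` variants wrap the output with `make_list_construct`, matching the tuple these native ops produce on the FX side. The epsilon branch above reduces to simple index arithmetic (illustrative sketch, names not from the commit):

# 9 inputs: (input, weight, bias, mean, var, training, momentum, eps, cudnn_enabled) -> eps at 7
# 7 inputs: (input, weight, bias, mean, var, momentum, eps)                          -> eps at 6
def eps_index(num_inputs: int) -> int:
    return 6 if num_inputs == 7 else 7

assert eps_index(7) == 6 and eps_index(9) == 7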
@@ -22,7 +22,8 @@ using namespace ov::op;
 
 OutputVector translate_cat_common(const NodeContext& context,
                                   const std::deque<ov::Output<ov::Node>>& list_elems,
-                                  int64_t axis) {
+                                  int64_t axis,
+                                  bool is_fx) {
     if (list_elems.empty()) {
         // couldn't get list elements
         auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
@@ -39,8 +40,8 @@ OutputVector translate_cat_common(const NodeContext& context,
                                   "<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
                                   "This case is not supported.");
     if (list_elems.size() == 1 &&
-        !std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr())) {
-        // Case when list was merged into tensor
+        !std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr()) && !is_fx) {
+        // Case when list was merged into tensor. This case doesn't work with torchfx
         auto tensor = list_elems[0];
         auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(tensor, element::i32));
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
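Why the new flag matters: under TorchFX, `cat` receives its operands as individual graph placeholders, so a one-element list is a genuine single input rather than a list that was folded into one tensor, and the unpacking shortcut misfires. An illustrative sketch of the FX side:

import torch

gm = torch.fx.symbolic_trace(lambda a, b: torch.cat([a, b], dim=0))
# `a` and `b` appear as separate placeholders feeding cat; there is no
# merged-list tensor for the TorchScript-style shortcut to unpack.
print(gm.graph)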
@@ -63,7 +64,7 @@ OutputVector translate_cat(const NodeContext& context) {
     num_inputs_check(context, 2, 3);
     const auto&& list_elems = get_list_as_outputs(context.get_input(0));
     auto axis = context.const_input<int64_t>(1);
-    auto out = translate_cat_common(context, list_elems, axis);
+    auto out = translate_cat_common(context, list_elems, axis, false);
     if (!context.input_is_none(2)) {
         context.mutate_input(2, out[0]);
     }
@@ -78,7 +79,7 @@ OutputVector translate_cat_fx(const NodeContext& context) {
         list_elems.push_back(context.get_input(static_cast<int>(i)));
     }
     auto axis = context.const_input<int64_t>(context.get_input_size() - 1);
-    return translate_cat_common(context, list_elems, axis);
+    return translate_cat_common(context, list_elems, axis, true);
 };
 
 OutputVector translate_quantized_cat(const NodeContext& context) {
@@ -87,7 +88,7 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
     auto axis = context.const_input<int64_t>(1);
     FRONT_END_OP_CONVERSION_CHECK(!list_elems.empty(), "Couldn't find quantized input for quantized::cat operation.");
     return {quantize(context,
-                     translate_cat_common(context, list_elems, axis)[0],
+                     translate_cat_common(context, list_elems, axis, false)[0],
                      context.get_input(2),
                      context.get_input(3),
                      list_elems.front())};
@@ -213,7 +213,8 @@ OP_CONVERTER(translate_quantized_linear);
 OP_CONVERTER(translate_xor);
 // Torch FX Translations
 OP_CONVERTER(translate_arange_fx);
-OP_CONVERTER(translate_batch_norm_fx);
+OP_CONVERTER(translate_batch_norm_legit_fx);
+OP_CONVERTER(translate_batch_norm_legit_no_training_fx);
 OP_CONVERTER(translate_cat_fx);
 OP_CONVERTER(translate_chunk_fx);
 OP_CONVERTER(translate_expand_fx);
@@ -612,7 +613,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
         {"aten.mm.default", op::translate_1to1_match_2_inputs<opset10::MatMul>},
         {"aten.mul.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
         {"aten.mul.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
-        {"aten.native_batch_norm.default", op::translate_batch_norm_fx},
+        {"aten.native_batch_norm.default", op::translate_batch_norm_legit_fx},
+        {"aten._native_batch_norm_legit.default", op::translate_batch_norm_legit_fx},
+        {"aten._native_batch_norm_legit_no_training.default", op::translate_batch_norm_legit_no_training_fx},
         {"aten.native_group_norm.default", op::translate_group_norm_fx},
         {"aten.native_layer_norm.default", op::translate_layer_norm_fx},
         {"aten.neg.default", op::translate_neg},