[TF FE] Support MobileNetV2 with FakeQuantize by TF FE (#12851)

Avoid extra transposes in the conversion of FakeQuantWithMinMaxVars and
add a translator for FakeQuantWithMinMaxVarsPerChannel.
This allows converting MobileNetV2 models that contain FakeQuantWithMinMaxVars
and FakeQuantWithMinMaxVarsPerChannel operations.
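
For background (an illustration added here, not part of the commit): TensorFlow keeps activations in NHWC layout, and the per-channel min/max inputs of FakeQuantWithMinMaxVarsPerChannel are 1-D tensors of length C. Since OpenVINO's FakeQuantize uses NumPy-style broadcasting, a {C} range already lines up with the trailing channel axis of an {N, H, W, C} tensor, so no NHWC<->NCHW transposes are needed around the quantization node. A minimal standalone sketch of that idea, assuming the public OpenVINO C++ headers and the ov::op::v0 opset; the shapes and range values below are made up for the example:

// Illustrative only: per-channel FakeQuantize on NHWC data without layout transposes.
#include <memory>
#include <vector>

#include "openvino/op/constant.hpp"
#include "openvino/op/fake_quantize.hpp"
#include "openvino/op/parameter.hpp"

int main() {
    using namespace ov;

    // NHWC activation, as TensorFlow models such as MobileNetV2 produce it.
    auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, 224, 224, 3});

    // Hypothetical per-channel ranges: one (min, max) pair per channel, C = 3.
    auto min = op::v0::Constant::create(element::f32, Shape{3}, std::vector<float>{0.0f, -1.0f, -0.5f});
    auto max = op::v0::Constant::create(element::f32, Shape{3}, std::vector<float>{6.0f, 1.0f, 0.5f});

    // levels = 2^num_bits - narrow_range; 256 for the default 8-bit, full-range case.
    const std::size_t levels = 256;

    // NumPy-style broadcasting aligns the {3} ranges with the trailing (channel)
    // axis of the {1, 224, 224, 3} input, so no NHWC<->NCHW transposes are needed.
    auto fq = std::make_shared<op::v0::FakeQuantize>(data, min, max, min, max, levels);
    return 0;
}

Broadcasting does the per-channel work here, which is also why the per-tensor and per-channel ops can share a single translator, as the diff below shows.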

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

Roman Kazantsev, 2022-09-01 09:36:24 +03:00 (committed by GitHub)
parent 3856d69ae1
commit c11151ae9e
2 changed files with 12 additions and 19 deletions

@@ -12,21 +12,21 @@ namespace ov {
 namespace frontend {
 namespace tensorflow {
 namespace op {
 OutputVector translate_fake_quant_op(const NodeContext& node) {
-    auto ng_input = node.get_input(0);
-    auto ng_min = node.get_input(1);
-    auto ng_max = node.get_input(2);
+    default_op_checks(node, 2, {"FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel"});
+    auto inputs = node.get_input(0);
+    auto min = node.get_input(1);
+    auto max = node.get_input(2);
     auto narrow_range = node.get_attribute<bool>("narrow_range");
     auto num_bits = node.get_attribute<int64_t>("num_bits");
     auto levels = std::pow(2, num_bits) - int(narrow_range);
-    auto min_less_max = make_shared<Less>(ng_min, ng_max);
-    auto minimum = make_shared<Select>(min_less_max, ng_min, ng_max);
-    auto maximum = make_shared<Select>(min_less_max, ng_max, ng_min);
+    auto min_less_max = make_shared<Less>(min, max);
+    auto minimum = make_shared<Select>(min_less_max, min, max);
+    auto maximum = make_shared<Select>(min_less_max, max, min);
-    auto zero = make_shared<Constant>(ng_min.get_element_type(), Shape{}, std::vector<int>({0}));
+    auto zero = make_shared<Constant>(min.get_element_type(), Shape{}, std::vector<int>({0}));
     auto min_greater_zero = make_shared<Greater>(minimum, zero);
     auto max_minus_min = make_shared<Subtract>(maximum, minimum);
@@ -50,19 +50,11 @@ OutputVector translate_fake_quant_op(const NodeContext& node) {
     auto adjustment = make_shared<Subtract>(min_adj, minimum);
     auto max_adj = make_shared<Add>(maximum, adjustment);
-    auto ng_input_shape = ng_input.get_shape();
-    if (ng_input_shape.size() == 4) {
-        ng_input = make_transpose(ng_input, {0, 3, 1, 2});
-    }
-    auto res = make_shared<FakeQuantize>(ng_input, min_adj, max_adj, min_adj, max_adj, levels)->output(0);
-    if (ng_input_shape.size() == 4) {
-        res = make_transpose(res, {0, 2, 3, 1});
-    }
-    set_node_name(node.get_name(), res.get_node_shared_ptr());
+    auto res = make_shared<FakeQuantize>(inputs, min_adj, max_adj, min_adj, max_adj, levels);
+    set_node_name(node.get_name(), res);
     return {res};
 }
 } // namespace op
 } // namespace tensorflow
 } // namespace frontend
 } // namespace ov
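
As a quick worked example of the levels arithmetic in the translator above (an illustration of the formula shown in the diff, not code from the commit): TensorFlow's FakeQuantWithMinMaxVars* ops default to num_bits = 8 and narrow_range = false, giving levels = 2^8 - 0 = 256; with narrow_range = true the lowest quantized value is dropped and levels becomes 255.

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
    // Same formula as in translate_fake_quant_op: levels = 2^num_bits - narrow_range.
    int64_t num_bits = 8;       // TensorFlow default for the FakeQuantWithMinMaxVars* ops
    bool narrow_range = false;  // TensorFlow default: keep the full 2^num_bits range
    auto levels = std::pow(2, num_bits) - int(narrow_range);
    std::cout << levels << "\n";  // 256; with narrow_range = true this would print 255
    return 0;
}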

@@ -186,6 +186,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"ExpandDims", translate_expand_dims_op},
         {"ExtractImagePatches", translate_extract_image_patches_op},
         {"FakeQuantWithMinMaxVars", translate_fake_quant_op},
+        {"FakeQuantWithMinMaxVarsPerChannel", translate_fake_quant_op},
         {"Fill", translate_fill_op},
         {"FloorDiv", translate_floor_div_op},
         {"FusedBatchNorm", translate_fused_batch_norm_op},