[TF FE] Support MobileNetV2 with FakeQuantize by TF FE (#12851)
Avoid extra transposes in the conversion of FakeQuantWithMinMaxVars and add a translator for FakeQuantWithMinMaxVarsPerChannel. This allows converting MobileNetV2 models that contain FakeQuantWithMinMaxVars and FakeQuantWithMinMaxVarsPerChannel operations. Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
parent 3856d69ae1
commit c11151ae9e
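The key observation behind the change: OpenVINO's FakeQuantize operation broadcasts its min/max range inputs against the data shape, so per-channel ranges laid out along the innermost (channel) axis of an NHWC tensor apply directly, and no NHWC-to-NCHW transposes are needed around the quantization node. Below is a minimal standalone sketch of that behaviour using the public opset8 C++ API; the shapes, range values, and opset version are illustrative assumptions, not taken from the commit.

#include <memory>

#include <openvino/openvino.hpp>
#include <openvino/opsets/opset8.hpp>

int main() {
    using namespace ov;
    // NHWC input, as produced by a TensorFlow model such as MobileNetV2.
    auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 224, 224, 3});
    // Per-channel quantization ranges broadcast along the last (channel) axis.
    auto min = opset8::Constant::create(element::f32, Shape{1, 1, 1, 3}, {-1.0f, -2.0f, -3.0f});
    auto max = opset8::Constant::create(element::f32, Shape{1, 1, 1, 3}, {1.0f, 2.0f, 3.0f});
    // levels = 2^num_bits - narrow_range, e.g. 2^8 - 0 = 256 (255 if narrow_range is set).
    size_t levels = 256;
    // The same range tensors serve as input and output ranges, as in the translator below.
    auto fq = std::make_shared<opset8::FakeQuantize>(data, min, max, min, max, levels);
    auto model = std::make_shared<Model>(OutputVector{fq->output(0)}, ParameterVector{data});
    return model->get_output_size() == 1 ? 0 : 1;
}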
@@ -12,21 +12,21 @@ namespace ov {
 namespace frontend {
 namespace tensorflow {
 namespace op {
 
 OutputVector translate_fake_quant_op(const NodeContext& node) {
-    auto ng_input = node.get_input(0);
-    auto ng_min = node.get_input(1);
-    auto ng_max = node.get_input(2);
+    default_op_checks(node, 2, {"FakeQuantWithMinMaxVars", "FakeQuantWithMinMaxVarsPerChannel"});
+    auto inputs = node.get_input(0);
+    auto min = node.get_input(1);
+    auto max = node.get_input(2);
 
     auto narrow_range = node.get_attribute<bool>("narrow_range");
     auto num_bits = node.get_attribute<int64_t>("num_bits");
 
     auto levels = std::pow(2, num_bits) - int(narrow_range);
-    auto min_less_max = make_shared<Less>(ng_min, ng_max);
-    auto minimum = make_shared<Select>(min_less_max, ng_min, ng_max);
-    auto maximum = make_shared<Select>(min_less_max, ng_max, ng_min);
+    auto min_less_max = make_shared<Less>(min, max);
+    auto minimum = make_shared<Select>(min_less_max, min, max);
+    auto maximum = make_shared<Select>(min_less_max, max, min);
 
-    auto zero = make_shared<Constant>(ng_min.get_element_type(), Shape{}, std::vector<int>({0}));
+    auto zero = make_shared<Constant>(min.get_element_type(), Shape{}, std::vector<int>({0}));
 
     auto min_greater_zero = make_shared<Greater>(minimum, zero);
     auto max_minus_min = make_shared<Subtract>(maximum, minimum);
@@ -50,19 +50,11 @@ OutputVector translate_fake_quant_op(const NodeContext& node) {
     auto adjustment = make_shared<Subtract>(min_adj, minimum);
     auto max_adj = make_shared<Add>(maximum, adjustment);
 
-    auto ng_input_shape = ng_input.get_shape();
-    if (ng_input_shape.size() == 4) {
-        ng_input = make_transpose(ng_input, {0, 3, 1, 2});
-    }
-    auto res = make_shared<FakeQuantize>(ng_input, min_adj, max_adj, min_adj, max_adj, levels)->output(0);
-    if (ng_input_shape.size() == 4) {
-        res = make_transpose(res, {0, 2, 3, 1});
-    }
-
-    set_node_name(node.get_name(), res.get_node_shared_ptr());
+    auto res = make_shared<FakeQuantize>(inputs, min_adj, max_adj, min_adj, max_adj, levels);
+    set_node_name(node.get_name(), res);
     return {res};
 }
 } // namespace op
 } // namespace tensorflow
 } // namespace frontend
 } // namespace ov
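For reference, the element-wise formula that FakeQuantize evaluates, following the OpenVINO operation specification: the translator above passes min_adj/max_adj as both the input and output ranges, so quantization and dequantization share the same interval. The sketch below is a small self-contained illustration; the helper name and sample values are assumptions, not from the commit.

#include <algorithm>
#include <cmath>
#include <iostream>

// Reference-style fake quantization: clamp to the input range, snap to one of
// `levels` evenly spaced steps, then rescale to the output range.
float fake_quantize(float x, float in_lo, float in_hi, float out_lo, float out_hi, int levels) {
    x = std::min(std::max(x, in_lo), in_hi);
    float step = std::round((x - in_lo) / (in_hi - in_lo) * (levels - 1));
    return step / (levels - 1) * (out_hi - out_lo) + out_lo;
}

int main() {
    // num_bits = 8, narrow_range = false  ->  levels = 2^8 = 256.
    std::cout << fake_quantize(0.337f, 0.0f, 1.0f, 0.0f, 1.0f, 256) << "\n";  // ~0.337, snapped to the grid
    return 0;
}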
@@ -186,6 +186,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"ExpandDims", translate_expand_dims_op},
         {"ExtractImagePatches", translate_extract_image_patches_op},
         {"FakeQuantWithMinMaxVars", translate_fake_quant_op},
+        {"FakeQuantWithMinMaxVarsPerChannel", translate_fake_quant_op},
         {"Fill", translate_fill_op},
         {"FloorDiv", translate_floor_div_op},
         {"FusedBatchNorm", translate_fused_batch_norm_op},
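A hypothetical way to exercise the new translators end to end is to read a quantized MobileNetV2 frozen graph through the TensorFlow frontend and compile it. The file name and device are placeholders, and this assumes a build in which ov::Core::read_model dispatches *.pb files to the TensorFlow frontend.

#include <iostream>

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // Placeholder path to a frozen MobileNetV2 graph containing
    // FakeQuantWithMinMaxVars / FakeQuantWithMinMaxVarsPerChannel nodes.
    auto model = core.read_model("mobilenet_v2_quantized.pb");
    auto compiled = core.compile_model(model, "CPU");
    std::cout << "model inputs: " << compiled.inputs().size() << std::endl;
    return 0;
}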