Files
openvino/model-optimizer/extensions/front/tf/BatchMatMul_ext.py

36 lines
900 B
Python
Raw Normal View History

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
2019-04-12 18:25:53 +03:00
2020-02-11 22:48:49 +03:00
from extensions.ops.MatMul import MatMul
2019-04-12 18:25:53 +03:00
from mo.front.extractor import FrontExtractorOp
2019-08-09 19:02:42 +03:00
class BatchMatMulExtractor(FrontExtractorOp):
    """Front extractor registered for nodes whose op is ``'BatchMatMul'``.

    The node's ``adj_x`` / ``adj_y`` protobuf attributes are forwarded to
    ``MatMul`` as the integer attributes ``transpose_a`` / ``transpose_b``.
    """

    op = 'BatchMatMul'
    enabled = True

    @classmethod
    def extract(cls, node):
        """Populate ``node`` with MatMul attributes taken from its protobuf.

        :param node: graph node exposing the framework protobuf as ``node.pb``.
        :return: the value of ``cls.enabled``.
        """
        pb_attrs = node.pb.attr

        # The ``.b`` fields are converted to int before being handed over.
        # NOTE(review): TF documents adj_* as "adjoint"; mapping it to a plain
        # transpose presumably assumes real-valued inputs -- confirm.
        transpose_first = int(pb_attrs['adj_x'].b)
        transpose_second = int(pb_attrs['adj_y'].b)

        MatMul.update_node_stat(
            node,
            {'transpose_a': transpose_first, 'transpose_b': transpose_second},
        )
        return cls.enabled
class BatchMatMulV2Extractor(FrontExtractorOp):
    """Front extractor registered for nodes whose op is ``'BatchMatMulV2'``.

    Reads the ``adj_x`` / ``adj_y`` protobuf attributes of the node and
    passes them to ``MatMul`` as integer ``transpose_a`` / ``transpose_b``.
    """

    op = 'BatchMatMulV2'
    enabled = True

    @classmethod
    def extract(cls, node):
        """Populate ``node`` with MatMul attributes taken from its protobuf.

        :param node: graph node exposing the framework protobuf as ``node.pb``.
        :return: the value of ``cls.enabled``.
        """
        proto_attrs = node.pb.attr

        # Convert the ``.b`` fields to int, in the same order as the keys.
        flag_a = int(proto_attrs['adj_x'].b)
        flag_b = int(proto_attrs['adj_y'].b)
        matmul_attrs = dict(transpose_a=flag_a, transpose_b=flag_b)

        MatMul.update_node_stat(node, matmul_attrs)
        return cls.enabled