Files
openvino/model-optimizer/mo/ops/power.py
Alexey Suhov 6478f1742a Align copyright notice in python scripts (CVS-51320) (#4974)
* Align copyright notice in python scripts (CVS-51320)
2021-03-26 17:54:28 +03:00

55 lines
1.9 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from mo.front.common.partial_infer.eltwise import eltwise_infer
from mo.graph.graph import Graph, Node
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.ops.op import Op
class AttributedPower(Op):
    """Power operation whose parameters are stored as node attributes.

    The node is created with ``type`` set to ``'Power'`` and the defaults
    ``power=1``, ``scale=1`` and ``shift=0``.  During inference the callable
    handed to ``eltwise_infer`` computes ``(a * scale + shift) ** power``.
    """

    op = 'AttributedPower'
    # NOTE(review): presumably keeps this op out of automatic registration -- confirm against `Op`.
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        """Build the operation with its default attributes.

        :param graph: graph passed through to the base ``Op`` constructor
        :param attrs: attributes passed to the base ``Op`` constructor
            together with the defaults defined here
        """
        default_attrs = {
            'op': self.op,
            'type': 'Power',
            'power': 1,
            'scale': 1,
            'shift': 0,
            'infer': self.infer,
            'type_infer': self.type_infer,
            'in_ports_count': 1,
            'out_ports_count': 1,
        }
        super().__init__(graph, default_attrs, attrs)

    def supported_attrs(self):
        """Return the attribute names this operation declares as supported."""
        return ['power', 'scale', 'shift']

    @staticmethod
    def type_infer(node: Node):
        """Set the data type of output port 0 from ``cmd_params.data_type``.

        :param node: node whose output port 0 is updated
        """
        cmd_params = node.graph.graph['cmd_params']
        np_type = data_type_str_to_np(cmd_params.data_type)
        node.out_port(0).set_data_type(np_type)

    @staticmethod
    def infer(node: Node):
        """Validate the node and delegate inference to ``eltwise_infer``.

        :param node: node to infer; must have exactly one connected input,
            on port 0, and valid ``scale``, ``shift`` and ``power`` attributes
        :raises AssertionError: if any of those checks fails
        """
        node_name = node.soft_get('name', node.id)

        # Collect only the input ports that are actually wired up.
        live_ports = {}
        for port_idx, in_port in node.in_ports().items():
            if not in_port.disconnected():
                live_ports[port_idx] = in_port

        assert 0 in live_ports and len(live_ports) == 1, (
            "AttributedPower should have 1 connected input port, but it doesn't for node: `{}`. Ports: {}"
            "".format(node_name, live_ports)
        )

        assert node.has_valid('scale'), (
            'AttributedPower operation should have `scale` parameter set, but it doesn`t for node {}'.format(node_name)
        )
        assert node.has_valid('shift'), (
            'AttributedPower operation should have `shift` parameter set, but it doesn`t for node {}'.format(node_name)
        )
        assert node.has_valid('power'), (
            'AttributedPower operation should have `power` parameter set, but it doesn`t for node {}'.format(node_name)
        )

        def scaled_shifted_power(a):
            # Attributes are read at call time, exactly when the value is computed.
            return np.power(a * node.scale + node.shift, node.power)

        eltwise_infer(node, scaled_shifted_power)