From 9149d899b349c5c96a326232d66a4f753b17ad98 Mon Sep 17 00:00:00 2001 From: Pavel Esir Date: Mon, 5 Oct 2020 12:48:57 +0300 Subject: [PATCH] [MO] Fix: add PermuteInputs to shape broadcasting (#2419) * fix: add PermuteInputs to shape broadcasting * fix type declaration typo --- model-optimizer/mo/ops/broadcast.py | 1 + model-optimizer/mo/utils/broadcasting.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/model-optimizer/mo/ops/broadcast.py b/model-optimizer/mo/ops/broadcast.py index 247bca668d8..8320cbbf1be 100644 --- a/model-optimizer/mo/ops/broadcast.py +++ b/model-optimizer/mo/ops/broadcast.py @@ -87,6 +87,7 @@ class Broadcast(Op): axes_mapping = node.in_port(2).data.get_value() assert axes_mapping is not None, 'Broadcast(mode="explicit") with dynamic axes_mapping input ' \ 'is not supported. Node: `{}`'.format(node_name) + PermuteInputs().set_input_permutation(node.in_node(2), node, 'output:0', 'axis') axes_mapping = node.in_port(2).data.get_value() new_shape,_ = explicit_shape_broadcasting(input_shape, target_shape, axes_mapping) node.out_port(0).data.set_shape(new_shape) diff --git a/model-optimizer/mo/utils/broadcasting.py b/model-optimizer/mo/utils/broadcasting.py index ec757f3835d..c1d251a98c7 100644 --- a/model-optimizer/mo/utils/broadcasting.py +++ b/model-optimizer/mo/utils/broadcasting.py @@ -79,7 +79,7 @@ def bi_directional_shape_broadcasting(input_shape_1: np.array, input_shape_2: np return np.maximum(shape_1, shape_2) -def explicit_shape_broadcasting(input_shape: np.array, target_shape: np.array, axes_mapping: np.array) -> np.array: +def explicit_shape_broadcasting(input_shape: np.array, target_shape: np.array, axes_mapping: np.array) -> [np.array, np.array]: """ Explicit shape broadcasting of input tensor. Function only asserts that values are correct and normalizes axes. Resulting shape is equal to target_shape.