Improve reshapeability of models with eltwise nodes influencing shapes (#2767)
* Fix ElementwiseInputReshape transformation Reshape node always needs to be inserted in order to preserve ShapeOf nodes (reshapability of a model) that can potentially be above elementwise node. Refactor EltwiseInputReshape_test and EltwiseInputNormalization_test since the logic of maintaining reshape for eltwise has been changed. Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Merge EltwiseInputNormalization and EltwiseInputReshape transformations Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Remove Unsqueeze from Fused_op Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix code after code review #1 Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix code after review #2 Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix code review #4 Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Perform full normalization based on shapes of all inputs to eltwise Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Refactor much to avoid old API and edges with unsqueeze_dims attribute Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix code after review Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
2a7f2f5eb6
commit
10b18a00c6
@ -651,11 +651,11 @@ CNNLayer::Ptr NodeConverter<ngraph::op::Squeeze>::createLayer(const std::shared_
|
||||
}
|
||||
|
||||
template <>
|
||||
CNNLayer::Ptr NodeConverter<ngraph::op::Unsqueeze>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
|
||||
CNNLayer::Ptr NodeConverter<ngraph::op::v0::Unsqueeze>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
|
||||
LayerParams params = {layer->get_friendly_name(), "Unsqueeze",
|
||||
details::convertPrecision(layer->get_output_element_type(0))};
|
||||
auto res = std::make_shared<InferenceEngine::CNNLayer>(params);
|
||||
auto castedLayer = ngraph::as_type_ptr<ngraph::op::Unsqueeze>(layer);
|
||||
auto castedLayer = ngraph::as_type_ptr<ngraph::op::v0::Unsqueeze>(layer);
|
||||
if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
|
||||
|
||||
return res;
|
||||
|
@ -52,7 +52,7 @@ ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() {
|
||||
if (auto new_max_per_class_const = std::dynamic_pointer_cast<opset1::Constant>(new_max_per_class.get_node_shared_ptr())) {
|
||||
new_max_per_class = opset1::Constant::create(element::i64, Shape{1}, new_max_per_class_const->cast_vector<int64_t>());
|
||||
} else {
|
||||
new_max_per_class = std::make_shared<ngraph::op::Unsqueeze>(
|
||||
new_max_per_class = std::make_shared<ngraph::op::v0::Unsqueeze>(
|
||||
nms->input_value(2),
|
||||
opset1::Constant::create(element::i64, Shape{1}, {0}));
|
||||
new_ops.push_back(new_max_per_class.get_node_shared_ptr());
|
||||
@ -60,14 +60,14 @@ ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() {
|
||||
}
|
||||
auto new_iou_threshold = nms->input_value(3);
|
||||
if (iou_threshold_rank.get_length() == 0) {
|
||||
new_iou_threshold = std::make_shared<ngraph::op::Unsqueeze>(
|
||||
new_iou_threshold = std::make_shared<ngraph::op::v0::Unsqueeze>(
|
||||
nms->input_value(3),
|
||||
opset1::Constant::create(element::i64, Shape{1}, {0}));
|
||||
new_ops.push_back(new_iou_threshold.get_node_shared_ptr());
|
||||
}
|
||||
auto new_score_threshold = nms->input_value(4);
|
||||
if (score_threshold_rank.get_length() == 0) {
|
||||
new_score_threshold = std::make_shared<ngraph::op::Unsqueeze>(
|
||||
new_score_threshold = std::make_shared<ngraph::op::v0::Unsqueeze>(
|
||||
nms->input_value(4),
|
||||
opset1::Constant::create(element::i64, Shape{1}, {0}));
|
||||
new_ops.push_back(new_score_threshold.get_node_shared_ptr());
|
||||
|
@ -58,7 +58,7 @@ void MatmulSqueezeAddTest::SetUp() {
|
||||
auto matmul_0 = std::make_shared<ngraph::op::MatMul>(params[0], constant_0, false, true);
|
||||
|
||||
auto constant_1 = std::make_shared<ngraph::op::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 1 }, std::vector<size_t>{0});
|
||||
auto unsqueeze_0 = std::make_shared<ngraph::op::Unsqueeze>(matmul_0, constant_1);
|
||||
auto unsqueeze_0 = std::make_shared<ngraph::op::v0::Unsqueeze>(matmul_0, constant_1);
|
||||
|
||||
auto constant_2 = ngraph::builder::makeConstant<float>(ngPrc, { 1, inputShape[0], outputSize },
|
||||
CommonTestUtils::generate_float_numbers(inputShape[0] * outputSize, 0, 1, seed), false);
|
||||
|
@ -75,7 +75,7 @@ namespace SubgraphTestsDefinitions {
|
||||
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
|
||||
|
||||
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
|
||||
auto permute_in_params = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}});
|
||||
auto permute_in = std::make_shared<ngraph::opset1::Transpose>(unsqueeze_input, permute_in_params);
|
||||
@ -100,7 +100,7 @@ namespace SubgraphTestsDefinitions {
|
||||
auto lstm = std::make_shared<ngraph::opset4::LSTMCell>(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize);
|
||||
|
||||
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
// body - outputs
|
||||
auto H_o = lstm->output(0);
|
||||
auto C_o = lstm->output(1);
|
||||
@ -158,7 +158,7 @@ namespace SubgraphTestsDefinitions {
|
||||
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
|
||||
|
||||
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
|
||||
auto cell_memory_constant = ngraph::builder::makeConstant<float>(ngPrc, cell_memory_dims, cell_memory_init);
|
||||
|
||||
@ -175,7 +175,7 @@ namespace SubgraphTestsDefinitions {
|
||||
reccurrenceWeightsNode, biasNode, hiddenSize);
|
||||
|
||||
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
|
||||
auto final_reshape_pattern = std::make_shared<ngraph::op::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{4}, std::vector<size_t>({1, 1, 1, hiddenSize}));
|
||||
|
@ -73,7 +73,7 @@ void MultipleLSTMCellTest::SetUp() {
|
||||
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
|
||||
|
||||
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
|
||||
auto permute_in_params = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}});
|
||||
auto permute_in = std::make_shared<ngraph::opset1::Transpose>(unsqueeze_input, permute_in_params);
|
||||
@ -100,7 +100,7 @@ void MultipleLSTMCellTest::SetUp() {
|
||||
auto lstm = std::make_shared<ngraph::opset4::LSTMCell>(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize);
|
||||
|
||||
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
// body - outputs
|
||||
auto H_o = lstm->output(0);
|
||||
auto C_o = lstm->output(1);
|
||||
@ -155,7 +155,7 @@ void MultipleLSTMCellTest::SetUp() {
|
||||
auto lstm_2 = std::make_shared<ngraph::opset4::LSTMCell>(squeeze_2, H_t_2, C_t_2, weightsNode_2, reccurrenceWeightsNode_2, biasNode_2, hiddenSize);
|
||||
|
||||
auto unsqueeze_2_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::op::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
|
||||
// body - outputs
|
||||
auto H_o_2 = lstm_2->output(0);
|
||||
auto C_o_2 = lstm_2->output(1);
|
||||
@ -219,7 +219,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() {
|
||||
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
|
||||
|
||||
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
|
||||
|
||||
// Body 1 - layers
|
||||
auto cell_memory_constant = ngraph::builder::makeConstant<float>(ngPrc, cell_memory_dims, cell_memory_init);
|
||||
@ -236,7 +236,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() {
|
||||
reccurrenceWeightsNode, biasNode, hiddenSize);
|
||||
|
||||
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
|
||||
|
||||
auto first_reshape_pattern = std::make_shared<ngraph::op::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{4}, std::vector<size_t>({1, 1, 1, hiddenSize}));
|
||||
@ -261,7 +261,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() {
|
||||
reccurrenceWeightsNode_2, biasNode_2, hiddenSize);
|
||||
|
||||
auto unsqueeze_2_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::op::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
|
||||
|
||||
auto final_reshape_pattern = std::make_shared<ngraph::op::Constant>(ngraph::element::i64,
|
||||
ngraph::Shape{4}, std::vector<size_t>({1, 1, 1, hiddenSize}));
|
||||
|
@ -67,7 +67,7 @@ std::shared_ptr<ngraph::Function> UnsqueezeFunction::getReference(
|
||||
|
||||
const std::shared_ptr<Node> dequantizationOpBefore = makeDequantization(input, dequantizationBefore);
|
||||
const auto unsqueeze = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Unsqueeze>>(
|
||||
op::Unsqueeze(dequantizationOpBefore, std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ axes.size() }, axes)),
|
||||
op::v0::Unsqueeze(dequantizationOpBefore, std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ axes.size() }, axes)),
|
||||
precisionAfterOperation);
|
||||
const std::shared_ptr<Node> dequantizationOpAfter = makeDequantization(unsqueeze, dequantizationAfter);
|
||||
dequantizationOpAfter->set_friendly_name("output");
|
||||
|
@ -526,7 +526,6 @@ extensions/middle/DeleteControlFlowEdges.py
|
||||
extensions/middle/DeleteNotExecutable.py
|
||||
extensions/middle/DilatedConvolution.py
|
||||
extensions/middle/EltwiseChecker.py
|
||||
extensions/middle/EltwiseInputNormalization.py
|
||||
extensions/middle/EltwiseInputReshape.py
|
||||
extensions/middle/FakeSplitOutputs.py
|
||||
extensions/middle/FusedBatchNormNonConstant.py
|
||||
|
@ -1,48 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.EltwiseInputReshape import EltwiseInputReshape
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
|
||||
|
||||
class EltwiseInputNormalize(EltwiseInputReshape, MiddleReplacementPattern):
|
||||
# This pass should be called directly from pipeline before layout change and other permutations
|
||||
enabled = False
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
eltwise_nodes = graph.get_op_nodes(is_eltwise=True)
|
||||
# Iterating over all Eltwise operations and check that every input has similar shape
|
||||
# in case of different shapes, we inserts new_shape attribute and then call EltwiseInputReshape extension
|
||||
# that insert reshapes (in case of not constant nodes) or directly reshapes values in data nodes for specified
|
||||
# shape
|
||||
for node in eltwise_nodes:
|
||||
output_shape = node.out_node().shape
|
||||
for in_node in node.in_nodes().values():
|
||||
if len(in_node.shape) != len(output_shape):
|
||||
# Set edge attribute new_shape for further transformation pass
|
||||
new_shape = in_node.shape
|
||||
for x in range(len(output_shape) - len(in_node.shape)):
|
||||
new_shape = np.insert(new_shape, 0, 1)
|
||||
|
||||
nx.set_edge_attributes(G=node.graph,
|
||||
values={(in_node.id, node.id, 0): new_shape},
|
||||
name='new_shape')
|
||||
|
||||
super().find_and_replace_pattern(graph)
|
@ -1,202 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.middle.passes.eliminate_test import build_graph
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
|
||||
# The dictionary with nodes attributes used to build various graphs. A key is the name of the node and the value is the
|
||||
# dictionary with node attributes.
|
||||
nodes_attributes = {
|
||||
# Placeholder layers
|
||||
'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
|
||||
# Reshape layers
|
||||
'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
||||
'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'reshape_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
|
||||
'reshape_1_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
|
||||
'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
||||
'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'reshape_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
|
||||
'reshape_2_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
|
||||
# Eltwise consumes layers
|
||||
'eltwise_1': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
'eltwise_2': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_2_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
'eltwise_3': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_3_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
'eltwise_4': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_4_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
# Concat
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
|
||||
}
|
||||
|
||||
|
||||
class EltwiseInputNormalizationTest(unittest.TestCase):
|
||||
def test1_not_constant(self):
|
||||
#
|
||||
# data1(1,3,64,64)----. data(1,3,64,64)-------.
|
||||
# data2(1,64,1)-------->Eltwise-->data(1,3,64,64) => data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->...
|
||||
# data3(64,1)------' data(64,1)->Reshape->data(1,1,64,1)-'
|
||||
#
|
||||
graph = build_graph(nodes_attributes, [
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1', 'placeholder_2_data'),
|
||||
('placeholder_1', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'eltwise_1'),
|
||||
('placeholder_3_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([1, 64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1', 'placeholder_2_data'),
|
||||
('placeholder_1', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('placeholder_3_data', 'reshape_2'),
|
||||
('reshape_2_const', 'reshape_2_const_data'),
|
||||
('reshape_2_const_data', 'reshape_2'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_2', 'reshape_2_data'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('reshape_2_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'reshape_1_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
|
||||
'reshape_1_const_data': {'value': int64_array([1, 1, 64, 1]),
|
||||
'shape': int64_array([4])},
|
||||
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
|
||||
'reshape_2_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
|
||||
'reshape_2_const_data': {'value': int64_array([1, 1, 64, 1]),
|
||||
'shape': int64_array([4])},
|
||||
'reshape_2_data': {'shape': np.array([1, 1, 64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
pattern = EltwiseInputNormalize()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_mega_hardcore(self):
|
||||
# ORIGINAL GRAPH
|
||||
#
|
||||
# data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
|
||||
# /\ /\ /\
|
||||
# data2(64,1)-----,-'--------------------------------'------------------------------'
|
||||
# \/ /
|
||||
# data3(64,1)----`-->Eltwise3->data(64,1)----------'
|
||||
#
|
||||
# REFERENCE GRAPH AFTER TRANSFORMATION
|
||||
#
|
||||
# data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
|
||||
# /\ /\ /\
|
||||
# data2(1,1,64,1)---'--------------------------------'-------------------------------'
|
||||
# /
|
||||
# data4(64,1)-------, Reshape(1,1,64,1)
|
||||
# \/ |
|
||||
# data3(64,1)------`---->Eltwise3->data(64,1)---'
|
||||
#
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_1_data', 'eltwise_2'),
|
||||
('placeholder_2_data', 'eltwise_3'),
|
||||
('placeholder_3_data', 'eltwise_3'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_3_data', 'eltwise_2'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_2_data', 'eltwise_4'),
|
||||
('placeholder_2_data', 'eltwise_4'),
|
||||
('eltwise_4', 'eltwise_4_data'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_1_data', 'eltwise_2'),
|
||||
('placeholder_4_data', 'eltwise_3'),
|
||||
('placeholder_3_data', 'eltwise_3'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_3_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_1_data', 'eltwise_2'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_2_data', 'eltwise_4'),
|
||||
('placeholder_2_data', 'eltwise_4'),
|
||||
('eltwise_4', 'eltwise_4_data'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([1, 1, 64, 1]),
|
||||
'value': np.ones([1, 1, 64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([64, 1])},
|
||||
'placeholder_4_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
|
||||
'reshape_1_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
|
||||
'reshape_1_const_data': {'value': int64_array([1, 1, 64, 1]),
|
||||
'shape': int64_array([4])},
|
||||
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
pattern = EltwiseInputNormalize()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -17,11 +17,13 @@
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.layout import get_features_dim, shape_for_layout
|
||||
from mo.graph.graph import Graph
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.op import Op
|
||||
from mo.ops.reshape import Reshape
|
||||
from mo.ops.unsqueeze import Unsqueeze
|
||||
|
||||
|
||||
class Eltwise1DInputReshape(MiddleReplacementPattern):
|
||||
@ -42,9 +44,6 @@ class Eltwise1DInputReshape(MiddleReplacementPattern):
|
||||
"""
|
||||
enabled = False
|
||||
|
||||
def run_after(self):
|
||||
return [EltwiseInputReshape]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
layout = graph.graph['layout']
|
||||
for eltwise_op_node in graph.get_op_nodes(is_eltwise=True):
|
||||
@ -64,59 +63,85 @@ class Eltwise1DInputReshape(MiddleReplacementPattern):
|
||||
reshape_op.out_port(0).connect(eltwise_op_node.in_port(port))
|
||||
|
||||
|
||||
class EltwiseInputReshape(MiddleReplacementPattern):
|
||||
enabled = True
|
||||
force_clean_up = True
|
||||
def compute_unsqueeze_map_for_eltwise(eltwise_node: Node):
|
||||
'''
|
||||
The function computes a map of unsqueeze_dims for each producer of eltwise node.
|
||||
These unsqueeze_dims are needed to normalize input shapes of eltwise node.
|
||||
'''
|
||||
eltwise_shape = eltwise_node.out_port(0).data.get_shape()
|
||||
max_dims = max(
|
||||
[len(port.data.get_shape()) for port in eltwise_node.in_ports().values() if port.data.get_shape() is not None])
|
||||
axis = eltwise_node.soft_get('axis', None)
|
||||
unsqueeze_dims_map = {}
|
||||
for consumer_port in eltwise_node.in_ports().values():
|
||||
producer_port = consumer_port.get_source()
|
||||
producer_shape = producer_port.data.get_shape()
|
||||
unsqueeze_dims = int64_array([])
|
||||
|
||||
def run_after(self):
|
||||
from extensions.middle.pass_separator import MiddleStart
|
||||
return [MiddleStart]
|
||||
# 1. Compute unsqueeze dimensions in the tail
|
||||
if len(producer_shape) != max_dims and len(producer_shape) > 0 and axis is not None:
|
||||
num_unsqueeze_dims = max_dims - axis - len(producer_shape)
|
||||
if num_unsqueeze_dims > 0:
|
||||
unsqueeze_dims = np.arange(len(producer_shape), len(producer_shape) + num_unsqueeze_dims,
|
||||
dtype=np.int64)
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_data_nodes():
|
||||
# Get all requested shapes for current node
|
||||
# This mapping will contain pairs like {shape:[list of consumers nodes]}
|
||||
mapping = {}
|
||||
for consumer in node.out_nodes():
|
||||
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
|
||||
if 'new_shape' in edge_attrs:
|
||||
if np.array_equal(edge_attrs['new_shape'], node.shape):
|
||||
continue
|
||||
new_shape = tuple([x for x in edge_attrs['new_shape']])
|
||||
if not new_shape in mapping:
|
||||
mapping.update({new_shape: [consumer]})
|
||||
else:
|
||||
mapping[new_shape].append(consumer)
|
||||
# 2. Compute unsqueeze dimensions in the head
|
||||
unsqueeze_dims_head = np.arange(len(eltwise_shape) - len(producer_shape) - len(unsqueeze_dims), dtype=np.int64)
|
||||
|
||||
if node.has_valid('value'):
|
||||
# Check that requested shape are the same
|
||||
# In case if they are different, we duplicate them
|
||||
for shape_key in mapping.keys():
|
||||
shape = list(shape_key)
|
||||
new_value = np.reshape(node.value, shape)
|
||||
node_copy = Op.create_input_data_node(graph, node.id + '/copy', value=np.array(new_value))
|
||||
for consumer in mapping[shape_key]:
|
||||
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
|
||||
del edge_attrs['new_shape']
|
||||
# Pay attention that unsqueeze dims order makes sense
|
||||
# since shape is normalized in the tail first and after in the head
|
||||
unsqueeze_dims = np.concatenate((unsqueeze_dims, unsqueeze_dims_head))
|
||||
unsqueeze_dims_map[producer_port] = unsqueeze_dims
|
||||
|
||||
# Remove edge from previous data node and connect new data node with its consumer
|
||||
graph.remove_edge(node.id, consumer.id)
|
||||
graph.add_edge(node_copy.id, consumer.id, **edge_attrs)
|
||||
return unsqueeze_dims_map
|
||||
|
||||
|
||||
def normalize_eltwise_inputs(graph: Graph):
|
||||
'''
|
||||
The function normalizes input shapes for eltwise nodes.
|
||||
In the first step the function gets to know which shapes/unsqueeze dims for inputs are required for normalization.
|
||||
In the second step the function inserts Unsqueeze nodes between non-normalized inputs and eltwise nodes.
|
||||
'''
|
||||
# Generate a map for producers of eltwise nodes with non-normalized shapes
|
||||
# and in this map every producer has another map that reflects normalized shape
|
||||
# to a list of eltwise consumers
|
||||
mapping = {}
|
||||
for eltwise_node in graph.get_op_nodes(is_eltwise=True):
|
||||
unsqueeze_dims_map = compute_unsqueeze_map_for_eltwise(eltwise_node)
|
||||
for consumer_port in eltwise_node.in_ports().values():
|
||||
producer_port = consumer_port.get_source()
|
||||
unsqueeze_dims = unsqueeze_dims_map[producer_port]
|
||||
if unsqueeze_dims is not None and len(unsqueeze_dims) > 0:
|
||||
unsqueeze_dims = tuple([x for x in unsqueeze_dims])
|
||||
if producer_port not in mapping:
|
||||
mapping.update({producer_port: {unsqueeze_dims: [consumer_port]}})
|
||||
elif unsqueeze_dims not in mapping[producer_port]:
|
||||
mapping[producer_port].update({unsqueeze_dims: [consumer_port]})
|
||||
else:
|
||||
mapping[producer_port][unsqueeze_dims].append(consumer_port)
|
||||
|
||||
# Walk through each produced in the map and insert Unsqueeze nodes between a producer and eltwise nodes
|
||||
for producer_port in mapping.keys():
|
||||
producer_node = producer_port.node
|
||||
for unsqueeze_dims in mapping[producer_port].keys():
|
||||
unsqueeze_name = producer_node.soft_get('name', producer_node.id) + '/EltwiseUnsqueeze'
|
||||
unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(list(unsqueeze_dims))},
|
||||
{'name': unsqueeze_name})
|
||||
unsqueeze_node.in_port(0).connect(producer_port)
|
||||
|
||||
# Insert Unsqueeze with determined unsqueeze dimensions between the current producer and eltwise node
|
||||
for consumer_port in mapping[producer_port][unsqueeze_dims]:
|
||||
consumer_port.connect(unsqueeze_node.out_port(0))
|
||||
|
||||
# The shape and value adjustments must be explicitly done within the transformation
|
||||
# since the transformation is called from Fusing transformation that excludes
|
||||
# automatic call of shape inference pass
|
||||
producer_port_value = producer_port.data.get_value()
|
||||
producer_port_shape = producer_port.data.get_shape()
|
||||
new_shape = producer_port_shape.copy()
|
||||
for unsqueeze_dim in unsqueeze_dims:
|
||||
new_shape = np.insert(new_shape, unsqueeze_dim, 1)
|
||||
if producer_port_value is not None:
|
||||
unsqueeze_node.out_port(0).data.set_value(np.reshape(producer_port_value, new_shape))
|
||||
else:
|
||||
# Insert Reshape layer between data node and consumer
|
||||
for shape_key in mapping.keys():
|
||||
shape = list(shape_key)
|
||||
reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape'
|
||||
reshape = Reshape(graph, attrs={'name': reshape_name})
|
||||
reshape_dim = Const(graph,
|
||||
{'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data()
|
||||
reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
|
||||
|
||||
# Iterate over consumers and reconnect them to Reshape layer output
|
||||
for consumer in mapping[shape_key]:
|
||||
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
|
||||
del edge_attrs['new_shape']
|
||||
|
||||
# Reconnect edge from original data node to Reshape output datanode
|
||||
graph.remove_edge(node.id, consumer.id)
|
||||
graph.add_edge(reshape_data.id, consumer.id, **edge_attrs)
|
||||
unsqueeze_node.out_port(0).data.set_shape(new_shape)
|
||||
|
@ -18,7 +18,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.EltwiseInputReshape import EltwiseInputReshape
|
||||
from extensions.middle.EltwiseInputReshape import normalize_eltwise_inputs
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.middle.passes.eliminate_test import build_graph
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
@ -28,47 +28,216 @@ from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
nodes_attributes = {
|
||||
# Placeholder layers
|
||||
'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_2': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_3': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
'placeholder_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
|
||||
|
||||
# Reshape layers
|
||||
'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
||||
'reshape_1': {'type': 'Unsqueeze', 'value': None, 'kind': 'op', 'op': 'Unsqueeze'},
|
||||
'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'reshape_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
|
||||
'reshape_1_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
|
||||
'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
|
||||
'reshape_2': {'type': 'Unsqueeze', 'value': None, 'kind': 'op', 'op': 'Unsqueeze'},
|
||||
'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
'reshape_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
|
||||
'reshape_2_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
|
||||
# Fake consumes layers
|
||||
'consumer_1': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
|
||||
'consumer_2': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
|
||||
'consumer_3': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
|
||||
# Eltwise consumes layers
|
||||
'eltwise_1': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_1_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
'eltwise_2': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_2_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
'eltwise_3': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_3_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
'eltwise_4': {'kind': 'op', 'is_eltwise': True},
|
||||
'eltwise_4_data': {'value': None, 'shape': None, 'kind': 'data'},
|
||||
|
||||
# Concat
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
|
||||
}
|
||||
|
||||
|
||||
class EltwiseInputReshapeTest(unittest.TestCase):
|
||||
class EltwiseInputNormalizationTest(unittest.TestCase):
|
||||
def test1_not_constant(self):
|
||||
#
|
||||
# data1(1,3,64,64)----. data(1,3,64,64)-------.
|
||||
# data2(1,64,1)-------->Eltwise-->data(1,3,64,64) => data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->...
|
||||
# data3(64,1)------' data(64,1)->Reshape->data(1,1,64,1)-'
|
||||
#
|
||||
graph = build_graph(nodes_attributes, [
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1', 'placeholder_2_data'),
|
||||
('placeholder_1', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'eltwise_1'),
|
||||
('placeholder_3_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([1, 64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1', 'placeholder_2_data'),
|
||||
('placeholder_1', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('placeholder_3_data', 'reshape_2'),
|
||||
('reshape_2_const', 'reshape_2_const_data'),
|
||||
('reshape_2_const_data', 'reshape_2'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_2', 'reshape_2_data'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('reshape_2_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])},
|
||||
'reshape_1_const_data': {'value': int64_array([0]),
|
||||
'shape': int64_array([1])},
|
||||
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
|
||||
'reshape_2_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
|
||||
'reshape_2_const_data': {'value': int64_array([0, 1]),
|
||||
'shape': int64_array([2])},
|
||||
'reshape_2_data': {'shape': np.array([1, 1, 64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_mega_hardcore(self):
|
||||
# ORIGINAL GRAPH
|
||||
#
|
||||
# data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
|
||||
# /\ /\ /\
|
||||
# data2(64,1)-----,-'--------------------------------'------------------------------'
|
||||
# \/ /
|
||||
# data3(64,1)----`-->Eltwise3->data(64,1)----------'
|
||||
#
|
||||
# REFERENCE GRAPH AFTER TRANSFORMATION
|
||||
#
|
||||
# data1(1,3,64,64)---------------------,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
|
||||
# /\ /\ /\
|
||||
# data2(64,1)-,- Reshape1(1,1,64,64)--'--------------------------------o-------------------------------'
|
||||
# | |
|
||||
# | Reshape(1,1,64,1)
|
||||
# \/ |
|
||||
# data3(64,1)----------->Eltwise3->data(64,1)--------------------------'
|
||||
#
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_2', 'placeholder_2_data'),
|
||||
('placeholder_3', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_1_data', 'eltwise_2'),
|
||||
('placeholder_2_data', 'eltwise_3'),
|
||||
('placeholder_3_data', 'eltwise_3'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_3_data', 'eltwise_2'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_2_data', 'eltwise_4'),
|
||||
('placeholder_2_data', 'eltwise_4'),
|
||||
('eltwise_4', 'eltwise_4_data'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_2', 'placeholder_2_data'),
|
||||
('placeholder_3', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_1_data', 'eltwise_2'),
|
||||
('placeholder_2_data', 'eltwise_3'),
|
||||
('placeholder_3_data', 'eltwise_3'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_3_data', 'reshape_2'),
|
||||
('reshape_2_const', 'reshape_2_const_data'),
|
||||
('reshape_2_const_data', 'reshape_2'),
|
||||
('reshape_2', 'reshape_2_data'),
|
||||
('reshape_2_data', 'eltwise_2'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_2_data', 'eltwise_4'),
|
||||
('reshape_1_data', 'eltwise_4'),
|
||||
('eltwise_4', 'eltwise_4_data'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([64, 1]),
|
||||
'value': np.ones([64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([64, 1])},
|
||||
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
|
||||
'reshape_1_const_data': {'value': int64_array([0, 1]),
|
||||
'shape': int64_array([2])},
|
||||
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
|
||||
|
||||
'reshape_2_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
|
||||
'reshape_2_const_data': {'value': int64_array([0, 1]),
|
||||
'shape': int64_array([2])},
|
||||
'reshape_2_data': {'shape': np.array([1, 1, 64, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_3_data': {'shape': np.array([64, 1])},
|
||||
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test2_not_constant(self):
|
||||
# ,-------------->consumer3 ,------------>consumer3
|
||||
# data---(new_shape1)-->consumer1 => data---->Reshape-->consumer1
|
||||
# `-(new_shape2)-->consumer2 `-->Reshape-->consumer2
|
||||
#
|
||||
graph = build_graph(nodes_attributes, [
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
|
||||
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 1, 3])}),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_1_data', 'eltwise_2'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])},
|
||||
'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])},
|
||||
'eltwise_2_data': {'shape': int64_array([1, 1, 3])},
|
||||
'eltwise_3_data': {'shape': int64_array([1, 3])},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[
|
||||
@ -79,32 +248,34 @@ class EltwiseInputReshapeTest(unittest.TestCase):
|
||||
('placeholder_1_data', 'reshape_2'),
|
||||
('reshape_2_const', 'reshape_2_const_data'),
|
||||
('reshape_2_const_data', 'reshape_2'),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_2', 'reshape_2_data'),
|
||||
('reshape_1_data', 'consumer_1'),
|
||||
('reshape_2_data', 'consumer_2'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('reshape_2_data', 'eltwise_2'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])},
|
||||
'reshape_1_const': {'value': int64_array([1, 3, 1, 1]), 'shape': int64_array([4])},
|
||||
'reshape_1_const_data': {'value': int64_array([1, 3, 1, 1]),
|
||||
'shape': int64_array([4])},
|
||||
'reshape_1_data': {'shape': int64_array([1, 3, 1, 1])},
|
||||
'reshape_2_const': {'value': int64_array([1, 1, 3]), 'shape': int64_array([3])},
|
||||
'reshape_2_const_data': {'value': int64_array([1, 1, 3]), 'shape': int64_array([3])},
|
||||
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
|
||||
'reshape_1_const_data': {'value': int64_array([0, 1]),
|
||||
'shape': int64_array([2])},
|
||||
'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])},
|
||||
'reshape_2_const': {'value': int64_array([0]), 'shape': int64_array([1])},
|
||||
'reshape_2_const_data': {'value': int64_array([0]), 'shape': int64_array([1])},
|
||||
'reshape_2_data': {'shape': int64_array([1, 1, 3])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
pattern = EltwiseInputReshape()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test2_not_constant(self):
|
||||
def test3_not_constant(self):
|
||||
# ,--------------->consumer3 ,----------->consumer3
|
||||
# data---(new_shape1)-->consumer1 => data-->Reshape-->consumer1
|
||||
# `-(new_shape1)-->consumer2 `-->consumer2
|
||||
@ -112,14 +283,22 @@ class EltwiseInputReshapeTest(unittest.TestCase):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
|
||||
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3, 1, 1])}),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_1_data', 'eltwise_2'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])},
|
||||
'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])},
|
||||
'eltwise_2_data': {'shape': int64_array([1, 1, 1, 3])},
|
||||
'eltwise_3_data': {'shape': int64_array([1, 3])},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[
|
||||
@ -127,123 +306,239 @@ class EltwiseInputReshapeTest(unittest.TestCase):
|
||||
('placeholder_1_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_1_data', 'consumer_1'),
|
||||
('reshape_1_data', 'consumer_2'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('reshape_1_data', 'eltwise_2'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])},
|
||||
'reshape_1_const': {'value': int64_array([1, 3, 1, 1]), 'shape': int64_array([4])},
|
||||
'reshape_1_const_data': {'value': int64_array([1, 3, 1, 1]),
|
||||
'shape': int64_array([4])},
|
||||
'reshape_1_data': {'shape': int64_array([1, 3, 1, 1])},
|
||||
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
|
||||
'reshape_1_const_data': {'value': int64_array([0, 1]),
|
||||
'shape': int64_array([2])},
|
||||
'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
pattern = EltwiseInputReshape()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test3_constant(self):
|
||||
# ,--------------->consumer3 data-->consumer3
|
||||
# data---(new_shape1)-->consumer1 => data-->consumer1
|
||||
# `-(new_shape2)-->consumer2 data-->consumer2
|
||||
#
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
|
||||
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 1, 3])}),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}},
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'consumer_1'),
|
||||
('placeholder_2_data', 'consumer_2'),
|
||||
('placeholder_3_data', 'consumer_3'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3, 1, 1]),
|
||||
'value': np.ones([1, 3, 1, 1])},
|
||||
'placeholder_2_data': {'shape': int64_array([1, 1, 3]), 'value': np.ones([1, 1, 3])},
|
||||
'placeholder_3_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
pattern = EltwiseInputReshape()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test4_constant(self):
|
||||
# ,--------------->consumer3 ,-->consumer3
|
||||
# data---(new_shape1)-->consumer1 => data-->consumer1
|
||||
# `-(new_shape2)-->consumer2 `->consumer2
|
||||
# ,--------------->consumer3 ,------------>consumer3
|
||||
# data---(new_shape1)-->consumer1 => data--->reshape1-->consumer1
|
||||
# `-(new_shape2)-->consumer2 `->reshape2-->consumer2
|
||||
#
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([3, 1, 1])}),
|
||||
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([3, 1, 1])}),
|
||||
('placeholder_1_data', 'consumer_3', {'new_shape': int64_array([3, 1, 1])}),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_1_data', 'eltwise_2'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}},
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
|
||||
'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])},
|
||||
'eltwise_2_data': {'shape': int64_array([1, 1, 3])},
|
||||
'eltwise_3_data': {'shape': int64_array([1, 3])},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'consumer_1'),
|
||||
('placeholder_1_data', 'consumer_2'),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('placeholder_1_data', 'reshape_2'),
|
||||
('reshape_2_const', 'reshape_2_const_data'),
|
||||
('reshape_2_const_data', 'reshape_2'),
|
||||
('reshape_2', 'reshape_2_data'),
|
||||
('reshape_2_data', 'eltwise_2'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([3, 1, 1]), 'value': np.ones([3, 1, 1])}
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
|
||||
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
|
||||
'reshape_1_const_data': {'value': int64_array([0, 1]),
|
||||
'shape': int64_array([2])},
|
||||
'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])},
|
||||
|
||||
'reshape_2_const': {'value': int64_array([0]), 'shape': int64_array([1])},
|
||||
'reshape_2_const_data': {'value': int64_array([0]),
|
||||
'shape': int64_array([1])},
|
||||
'reshape_2_data': {'shape': int64_array([1, 1, 3])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
pattern = EltwiseInputReshape()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test5_not_constant(self):
|
||||
def test5_constant(self):
|
||||
# ,-(new_shape)-->consumer3 ,-->consumer3
|
||||
# data---(new_shape)-->consumer1 => data-->reshape---->consumer1
|
||||
# `-(new_shape)-->consumer2 `-->consumer2
|
||||
#
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_1_data', 'eltwise_2'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
|
||||
'eltwise_1_data': {'shape': int64_array([1, 1, 3])},
|
||||
'eltwise_2_data': {'shape': int64_array([1, 1, 3])},
|
||||
'eltwise_3_data': {'shape': int64_array([1, 1, 3])},
|
||||
},
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('reshape_1_data', 'eltwise_2'),
|
||||
('reshape_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
|
||||
'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])},
|
||||
'reshape_1_const_data': {'value': int64_array([0]),
|
||||
'shape': int64_array([1])},
|
||||
'reshape_1_data': {'shape': int64_array([1, 1, 3])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test6_not_constant(self):
|
||||
# ,--------------->consumer3 ,->consumer3
|
||||
# data---(new_shape1)-->consumer1 => data----->consumer1
|
||||
# `-(new_shape1)-->consumer2 `-->consumer2
|
||||
#
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}),
|
||||
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_1_data', 'eltwise_2'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])},
|
||||
'eltwise_1_data': {'shape': int64_array([1, 3])},
|
||||
'eltwise_2_data': {'shape': int64_array([1, 3])},
|
||||
'eltwise_3_data': {'shape': int64_array([1, 3])},
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}),
|
||||
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}),
|
||||
('placeholder_1_data', 'consumer_3'),
|
||||
('consumer_1', 'concat'),
|
||||
('consumer_2', 'concat'),
|
||||
('consumer_3', 'concat'),
|
||||
[('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_1_data', 'eltwise_2'),
|
||||
('placeholder_1_data', 'eltwise_3'),
|
||||
('eltwise_1', 'eltwise_1_data'),
|
||||
('eltwise_2', 'eltwise_2_data'),
|
||||
('eltwise_3', 'eltwise_3_data'),
|
||||
('eltwise_1_data', 'concat'),
|
||||
('eltwise_2_data', 'concat'),
|
||||
('eltwise_3_data', 'concat'),
|
||||
],
|
||||
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
|
||||
|
||||
pattern = EltwiseInputReshape()
|
||||
pattern.find_and_replace_pattern(graph)
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test7_axis1_not_constant(self):
|
||||
#
|
||||
# data1(1,3,64,64)----. data(1,3,64,64)-------.
|
||||
# data2(3,64,1)-------->Eltwise-->data(1,3,64,64)=> data(3,64,1)->Unsqueeze(0)->data(1,3,64,1)-->Eltwise->...
|
||||
# data3(3,1)------' data(3,1)->Unsqueeze(2, 0)->data(1,3,1,1)-'
|
||||
#
|
||||
graph = build_graph(nodes_attributes, [
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_2', 'placeholder_2_data'),
|
||||
('placeholder_3', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'eltwise_1'),
|
||||
('placeholder_3_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([3, 64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([3, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'eltwise_1' : {'axis': 1}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[
|
||||
('placeholder_1', 'placeholder_1_data'),
|
||||
('placeholder_2', 'placeholder_2_data'),
|
||||
('placeholder_3', 'placeholder_3_data'),
|
||||
('placeholder_1_data', 'eltwise_1'),
|
||||
('placeholder_2_data', 'reshape_1'),
|
||||
('reshape_1_const', 'reshape_1_const_data'),
|
||||
('reshape_1_const_data', 'reshape_1'),
|
||||
('placeholder_3_data', 'reshape_2'),
|
||||
('reshape_2_const', 'reshape_2_const_data'),
|
||||
('reshape_2_const_data', 'reshape_2'),
|
||||
('reshape_1', 'reshape_1_data'),
|
||||
('reshape_2', 'reshape_2_data'),
|
||||
('reshape_1_data', 'eltwise_1'),
|
||||
('reshape_2_data', 'eltwise_1'),
|
||||
('eltwise_1', 'eltwise_1_data')
|
||||
],
|
||||
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
|
||||
'placeholder_2_data': {'shape': np.array([3, 64, 1])},
|
||||
'placeholder_3_data': {'shape': np.array([3, 1])},
|
||||
'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])},
|
||||
'reshape_1_const_data': {'value': int64_array([0]),
|
||||
'shape': int64_array([1])},
|
||||
'reshape_1_data': {'shape': np.array([1, 3, 64, 1])},
|
||||
'reshape_2_const': {'value': int64_array([2, 0]), 'shape': int64_array([2])},
|
||||
'reshape_2_const_data': {'value': int64_array([2, 0]),
|
||||
'shape': int64_array([2])},
|
||||
'reshape_2_data': {'shape': np.array([1, 3, 1, 1])},
|
||||
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
|
||||
}, nodes_with_edges_only=True)
|
||||
|
||||
normalize_eltwise_inputs(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
@ -16,7 +16,7 @@
|
||||
from extensions.front.div import Div
|
||||
from extensions.front.sub import Sub
|
||||
from extensions.middle.AddFakeQuantizeFuse import AddFakeQuantizeFuse
|
||||
from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
|
||||
from extensions.middle.EltwiseInputReshape import normalize_eltwise_inputs
|
||||
from extensions.middle.MulFakeQuantizeFuse import MulFakeQuantizeFuse
|
||||
from extensions.middle.RemoveRedundantReshapes import RemoveRedundantReshapes
|
||||
|
||||
@ -82,7 +82,7 @@ class Fusing(MiddleReplacementPattern):
|
||||
for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
|
||||
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
|
||||
|
||||
EltwiseInputNormalize().find_and_replace_pattern(graph)
|
||||
normalize_eltwise_inputs(graph)
|
||||
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
|
||||
|
||||
# Fusing linear operation to Convolution
|
||||
@ -96,7 +96,7 @@ class Fusing(MiddleReplacementPattern):
|
||||
for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
|
||||
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
|
||||
|
||||
EltwiseInputNormalize().find_and_replace_pattern(graph)
|
||||
normalize_eltwise_inputs(graph)
|
||||
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
|
||||
|
||||
MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)
|
||||
|
@ -14,17 +14,14 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node
|
||||
|
||||
|
||||
def eltwise_infer(node, op=None, **kwargs):
|
||||
raw_inputs = [(inp, attr) for inp, attr in node.get_sorted_inputs()
|
||||
if 'control_flow_edge' not in attr or not attr['control_flow_edge']]
|
||||
inputs = [Node(node.graph, inp) for inp, attr in raw_inputs]
|
||||
shapes = [node.graph.node[inp]['shape'] for inp, attr in raw_inputs]
|
||||
values = [node.graph.node[inp]['value'] for inp, attr in raw_inputs]
|
||||
|
||||
@ -53,14 +50,6 @@ def eltwise_infer(node, op=None, **kwargs):
|
||||
|
||||
shapes[id] = new_shape
|
||||
|
||||
# Save shape for further transformation that applies this shapes for input nodes
|
||||
# We set new_shape attribute on edge for given input node
|
||||
edge_attrs = node.graph.get_edge_data(inputs[id].id, node.id)[0]
|
||||
|
||||
nx.set_edge_attributes(G=node.graph,
|
||||
values={(inputs[id].id, node.id, 0): new_shape},
|
||||
name='new_shape')
|
||||
|
||||
# Reshape value to correctly calculate output shape
|
||||
if values[id] is not None:
|
||||
values[id] = np.reshape(values[id], new_shape)
|
||||
|
@ -38,13 +38,13 @@ def get_canonical_axis_index_node(rank: Node, axis: int) -> Node:
|
||||
graph = rank.graph
|
||||
name = rank.soft_get('name', rank.id)
|
||||
if axis < 0:
|
||||
axis = Const(graph, {'name': name + '/negative_axis', 'value': int64_array([axis])}).create_node()
|
||||
axis = Const(graph, {'name': name + '/negative_axis', 'value': int64_array(axis)}).create_node()
|
||||
add = Add(graph, {'name': name + '/positive_axis'}).create_node()
|
||||
rank.out_port(0).connect(add.in_port(0))
|
||||
axis.out_port(0).connect(add.in_port(1))
|
||||
return add
|
||||
else:
|
||||
return Const(graph, {'name': name + '/positive_axis', 'value': int64_array([axis])}).create_node()
|
||||
return Const(graph, {'name': name + '/positive_axis', 'value': int64_array(axis)}).create_node()
|
||||
|
||||
|
||||
def get_range_node_of_idxs(rank: Node, begin: int, end: int,
|
||||
@ -66,20 +66,20 @@ def get_range_node_of_idxs(rank: Node, begin: int, end: int,
|
||||
end_idx = get_canonical_axis_index_node(rank, end)
|
||||
|
||||
if not include_begin:
|
||||
const = Const(graph, {'value': int64_array([1]), 'name': name + '/exclude_begin/value'}).create_node()
|
||||
const = Const(graph, {'value': int64_array(1), 'name': name + '/exclude_begin/value'}).create_node()
|
||||
add = Add(graph, {'name': name + '/exclude_begin'}).create_node()
|
||||
start_idx.out_port(0).connect(add.in_port(0))
|
||||
const.out_port(0).connect(add.in_port(1))
|
||||
start_idx = add
|
||||
|
||||
if include_end:
|
||||
const = Const(graph, {'value': int64_array([1]), 'name': name + '/including_end/value'}).create_node()
|
||||
const = Const(graph, {'value': int64_array(1), 'name': name + '/including_end/value'}).create_node()
|
||||
add = Add(graph, {'name': name + '/including_end'}).create_node()
|
||||
end_idx.out_port(0).connect(add.in_port(0))
|
||||
const.out_port(0).connect(add.in_port(1))
|
||||
end_idx = add
|
||||
|
||||
delta = Const(graph, {'name': name + '/delta', 'value': int64_array([1])}).create_node()
|
||||
delta = Const(graph, {'name': name + '/delta', 'value': int64_array(1)}).create_node()
|
||||
range_node = Range(graph, {'name': name + '/range_idxs'}).create_node()
|
||||
|
||||
start_idx.out_port(0).connect(range_node.in_port(0))
|
||||
|
@ -21,9 +21,6 @@
|
||||
#include "ngraph/axis_vector.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -31,7 +28,7 @@ namespace ngraph
|
||||
{
|
||||
namespace v0
|
||||
{
|
||||
class NGRAPH_API Unsqueeze : public ngraph::op::util::FusedOp
|
||||
class NGRAPH_API Unsqueeze : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -39,9 +36,7 @@ namespace ngraph
|
||||
Unsqueeze() = default;
|
||||
Unsqueeze(const Output<Node>& data, const Output<Node>& axes);
|
||||
|
||||
virtual void pre_validate_and_infer_types() override;
|
||||
virtual OutputVector decompose_op() const override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
@ -55,5 +50,3 @@ namespace ngraph
|
||||
using v0::Unsqueeze;
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
@ -29,17 +29,15 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::Unsqueeze, "Unsqueeze", 0);
|
||||
|
||||
op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
|
||||
: FusedOp({data, axes})
|
||||
op::v0::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
|
||||
: Op({data, axes})
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::Unsqueeze::pre_validate_and_infer_types()
|
||||
void op::v0::Unsqueeze::validate_and_infer_types()
|
||||
{
|
||||
const auto data = input_value(0);
|
||||
auto data_partial_shape = data.get_partial_shape();
|
||||
@ -79,24 +77,12 @@ void op::Unsqueeze::pre_validate_and_infer_types()
|
||||
set_output_type(0, get_input_element_type(0), PartialShape{output_shape});
|
||||
}
|
||||
|
||||
OutputVector op::Unsqueeze::decompose_op() const
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
(get_output_partial_shape(0).is_static()),
|
||||
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
|
||||
auto data = input_value(0);
|
||||
auto data_shape = data.get_shape();
|
||||
auto output_shape = get_output_shape(0);
|
||||
return {builder::opset1::reshape(data, output_shape)};
|
||||
}
|
||||
|
||||
bool ngraph::op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor)
|
||||
bool op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::Unsqueeze::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
shared_ptr<Node> op::v0::Unsqueeze::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
if (new_args.size() != 2)
|
||||
{
|
||||
|
@ -1487,7 +1487,7 @@ NGRAPH_TEST(${BACKEND_NAME}, unsqueeze)
|
||||
auto data_node = make_shared<op::Parameter>(element::f32, Shape{4, 2});
|
||||
auto axes_node =
|
||||
make_shared<ngraph::op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
|
||||
auto squeeze = make_shared<op::Unsqueeze>(data_node, axes_node);
|
||||
auto squeeze = make_shared<op::v0::Unsqueeze>(data_node, axes_node);
|
||||
|
||||
auto function = make_shared<Function>(NodeVector{squeeze}, ParameterVector{data_node});
|
||||
auto test_case = test::TestCase<TestEngine>(function);
|
||||
|
@ -189,7 +189,7 @@ TEST(constant_folding, constant_unsqueeze)
|
||||
auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
|
||||
vector<int64_t> values_axes{2, 3};
|
||||
auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
|
||||
auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
|
||||
auto unsqueeze = make_shared<op::v0::Unsqueeze>(constant, constant_axes);
|
||||
unsqueeze->set_friendly_name("test");
|
||||
auto f = make_shared<Function>(unsqueeze, ParameterVector{});
|
||||
|
||||
@ -197,7 +197,7 @@ TEST(constant_folding, constant_unsqueeze)
|
||||
pass_manager.register_pass<pass::ConstantFolding>();
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 0);
|
||||
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
|
||||
|
||||
auto new_const =
|
||||
|
@ -877,7 +877,7 @@ namespace
|
||||
|
||||
void op_is_Unsqueeze()
|
||||
{
|
||||
op::Unsqueeze node;
|
||||
op::v0::Unsqueeze node;
|
||||
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
|
||||
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
|
||||
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
|
||||
|
@ -126,5 +126,5 @@ NGRAPH_OP(Tan, ngraph::op)
|
||||
NGRAPH_OP(Tanh, ngraph::op)
|
||||
NGRAPH_OP(TensorIterator, ngraph::op)
|
||||
NGRAPH_OP(Tile, ngraph::op::v0)
|
||||
NGRAPH_OP(Unsqueeze, ngraph::op)
|
||||
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
|
||||
NGRAPH_OP(Xor, ngraph::op)
|
||||
|
@ -26,7 +26,7 @@ TEST(type_prop, unsqueeze)
|
||||
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8});
|
||||
auto axes_node =
|
||||
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
|
||||
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
|
||||
auto unsqueeze = make_shared<op::v0::Unsqueeze>(param, axes_node);
|
||||
|
||||
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
|
||||
ASSERT_EQ(unsqueeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
|
||||
@ -37,7 +37,7 @@ TEST(type_prop, unsqueeze_dynamic)
|
||||
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
|
||||
auto axes_node =
|
||||
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
|
||||
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
|
||||
auto unsqueeze = make_shared<op::v0::Unsqueeze>(param, axes_node);
|
||||
|
||||
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
|
||||
EXPECT_TRUE(
|
||||
|
Loading…
Reference in New Issue
Block a user