Improve reshapeability of models with eltwise nodes influencing shapes (#2767)

* Fix ElementwiseInputReshape transformation

Reshape node always needs to be inserted
in order to preserve ShapeOf nodes (reshapability of a model) that can potentially be above
elementwise node.

Refactor EltwiseInputReshape_test and EltwiseInputNormalization_test since the logic of maintaining reshape for eltwise has been changed.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Merge EltwiseInputNormalization and EltwiseInputReshape transformations

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Remove Unsqueeze from Fused_op

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix code after code review #1

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix code after review #2

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix code review #4

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Perform full normalization based on shapes of all inputs to eltwise

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Refactor much to avoid old API and edges with unsqueeze_dims attribute

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix code after review

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2020-11-16 09:50:41 +03:00 committed by GitHub
parent 2a7f2f5eb6
commit 10b18a00c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 535 additions and 498 deletions

View File

@ -651,11 +651,11 @@ CNNLayer::Ptr NodeConverter<ngraph::op::Squeeze>::createLayer(const std::shared_
}
template <>
CNNLayer::Ptr NodeConverter<ngraph::op::Unsqueeze>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
CNNLayer::Ptr NodeConverter<ngraph::op::v0::Unsqueeze>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
LayerParams params = {layer->get_friendly_name(), "Unsqueeze",
details::convertPrecision(layer->get_output_element_type(0))};
auto res = std::make_shared<InferenceEngine::CNNLayer>(params);
auto castedLayer = ngraph::as_type_ptr<ngraph::op::Unsqueeze>(layer);
auto castedLayer = ngraph::as_type_ptr<ngraph::op::v0::Unsqueeze>(layer);
if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
return res;

View File

@ -52,7 +52,7 @@ ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() {
if (auto new_max_per_class_const = std::dynamic_pointer_cast<opset1::Constant>(new_max_per_class.get_node_shared_ptr())) {
new_max_per_class = opset1::Constant::create(element::i64, Shape{1}, new_max_per_class_const->cast_vector<int64_t>());
} else {
new_max_per_class = std::make_shared<ngraph::op::Unsqueeze>(
new_max_per_class = std::make_shared<ngraph::op::v0::Unsqueeze>(
nms->input_value(2),
opset1::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(new_max_per_class.get_node_shared_ptr());
@ -60,14 +60,14 @@ ngraph::pass::ConvertNMSToNMSIEMatcher::ConvertNMSToNMSIEMatcher() {
}
auto new_iou_threshold = nms->input_value(3);
if (iou_threshold_rank.get_length() == 0) {
new_iou_threshold = std::make_shared<ngraph::op::Unsqueeze>(
new_iou_threshold = std::make_shared<ngraph::op::v0::Unsqueeze>(
nms->input_value(3),
opset1::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(new_iou_threshold.get_node_shared_ptr());
}
auto new_score_threshold = nms->input_value(4);
if (score_threshold_rank.get_length() == 0) {
new_score_threshold = std::make_shared<ngraph::op::Unsqueeze>(
new_score_threshold = std::make_shared<ngraph::op::v0::Unsqueeze>(
nms->input_value(4),
opset1::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(new_score_threshold.get_node_shared_ptr());

View File

@ -58,7 +58,7 @@ void MatmulSqueezeAddTest::SetUp() {
auto matmul_0 = std::make_shared<ngraph::op::MatMul>(params[0], constant_0, false, true);
auto constant_1 = std::make_shared<ngraph::op::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{ 1 }, std::vector<size_t>{0});
auto unsqueeze_0 = std::make_shared<ngraph::op::Unsqueeze>(matmul_0, constant_1);
auto unsqueeze_0 = std::make_shared<ngraph::op::v0::Unsqueeze>(matmul_0, constant_1);
auto constant_2 = ngraph::builder::makeConstant<float>(ngPrc, { 1, inputShape[0], outputSize },
CommonTestUtils::generate_float_numbers(inputShape[0] * outputSize, 0, 1, seed), false);

View File

@ -75,7 +75,7 @@ namespace SubgraphTestsDefinitions {
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
auto permute_in_params = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}});
auto permute_in = std::make_shared<ngraph::opset1::Transpose>(unsqueeze_input, permute_in_params);
@ -100,7 +100,7 @@ namespace SubgraphTestsDefinitions {
auto lstm = std::make_shared<ngraph::opset4::LSTMCell>(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize);
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
// body - outputs
auto H_o = lstm->output(0);
auto C_o = lstm->output(1);
@ -158,7 +158,7 @@ namespace SubgraphTestsDefinitions {
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
auto cell_memory_constant = ngraph::builder::makeConstant<float>(ngPrc, cell_memory_dims, cell_memory_init);
@ -175,7 +175,7 @@ namespace SubgraphTestsDefinitions {
reccurrenceWeightsNode, biasNode, hiddenSize);
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
auto final_reshape_pattern = std::make_shared<ngraph::op::Constant>(ngraph::element::i64,
ngraph::Shape{4}, std::vector<size_t>({1, 1, 1, hiddenSize}));

View File

@ -73,7 +73,7 @@ void MultipleLSTMCellTest::SetUp() {
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
auto permute_in_params = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, ngraph::Shape{{1, 0, 2}});
auto permute_in = std::make_shared<ngraph::opset1::Transpose>(unsqueeze_input, permute_in_params);
@ -100,7 +100,7 @@ void MultipleLSTMCellTest::SetUp() {
auto lstm = std::make_shared<ngraph::opset4::LSTMCell>(squeeze, H_t, C_t, weightsNode, reccurrenceWeightsNode, biasNode, hiddenSize);
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
// body - outputs
auto H_o = lstm->output(0);
auto C_o = lstm->output(1);
@ -155,7 +155,7 @@ void MultipleLSTMCellTest::SetUp() {
auto lstm_2 = std::make_shared<ngraph::opset4::LSTMCell>(squeeze_2, H_t_2, C_t_2, weightsNode_2, reccurrenceWeightsNode_2, biasNode_2, hiddenSize);
auto unsqueeze_2_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze_2 = std::make_shared<ngraph::op::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
auto unsqueeze_2 = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
// body - outputs
auto H_o_2 = lstm_2->output(0);
auto C_o_2 = lstm_2->output(1);
@ -219,7 +219,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() {
auto mul = ngraph::builder::makeEltwise(add, input_mul_const, ngraph::helpers::EltwiseTypes::MULTIPLY);
auto unsqueeze_input_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze_input = std::make_shared<ngraph::op::Unsqueeze>(mul, unsqueeze_input_const);
auto unsqueeze_input = std::make_shared<ngraph::op::v0::Unsqueeze>(mul, unsqueeze_input_const);
// Body 1 - layers
auto cell_memory_constant = ngraph::builder::makeConstant<float>(ngPrc, cell_memory_dims, cell_memory_init);
@ -236,7 +236,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() {
reccurrenceWeightsNode, biasNode, hiddenSize);
auto unsqueeze_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze = std::make_shared<ngraph::op::Unsqueeze>(lstm->output(0), unsqueeze_const);
auto unsqueeze = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm->output(0), unsqueeze_const);
auto first_reshape_pattern = std::make_shared<ngraph::op::Constant>(ngraph::element::i64,
ngraph::Shape{4}, std::vector<size_t>({1, 1, 1, hiddenSize}));
@ -261,7 +261,7 @@ void MultipleLSTMCellTest::switchToNgraphFriendlyModel() {
reccurrenceWeightsNode_2, biasNode_2, hiddenSize);
auto unsqueeze_2_const = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{1}, squeeze_axes);
auto unsqueeze_2 = std::make_shared<ngraph::op::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
auto unsqueeze_2 = std::make_shared<ngraph::op::v0::Unsqueeze>(lstm_2->output(0), unsqueeze_2_const);
auto final_reshape_pattern = std::make_shared<ngraph::op::Constant>(ngraph::element::i64,
ngraph::Shape{4}, std::vector<size_t>({1, 1, 1, hiddenSize}));

View File

@ -67,7 +67,7 @@ std::shared_ptr<ngraph::Function> UnsqueezeFunction::getReference(
const std::shared_ptr<Node> dequantizationOpBefore = makeDequantization(input, dequantizationBefore);
const auto unsqueeze = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Unsqueeze>>(
op::Unsqueeze(dequantizationOpBefore, std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ axes.size() }, axes)),
op::v0::Unsqueeze(dequantizationOpBefore, std::make_shared<ngraph::opset1::Constant>(element::i64, Shape{ axes.size() }, axes)),
precisionAfterOperation);
const std::shared_ptr<Node> dequantizationOpAfter = makeDequantization(unsqueeze, dequantizationAfter);
dequantizationOpAfter->set_friendly_name("output");

View File

@ -526,7 +526,6 @@ extensions/middle/DeleteControlFlowEdges.py
extensions/middle/DeleteNotExecutable.py
extensions/middle/DilatedConvolution.py
extensions/middle/EltwiseChecker.py
extensions/middle/EltwiseInputNormalization.py
extensions/middle/EltwiseInputReshape.py
extensions/middle/FakeSplitOutputs.py
extensions/middle/FusedBatchNormNonConstant.py

View File

@ -1,48 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import networkx as nx
import numpy as np
from extensions.middle.EltwiseInputReshape import EltwiseInputReshape
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
class EltwiseInputNormalize(EltwiseInputReshape, MiddleReplacementPattern):
# This pass should be called directly from pipeline before layout change and other permutations
enabled = False
def find_and_replace_pattern(self, graph: Graph):
eltwise_nodes = graph.get_op_nodes(is_eltwise=True)
# Iterating over all Eltwise operations and check that every input has similar shape
# in case of different shapes, we inserts new_shape attribute and then call EltwiseInputReshape extension
# that insert reshapes (in case of not constant nodes) or directly reshapes values in data nodes for specified
# shape
for node in eltwise_nodes:
output_shape = node.out_node().shape
for in_node in node.in_nodes().values():
if len(in_node.shape) != len(output_shape):
# Set edge attribute new_shape for further transformation pass
new_shape = in_node.shape
for x in range(len(output_shape) - len(in_node.shape)):
new_shape = np.insert(new_shape, 0, 1)
nx.set_edge_attributes(G=node.graph,
values={(in_node.id, node.id, 0): new_shape},
name='new_shape')
super().find_and_replace_pattern(graph)

View File

@ -1,202 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import numpy as np
from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
from mo.front.common.partial_infer.utils import int64_array
from mo.middle.passes.eliminate_test import build_graph
from mo.utils.ir_engine.compare_graphs import compare_graphs
# The dictionary with nodes attributes used to build various graphs. A key is the name of the node and the value is the
# dictionary with node attributes.
nodes_attributes = {
# Placeholder layers
'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# Reshape layers
'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'reshape_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
'reshape_1_const_data': {'kind': 'data', 'value': None, 'shape': None},
'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'},
'reshape_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
'reshape_2_const_data': {'kind': 'data', 'value': None, 'shape': None},
# Eltwise consumes layers
'eltwise_1': {'kind': 'op', 'is_eltwise': True},
'eltwise_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'eltwise_2': {'kind': 'op', 'is_eltwise': True},
'eltwise_2_data': {'value': None, 'shape': None, 'kind': 'data'},
'eltwise_3': {'kind': 'op', 'is_eltwise': True},
'eltwise_3_data': {'value': None, 'shape': None, 'kind': 'data'},
'eltwise_4': {'kind': 'op', 'is_eltwise': True},
'eltwise_4_data': {'value': None, 'shape': None, 'kind': 'data'},
# Concat
'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
}
class EltwiseInputNormalizationTest(unittest.TestCase):
def test1_not_constant(self):
#
# data1(1,3,64,64)----. data(1,3,64,64)-------.
# data2(1,64,1)-------->Eltwise-->data(1,3,64,64) => data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->...
# data3(64,1)------' data(64,1)->Reshape->data(1,1,64,1)-'
#
graph = build_graph(nodes_attributes, [
('placeholder_1', 'placeholder_1_data'),
('placeholder_1', 'placeholder_2_data'),
('placeholder_1', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
('placeholder_3_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([1, 64, 1])},
'placeholder_3_data': {'shape': np.array([64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[
('placeholder_1', 'placeholder_1_data'),
('placeholder_1', 'placeholder_2_data'),
('placeholder_1', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('placeholder_3_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('reshape_1', 'reshape_1_data'),
('reshape_2', 'reshape_2_data'),
('reshape_1_data', 'eltwise_1'),
('reshape_2_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'reshape_1_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
'reshape_1_const_data': {'value': int64_array([1, 1, 64, 1]),
'shape': int64_array([4])},
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
'reshape_2_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
'reshape_2_const_data': {'value': int64_array([1, 1, 64, 1]),
'shape': int64_array([4])},
'reshape_2_data': {'shape': np.array([1, 1, 64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
pattern = EltwiseInputNormalize()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_mega_hardcore(self):
# ORIGINAL GRAPH
#
# data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
# /\ /\ /\
# data2(64,1)-----,-'--------------------------------'------------------------------'
# \/ /
# data3(64,1)----`-->Eltwise3->data(64,1)----------'
#
# REFERENCE GRAPH AFTER TRANSFORMATION
#
# data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
# /\ /\ /\
# data2(1,1,64,1)---'--------------------------------'-------------------------------'
# /
# data4(64,1)-------, Reshape(1,1,64,1)
# \/ |
# data3(64,1)------`---->Eltwise3->data(64,1)---'
#
graph = build_graph(nodes_attributes,
[('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_1_data', 'eltwise_2'),
('placeholder_2_data', 'eltwise_3'),
('placeholder_3_data', 'eltwise_3'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_3_data', 'eltwise_2'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_2_data', 'eltwise_4'),
('placeholder_2_data', 'eltwise_4'),
('eltwise_4', 'eltwise_4_data'),
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
'placeholder_3_data': {'shape': np.array([64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_3_data': {'shape': np.array([64, 1])},
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_1_data', 'eltwise_2'),
('placeholder_4_data', 'eltwise_3'),
('placeholder_3_data', 'eltwise_3'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_3_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'eltwise_2'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_2_data', 'eltwise_4'),
('placeholder_2_data', 'eltwise_4'),
('eltwise_4', 'eltwise_4_data'),
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([1, 1, 64, 1]),
'value': np.ones([1, 1, 64, 1])},
'placeholder_3_data': {'shape': np.array([64, 1])},
'placeholder_4_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
'reshape_1_const': {'value': int64_array([1, 1, 64, 1]), 'shape': int64_array([4])},
'reshape_1_const_data': {'value': int64_array([1, 1, 64, 1]),
'shape': int64_array([4])},
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_3_data': {'shape': np.array([64, 1])},
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
pattern = EltwiseInputNormalize()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -17,11 +17,13 @@
import numpy as np
from mo.front.common.layout import get_features_dim, shape_for_layout
from mo.graph.graph import Graph
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.op import Op
from mo.ops.reshape import Reshape
from mo.ops.unsqueeze import Unsqueeze
class Eltwise1DInputReshape(MiddleReplacementPattern):
@ -42,9 +44,6 @@ class Eltwise1DInputReshape(MiddleReplacementPattern):
"""
enabled = False
def run_after(self):
return [EltwiseInputReshape]
def find_and_replace_pattern(self, graph: Graph):
layout = graph.graph['layout']
for eltwise_op_node in graph.get_op_nodes(is_eltwise=True):
@ -64,59 +63,85 @@ class Eltwise1DInputReshape(MiddleReplacementPattern):
reshape_op.out_port(0).connect(eltwise_op_node.in_port(port))
class EltwiseInputReshape(MiddleReplacementPattern):
enabled = True
force_clean_up = True
def compute_unsqueeze_map_for_eltwise(eltwise_node: Node):
'''
The function computes a map of unsqueeze_dims for each producer of eltwise node.
These unsqueeze_dims are needed to normalize input shapes of eltwise node.
'''
eltwise_shape = eltwise_node.out_port(0).data.get_shape()
max_dims = max(
[len(port.data.get_shape()) for port in eltwise_node.in_ports().values() if port.data.get_shape() is not None])
axis = eltwise_node.soft_get('axis', None)
unsqueeze_dims_map = {}
for consumer_port in eltwise_node.in_ports().values():
producer_port = consumer_port.get_source()
producer_shape = producer_port.data.get_shape()
unsqueeze_dims = int64_array([])
def run_after(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]
# 1. Compute unsqueeze dimensions in the tail
if len(producer_shape) != max_dims and len(producer_shape) > 0 and axis is not None:
num_unsqueeze_dims = max_dims - axis - len(producer_shape)
if num_unsqueeze_dims > 0:
unsqueeze_dims = np.arange(len(producer_shape), len(producer_shape) + num_unsqueeze_dims,
dtype=np.int64)
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_data_nodes():
# Get all requested shapes for current node
# This mapping will contain pairs like {shape:[list of consumers nodes]}
mapping = {}
for consumer in node.out_nodes():
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
if 'new_shape' in edge_attrs:
if np.array_equal(edge_attrs['new_shape'], node.shape):
continue
new_shape = tuple([x for x in edge_attrs['new_shape']])
if not new_shape in mapping:
mapping.update({new_shape: [consumer]})
else:
mapping[new_shape].append(consumer)
# 2. Compute unsqueeze dimensions in the head
unsqueeze_dims_head = np.arange(len(eltwise_shape) - len(producer_shape) - len(unsqueeze_dims), dtype=np.int64)
if node.has_valid('value'):
# Check that requested shape are the same
# In case if they are different, we duplicate them
for shape_key in mapping.keys():
shape = list(shape_key)
new_value = np.reshape(node.value, shape)
node_copy = Op.create_input_data_node(graph, node.id + '/copy', value=np.array(new_value))
for consumer in mapping[shape_key]:
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
del edge_attrs['new_shape']
# Pay attention that unsqueeze dims order makes sense
# since shape is normalized in the tail first and after in the head
unsqueeze_dims = np.concatenate((unsqueeze_dims, unsqueeze_dims_head))
unsqueeze_dims_map[producer_port] = unsqueeze_dims
# Remove edge from previous data node and connect new data node with its consumer
graph.remove_edge(node.id, consumer.id)
graph.add_edge(node_copy.id, consumer.id, **edge_attrs)
return unsqueeze_dims_map
def normalize_eltwise_inputs(graph: Graph):
'''
The function normalizes input shapes for eltwise nodes.
In the first step the function gets to know which shapes/unsqueeze dims for inputs are required for normalization.
In the second step the function inserts Unsqueeze nodes between non-normalized inputs and eltwise nodes.
'''
# Generate a map for producers of eltwise nodes with non-normalized shapes
# and in this map every producer has another map that reflects normalized shape
# to a list of eltwise consumers
mapping = {}
for eltwise_node in graph.get_op_nodes(is_eltwise=True):
unsqueeze_dims_map = compute_unsqueeze_map_for_eltwise(eltwise_node)
for consumer_port in eltwise_node.in_ports().values():
producer_port = consumer_port.get_source()
unsqueeze_dims = unsqueeze_dims_map[producer_port]
if unsqueeze_dims is not None and len(unsqueeze_dims) > 0:
unsqueeze_dims = tuple([x for x in unsqueeze_dims])
if producer_port not in mapping:
mapping.update({producer_port: {unsqueeze_dims: [consumer_port]}})
elif unsqueeze_dims not in mapping[producer_port]:
mapping[producer_port].update({unsqueeze_dims: [consumer_port]})
else:
mapping[producer_port][unsqueeze_dims].append(consumer_port)
# Walk through each produced in the map and insert Unsqueeze nodes between a producer and eltwise nodes
for producer_port in mapping.keys():
producer_node = producer_port.node
for unsqueeze_dims in mapping[producer_port].keys():
unsqueeze_name = producer_node.soft_get('name', producer_node.id) + '/EltwiseUnsqueeze'
unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(list(unsqueeze_dims))},
{'name': unsqueeze_name})
unsqueeze_node.in_port(0).connect(producer_port)
# Insert Unsqueeze with determined unsqueeze dimensions between the current producer and eltwise node
for consumer_port in mapping[producer_port][unsqueeze_dims]:
consumer_port.connect(unsqueeze_node.out_port(0))
# The shape and value adjustments must be explicitly done within the transformation
# since the transformation is called from Fusing transformation that excludes
# automatic call of shape inference pass
producer_port_value = producer_port.data.get_value()
producer_port_shape = producer_port.data.get_shape()
new_shape = producer_port_shape.copy()
for unsqueeze_dim in unsqueeze_dims:
new_shape = np.insert(new_shape, unsqueeze_dim, 1)
if producer_port_value is not None:
unsqueeze_node.out_port(0).data.set_value(np.reshape(producer_port_value, new_shape))
else:
# Insert Reshape layer between data node and consumer
for shape_key in mapping.keys():
shape = list(shape_key)
reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape'
reshape = Reshape(graph, attrs={'name': reshape_name})
reshape_dim = Const(graph,
{'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data()
reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
# Iterate over consumers and reconnect them to Reshape layer output
for consumer in mapping[shape_key]:
edge_attrs = graph.get_edge_data(node.id, consumer.id)[0]
del edge_attrs['new_shape']
# Reconnect edge from original data node to Reshape output datanode
graph.remove_edge(node.id, consumer.id)
graph.add_edge(reshape_data.id, consumer.id, **edge_attrs)
unsqueeze_node.out_port(0).data.set_shape(new_shape)

View File

@ -18,7 +18,7 @@ import unittest
import numpy as np
from extensions.middle.EltwiseInputReshape import EltwiseInputReshape
from extensions.middle.EltwiseInputReshape import normalize_eltwise_inputs
from mo.front.common.partial_infer.utils import int64_array
from mo.middle.passes.eliminate_test import build_graph
from mo.utils.ir_engine.compare_graphs import compare_graphs
@ -28,47 +28,216 @@ from mo.utils.ir_engine.compare_graphs import compare_graphs
nodes_attributes = {
# Placeholder layers
'placeholder_1': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_2': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_3': {'value': None, 'shape': None, 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_1_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_2_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_3_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
'placeholder_4_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': None},
# Reshape layers
'reshape_1': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
'reshape_1': {'type': 'Unsqueeze', 'value': None, 'kind': 'op', 'op': 'Unsqueeze'},
'reshape_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'reshape_1_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
'reshape_1_const_data': {'kind': 'data', 'value': None, 'shape': None},
'reshape_2': {'type': 'Reshape', 'value': None, 'kind': 'op', 'op': 'Reshape'},
'reshape_2': {'type': 'Unsqueeze', 'value': None, 'kind': 'op', 'op': 'Unsqueeze'},
'reshape_2_data': {'value': None, 'shape': None, 'kind': 'data'},
'reshape_2_const': {'type': 'Const', 'kind': 'op', 'op': 'Const', 'value': None},
'reshape_2_const_data': {'kind': 'data', 'value': None, 'shape': None},
# Fake consumes layers
'consumer_1': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
'consumer_2': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
'consumer_3': {'type': 'Consumer', 'value': None, 'kind': 'op', 'op': 'Consumer'},
# Eltwise consumes layers
'eltwise_1': {'kind': 'op', 'is_eltwise': True},
'eltwise_1_data': {'value': None, 'shape': None, 'kind': 'data'},
'eltwise_2': {'kind': 'op', 'is_eltwise': True},
'eltwise_2_data': {'value': None, 'shape': None, 'kind': 'data'},
'eltwise_3': {'kind': 'op', 'is_eltwise': True},
'eltwise_3_data': {'value': None, 'shape': None, 'kind': 'data'},
'eltwise_4': {'kind': 'op', 'is_eltwise': True},
'eltwise_4_data': {'value': None, 'shape': None, 'kind': 'data'},
# Concat
'concat': {'type': 'Concat', 'kind': 'op', 'op': 'Concat'},
}
class EltwiseInputReshapeTest(unittest.TestCase):
class EltwiseInputNormalizationTest(unittest.TestCase):
def test1_not_constant(self):
#
# data1(1,3,64,64)----. data(1,3,64,64)-------.
# data2(1,64,1)-------->Eltwise-->data(1,3,64,64) => data(1,64,1)->Reshape->data(1,1,64,1)-->Eltwise->...
# data3(64,1)------' data(64,1)->Reshape->data(1,1,64,1)-'
#
graph = build_graph(nodes_attributes, [
('placeholder_1', 'placeholder_1_data'),
('placeholder_1', 'placeholder_2_data'),
('placeholder_1', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
('placeholder_3_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([1, 64, 1])},
'placeholder_3_data': {'shape': np.array([64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[
('placeholder_1', 'placeholder_1_data'),
('placeholder_1', 'placeholder_2_data'),
('placeholder_1', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('placeholder_3_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('reshape_1', 'reshape_1_data'),
('reshape_2', 'reshape_2_data'),
('reshape_1_data', 'eltwise_1'),
('reshape_2_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])},
'reshape_1_const_data': {'value': int64_array([0]),
'shape': int64_array([1])},
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
'reshape_2_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
'reshape_2_const_data': {'value': int64_array([0, 1]),
'shape': int64_array([2])},
'reshape_2_data': {'shape': np.array([1, 1, 64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_mega_hardcore(self):
# ORIGINAL GRAPH
#
# data1(1,3,64,64)---,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
# /\ /\ /\
# data2(64,1)-----,-'--------------------------------'------------------------------'
# \/ /
# data3(64,1)----`-->Eltwise3->data(64,1)----------'
#
# REFERENCE GRAPH AFTER TRANSFORMATION
#
# data1(1,3,64,64)---------------------,->Eltwise1->data(1,3,64,64)-----,->Eltwise2->data(1,3,64,64)---,->Eltwise4->data(1,3,64,64)
# /\ /\ /\
# data2(64,1)-,- Reshape1(1,1,64,64)--'--------------------------------o-------------------------------'
# | |
# | Reshape(1,1,64,1)
# \/ |
# data3(64,1)----------->Eltwise3->data(64,1)--------------------------'
#
graph = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_3', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_1_data', 'eltwise_2'),
('placeholder_2_data', 'eltwise_3'),
('placeholder_3_data', 'eltwise_3'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_3_data', 'eltwise_2'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_2_data', 'eltwise_4'),
('placeholder_2_data', 'eltwise_4'),
('eltwise_4', 'eltwise_4_data'),
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([64, 1]), 'value': np.ones([64, 1])},
'placeholder_3_data': {'shape': np.array([64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_3_data': {'shape': np.array([64, 1])},
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_3', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_1_data', 'eltwise_2'),
('placeholder_2_data', 'eltwise_3'),
('placeholder_3_data', 'eltwise_3'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_3_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'eltwise_2'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_2_data', 'eltwise_4'),
('reshape_1_data', 'eltwise_4'),
('eltwise_4', 'eltwise_4_data'),
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([64, 1]),
'value': np.ones([64, 1])},
'placeholder_3_data': {'shape': np.array([64, 1])},
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
'reshape_1_const_data': {'value': int64_array([0, 1]),
'shape': int64_array([2])},
'reshape_1_data': {'shape': np.array([1, 1, 64, 1])},
'reshape_2_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
'reshape_2_const_data': {'value': int64_array([0, 1]),
'shape': int64_array([2])},
'reshape_2_data': {'shape': np.array([1, 1, 64, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_2_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_3_data': {'shape': np.array([64, 1])},
'eltwise_4_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_4', check_op_attrs=True)
self.assertTrue(flag, resp)
def test2_not_constant(self):
# ,-------------->consumer3 ,------------>consumer3
# data---(new_shape1)-->consumer1 => data---->Reshape-->consumer1
# `-(new_shape2)-->consumer2 `-->Reshape-->consumer2
#
graph = build_graph(nodes_attributes, [
('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 1, 3])}),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_1_data', 'eltwise_2'),
('placeholder_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
{'placeholder_1_data': {'shape': int64_array([1, 3])},
'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])},
'eltwise_2_data': {'shape': int64_array([1, 1, 3])},
'eltwise_3_data': {'shape': int64_array([1, 3])},
},
nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[
@ -79,32 +248,34 @@ class EltwiseInputReshapeTest(unittest.TestCase):
('placeholder_1_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('placeholder_1_data', 'consumer_3'),
('placeholder_1_data', 'eltwise_3'),
('reshape_1', 'reshape_1_data'),
('reshape_2', 'reshape_2_data'),
('reshape_1_data', 'consumer_1'),
('reshape_2_data', 'consumer_2'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
('reshape_1_data', 'eltwise_1'),
('reshape_2_data', 'eltwise_2'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3])},
'reshape_1_const': {'value': int64_array([1, 3, 1, 1]), 'shape': int64_array([4])},
'reshape_1_const_data': {'value': int64_array([1, 3, 1, 1]),
'shape': int64_array([4])},
'reshape_1_data': {'shape': int64_array([1, 3, 1, 1])},
'reshape_2_const': {'value': int64_array([1, 1, 3]), 'shape': int64_array([3])},
'reshape_2_const_data': {'value': int64_array([1, 1, 3]), 'shape': int64_array([3])},
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
'reshape_1_const_data': {'value': int64_array([0, 1]),
'shape': int64_array([2])},
'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])},
'reshape_2_const': {'value': int64_array([0]), 'shape': int64_array([1])},
'reshape_2_const_data': {'value': int64_array([0]), 'shape': int64_array([1])},
'reshape_2_data': {'shape': int64_array([1, 1, 3])},
}, nodes_with_edges_only=True)
pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
def test2_not_constant(self):
def test3_not_constant(self):
# ,--------------->consumer3 ,----------->consumer3
# data---(new_shape1)-->consumer1 => data-->Reshape-->consumer1
# `-(new_shape1)-->consumer2 `-->consumer2
@ -112,14 +283,22 @@ class EltwiseInputReshapeTest(unittest.TestCase):
graph = build_graph(nodes_attributes,
[
('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3, 1, 1])}),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_1_data', 'eltwise_2'),
('placeholder_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
{'placeholder_1_data': {'shape': int64_array([1, 3])},
'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])},
'eltwise_2_data': {'shape': int64_array([1, 1, 1, 3])},
'eltwise_3_data': {'shape': int64_array([1, 3])},
},
nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[
@ -127,123 +306,239 @@ class EltwiseInputReshapeTest(unittest.TestCase):
('placeholder_1_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('placeholder_1_data', 'consumer_3'),
('placeholder_1_data', 'eltwise_3'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'consumer_1'),
('reshape_1_data', 'consumer_2'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
('reshape_1_data', 'eltwise_1'),
('reshape_1_data', 'eltwise_2'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3])},
'reshape_1_const': {'value': int64_array([1, 3, 1, 1]), 'shape': int64_array([4])},
'reshape_1_const_data': {'value': int64_array([1, 3, 1, 1]),
'shape': int64_array([4])},
'reshape_1_data': {'shape': int64_array([1, 3, 1, 1])},
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
'reshape_1_const_data': {'value': int64_array([0, 1]),
'shape': int64_array([2])},
'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])},
}, nodes_with_edges_only=True)
pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
def test3_constant(self):
# ,--------------->consumer3 data-->consumer3
# data---(new_shape1)-->consumer1 => data-->consumer1
# `-(new_shape2)-->consumer2 data-->consumer2
#
graph = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3, 1, 1])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 1, 3])}),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}},
nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1'),
('placeholder_2_data', 'consumer_2'),
('placeholder_3_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3, 1, 1]),
'value': np.ones([1, 3, 1, 1])},
'placeholder_2_data': {'shape': int64_array([1, 1, 3]), 'value': np.ones([1, 1, 3])},
'placeholder_3_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
}, nodes_with_edges_only=True)
pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
def test4_constant(self):
# ,--------------->consumer3 ,-->consumer3
# data---(new_shape1)-->consumer1 => data-->consumer1
# `-(new_shape2)-->consumer2 `->consumer2
# ,--------------->consumer3 ,------------>consumer3
# data---(new_shape1)-->consumer1 => data--->reshape1-->consumer1
# `-(new_shape2)-->consumer2 `->reshape2-->consumer2
#
graph = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([3, 1, 1])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([3, 1, 1])}),
('placeholder_1_data', 'consumer_3', {'new_shape': int64_array([3, 1, 1])}),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_1_data', 'eltwise_2'),
('placeholder_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])}},
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
'eltwise_1_data': {'shape': int64_array([1, 1, 1, 3])},
'eltwise_2_data': {'shape': int64_array([1, 1, 3])},
'eltwise_3_data': {'shape': int64_array([1, 3])},
},
nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1'),
('placeholder_1_data', 'consumer_2'),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'eltwise_1'),
('placeholder_1_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('reshape_2', 'reshape_2_data'),
('reshape_2_data', 'eltwise_2'),
('placeholder_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([3, 1, 1]), 'value': np.ones([3, 1, 1])}
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
'reshape_1_const': {'value': int64_array([0, 1]), 'shape': int64_array([2])},
'reshape_1_const_data': {'value': int64_array([0, 1]),
'shape': int64_array([2])},
'reshape_1_data': {'shape': int64_array([1, 1, 1, 3])},
'reshape_2_const': {'value': int64_array([0]), 'shape': int64_array([1])},
'reshape_2_const_data': {'value': int64_array([0]),
'shape': int64_array([1])},
'reshape_2_data': {'shape': int64_array([1, 1, 3])},
}, nodes_with_edges_only=True)
pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
def test5_not_constant(self):
def test5_constant(self):
# ,-(new_shape)-->consumer3 ,-->consumer3
# data---(new_shape)-->consumer1 => data-->reshape---->consumer1
# `-(new_shape)-->consumer2 `-->consumer2
#
graph = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_1_data', 'eltwise_2'),
('placeholder_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
'eltwise_1_data': {'shape': int64_array([1, 1, 3])},
'eltwise_2_data': {'shape': int64_array([1, 1, 3])},
'eltwise_3_data': {'shape': int64_array([1, 1, 3])},
},
nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('reshape_1', 'reshape_1_data'),
('reshape_1_data', 'eltwise_1'),
('reshape_1_data', 'eltwise_2'),
('reshape_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3]), 'value': np.ones([1, 3])},
'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])},
'reshape_1_const_data': {'value': int64_array([0]),
'shape': int64_array([1])},
'reshape_1_data': {'shape': int64_array([1, 1, 3])},
}, nodes_with_edges_only=True)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
def test6_not_constant(self):
# ,--------------->consumer3 ,->consumer3
# data---(new_shape1)-->consumer1 => data----->consumer1
# `-(new_shape1)-->consumer2 `-->consumer2
#
graph = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_1_data', 'eltwise_2'),
('placeholder_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
{'placeholder_1_data': {'shape': int64_array([1, 3])},
'eltwise_1_data': {'shape': int64_array([1, 3])},
'eltwise_2_data': {'shape': int64_array([1, 3])},
'eltwise_3_data': {'shape': int64_array([1, 3])},
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[('placeholder_1_data', 'consumer_1', {'new_shape': int64_array([1, 3])}),
('placeholder_1_data', 'consumer_2', {'new_shape': int64_array([1, 3])}),
('placeholder_1_data', 'consumer_3'),
('consumer_1', 'concat'),
('consumer_2', 'concat'),
('consumer_3', 'concat'),
[('placeholder_1', 'placeholder_1_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_1_data', 'eltwise_2'),
('placeholder_1_data', 'eltwise_3'),
('eltwise_1', 'eltwise_1_data'),
('eltwise_2', 'eltwise_2_data'),
('eltwise_3', 'eltwise_3_data'),
('eltwise_1_data', 'concat'),
('eltwise_2_data', 'concat'),
('eltwise_3_data', 'concat'),
],
{'placeholder_1_data': {'shape': int64_array([1, 3])}}, nodes_with_edges_only=True)
pattern = EltwiseInputReshape()
pattern.find_and_replace_pattern(graph)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
self.assertTrue(flag, resp)
def test7_axis1_not_constant(self):
#
# data1(1,3,64,64)----. data(1,3,64,64)-------.
# data2(3,64,1)-------->Eltwise-->data(1,3,64,64)=> data(3,64,1)->Unsqueeze(0)->data(1,3,64,1)-->Eltwise->...
# data3(3,1)------' data(3,1)->Unsqueeze(2, 0)->data(1,3,1,1)-'
#
graph = build_graph(nodes_attributes, [
('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_3', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'eltwise_1'),
('placeholder_3_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([3, 64, 1])},
'placeholder_3_data': {'shape': np.array([3, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])},
'eltwise_1' : {'axis': 1}
}, nodes_with_edges_only=True)
graph_ref = build_graph(nodes_attributes,
[
('placeholder_1', 'placeholder_1_data'),
('placeholder_2', 'placeholder_2_data'),
('placeholder_3', 'placeholder_3_data'),
('placeholder_1_data', 'eltwise_1'),
('placeholder_2_data', 'reshape_1'),
('reshape_1_const', 'reshape_1_const_data'),
('reshape_1_const_data', 'reshape_1'),
('placeholder_3_data', 'reshape_2'),
('reshape_2_const', 'reshape_2_const_data'),
('reshape_2_const_data', 'reshape_2'),
('reshape_1', 'reshape_1_data'),
('reshape_2', 'reshape_2_data'),
('reshape_1_data', 'eltwise_1'),
('reshape_2_data', 'eltwise_1'),
('eltwise_1', 'eltwise_1_data')
],
{'placeholder_1_data': {'shape': np.array([1, 3, 64, 64])},
'placeholder_2_data': {'shape': np.array([3, 64, 1])},
'placeholder_3_data': {'shape': np.array([3, 1])},
'reshape_1_const': {'value': int64_array([0]), 'shape': int64_array([1])},
'reshape_1_const_data': {'value': int64_array([0]),
'shape': int64_array([1])},
'reshape_1_data': {'shape': np.array([1, 3, 64, 1])},
'reshape_2_const': {'value': int64_array([2, 0]), 'shape': int64_array([2])},
'reshape_2_const_data': {'value': int64_array([2, 0]),
'shape': int64_array([2])},
'reshape_2_data': {'shape': np.array([1, 3, 1, 1])},
'eltwise_1_data': {'shape': np.array([1, 3, 64, 64])}
}, nodes_with_edges_only=True)
normalize_eltwise_inputs(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'eltwise_1', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -16,7 +16,7 @@
from extensions.front.div import Div
from extensions.front.sub import Sub
from extensions.middle.AddFakeQuantizeFuse import AddFakeQuantizeFuse
from extensions.middle.EltwiseInputNormalization import EltwiseInputNormalize
from extensions.middle.EltwiseInputReshape import normalize_eltwise_inputs
from extensions.middle.MulFakeQuantizeFuse import MulFakeQuantizeFuse
from extensions.middle.RemoveRedundantReshapes import RemoveRedundantReshapes
@ -82,7 +82,7 @@ class Fusing(MiddleReplacementPattern):
for_graph_and_each_sub_graph_recursively(graph, fuse_mul_add_sequence)
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
EltwiseInputNormalize().find_and_replace_pattern(graph)
normalize_eltwise_inputs(graph)
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
# Fusing linear operation to Convolution
@ -96,7 +96,7 @@ class Fusing(MiddleReplacementPattern):
for_graph_and_each_sub_graph_recursively(graph, fuse_linear_ops)
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
EltwiseInputNormalize().find_and_replace_pattern(graph)
normalize_eltwise_inputs(graph)
for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())
MarkNodesToFuseUpToFakeQuantize().find_and_replace_pattern(graph)

View File

@ -14,17 +14,14 @@
limitations under the License.
"""
import networkx as nx
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
def eltwise_infer(node, op=None, **kwargs):
raw_inputs = [(inp, attr) for inp, attr in node.get_sorted_inputs()
if 'control_flow_edge' not in attr or not attr['control_flow_edge']]
inputs = [Node(node.graph, inp) for inp, attr in raw_inputs]
shapes = [node.graph.node[inp]['shape'] for inp, attr in raw_inputs]
values = [node.graph.node[inp]['value'] for inp, attr in raw_inputs]
@ -53,14 +50,6 @@ def eltwise_infer(node, op=None, **kwargs):
shapes[id] = new_shape
# Save shape for further transformation that applies this shapes for input nodes
# We set new_shape attribute on edge for given input node
edge_attrs = node.graph.get_edge_data(inputs[id].id, node.id)[0]
nx.set_edge_attributes(G=node.graph,
values={(inputs[id].id, node.id, 0): new_shape},
name='new_shape')
# Reshape value to correctly calculate output shape
if values[id] is not None:
values[id] = np.reshape(values[id], new_shape)

View File

@ -38,13 +38,13 @@ def get_canonical_axis_index_node(rank: Node, axis: int) -> Node:
graph = rank.graph
name = rank.soft_get('name', rank.id)
if axis < 0:
axis = Const(graph, {'name': name + '/negative_axis', 'value': int64_array([axis])}).create_node()
axis = Const(graph, {'name': name + '/negative_axis', 'value': int64_array(axis)}).create_node()
add = Add(graph, {'name': name + '/positive_axis'}).create_node()
rank.out_port(0).connect(add.in_port(0))
axis.out_port(0).connect(add.in_port(1))
return add
else:
return Const(graph, {'name': name + '/positive_axis', 'value': int64_array([axis])}).create_node()
return Const(graph, {'name': name + '/positive_axis', 'value': int64_array(axis)}).create_node()
def get_range_node_of_idxs(rank: Node, begin: int, end: int,
@ -66,20 +66,20 @@ def get_range_node_of_idxs(rank: Node, begin: int, end: int,
end_idx = get_canonical_axis_index_node(rank, end)
if not include_begin:
const = Const(graph, {'value': int64_array([1]), 'name': name + '/exclude_begin/value'}).create_node()
const = Const(graph, {'value': int64_array(1), 'name': name + '/exclude_begin/value'}).create_node()
add = Add(graph, {'name': name + '/exclude_begin'}).create_node()
start_idx.out_port(0).connect(add.in_port(0))
const.out_port(0).connect(add.in_port(1))
start_idx = add
if include_end:
const = Const(graph, {'value': int64_array([1]), 'name': name + '/including_end/value'}).create_node()
const = Const(graph, {'value': int64_array(1), 'name': name + '/including_end/value'}).create_node()
add = Add(graph, {'name': name + '/including_end'}).create_node()
end_idx.out_port(0).connect(add.in_port(0))
const.out_port(0).connect(add.in_port(1))
end_idx = add
delta = Const(graph, {'name': name + '/delta', 'value': int64_array([1])}).create_node()
delta = Const(graph, {'name': name + '/delta', 'value': int64_array(1)}).create_node()
range_node = Range(graph, {'name': name + '/range_idxs'}).create_node()
start_idx.out_port(0).connect(range_node.in_port(0))

View File

@ -21,9 +21,6 @@
#include "ngraph/axis_vector.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/fused_op.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
namespace ngraph
{
@ -31,7 +28,7 @@ namespace ngraph
{
namespace v0
{
class NGRAPH_API Unsqueeze : public ngraph::op::util::FusedOp
class NGRAPH_API Unsqueeze : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
@ -39,9 +36,7 @@ namespace ngraph
Unsqueeze() = default;
Unsqueeze(const Output<Node>& data, const Output<Node>& axes);
virtual void pre_validate_and_infer_types() override;
virtual OutputVector decompose_op() const override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
@ -55,5 +50,3 @@ namespace ngraph
using v0::Unsqueeze;
}
}
NGRAPH_SUPPRESS_DEPRECATED_END

View File

@ -29,17 +29,15 @@
using namespace std;
using namespace ngraph;
NGRAPH_SUPPRESS_DEPRECATED_START
NGRAPH_RTTI_DEFINITION(op::v0::Unsqueeze, "Unsqueeze", 0);
op::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
: FusedOp({data, axes})
op::v0::Unsqueeze::Unsqueeze(const Output<Node>& data, const Output<Node>& axes)
: Op({data, axes})
{
constructor_validate_and_infer_types();
}
void op::Unsqueeze::pre_validate_and_infer_types()
void op::v0::Unsqueeze::validate_and_infer_types()
{
const auto data = input_value(0);
auto data_partial_shape = data.get_partial_shape();
@ -79,24 +77,12 @@ void op::Unsqueeze::pre_validate_and_infer_types()
set_output_type(0, get_input_element_type(0), PartialShape{output_shape});
}
OutputVector op::Unsqueeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
this,
(get_output_partial_shape(0).is_static()),
"output shape was not calculated during pre_validate_and_infer_types. Can not decompose.");
auto data = input_value(0);
auto data_shape = data.get_shape();
auto output_shape = get_output_shape(0);
return {builder::opset1::reshape(data, output_shape)};
}
bool ngraph::op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor)
bool op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Unsqueeze::clone_with_new_inputs(const OutputVector& new_args) const
shared_ptr<Node> op::v0::Unsqueeze::clone_with_new_inputs(const OutputVector& new_args) const
{
if (new_args.size() != 2)
{

View File

@ -1487,7 +1487,7 @@ NGRAPH_TEST(${BACKEND_NAME}, unsqueeze)
auto data_node = make_shared<op::Parameter>(element::f32, Shape{4, 2});
auto axes_node =
make_shared<ngraph::op::Constant>(element::i64, Shape{2}, vector<int64_t>{1, 2});
auto squeeze = make_shared<op::Unsqueeze>(data_node, axes_node);
auto squeeze = make_shared<op::v0::Unsqueeze>(data_node, axes_node);
auto function = make_shared<Function>(NodeVector{squeeze}, ParameterVector{data_node});
auto test_case = test::TestCase<TestEngine>(function);

View File

@ -189,7 +189,7 @@ TEST(constant_folding, constant_unsqueeze)
auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
vector<int64_t> values_axes{2, 3};
auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
auto unsqueeze = make_shared<op::v0::Unsqueeze>(constant, constant_axes);
unsqueeze->set_friendly_name("test");
auto f = make_shared<Function>(unsqueeze, ParameterVector{});
@ -197,7 +197,7 @@ TEST(constant_folding, constant_unsqueeze)
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
ASSERT_EQ(count_ops_of_type<op::v0::Unsqueeze>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =

View File

@ -877,7 +877,7 @@ namespace
void op_is_Unsqueeze()
{
op::Unsqueeze node;
op::v0::Unsqueeze node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));

View File

@ -126,5 +126,5 @@ NGRAPH_OP(Tan, ngraph::op)
NGRAPH_OP(Tanh, ngraph::op)
NGRAPH_OP(TensorIterator, ngraph::op)
NGRAPH_OP(Tile, ngraph::op::v0)
NGRAPH_OP(Unsqueeze, ngraph::op)
NGRAPH_OP(Unsqueeze, ngraph::op::v0)
NGRAPH_OP(Xor, ngraph::op)

View File

@ -26,7 +26,7 @@ TEST(type_prop, unsqueeze)
auto param = make_shared<op::Parameter>(element::f32, Shape{4, 1, 4, 1, 8});
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
auto unsqueeze = make_shared<op::v0::Unsqueeze>(param, axes_node);
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
ASSERT_EQ(unsqueeze->get_shape(), (Shape{4, 1, 1, 1, 4, 1, 8}));
@ -37,7 +37,7 @@ TEST(type_prop, unsqueeze_dynamic)
auto param = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(5));
auto axes_node =
make_shared<ngraph::op::Constant>(element::u64, Shape{2}, vector<int64_t>{1, 2});
auto unsqueeze = make_shared<op::Unsqueeze>(param, axes_node);
auto unsqueeze = make_shared<op::v0::Unsqueeze>(param, axes_node);
ASSERT_EQ(unsqueeze->get_element_type(), element::f32);
EXPECT_TRUE(