publish master branch snapshot, revision ea98a886d925eb152931aab13856e68037665562

This commit is contained in:
Alexey Suhov 2020-05-22 03:42:00 +03:00
parent deb008a26f
commit ccb7438803
45 changed files with 2674 additions and 875 deletions

View File

@ -85,7 +85,7 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
THROW_IE_EXCEPTION << "Invalid IR! " << node_param.name << " name is not unique!";
opName.insert(node_param.name);
params[node_param.layerId] = {node, node_param};
if (node_param.type == "Result") {
if (node_param.type == "Result" || node_param.type == "Assign") {
outputs.push_back(node_param.layerId);
}
}
@ -118,7 +118,9 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
ngraph::ParameterVector parameter_nodes;
ngraph::ResultVector result_nodes;
std::vector<std::shared_ptr<ngraph::Node>> allNodes;
ngraph::NodeVector allNodes;
std::vector<std::shared_ptr<ngraph::op::Assign>> assign_nodes;
std::map<std::string, std::shared_ptr<ngraph::Node>> variable_id_to_read_value;
// Following topological order create nGraph operations
for (auto& layer_id : order) {
@ -159,12 +161,28 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
if (auto result_node = std::dynamic_pointer_cast<ngraph::op::Result>(node)) {
result_nodes.emplace_back(result_node);
}
if (auto assign_node = std::dynamic_pointer_cast<ngraph::op::Assign>(node)) {
assign_nodes.emplace_back(assign_node);
}
if (auto read_value_node = std::dynamic_pointer_cast<ngraph::op::ReadValue>(node)) {
variable_id_to_read_value[read_value_node->get_variable_id()] = read_value_node;
}
allNodes.emplace_back(node);
}
::ngraph::op::GenericIE::DisableReshape noReshape(allNodes);
return std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
auto function = std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
if (!result_nodes.empty()) {
for (const auto& assign : assign_nodes) {
assign->add_control_dependency(variable_id_to_read_value.at(assign->get_variable_id()));
// often Assign node is a leaf of the graph, we add control_dependency for one of the results
// to make Assign node visible for traversals get_ops(), get_ordered_ops()
result_nodes[0]->add_control_dependency(assign);
}
}
return function;
}
V10Parser::GenericLayerParams V10Parser::parseGenericParams(const pugi::xml_node& node) {

View File

@ -334,6 +334,28 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
return res;
});
addSpecificCreator({"Assign"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
LayerParams attrs = {node->get_friendly_name(), "Memory",
details::convertPrecision(node->get_output_element_type(0))};
auto res = std::make_shared<CNNLayer>(attrs);
res->params["id"] = params.at("variable_id");
res->params["index"] = "0";
res->params["size"] = "2";
return res;
});
addSpecificCreator({"ReadValue"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
LayerParams attrs = {node->get_friendly_name(), "Memory",
details::convertPrecision(node->get_output_element_type(0))};
auto res = std::make_shared<CNNLayer>(attrs);
res->params["id"] = params.at("variable_id");
res->params["index"] = "1";
res->params["size"] = "2";
return res;
});
addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name()
@ -656,14 +678,16 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
// Collect all names from current graph
// It is necessary in order to differentiate outputs from constant layers when we share constants
// (Constant operations contains outputs for converted and original functions)
const ngraph::NodeVector& nodes = graph->get_ops();
std::unordered_set<std::string> op_names;
for (const auto &layer : graph->get_ops())
for (const auto &layer : nodes)
op_names.insert(layer->get_name());
bool keep_constants = ::ngraph::op::util::has_op_with_type<::ngraph::op::FakeQuantize>(graph);
// Create layers and output data
for (const auto &layer : graph->get_ops()) {
for (const auto &layer : nodes) {
if (isInternalLayer(layer, op_names, keep_constants)) continue;
// TODO: remove this rt info when all blobs will be inputs
@ -703,8 +727,18 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
}
inputCount++;
}
if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "1") {
inputCount = 0;
}
cnnLayer->insData.resize(inputCount);
for (size_t i = 0; i < layer->get_output_size(); i++) {
if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
cnnLayer->outData.clear();
continue;
}
std::string outName = layer->get_friendly_name();
if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
@ -747,6 +781,8 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
// Set input data
for (const auto &layer : graph->get_ordered_ops()) {
if (std::dynamic_pointer_cast<::ngraph::op::ReadValue>(layer))
continue;
if (std::dynamic_pointer_cast<::ngraph::op::Result>(layer)) {
IE_ASSERT(layer->get_inputs().size() == 1);
const auto &input = layer->input_value(0);

View File

@ -62,9 +62,9 @@ void ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::convert_normalize_l2_
across_spatial,
channel_shared);
normalize_ie->set_friendly_name(normalize->get_friendly_name());
ngraph::copy_runtime_info(normalize, normalize_ie);
ngraph::replace_node(normalize, normalize_ie);
normalize_ie->set_friendly_name(mul->get_friendly_name());
ngraph::copy_runtime_info({normalize, mul}, normalize_ie);
ngraph::replace_node(mul, normalize_ie);
return true;
};

View File

@ -46,6 +46,7 @@ extensions/back/ParameterToPlaceholder.py
extensions/back/pass_separator.py
extensions/back/priorbox_mutation.py
extensions/back/ProposalMutation.py
extensions/back/ReadValueAssignToMemory.py
extensions/back/ReduceToPooling.py
extensions/back/ReduceTransposeDimensions.py
extensions/back/remove_last_softmax_pattern.py
@ -872,6 +873,7 @@ mo/middle/pattern_match.py
mo/middle/replacement.py
mo/ops/__init__.py
mo/ops/activation.py
mo/ops/assign.py
mo/ops/broadcast.py
mo/ops/clamp.py
mo/ops/concat.py
@ -897,6 +899,7 @@ mo/ops/pad.py
mo/ops/permute.py
mo/ops/pooling.py
mo/ops/power.py
mo/ops/read_value.py
mo/ops/reshape.py
mo/ops/result.py
mo/ops/roipooling.py

View File

@ -23,7 +23,7 @@ from mo.ops.crop import Crop
from mo.utils.logger import log
class CutMemory(BackReplacementPattern):
class CutMemoryInput(BackReplacementPattern):
"""
Cut Memory layers and have inputs/outputs in graph instead of them
"""
@ -38,30 +38,56 @@ class CutMemory(BackReplacementPattern):
def pattern():
return dict(
nodes=[
('op', dict(kind='op', op='Memory'))],
('op', dict(kind='op', op='ReadValue'))],
edges=[]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['op']
node_id = node['id']
node_id = node['variable_id']
if node.in_port(0).disconnected():
i = 0
for dest in node.out_port(0).get_destinations():
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
'shape': dest.data.get_shape()}).create_node()
i += 1
dest.disconnect()
new_in.out_port(0).connect(dest)
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
extra={'is_warning': True})
else:
out_node_port = node.out_port(0).get_destination()
in_node_port = node.in_port(0).get_source()
node.in_port(0).disconnect()
node.out_port(0).disconnect()
crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]), 'axis': np.array([0])}).create_node()
in_node_port.connect(crop.in_port(0))
crop.out_port(0).connect(out_node_port)
i = 0
node.in_port(0).disconnect()
for dest in node.out_port(0).get_destinations():
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
'shape': dest.data.get_shape()}).create_node()
i += 1
dest.disconnect()
new_in.out_port(0).connect(dest)
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
extra={'is_warning': True})
class CutMemoryOutput(BackReplacementPattern):
"""
Cut Memory layers and have inputs/outputs in graph instead of them
"""
enabled = True
graph_condition = [lambda graph: graph.graph['fw'] == "kaldi" and graph.graph['cmd_params'].remove_memory]
force_clean_up = True
def run_before(self):
return [ParameterToInput]
@staticmethod
def pattern():
return dict(
nodes=[
('op', dict(kind='op', op='Assign'))],
edges=[]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['op']
node_id = node['variable_id']
out_node_port = node.out_port(0).get_destination()
in_node_port = node.in_port(0).get_source()
node.in_port(0).disconnect()
node.out_port(0).disconnect()
crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]),
'axis': np.array([0])}).create_node()
in_node_port.connect(crop.in_port(0))
crop.out_port(0).connect(out_node_port)

View File

@ -17,7 +17,7 @@ import unittest
import numpy as np
from extensions.back.CutMemory import CutMemory
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
@ -29,18 +29,21 @@ class CutMemoryTest(unittest.TestCase):
nodes_attrs={
'input': {'kind': 'op'},
'data_in': {'kind': 'data', 'shape': None, 'value': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'index': 1, 'id': 'memory_', 'in_ports_count': 1},
'const_0': {'kind': 'op', 'op': 'Const'},
'const_0_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'variable_id': 'memory_'},
'data_mem': {'kind': 'data', 'shape': None, 'value': None},
'concat': {'kind': 'op', 'op': 'Concat', 'axis': 0},
'concat_data': {'kind': 'data', 'shape': None, 'value': None},
'some_op': {'kind': 'op'},
'some_op_data': {'kind': 'data', 'shape': None, 'value': None},
'memory_out': {'kind': 'op', 'op': 'Memory', 'index': 0, 'id': 'memory_'},
'memory_out': {'kind': 'op', 'op': 'Assign', 'variable_id': 'memory_'},
'data_mem_out': {'kind': 'data', 'shape': None, 'value': None},
'mem_out_result': {'kind': 'op', 'op': 'Result'}
},
edges=[
('input', 'data_in'), ('memory_in', 'data_mem'),
('input', 'data_in'),
('const_0', 'const_0_data'), ('const_0_data', 'memory_in'), ('memory_in', 'data_mem'),
('data_in', 'concat', {'in': 0}), ('data_mem', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'some_op'),
('some_op', 'some_op_data'), ('some_op_data', 'memory_out'),
@ -69,7 +72,8 @@ class CutMemoryTest(unittest.TestCase):
('crop', 'crop_data'), ('crop_data', 'mem_out_result')
],
)
CutMemory().find_and_replace_pattern(graph)
CutMemoryInput().find_and_replace_pattern(graph)
CutMemoryOutput().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mem_out_result', check_op_attrs=True)
self.assertTrue(flag, resp)

View File

@ -0,0 +1,140 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
from mo.ops.memory import Memory
"""
All transformations in this file should be removed after removing IR v7 support
"""
class ReplaceReadValueByMemory(BackReplacementPattern):
"""
Replace ReadValue by Memory. Should be removed after v7 IR support removing.
"""
enabled = True
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
force_clean_up = True
def run_after(self):
return [CutMemoryInput]
@staticmethod
def pattern():
return dict(
nodes=[
('op', dict(kind='op', op='ReadValue'))],
edges=[]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['op']
node_id = node['variable_id']
node.in_port(0).disconnect()
new_in = Memory(graph, {'name': node.id, 'id': node_id, 'index': 1, 'size': 2,
'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
for dest in node.out_port(0).get_destinations():
dest.disconnect()
new_in.out_port(0).connect(dest)
class ReplaceAssignByMemory(BackReplacementPattern):
"""
Replace Assign by Memory. Should be removed after v7 IR support removing.
"""
enabled = True
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
force_clean_up = True
def run_after(self):
return [CutMemoryOutput]
@staticmethod
def pattern():
return dict(
nodes=[
('op', dict(kind='op', op='Assign'))],
edges=[]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['op']
node_id = node['variable_id']
new_out = Memory(graph, {'name': node.id, 'id': node_id, 'index': 0, 'size': 2,
'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
node.in_port(0).get_source().connect(new_out.in_port(0))
node.in_port(0).disconnect()
node.out_port(0).get_connection().set_source(new_out.out_port(0))
class KaldiRemoveMemoryOutputBackReplacementPatternV7(BackReplacementPattern):
enabled = True
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
def run_after(self):
from extensions.back.pass_separator import BackFinish
return [BackFinish]
def run_before(self):
from extensions.back.SpecialNodesFinalization import RemoveOutputOps
return [RemoveOutputOps]
@staticmethod
def pattern():
return dict(
nodes=[
('memory_node', dict(op='Memory')),
('data_node', dict(kind='data')),
('op_output', dict(op='Result'))
],
edges=[
('memory_node', 'data_node'),
('data_node', 'op_output')
]
)
@staticmethod
def replace_pattern(graph: Graph, match: dict):
"""
Need to find the pattern: Memory -> Data -> Result
It is needed to make Memory nodes appear in IR,
but they are output nodes by default and we remove the Result node after each output memory.
DO NOT use graph clean up after it
otherwise Memory nodes would be removed as they are not on the path from input to output
Parameters
----------
graph : Graph
Graph with loaded model.
match : dict
Patterns which were found in graph structure.
"""
memory = match['memory_node']
data = match['data_node']
op_output = match['op_output']
graph.remove_edge(memory.id, data.id)
graph.remove_node(data.id)
graph.remove_node(op_output.id)

View File

@ -33,7 +33,7 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
def pattern():
return dict(
nodes=[
('memory_node', dict(op='Memory')),
('memory_node', dict(op='Assign')),
('data_node', dict(kind='data')),
('op_output', dict(op='Result'))
],
@ -63,6 +63,8 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
"""
memory = match['memory_node']
data = match['data_node']
op_output = match['op_output']
graph.remove_edge(memory.id, data.id)
graph.remove_node(data.id)
graph.remove_node(op_output.id)

View File

@ -26,7 +26,7 @@ class KaldiRemoveMemoryOutputTest(unittest.TestCase):
'kind': 'data'
},
'memory_node': {
'op': 'Memory',
'op': 'Assign',
'kind': 'op'
},
'output_node': {

View File

@ -63,7 +63,7 @@ def apply_biases_to_last_layer(graph, counts):
outputs_ids = find_outputs(graph)
for output in outputs_ids.copy():
node = Node(graph, output)
if node.op != 'Memory':
if node.op != 'Assign':
continue
outputs_ids.remove(output)

View File

@ -13,12 +13,12 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.elementwise import Add, Mul
from extensions.ops.split import Split
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Node, Graph
from mo.ops.eltwise import Eltwise
from mo.ops.eltwise_n import EltwiseN
from mo.utils.error import Error
@ -43,8 +43,12 @@ class ReplaceEltwiseNin1NodePattern(FrontReplacementOp):
edge_attrs = inp[0][1]
graph.add_edge(in_node, ss_node.id, **edge_attrs)
if ss_node.num_splits == 2:
eltwise_node = Eltwise(graph, attrs={'name': 'Eltwise_' + node.name,
'operation': node['operation']}).create_node()
if node['operation'] == 'mul':
eltwise_node = Mul(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
elif node['operation'] == 'sum':
eltwise_node = Add(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
else:
raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
elif ss_node.num_splits > 2:
eltwise_node = EltwiseN(graph, attrs={'name': 'Eltwise_' + node.name,
'operation': node['operation']}).create_node()

View File

@ -20,13 +20,19 @@ from extensions.ops.activation_ops import Tanh, Sigmoid
from extensions.ops.elementwise import Add, Mul
from extensions.ops.split import Split
from mo.front.caffe.extractors.utils import input_as_const
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph
from mo.graph.graph import Node, Graph, Port
from mo.ops.assign import Assign
from mo.ops.broadcast import Broadcast
from mo.ops.clamp import Clamp
from mo.ops.crop import Crop
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.ops.scale_shift import ScaleShiftOp
from mo.ops.shape import Shape
def unique_id(prefix: str = 'id') -> str:
@ -46,6 +52,35 @@ def unique_id(prefix: str = 'id') -> str:
unique_id.names = []
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float):
# create init_graph connected to ReadValue
graph = input_out_port.node.graph
input_name = input_out_port.node.name
shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
shape_of_input.in_port(0).connect(input_out_port)
dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/'+shape_of_input.name,
'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
'axis': int64_array([0]), 'offset': int64_array([0])
}).create_node()
get_batch.in_port(0).connect(shape_of_input.out_port(0))
get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/'+input_name,
'value': int64_array([second_dim]),
'shape': int64_array([1])}).create_node()
mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
'axis': 0, 'in_ports_count': 2}).create_node()
mem_shape.in_port(0).connect(get_batch.out_port(0))
mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
fill_value = Const(graph, {'name': 'fill_value/'+input_name,
'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/'+input_name,
}).create_node()
init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
return init_value_prev_lstm_output
class ReplaceLSTMNodePattern(FrontReplacementOp):
op = "LSTMCell"
enabled = True
@ -69,7 +104,7 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
)
def replace_op(self, graph: Graph, node: Node):
input_node = node.in_node()
input_out_port = node.in_port(0).get_source()
memory_pair_input = unique_id('id')
memory_pair_output = unique_id('id')
@ -81,16 +116,17 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
'bias_term': True,
}
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node([input_node])
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node()
fc_layer_after_input.in_port(0).connect(input_out_port)
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
prev_lstm_output = Memory(graph, {'name': 'prev_memory_output',
'id': memory_pair_input,
'index': 1,
'size': 2,
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
}).create_node()
init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
node.gifo_r_weights_shape[1])
prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
'variable_id': memory_pair_input
}).create_node()
prev_lstm_output.in_port(0).connect(init_value_prev_lstm_output.out_port(0))
# *Memory(output) -> FullyConnected
fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
@ -99,15 +135,16 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
'bias_term': False,
}
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node(
[prev_lstm_output])
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node()
fc_layer_from_prev_state.in_port(0).connect(prev_lstm_output.out_port(0))
input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)
# Memory -> FullyConnected \
# *Eltwise(sum)
# Input -> FullyConnected /
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise',
}).create_node([fc_layer_from_prev_state, fc_layer_after_input])
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise'}).create_node()
join_input_prev_state_sum.in_port(0).connect(fc_layer_from_prev_state.out_port(0))
join_input_prev_state_sum.in_port(1).connect(fc_layer_after_input.out_port(0))
# *Eltwise(sum) -> Split
# it is split into 4 nodes: Act, Eltw*3
@ -120,131 +157,147 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
# |____(4)Eltwise(sum)
split_joined_input_axis = Const(graph, {'value': np.int64(1)}).create_node()
split_joined_input = Split(graph, {'name': 'join_input_split',
'num_splits': 4,
}).create_node([join_input_prev_state_sum, split_joined_input_axis])
'num_splits': 4, 'out_ports_count': 4}).create_node()
split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))
prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
'id': memory_pair_output,
'index': 1,
'size': 2,
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
}).create_node()
# prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
# 'id': memory_pair_output,
# 'index': 1,
# 'size': 2,
# 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
# }).create_node()
init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
node.input_gate_weights.shape[0])
prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
'variable_id': memory_pair_output}).create_node()
prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))
# *Memory(state) -> *ScaleShift(input)
state_input_scaleshift_attrs = {'name': 'input_scaleshift',
'bias_term': False
}
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node()
state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)
# *Memory(state) -> *ScaleShift(forget)
state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
'bias_term': False
}
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node()
state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)
# Split \
# (2)Eltwise(sum)
# Memory(state) -> *ScaleShift(input) /
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise',
}).create_node([(split_joined_input, 1),
state_input_scaleshift
])
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise'
}).create_node()
join_prev_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(1))
join_prev_lstm_input_joined_input_sum.in_port(1).connect(state_input_scaleshift.out_port(0))
# Split \
# (3)Eltwise(sum)
# Memory(state) -> *ScaleShift(forget) /
join_prev_lstm_input_joined_forget_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
}).create_node([(split_joined_input, 2),
state_forget_scaleshift
])
}).create_node()
join_prev_lstm_input_joined_forget_sum.in_port(0).connect(split_joined_input.out_port(2))
join_prev_lstm_input_joined_forget_sum.in_port(1).connect(state_forget_scaleshift.out_port(0))
# Split -> Tanh
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node([(split_joined_input, 0)])
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
remember_tahn.in_port(0).connect(split_joined_input.out_port(0))
# Split -> (2)Eltwise(sum) -> *Sigmoid
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'
}).create_node([join_prev_lstm_input_joined_input_sum])
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'}).create_node()
remember_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_input_sum.out_port(0))
# Split -> (3)Eltwise(sum) -> **Sigmoid
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'
}).create_node([join_prev_lstm_input_joined_forget_sum])
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'}).create_node()
forget_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_forget_sum.out_port(0))
# *Memory(state) \
# (6)Eltwise(mul)
# Split -> (3)Eltwise(sum) -> **Sigmoid /
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul',
}).create_node([forget_sigmoid, prev_lstm_state])
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul'}).create_node()
join_forget_prev_state_mul.in_port(0).connect(forget_sigmoid.out_port(0))
join_forget_prev_state_mul.in_port(1).connect(prev_lstm_state.out_port(0))
# Split -> Tahn \
# (5)Eltwise(mul)
# Split -> (2)Eltwise(sum) -> *Sigmoid /
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul',
}).create_node([remember_tahn, remember_sigmoid])
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul'}).create_node()
join_remember_candidates_mul.in_port(0).connect(remember_tahn.out_port(0))
join_remember_candidates_mul.in_port(1).connect(remember_sigmoid.out_port(0))
# (5)Eltwise(mul) \
# (7)Eltwise(sum)
# (6)Eltwise(mul) /
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum',
}).create_node(
[join_forget_prev_state_mul, join_remember_candidates_mul])
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum'}).create_node()
join_forget_remember_sum.in_port(0).connect(join_forget_prev_state_mul.out_port(0))
join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))
# (7)Eltwise(sum) -> Clamp
join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
'max': node.clip_value,
'min': -node.clip_value
}).create_node(
[join_forget_remember_sum])
'min': -node.clip_value}).create_node()
join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
#
# Clamp -> (2)Memory(state)
next_lstm_state = Memory(graph, {'name': 'next_lstm_state',
'id': memory_pair_output,
'index': 0,
'size': 2,
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
}).create_node([join_forget_clamp])
Result(graph, {'name': 'next_lstm_state_out'}).create_node([next_lstm_state])
next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
'variable_id': memory_pair_output}).create_node()
next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))
res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
res_node.in_port(0).connect(next_lstm_state.out_port(0))
# Clamp -> (2)Tahn
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node([join_forget_clamp])
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node()
state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))
# Clamp -> (2)ScaleShift
clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
'bias_term': False}
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)
# Split \
# (4)Eltwise(sum)
# Clamp -> (2)ScaleShift /
join_next_lstm_input_joined_input_sum = Add(graph, {'name': 'join_next_lstm_input_joined_input_sum',
}).create_node([(split_joined_input, 3), clamp_scaleshift])
}).create_node()
join_next_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(3))
join_next_lstm_input_joined_input_sum.in_port(1).connect(clamp_scaleshift.out_port(0))
# (4)Eltwise(sum) -> (3)Sigmoid
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node([join_next_lstm_input_joined_input_sum])
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node()
output_sigmoid.in_port(0).connect(join_next_lstm_input_joined_input_sum.out_port(0))
# (4)Eltwise(sum) -> (3)Sigmoid \
# (5)Eltwise(mul)
# Clamp -> (2)Tahn /
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node([state_filtered_tahn, output_sigmoid])
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node()
joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))
# (5)Eltwise(mul) -> (3)FullyConnected
fc_output_attrs = {'name': 'FullyConnected',
'out-size': node.projection_weights_shape[0],
'transpose_weights': True,
'bias_term': False}
fc_output = FullyConnected(graph, fc_output_attrs).create_node([joined_output_mul])
fc_output = FullyConnected(graph, fc_output_attrs).create_node()
fc_output.in_port(0).connect(joined_output_mul.out_port(0))
input_as_const(fc_output, fc_output_attrs, 1, 'weights', node.projection_weights)
# / (2)Memory(output)
# (3)FullyConnected
# \ Output (any next node) (edge created automatically after replacement)
next_lstm_output = Memory(graph, {'name': 'next_lstm_output',
'id': memory_pair_input,
'index': 0,
'size': 2,
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
}).create_node([fc_output])
Result(graph, {'name': 'next_lstm_output_out'}).create_node([next_lstm_output])
next_lstm_output = Assign(graph, {'name': 'next_lstm_output',
'variable_id': memory_pair_input}).create_node()
next_lstm_output.in_port(0).connect(fc_output.out_port(0))
res_node_lstm_output = Result(graph, {'name': 'next_lstm_output_out'}).create_node()
res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))
return [fc_output.id]

View File

@ -22,7 +22,7 @@ from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.eltwise import Eltwise
from extensions.ops.elementwise import Add, Mul
from mo.ops.scale_shift import ScaleShiftOp
@ -41,19 +41,19 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
def replace_op(self, graph: Graph, node: Node):
# split input to (i_part, f_part, c_part, o_part, ct_1)
split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
split_node = Split(graph, {'name': graph.unique_id(prefix='Split_lstm_input_'),
split_node = Split(graph, {'name': 'Split_lstm_input_',
'num_splits': 5}).create_node()
node.in_port(0).get_connection().set_destination(split_node.in_port(0))
split_node.in_port(1).connect(split_node_axis.out_port(0))
# i_t = Sigmoid(i_part + w_ic*ct_1)
i_scale_attrs = {'name': graph.unique_id(prefix='i_scaleshift'),
i_scale_attrs = {'name': 'i_scaleshift',
'bias_term': False}
i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
split_node.out_port(4).connect(i_scale.in_port(0))
sum_i_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_i_c_'), 'operation': 'sum'}).create_node()
sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
split_node.out_port(0).connect(sum_i_c.in_port(0))
i_scale.out_port(0).connect(sum_i_c.in_port(1))
@ -61,13 +61,13 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
# f_t = Sigmoid(f_part + w_fc*ct_1)
f_scale_attrs = {'name': graph.unique_id(prefix='f_scaleshift'),
f_scale_attrs = {'name': 'f_scaleshift',
'bias_term': False}
f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
split_node.out_port(4).connect(f_scale.in_port(0))
sum_f_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_c_'), 'operation': 'sum'}).create_node()
sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
split_node.out_port(1).connect(sum_f_c.in_port(0))
f_scale.out_port(0).connect(sum_f_c.in_port(1))
@ -78,28 +78,26 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
split_node.out_port(2).connect(c_tanh.in_port(0))
prod_i_c_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_i_c_tanh_'),
'operation': 'mul'}).create_node()
prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))
prod_f_ct_1 = Eltwise(graph, {'name': graph.unique_id(prefix='prod_f_ct_1_'),
'operation': 'mul'}).create_node()
prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
split_node.out_port(4).connect(prod_f_ct_1.in_port(1))
sum_f_i = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_i_'), 'operation': 'sum'}).create_node()
sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))
# o_t = Sigmoid(o_part + w_oc*c_t)
o_scale_attrs = {'name': graph.unique_id(prefix='o_scaleshift'),
o_scale_attrs = {'name': 'o_scaleshift',
'bias_term': False}
o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
sum_f_i.out_port(0).connect(o_scale.in_port(0))
sum_o_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_o_c_'), 'operation': 'sum'}).create_node()
sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
split_node.out_port(3).connect(sum_o_c.in_port(0))
o_scale.out_port(0).connect(sum_o_c.in_port(1))
@ -110,13 +108,12 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
prod_o_c_t_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_o_c_t_tanh_'),
'operation': 'mul'}).create_node()
prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))
# add concat to create 1 output
concat = Concat(graph, {'name': graph.unique_id(prefix='Concat_c_m')}).create_node()
concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
concat.add_sequence_of_ports('in', range(2))
sum_f_i.out_port(0).connect(concat.in_port(0))
prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))

View File

@ -15,15 +15,18 @@
"""
import numpy as np
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
from extensions.ops.elementwise import Equal
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.crop import Crop
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.utils.error import Error
from mo.utils.graph import invert_sub_graph_between_nodes
@ -48,7 +51,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
@staticmethod
def pattern():
return dict(
nodes=[('op', dict(op='Memory', index=0))],
nodes=[('op', dict(op='Assign'))],
edges=[])
@staticmethod
@ -93,9 +96,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
select_node.in_port(2).connect(zero_else.out_port(0))
# check if we have already appropriate iteration counter
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='Memory', index=1,
shape=int64_array([context_len]))),
('mem_in_data', dict()),
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
('mem_in_data', dict(shape=int64_array([context_len]))),
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
offset=int64_array([1]),
dim=int64_array([context_len-1]))),
@ -104,8 +106,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
('concat_data', dict()),
('const_1', dict(op='Const')),
('const_1_data', dict()),
('mem_out', dict(op='Memory', index=0,
shape=int64_array([context_len]))),
('mem_out', dict(op='Assign')),
('crop_out', dict(op='Crop', axis=int64_array([1]),
offset=int64_array([0]),
dim=int64_array([1]))),
@ -122,12 +123,13 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
('crop_out_data', 'select')])
counter_match = next(existing_counters, None)
if counter_match is not None:
ones = Node(graph, inverse_dict(counter_match)['const_1'])
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
else:
mem_out = Memory(graph, {'name': 'iteration_number', 'size': 2,
'index': 1, 'id': 'iteration_' + node.name,
'shape': int64_array([context_len]),
'dst_type': np.int32}).create_node()
init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
mem_out = ReadValue(graph, {'name': 'iteration_number',
'variable_id': 'iteration_'+node.name}).create_node()
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
cut_first.in_port(0).connect(mem_out.out_port(0))
@ -135,9 +137,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
concat.in_port(0).connect(cut_first.out_port(0))
concat.in_port(1).connect(ones.out_port(0))
mem_in = Memory(graph, {'name': 'iteration_number_out', 'size': 2,
'index': 0, 'id': 'iteration_' + node.name,
'shape': int64_array([context_len])}).create_node()
mem_in = Assign(graph, {'name': 'iteration_number_out',
'variable_id': 'iteration_'+node.name}).create_node()
mem_in.in_port(0).connect(concat.out_port(0))
res = Result(graph, {}).create_node()
mem_in.out_port(0).connect(res.in_port(0))
@ -146,6 +147,12 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
cut_last.in_port(0).connect(concat.out_port(0))
input_port = cut_last.out_port(0)
select_node.in_port(0).connect(input_port)
# Check if data from memory is 1
# if it is True, we have correct data and should proceed with saving it to memory
# else we have not gathered context and have garbage here, shouldn't change initial state of memory
cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
cast_in.in_port(0).connect(ones.out_port(0))
cast_in.in_port(1).connect(input_port)
select_node.in_port(0).connect(cast_in.out_port(0))
select_node.out_port(0).connect(node.in_port(0))
select_node.out_port(0).data.set_shape(in_node_shape)

View File

@ -30,23 +30,14 @@ class InsertSelectTests(unittest.TestCase):
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'memory')
],
nodes_with_edges_only=True)
ref_graph = graph.copy()
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'memory')
],
nodes_with_edges_only=True
)
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
self.assertTrue(flag, resp)
@ -60,7 +51,7 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
'placeholder_2': {'kind': 'op', 'op': None},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@ -76,15 +67,32 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
'placeholder_2': {'kind': 'op', 'op': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim':{'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
'crop_in_data': {'kind': 'data'},
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
'crop_out_data': {'kind': 'data'},
'equal': {'kind': 'op', 'op': 'Equal'},
'equal_data': {'kind': 'data'},
'select': {'kind': 'op', 'op': 'Select'},
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
@ -95,22 +103,34 @@ class InsertSelectTests(unittest.TestCase):
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('select', 'select_out_data'),
('select_out_data', 'memory')
],
@ -132,7 +152,7 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@ -151,15 +171,32 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
'crop_in_data': {'kind': 'data'},
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
'crop_out_data': {'kind': 'data'},
'equal': {'kind': 'op', 'op': 'Equal'},
'equal_data': {'kind': 'data'},
'select': {'kind': 'op', 'op': 'Select'},
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
@ -170,7 +207,7 @@ class InsertSelectTests(unittest.TestCase):
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@ -178,13 +215,25 @@ class InsertSelectTests(unittest.TestCase):
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'select', {'in': 0}),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('select', 'select_out_data'),
@ -208,7 +257,7 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@ -227,15 +276,32 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([7])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 6},
'crop_in_data': {'kind': 'data'},
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
'crop_out_data': {'kind': 'data'},
'equal': {'kind': 'op', 'op': 'Equal'},
'equal_data': {'kind': 'data'},
'select': {'kind': 'op', 'op': 'Select'},
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
@ -246,7 +312,7 @@ class InsertSelectTests(unittest.TestCase):
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@ -254,13 +320,25 @@ class InsertSelectTests(unittest.TestCase):
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'select', {'in': 0}),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('select', 'select_out_data'),

View File

@ -59,7 +59,18 @@ class RemoveUselessCropsPattern(MiddleReplacementPattern):
if out['op'] == 'Crop' and out['axis'] == axis and \
len(out.out_port(0).get_destinations()) == 1 and \
out.out_port(0).get_destination().node == concat_node:
offsets_dims.append((out['offset'], out['dim']))
# crop type 1
if 'dim' in out:
offsets_dims.append((out['offset'], out['dim']))
# crop type 3
elif 'crop_begin' in out and 'crop_end' in out:
offsets_dims.append((out['crop_begin'], out['crop_end']-out['crop_begin']))
# crop type 2 with const dim
elif not out.in_port(1).disconnected() and out.in_port(1).data.get_value() is not None:
offsets_dims.append((out['offset'], out.in_port(1).data.get_value()))
# crop type 2 with non-const dim or strange type of crop
else:
return
crop_list.append(out)
offsets_dims.sort(key=lambda off_dim: off_dim[0])

View File

@ -84,6 +84,136 @@ class RemoveUselessCropsPatternTests(unittest.TestCase):
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
self.assertTrue(flag, resp)
def test_useless_crops_type2(self):
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
'const_26_data': {'kind': 'data', 'value': 26},
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('crop_data_1', 'concat'),
('crop_data_2', 'concat'),
('crop_data_3', 'concat'),
('crop_data_4', 'concat'),
('crop_data_5', 'concat'),
('concat', 'concat_data'),
('concat_data', 'placeholder')])
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
'const_26_data': {'kind': 'data', 'value': 26},
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'dim': 26, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[
('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('concat', 'concat_data'),
('in_node', 'placeholder')
]
)
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
self.assertTrue(flag, resp)
def test_useless_crops_type3(self):
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('crop_data_1', 'concat'),
('crop_data_2', 'concat'),
('crop_data_3', 'concat'),
('crop_data_4', 'concat'),
('crop_data_5', 'concat'),
('concat', 'concat_data'),
('concat_data', 'placeholder')])
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 130]},
'placeholder': {'kind': 'op', 'op': 'Parameter'},
},
[
('placeholder_in', 'in_node'),
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
('concat', 'concat_data'),
('in_node', 'placeholder')
]
)
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
self.assertTrue(flag, resp)
def test_useful_crops(self):
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
'in_node': {'kind': 'data', 'shape': [1, 130]},

View File

@ -15,13 +15,15 @@
"""
import numpy as np
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
from extensions.ops.splice import Splice
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.crop import Crop
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.utils.error import Error
@ -67,7 +69,8 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
splice = Splice(graph, {'name': node_name,
'id': node_id,
'context': int64_array(range(node_t, 1)) if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
'context': int64_array(range(node_t, 1))
if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
splice.in_port(0).connect(input_node_out_port)
# offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
@ -106,6 +109,7 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
Replace MemoryOffset with Memory if IfDefined used with it to avoid cycles
"""
enabled = True
force_shape_inference = True
def run_before(self):
from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern
@ -141,43 +145,34 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
in_shape = input_port.data.get_shape()
node_t = abs(node.t)
memory_out = Memory(graph, {'name': pair_name, 'id': node_name+pair_name,
'index': 1, 'size': 2,
'shape': np.array([in_shape[1]*node_t])}).create_node()
init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
init_value_memory_out.out_port(0).connect(memory_out.in_port(0))
if node_t > 1:
crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
memory_out.out_port(0).connect(crop_concat.in_port(0))
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
concat.add_sequence_of_ports('in', range(2))
crop_concat.out_port(0).connect(concat.in_port(0))
crop_concat.out_port(0).data.set_shape(np.array([in_shape[0], crop_concat.dim]))
concat.in_port(1).connect(input_port)
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
'index': 0, 'size': 2,
'shape': memory_out.shape}).create_node()
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
concat.out_port(0).connect(memory_in.in_port(0))
concat.out_port(0).data.set_shape(np.array([in_shape[0], memory_in.shape[0]]))
out = Result(graph, {'name': 'Memory_output'}).create_node()
memory_in.out_port(0).connect(out.in_port(0))
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
'offset': np.array([0]), 'axis': np.array([1])}).create_node()
memory_out.out_port(0).connect(crop_out.in_port(0))
out_port.get_connection().set_source(crop_out.out_port(0))
crop_out.out_port(0).data.set_shape(np.array([in_shape[0], crop_out.dim]))
else:
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
'index': 0, 'size': 2,
'shape': memory_out.shape}).create_node()
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
memory_in.in_port(0).connect(input_port)
out = Result(graph, {'name': 'Memory_output'}).create_node()
memory_in.out_port(0).connect(out.in_port(0))
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
out_port.get_connection().set_source(memory_out.out_port(0))
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
graph.remove_node(op_output_id)
graph.remove_node(node.id)

View File

@ -13,15 +13,16 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id, create_zero_value_with_batch_from_input
from extensions.ops.split import VariadicSplit
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.crop import Crop
from mo.ops.memory import Memory
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
@ -39,7 +40,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
So this pass will convert this graph to the next one:
Input [N, H] __
\ /
/ /
Concat [N, k*H]
/ \
Memory [N, k*H] -> Slice [N, (k-1)*H] Memory [N, k*H]
@ -67,11 +68,9 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
memory_pair_id = unique_id('id')
# Memory(in)
input_memory = Memory(graph, {'name': 'prev_splice_memory',
'id': memory_pair_id,
'index': 1,
'size': 2,
'shape': int64_array([memory_size])}).create_node()
input_memory = ReadValue(graph, {'name': 'prev_splice_memory',
'variable_id': memory_pair_id}).create_node()
# Memory(in) \
# Crop
# Input(temp) /
@ -90,11 +89,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
concat_node.in_port(0).connect(crop.out_port(0))
# Concat -> Memory(out)
mem_out = Memory(graph, {'name': 'out_splice_memory',
'id': memory_pair_id,
'index': 0,
'size': 2,
'shape': int64_array([memory_size])}).create_node()
mem_out = Assign(graph, {'name': 'out_splice_memory', 'variable_id': memory_pair_id}).create_node()
mem_out.in_port(0).connect(concat_node.out_port(0))
Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))
@ -110,11 +105,12 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
# create separate splice construction for const_dim
memory_pair_id = unique_id('memory_for_const_dim')
input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory',
'id': memory_pair_id,
'index': 1,
'size': 2,
'shape': int64_array([memory_size_constdim])}).create_node()
init_value_input_memory_const_dim = create_zero_value_with_batch_from_input(split.out_port(1),
memory_size_constdim)
input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
'variable_id': memory_pair_id}).create_node()
init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))
crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
'axis': int64_array([1]),
'offset': int64_array([memory_element_constdim]),
@ -127,11 +123,8 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
'axis': 1}).create_node()
concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))
mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory',
'id': memory_pair_id,
'index': 0,
'size': 2,
'shape': int64_array([memory_size_constdim])}).create_node()
mem_out_const_dim = Assign(graph, {'name': 'const_dim_out_memory',
'variable_id': memory_pair_id}).create_node()
mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))
@ -148,9 +141,15 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
concat_const.in_port(1).connect(crop_first.out_port(0))
concat_const.in_port(0).connect(concat_node.out_port(0))
init_value_input_memory = create_zero_value_with_batch_from_input(split.out_port(0),
memory_size)
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
node.in_port(0).get_connection().set_destination(split.in_port(0))
node.out_port(0).get_connection().set_source(concat_const.out_port(0))
else:
init_value_input_memory = create_zero_value_with_batch_from_input(node.in_port(0).get_source(),
memory_size)
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
node.out_port(0).get_connection().set_source(concat_node.out_port(0))

View File

@ -16,6 +16,7 @@
import unittest
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
@ -42,19 +43,47 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
'in_node': {'kind': 'data', 'shape': [1, 13]},
'memory_in': {'kind': 'op', 'op': 'Memory'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([143])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
'crop_mem_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data', 'shape': [1, 143]},
'memory_out': {'kind': 'op', 'op': 'Memory'},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
},
[
('in_placeholder', 'in_node'),
('in_node', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}),
('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'),
('memory_in_data', 'crop_mem'),
('crop_mem', 'crop_mem_data'),
@ -86,22 +115,54 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
'split': {'kind': 'op', 'op': 'Split'},
'split_data_0': {'kind': 'data'},
'split_data_1': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'Memory'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
'crop_mem_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory'},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
'shape_2': {'kind': 'op', 'op': 'ShapeOf'},
'shape_2_data': {'kind': 'data'},
'crop_batch_2': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_2_data': {'kind': 'data'},
'crop_batch_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_2_data': {'kind': 'data'},
'second_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_2_data': {'kind': 'data'},
'gather_shape_2': {'kind': 'op', 'op': 'Concat'},
'gather_shape_2_data': {'kind': 'data'},
'fill_value_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_2_data': {'kind': 'data'},
'broadcast_2': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_2_data': {'kind': 'data'},
'memory_in_constdims': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_constdims_data': {'kind': 'data'},
'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
'crop_mem_constdims_data': {'kind': 'data'},
'concat_constdims': {'kind': 'op', 'op': 'Concat'},
'concat_constdims_data': {'kind': 'data'},
'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
'memory_out_constdims': {'kind': 'op', 'op': 'Assign'},
'memory_out_constdims_data': {'kind': 'data'},
'result_constdims': {'kind': 'op', 'op': 'Result'},
'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
@ -121,6 +182,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('in_node', 'split', {'in': 0}),
('split', 'split_data_0', {'out': 0}),
('split', 'split_data_1', {'out': 1}),
('split_data_0', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}),
('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'),
('memory_in_data', 'crop_mem'),
('crop_mem', 'crop_mem_data'),
@ -130,6 +203,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'),
('memory_out_data', 'result'),
('split_data_1', 'shape_2'), ('shape_2', 'shape_2_data'),
('shape_2_data', 'crop_batch_2'), ('crop_batch_2', 'crop_batch_2_data'),
('crop_batch_dim_2', 'crop_batch_dim_2_data'),
('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
('second_dim_2', 'second_dim_2_data'), ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
('gather_shape_2', 'gather_shape_2_data'),
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'broadcast_2', {'in': 0}),
('gather_shape_2_data', 'broadcast_2', {'in': 1}), ('broadcast_2', 'broadcast_2_data'),
('broadcast_2_data', 'memory_in_constdims'),
('memory_in_constdims', 'memory_in_constdims_data'),
('memory_in_constdims_data', 'crop_mem_constdims'),
('crop_mem_constdims', 'crop_mem_constdims_data'),

View File

@ -113,9 +113,6 @@ def prepare_ir(argv: argparse.Namespace):
elif (is_kaldi or is_onnx) and not argv.input_model:
raise Error('Path to input model is required: use --input_model.')
if is_kaldi:
argv.generate_experimental_IR_V10 = False
log.debug(str(argv))
log.debug("Model Optimizer started")

View File

@ -0,0 +1,40 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.graph.graph import Graph, Node
from mo.ops.op import Op
class Assign(Op):
op = 'Assign'
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': self.op,
'op': self.op,
'version': 'opset3',
'infer': self.infer,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
def backend_attrs(self):
return ['variable_id']
@staticmethod
def infer(node: Node):
assert node.has_valid('variable_id'), \
"There is no required attribute variable_id in Assign op with name " + node.id
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

View File

@ -0,0 +1,45 @@
"""
Copyright (C) 2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.graph.graph import Graph, Node
from mo.ops.op import Op
class ReadValue(Op):
op = 'ReadValue'
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': self.op,
'op': self.op,
'version': 'opset3',
'infer': self.infer,
'type_infer': self.type_infer,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
def backend_attrs(self):
return ['variable_id']
@staticmethod
def type_infer(node: Node):
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
@staticmethod
def infer(node: Node):
assert node.has_valid('variable_id'), \
"There is no required attribute variable_id in ReadValue op with name " + node.id
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())

View File

@ -32,6 +32,7 @@ from mo.ops.convolution import Convolution
from mo.ops.deconvolution import Deconvolution
from mo.ops.op import Op
from mo.ops.pooling import Pooling
from mo.ops.result import Result
from mo.utils.class_registration import update_registration
from mo.utils.import_extensions import import_by_path
from mo.utils.ir_reader.extender import Extender
@ -218,6 +219,18 @@ def ti_add_edge_attrs(op: Node):
i += 1
def assign_add_output_result(op: Node):
"""
Function adds necessary output result node for Assign node
:param op:
:return:
"""
assert op.soft_get('type') == 'Assign', 'Wrong operation type, {} instead of Assign!' \
''.format(op.soft_get('type'))
tmp_result = Result(op.graph, {'name': op.soft_get('name', op.id) + '/Result'}).create_node()
op.out_port(0).connect(tmp_result.in_port(0))
def copy_input_blobs(op: Node, copy_op: Node):
"""
Function copy input blob data nodes from restored graph to copied one
@ -243,6 +256,7 @@ preprocessing_op_nodes = {
# Map with postprocessing functions for nodes
postprocessing_op_nodes = {
'Assign': assign_add_output_result,
'TensorIterator': ti_add_edge_attrs,
}

View File

@ -22,6 +22,7 @@ from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConst
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
from extensions.back.TopKNormalizer import TopKNormalizer
from extensions.back.blob_normalizer import BlobNormalizer
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_precision
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
@ -77,6 +78,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
PackBinaryWeights,
BlobNormalizer,
ConvolutionNormalizer,
KaldiRemoveMemoryOutputBackReplacementPattern,
]
# We need to run some specific passes from MO back stage.

View File

@ -137,6 +137,8 @@ set (SRC
op/asin.hpp
op/asinh.cpp
op/asinh.hpp
op/assign.cpp
op/assign.hpp
op/atan.cpp
op/atan.hpp
op/atanh.cpp
@ -314,6 +316,8 @@ set (SRC
op/proposal.cpp
op/psroi_pooling.hpp
op/psroi_pooling.cpp
op/read_value.hpp
op/read_value.cpp
op/reduce_logical_and.cpp
op/reduce_logical_and.hpp
op/reduce_logical_or.cpp
@ -518,6 +522,7 @@ set (SRC
op/util/scatter_base.hpp
op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise_arithmetic.hpp
op/util/variable.hpp
ops.hpp
opsets/opset.cpp
partial_shape.cpp

View File

@ -0,0 +1,81 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/assign.hpp"
#include <ops.hpp>
#include "ngraph/op/read_value.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v3::Assign::type_info;
op::v3::Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
: Op({new_value})
, m_variable_id(variable_id)
{
constructor_validate_and_infer_types();
}
void op::v3::Assign::validate_and_infer_types()
{
auto value = input_value(0);
auto arg_t = get_input_element_type(0);
auto output_shape = get_input_partial_shape(0);
if (!m_variable)
{
NodeVector start_nodes;
for (const auto& input : inputs())
{
start_nodes.push_back(input.get_source_output().get_node_shared_ptr());
}
auto nodes = topological_sort(start_nodes);
for (const auto& node : nodes)
{
if (auto read_value = as_type_ptr<op::v3::ReadValue>(node))
{
if (read_value->get_variable_id() == m_variable_id)
m_variable = read_value->get_variable();
}
}
NODE_VALIDATION_CHECK(
this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id);
}
auto variable_info = m_variable->get_info();
NODE_VALIDATION_CHECK(this,
m_variable_id == variable_info.variable_id,
"Variables identifiers are inconsistent.");
NODE_VALIDATION_CHECK(
this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
NODE_VALIDATION_CHECK(this,
output_shape == variable_info.data_shape,
"Variables output shapes are inconsistent.");
set_output_type(0, arg_t, output_shape);
}
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Assign>(new_args.at(0), m_variable_id);
}
bool op::v3::Assign::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("variable_id", m_variable_id);
return true;
}

View File

@ -0,0 +1,67 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/variable.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief Assign operation sets an input value to the variable with `variable_id`
class NGRAPH_API Assign : public Op
{
public:
static constexpr NodeTypeInfo type_info{"Assign", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Assign() = default;
/// \brief Constructs an Assign operation.
///
/// \param new_value Node that produces the input tensor.
/// \param variable_id identificator of the variable to be updated.
Assign(const Output<Node>& new_value, const std::string& variable_id);
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::string get_variable_id() { return m_variable_id; }
std::shared_ptr<ngraph::Variable> get_variable() { return m_variable; }
void set_variable_id(const std::string& variable_id)
{
m_variable_id = variable_id;
}
void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
{
m_variable = variable;
}
private:
std::string m_variable_id;
std::shared_ptr<ngraph::Variable> m_variable;
};
}
using v3::Assign;
}
}

View File

@ -265,3 +265,5 @@ NGRAPH_OP(Transpose, ngraph::op::v1, 1)
NGRAPH_OP(Unsqueeze, ngraph::op::v0, 0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1, 1)
NGRAPH_OP(Xor, ngraph::op::v0, 0)
NGRAPH_OP(Assign, ngraph::op::v3, 3)
NGRAPH_OP(ReadValue, ngraph::op::v3, 3)

View File

@ -0,0 +1,51 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/read_value.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::ReadValue::type_info;
op::ReadValue::ReadValue(const Output<Node>& new_value, const std::string& variable_id)
: Op({new_value})
, m_variable_id(variable_id)
{
constructor_validate_and_infer_types();
}
void op::ReadValue::validate_and_infer_types()
{
auto arg_t = get_input_element_type(0);
auto output_shape = get_input_partial_shape(0);
VariableInfo info = {output_shape, arg_t, m_variable_id};
m_variable = std::make_shared<Variable>(info);
set_output_type(0, arg_t, output_shape);
}
shared_ptr<Node> op::ReadValue::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<ReadValue>(new_args.at(0), m_variable_id);
}
bool op::v3::ReadValue::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("variable_id", m_variable_id);
return true;
}

View File

@ -0,0 +1,68 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/variable.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
/// \brief ReadValue operation creates the variable with `variable_id` and returns value
/// of this variable.
class NGRAPH_API ReadValue : public Op
{
public:
static constexpr NodeTypeInfo type_info{"ReadValue", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
ReadValue() = default;
/// \brief Constructs a ReadValue operation.
///
/// \param new_value Node that produces the input tensor.
/// \param variable_id identificator of the variable to create.
ReadValue(const Output<Node>& new_value, const std::string& variable_id);
void validate_and_infer_types() override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::string get_variable_id() { return m_variable_id; }
std::shared_ptr<ngraph::Variable> get_variable() { return m_variable; }
void set_variable_id(const std::string& variable_id)
{
m_variable_id = variable_id;
}
void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
{
m_variable = variable;
}
private:
std::string m_variable_id;
std::shared_ptr<ngraph::Variable> m_variable;
};
}
using v3::ReadValue;
}
}

View File

@ -23,6 +23,9 @@
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/topk.hpp"
using namespace std;
using namespace ngraph;
@ -225,9 +228,224 @@ void op::v0::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */,
throw ngraph_error("Forward-propagation-only operation");
}
namespace
{
template <element::Type_t INPUT_ET, element::Type_t INDEX_ET>
inline bool evaluate_execute(const HostTensorPtr& arg0,
const HostTensorPtr& out_indices,
const HostTensorPtr& out_values,
const Shape out_shape,
const size_t axis,
const size_t k,
const bool compute_max,
const op::TopK::SortType sort)
{
using T = typename element_type_traits<INPUT_ET>::value_type;
using U = typename element_type_traits<INDEX_ET>::value_type;
const Shape in_shape = arg0->get_shape();
out_indices->set_shape(out_shape);
out_indices->set_element_type(INDEX_ET);
out_values->set_shape(out_shape);
out_values->set_element_type(arg0->get_element_type());
runtime::reference::topk<T, U>(arg0->get_data_ptr<INPUT_ET>(),
out_indices->get_data_ptr<INDEX_ET>(),
out_values->get_data_ptr<INPUT_ET>(),
in_shape,
out_shape,
axis,
k,
compute_max,
sort);
return true;
}
template <element::Type_t INPUT_ET>
bool evaluate(const HostTensorPtr& arg,
const HostTensorPtr& out_indices,
const HostTensorPtr& out_values,
const Shape out_shape,
const size_t axis,
const size_t k,
const bool max,
const op::TopK::SortType sort,
const element::Type index_et)
{
bool rc = true;
switch (index_et)
{
case element::Type_t::i64:
evaluate_execute<INPUT_ET, element::Type_t::i64>(
arg, out_indices, out_values, out_shape, axis, k, max, sort);
break;
case element::Type_t::i32:
evaluate_execute<INPUT_ET, element::Type_t::i32>(
arg, out_indices, out_values, out_shape, axis, k, max, sort);
break;
default: rc = false; break;
}
return rc;
}
bool evaluate_topk(const HostTensorPtr& arg,
const HostTensorPtr& out_indices,
const HostTensorPtr& out_values,
const Shape out_shape,
const size_t axis,
const size_t k,
const bool max,
const op::TopK::SortType sort,
const element::Type index_et)
{
bool rc = true;
switch (arg->get_element_type())
{
TYPE_CASE(i8)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(i16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(i32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(i64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(u8)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(u16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(u32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(u64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(bf16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(f16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(f32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
TYPE_CASE(f64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
break;
default: rc = false; break;
}
return rc;
}
template <element::Type_t K_ET>
size_t get_k_from_hosttensor(const HostTensorPtr& arg)
{
using T = typename element_type_traits<K_ET>::value_type;
auto p = arg->get_data_ptr<T>();
size_t k = p[0];
return k;
}
#define CASE_GET_K(a) \
case element::Type_t::a: k = get_k_from_hosttensor<element::Type_t::a>
size_t read_k_from_host_tensor(const HostTensorPtr& arg_k)
{
size_t k = 0;
switch (arg_k->get_element_type())
{
CASE_GET_K(i8)(arg_k);
break;
CASE_GET_K(i16)(arg_k);
break;
CASE_GET_K(i32)(arg_k);
break;
CASE_GET_K(i64)(arg_k);
break;
CASE_GET_K(u8)(arg_k);
break;
CASE_GET_K(u16)(arg_k);
break;
CASE_GET_K(u32)(arg_k);
break;
CASE_GET_K(u64)(arg_k);
break;
default:
// other types are not supported and would have thrown in ctor
ngraph_error("read_k_from_host_tensor: type is not integral\n");
break;
}
return k;
}
// used in only v0, where type is set as int64_t
size_t read_top_k_axis_from_host_tensor(const HostTensorPtr& arg)
{
NGRAPH_CHECK(arg->get_element_type() == element::i64,
"TopK axis element type should be i64");
auto p = arg->get_data_ptr<int64_t>();
size_t axis = static_cast<size_t>(p[0]);
return axis;
}
}
Shape op::v0::TopK::compute_output_shape(const Shape input_shape,
const int64_t k,
const size_t axis)
{
Shape output_shape{input_shape};
if (k != 0)
{
output_shape[axis] = k;
}
return output_shape;
}
bool op::v0::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
{
// check data types for arg, k and output element type
Shape arg_shape = inputs[0]->get_shape();
// 1. get axis, mode ( max/min), sort_type
size_t axis = 0;
Dimension axis_dim = get_top_k_axis_dynamic();
if (axis_dim.is_static())
{
axis = axis_dim.get_length();
}
else
{
axis = read_top_k_axis_from_host_tensor(inputs[2]);
NGRAPH_CHECK(axis <= arg_shape.size(), "TopK axis is out of bounds");
}
bool compute_max = get_compute_max();
SortType sort_type = get_sort();
// 2. get value of k - from constant node or from HT
size_t k = get_k();
if (k == 0)
{
k = read_k_from_host_tensor(inputs[1]);
if (k == 0)
{
// the kernel can't handle k = 0, but output_shape[axis] = arg_shape[axis]
k = arg_shape[axis];
}
}
NGRAPH_CHECK(k <= arg_shape.at(axis), "K exceeds the dimension of the TopK axis");
// 3. Compute output_shape
auto output_shape = compute_output_shape(inputs[0]->get_shape(), k, axis);
return evaluate_topk(inputs[0],
outputs[0],
outputs[1],
output_shape,
axis,
k,
compute_max,
sort_type,
get_index_element_type());
}
// v1 version starts
constexpr NodeTypeInfo op::v1::TopK::type_info;
static const std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
@ -236,7 +454,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
const element::Type& index_element_type)
: Op{{data, k}}
, m_axis{axis}
, m_normalized_axis{0}
, m_normalized_axis{UNKNOWN_NORMALIZED_AXIS}
, m_mode{as_enum<Mode>(mode)}
, m_sort{as_enum<SortType>(sort)}
, m_index_element_type{index_element_type}
@ -244,8 +462,6 @@ op::v1::TopK::TopK(const Output<Node>& data,
constructor_validate_and_infer_types();
}
static const std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
op::v1::TopK::TopK(const Output<Node>& data,
const Output<Node>& k,
const int64_t axis,
@ -318,6 +534,25 @@ void op::v1::TopK::validate_and_infer_types()
set_output_type(1, m_index_element_type, output_shape);
}
Shape op::v1::TopK::compute_output_shape(const std::string& node_description,
const PartialShape input_partial_shape,
const int64_t k)
{
PartialShape output_shape{input_partial_shape};
m_normalized_axis = ngraph::normalize_axis(node_description, m_axis, output_shape.rank());
if (k != 0)
{
output_shape[m_normalized_axis] = k;
}
else
{
output_shape[m_normalized_axis] = input_partial_shape[m_normalized_axis];
}
return output_shape.get_shape();
}
void op::v1::TopK::set_axis(const int64_t axis)
{
const auto input_rank = get_input_partial_shape(0).rank();
@ -332,6 +567,19 @@ void op::v1::TopK::set_axis(const int64_t axis)
m_axis = axis;
}
void op::v1::TopK::set_axis(const Rank input_rank, const int64_t axis)
{
if (input_rank.is_static())
{
m_normalized_axis = ngraph::normalize_axis(this, axis, input_rank);
}
else
{
m_normalized_axis = UNKNOWN_NORMALIZED_AXIS;
}
m_axis = axis;
}
uint64_t op::v1::TopK::get_axis() const
{
NODE_VALIDATION_CHECK(
@ -433,6 +681,49 @@ void op::v1::TopK::set_k(size_t k)
op::Constant::create(element::i64, Shape{}, {k})->output(0));
}
bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
{
Shape arg_shape = inputs[0]->get_shape();
// 1. get axis, mode ( max/min), sort_type
set_axis(arg_shape.size(), m_axis);
size_t axis = get_axis();
bool compute_max = get_mode() == TopKMode::MAX ? true : false;
SortType sort_type = get_sort_type();
// 2. get value of k - from constant node or from HT
size_t k = 0;
if (input_value(1).get_node_shared_ptr()->is_constant())
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
NGRAPH_CHECK(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
}
else
{
k = read_k_from_host_tensor(inputs[1]);
}
// 3. Compute output_shape
auto output_shape = compute_output_shape(this->description(), inputs[0]->get_shape(), k);
// do this after compute_output_shape
if (k == 0)
{
// the kernel can't handle k = 0, but output_shape[axis] = arg_shape[axis]
k = arg_shape[axis];
}
return evaluate_topk(inputs[0],
outputs[1],
outputs[0],
output_shape,
axis,
k,
compute_max,
sort_type,
get_index_element_type());
}
// v3 version starts
constexpr NodeTypeInfo op::v3::TopK::type_info;
@ -516,3 +807,8 @@ shared_ptr<Node> op::v3::TopK::clone_with_new_inputs(const OutputVector& new_arg
return std::move(new_v3_topk);
}
bool op::v3::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
{
return op::v1::TopK::evaluate(outputs, inputs);
}

View File

@ -102,12 +102,18 @@ namespace ngraph
bool get_compute_max() const { return m_compute_max; }
SortType get_sort() const { return m_sort; }
size_t get_default_output_index() const override { return no_default_index(); }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
protected:
element::Type m_index_element_type;
bool m_compute_max{false};
SortType m_sort{SortType::NONE};
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const OutputVector& deltas) override;
Shape compute_output_shape(const Shape input_shape,
const int64_t k,
const size_t axis);
};
} // namespace v0
@ -181,6 +187,9 @@ namespace ngraph
size_t get_k() const;
void set_k(size_t k);
size_t get_default_output_index() const override { return no_default_index(); }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
protected:
int64_t m_axis;
uint64_t m_normalized_axis;
@ -196,6 +205,10 @@ namespace ngraph
template <typename T>
size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
Shape compute_output_shape(const std::string& node_description,
const PartialShape input_partial_shape,
const int64_t k);
void set_axis(const Rank input_rank, const int64_t axis);
};
} // namespace v1
@ -240,6 +253,9 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
protected:
virtual size_t
read_k_from_constant_node(const std::shared_ptr<Node>& node,

View File

@ -142,7 +142,8 @@ void op::util::BroadcastBase::validate_and_infer_types()
auto output_rank = input_value(1).get_partial_shape();
if (input_rank.is_static() && output_rank.is_static() && output_rank[0].is_static())
{
result_shape = PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
result_shape =
PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
}
const auto shape_constant = as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());

View File

@ -0,0 +1,47 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <utility>
#include "ngraph/op/op.hpp"
namespace ngraph
{
struct VariableInfo
{
PartialShape data_shape;
element::Type data_type;
std::string variable_id;
};
class NGRAPH_API Variable
{
public:
Variable() = default;
explicit Variable(const VariableInfo& variable_info)
: m_info(variable_info)
{
}
VariableInfo get_info() { return m_info; }
void update(const VariableInfo& variable_info) { m_info = variable_info; }
private:
VariableInfo m_info;
};
}

View File

@ -30,6 +30,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/asinh.hpp"
#include "ngraph/op/assign.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/atanh.hpp"
@ -149,6 +150,7 @@
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/range.hpp"
#include "ngraph/op/read_value.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_logical_and.hpp"
#include "ngraph/op/reduce_logical_or.hpp"

View File

@ -166,4 +166,6 @@ NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
NGRAPH_OP(ShapeOf, ngraph::op::v3)
NGRAPH_OP(Assign, ngraph::op::v3)
NGRAPH_OP(ReadValue, ngraph::op::v3)
NGRAPH_OP(TopK, ngraph::op::v3)

View File

@ -16,6 +16,7 @@
#include "ngraph/specialize_function.hpp"
#include <pass/constant_folding.hpp>
#include "ngraph/op/assign.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/tensor_iterator.hpp"
@ -84,7 +85,18 @@ std::shared_ptr<Function>
}
else
{
m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
NodeVector cloned_dependencies;
for (auto& dependency : old_node->get_control_dependencies())
{
std::shared_ptr<Node> dependent = m.at(dependency.get());
if (find(cloned_dependencies.begin(), cloned_dependencies.end(), dependent) ==
cloned_dependencies.end())
{
cloned_dependencies.push_back(dependent);
}
}
m[old_node.get()] = old_node->copy_with_new_inputs(new_args, cloned_dependencies);
// TODO: workaround for shape inference, delete it after fix
if (::ngraph::as_type_ptr<ngraph::op::TensorIterator>(m[old_node.get()]))
{

View File

@ -114,6 +114,7 @@ set(SRC
tensor.cpp
type_prop/all.cpp
type_prop/any.cpp
type_prop/assign.cpp
type_prop/avg_pool.cpp
type_prop/batch_mat_mul.cpp
type_prop/batch_mat_mul_transpose.cpp
@ -178,6 +179,7 @@ set(SRC
type_prop/quantized_dot.cpp
type_prop/random_uniform.cpp
type_prop/range.cpp
type_prop/read_value.cpp
type_prop/replace_slice.cpp
type_prop/reshape.cpp
type_prop/reverse.cpp

File diff suppressed because it is too large Load Diff

View File

@ -64,10 +64,10 @@
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "runtime/backend.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/test_tools.hpp"
@ -1555,3 +1555,358 @@ TEST(eval, evaluate_dynamic_scatter_elements_update_one_elem_i32)
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
ASSERT_EQ(cval, out);
}
TEST(eval, topk_v1)
{
Shape shape{2, 3, 2};
Shape rshape{2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
const auto k = op::Constant::create(element::i32, Shape{}, {2});
auto B = make_shared<op::v1::TopK>(A, k, 1, "max", "index", element::i32);
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7})}));
EXPECT_EQ(result0->get_element_type(), element::f32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
EXPECT_EQ(result1->get_element_type(), element::i32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
auto result0_val = read_vector<float>(result0);
auto result1_val = read_vector<int32_t>(result1);
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
ASSERT_EQ(result0_val, expec0);
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
ASSERT_EQ(result1_val, expec1);
}
TEST(eval, topk_v1_dyn)
{
Shape shape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto k = make_shared<op::Parameter>(element::u32, Shape{});
auto B = make_shared<op::v1::TopK>(A, k, 1, "max", "index", element::i32);
auto fun =
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
EXPECT_EQ(result0->get_element_type(), element::f32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
EXPECT_EQ(result1->get_element_type(), element::i32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
auto result0_val = read_vector<float>(result0);
auto result1_val = read_vector<int32_t>(result1);
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
ASSERT_EQ(result0_val, expec0);
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
ASSERT_EQ(result1_val, expec1);
}
TEST(eval, topk_v3_dyn)
{
Shape shape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto k = make_shared<op::Parameter>(element::u32, Shape{});
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "index", element::i32);
auto fun =
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
EXPECT_EQ(result0->get_element_type(), element::f32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
EXPECT_EQ(result1->get_element_type(), element::i32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
auto result0_val = read_vector<float>(result0);
auto result1_val = read_vector<int32_t>(result1);
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
ASSERT_EQ(result0_val, expec0);
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
ASSERT_EQ(result1_val, expec1);
}
TEST(eval, topk_v3_dyn_values)
{
Shape shape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto k = make_shared<op::Parameter>(element::u32, Shape{});
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
auto fun =
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
EXPECT_EQ(result0->get_element_type(), element::f32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
EXPECT_EQ(result1->get_element_type(), element::i32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
auto result0_val = read_vector<float>(result0);
auto result1_val = read_vector<int32_t>(result1);
vector<float> expec0{12, 9, 10, 4, 11, 7, 6, 3};
ASSERT_EQ(result0_val, expec0);
vector<int32_t> expec1{0, 1, 1, 2, 2, 2, 0, 1};
ASSERT_EQ(result1_val, expec1);
}
TEST(eval, topk_v3_dyn_values_k0)
{
Shape shape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto k = make_shared<op::Parameter>(element::u32, Shape{});
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
auto fun =
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i32>(Shape{}, {0})}));
EXPECT_EQ(result0->get_element_type(), element::f32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
EXPECT_EQ(result1->get_element_type(), element::i32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
auto result0_val = read_vector<float>(result0);
auto result1_val = read_vector<int32_t>(result1);
vector<float> expec0{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
ASSERT_EQ(result0_val, expec0);
vector<int32_t> expec1{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
ASSERT_EQ(result1_val, expec1);
}
TEST(eval, topk_v0_dyn)
{
Shape shape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto k = make_shared<op::Parameter>(element::i64, Shape{});
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
element::Type result_et{element::i32};
bool compute_max = true;
auto B = make_shared<op::v0::TopK>(
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
ParameterVector{A, k, axis});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i64>(Shape{}, {2}),
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
EXPECT_EQ(result0->get_element_type(), element::i32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
EXPECT_EQ(result1->get_element_type(), element::f32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
auto result1_val = read_vector<float>(result1);
auto result0_val = read_vector<int32_t>(result0);
vector<float> expec1{12, 9, 10, 4, 11, 7, 6, 3};
ASSERT_EQ(result1_val, expec1);
vector<int32_t> expec0{0, 1, 1, 2, 2, 2, 0, 1};
ASSERT_EQ(result0_val, expec0);
}
TEST(eval, topk_v0_dyn_k0)
{
Shape shape{2, 3, 2};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto k = make_shared<op::Parameter>(element::i64, Shape{});
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
element::Type result_et{element::i32};
bool compute_max = true;
auto B = make_shared<op::v0::TopK>(
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
ParameterVector{A, k, axis});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i64>(Shape{}, {0}),
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
EXPECT_EQ(result0->get_element_type(), element::i32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
EXPECT_EQ(result1->get_element_type(), element::f32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
auto result1_val = read_vector<float>(result1);
auto result0_val = read_vector<int32_t>(result0);
vector<float> expec1{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
ASSERT_EQ(result1_val, expec1);
vector<int32_t> expec0{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
ASSERT_EQ(result0_val, expec0);
}
TEST(eval, topk_v3_param_dyn_values_k0)
{
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto k = make_shared<op::Parameter>(element::u32, Shape{});
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
auto fun =
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i32>(Shape{}, {0})}));
EXPECT_EQ(result0->get_element_type(), element::f32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
EXPECT_EQ(result1->get_element_type(), element::i32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
auto result0_val = read_vector<float>(result0);
auto result1_val = read_vector<int32_t>(result1);
vector<float> expec0{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
ASSERT_EQ(result0_val, expec0);
vector<int32_t> expec1{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
ASSERT_EQ(result1_val, expec1);
}
TEST(eval, topk_v3_param_dyn_values_k2)
{
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto k = make_shared<op::Parameter>(element::u32, Shape{});
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
auto fun =
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
EXPECT_EQ(result0->get_element_type(), element::f32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
EXPECT_EQ(result1->get_element_type(), element::i32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
auto result0_val = read_vector<float>(result0);
auto result1_val = read_vector<int32_t>(result1);
vector<float> expec0{12, 9, 10, 4, 11, 7, 6, 3};
ASSERT_EQ(result0_val, expec0);
vector<int32_t> expec1{0, 1, 1, 2, 2, 2, 0, 1};
ASSERT_EQ(result1_val, expec1);
}
TEST(eval, topk_v0_param_dyn_k2)
{
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto k = make_shared<op::Parameter>(element::i64, Shape{});
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
element::Type result_et{element::i32};
bool compute_max = true;
auto B = make_shared<op::v0::TopK>(
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
ParameterVector{A, k, axis});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i64>(Shape{}, {2}),
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
EXPECT_EQ(result0->get_element_type(), element::i32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
EXPECT_EQ(result1->get_element_type(), element::f32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
auto result1_val = read_vector<float>(result1);
auto result0_val = read_vector<int32_t>(result0);
vector<float> expec1{12, 9, 10, 4, 11, 7, 6, 3};
ASSERT_EQ(result1_val, expec1);
vector<int32_t> expec0{0, 1, 1, 2, 2, 2, 0, 1};
ASSERT_EQ(result0_val, expec0);
}
TEST(eval, topk_v0_param_dyn_k0)
{
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
auto k = make_shared<op::Parameter>(element::i64, Shape{});
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
element::Type result_et{element::i32};
bool compute_max = true;
auto B = make_shared<op::v0::TopK>(
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
ParameterVector{A, k, axis});
auto result0 = make_shared<HostTensor>();
auto result1 = make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result0, result1},
{make_host_tensor<element::Type_t::f32>(
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
make_host_tensor<element::Type_t::i64>(Shape{}, {0}),
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
EXPECT_EQ(result0->get_element_type(), element::i32);
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
EXPECT_EQ(result1->get_element_type(), element::f32);
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
auto result1_val = read_vector<float>(result1);
auto result0_val = read_vector<int32_t>(result0);
vector<float> expec1{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
ASSERT_EQ(result1_val, expec1);
vector<int32_t> expec0{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
ASSERT_EQ(result0_val, expec0);
}

View File

@ -0,0 +1,52 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, assign_variable_not_found)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
try
{
auto space_to_depth = make_shared<op::Assign>(A, "variable_id");
// Should have thrown, so fail if it didn't
FAIL() << "Should not find variable with variable_id";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Can't find variable with id = variable_id"));
}
catch (...)
{
FAIL() << "Deduced type check failed for unexpected reason";
}
}
TEST(type_prop, assign_deduce)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
auto read_value = make_shared<op::ReadValue>(input, "variable_id");
auto assign = make_shared<op::Assign>(read_value, "variable_id");
ASSERT_EQ(assign->get_element_type(), element::f32);
ASSERT_EQ(assign->get_shape(), (Shape{1, 2, 64, 64}));
}

View File

@ -0,0 +1,31 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, read_value_deduce)
{
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
auto read_value = make_shared<op::ReadValue>(input, "variable_id");
ASSERT_EQ(read_value->get_element_type(), element::f32);
ASSERT_EQ(read_value->get_shape(), (Shape{1, 2, 64, 64}));
}

View File

@ -49,7 +49,7 @@ ngraph::test::NgraphTestCase::NgraphTestCase(const std::shared_ptr<Function>& fu
}
}
void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
::testing::AssertionResult ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
{
m_tolerance_bits = tolerance_bits;
const auto& function_results = m_function->get_results();
@ -85,6 +85,7 @@ void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
m_output_index = 0;
m_expected_outputs.clear();
m_input_tensors.clear();
return ::testing::AssertionSuccess();
}
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)

View File

@ -187,7 +187,7 @@ namespace ngraph
add_expected_output<T>(expected_shape, value);
}
void run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
::testing::AssertionResult run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
private:
template <typename T>