publish master branch snapshot, revision ea98a886d925eb152931aab13856e68037665562
This commit is contained in:
parent
deb008a26f
commit
ccb7438803
@ -85,7 +85,7 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
|
||||
THROW_IE_EXCEPTION << "Invalid IR! " << node_param.name << " name is not unique!";
|
||||
opName.insert(node_param.name);
|
||||
params[node_param.layerId] = {node, node_param};
|
||||
if (node_param.type == "Result") {
|
||||
if (node_param.type == "Result" || node_param.type == "Assign") {
|
||||
outputs.push_back(node_param.layerId);
|
||||
}
|
||||
}
|
||||
@ -118,7 +118,9 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
|
||||
|
||||
ngraph::ParameterVector parameter_nodes;
|
||||
ngraph::ResultVector result_nodes;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> allNodes;
|
||||
ngraph::NodeVector allNodes;
|
||||
std::vector<std::shared_ptr<ngraph::op::Assign>> assign_nodes;
|
||||
std::map<std::string, std::shared_ptr<ngraph::Node>> variable_id_to_read_value;
|
||||
|
||||
// Following topological order create nGraph operations
|
||||
for (auto& layer_id : order) {
|
||||
@ -159,12 +161,28 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
|
||||
if (auto result_node = std::dynamic_pointer_cast<ngraph::op::Result>(node)) {
|
||||
result_nodes.emplace_back(result_node);
|
||||
}
|
||||
|
||||
if (auto assign_node = std::dynamic_pointer_cast<ngraph::op::Assign>(node)) {
|
||||
assign_nodes.emplace_back(assign_node);
|
||||
}
|
||||
|
||||
if (auto read_value_node = std::dynamic_pointer_cast<ngraph::op::ReadValue>(node)) {
|
||||
variable_id_to_read_value[read_value_node->get_variable_id()] = read_value_node;
|
||||
}
|
||||
allNodes.emplace_back(node);
|
||||
}
|
||||
|
||||
::ngraph::op::GenericIE::DisableReshape noReshape(allNodes);
|
||||
|
||||
return std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
|
||||
auto function = std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
|
||||
if (!result_nodes.empty()) {
|
||||
for (const auto& assign : assign_nodes) {
|
||||
assign->add_control_dependency(variable_id_to_read_value.at(assign->get_variable_id()));
|
||||
// often Assign node is a leaf of the graph, we add control_dependency for one of the results
|
||||
// to make Assign node visible for traversals get_ops(), get_ordered_ops()
|
||||
result_nodes[0]->add_control_dependency(assign);
|
||||
}
|
||||
}
|
||||
return function;
|
||||
}
|
||||
|
||||
V10Parser::GenericLayerParams V10Parser::parseGenericParams(const pugi::xml_node& node) {
|
||||
|
@ -334,6 +334,28 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
|
||||
return res;
|
||||
});
|
||||
|
||||
addSpecificCreator({"Assign"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
||||
LayerParams attrs = {node->get_friendly_name(), "Memory",
|
||||
details::convertPrecision(node->get_output_element_type(0))};
|
||||
auto res = std::make_shared<CNNLayer>(attrs);
|
||||
res->params["id"] = params.at("variable_id");
|
||||
res->params["index"] = "0";
|
||||
res->params["size"] = "2";
|
||||
return res;
|
||||
});
|
||||
|
||||
addSpecificCreator({"ReadValue"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
||||
LayerParams attrs = {node->get_friendly_name(), "Memory",
|
||||
details::convertPrecision(node->get_output_element_type(0))};
|
||||
auto res = std::make_shared<CNNLayer>(attrs);
|
||||
res->params["id"] = params.at("variable_id");
|
||||
res->params["index"] = "1";
|
||||
res->params["size"] = "2";
|
||||
return res;
|
||||
});
|
||||
|
||||
addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||
const std::map<std::string, std::string> params) -> CNNLayerPtr {
|
||||
THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name()
|
||||
@ -656,14 +678,16 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
|
||||
// Collect all names from current graph
|
||||
// It is necessary in order to differentiate outputs from constant layers when we share constants
|
||||
// (Constant operations contains outputs for converted and original functions)
|
||||
const ngraph::NodeVector& nodes = graph->get_ops();
|
||||
|
||||
std::unordered_set<std::string> op_names;
|
||||
for (const auto &layer : graph->get_ops())
|
||||
for (const auto &layer : nodes)
|
||||
op_names.insert(layer->get_name());
|
||||
|
||||
bool keep_constants = ::ngraph::op::util::has_op_with_type<::ngraph::op::FakeQuantize>(graph);
|
||||
|
||||
// Create layers and output data
|
||||
for (const auto &layer : graph->get_ops()) {
|
||||
for (const auto &layer : nodes) {
|
||||
if (isInternalLayer(layer, op_names, keep_constants)) continue;
|
||||
|
||||
// TODO: remove this rt info when all blobs will be inputs
|
||||
@ -703,8 +727,18 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
|
||||
}
|
||||
inputCount++;
|
||||
}
|
||||
|
||||
if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "1") {
|
||||
inputCount = 0;
|
||||
}
|
||||
|
||||
cnnLayer->insData.resize(inputCount);
|
||||
|
||||
for (size_t i = 0; i < layer->get_output_size(); i++) {
|
||||
if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
|
||||
cnnLayer->outData.clear();
|
||||
continue;
|
||||
}
|
||||
std::string outName = layer->get_friendly_name();
|
||||
if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
|
||||
DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
|
||||
@ -747,6 +781,8 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
|
||||
|
||||
// Set input data
|
||||
for (const auto &layer : graph->get_ordered_ops()) {
|
||||
if (std::dynamic_pointer_cast<::ngraph::op::ReadValue>(layer))
|
||||
continue;
|
||||
if (std::dynamic_pointer_cast<::ngraph::op::Result>(layer)) {
|
||||
IE_ASSERT(layer->get_inputs().size() == 1);
|
||||
const auto &input = layer->input_value(0);
|
||||
|
@ -62,9 +62,9 @@ void ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::convert_normalize_l2_
|
||||
across_spatial,
|
||||
channel_shared);
|
||||
|
||||
normalize_ie->set_friendly_name(normalize->get_friendly_name());
|
||||
ngraph::copy_runtime_info(normalize, normalize_ie);
|
||||
ngraph::replace_node(normalize, normalize_ie);
|
||||
normalize_ie->set_friendly_name(mul->get_friendly_name());
|
||||
ngraph::copy_runtime_info({normalize, mul}, normalize_ie);
|
||||
ngraph::replace_node(mul, normalize_ie);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -46,6 +46,7 @@ extensions/back/ParameterToPlaceholder.py
|
||||
extensions/back/pass_separator.py
|
||||
extensions/back/priorbox_mutation.py
|
||||
extensions/back/ProposalMutation.py
|
||||
extensions/back/ReadValueAssignToMemory.py
|
||||
extensions/back/ReduceToPooling.py
|
||||
extensions/back/ReduceTransposeDimensions.py
|
||||
extensions/back/remove_last_softmax_pattern.py
|
||||
@ -872,6 +873,7 @@ mo/middle/pattern_match.py
|
||||
mo/middle/replacement.py
|
||||
mo/ops/__init__.py
|
||||
mo/ops/activation.py
|
||||
mo/ops/assign.py
|
||||
mo/ops/broadcast.py
|
||||
mo/ops/clamp.py
|
||||
mo/ops/concat.py
|
||||
@ -897,6 +899,7 @@ mo/ops/pad.py
|
||||
mo/ops/permute.py
|
||||
mo/ops/pooling.py
|
||||
mo/ops/power.py
|
||||
mo/ops/read_value.py
|
||||
mo/ops/reshape.py
|
||||
mo/ops/result.py
|
||||
mo/ops/roipooling.py
|
||||
|
@ -23,7 +23,7 @@ from mo.ops.crop import Crop
|
||||
from mo.utils.logger import log
|
||||
|
||||
|
||||
class CutMemory(BackReplacementPattern):
|
||||
class CutMemoryInput(BackReplacementPattern):
|
||||
"""
|
||||
Cut Memory layers and have inputs/outputs in graph instead of them
|
||||
"""
|
||||
@ -38,30 +38,56 @@ class CutMemory(BackReplacementPattern):
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='Memory'))],
|
||||
('op', dict(kind='op', op='ReadValue'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['id']
|
||||
node_id = node['variable_id']
|
||||
|
||||
if node.in_port(0).disconnected():
|
||||
i = 0
|
||||
for dest in node.out_port(0).get_destinations():
|
||||
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
|
||||
'shape': dest.data.get_shape()}).create_node()
|
||||
i += 1
|
||||
dest.disconnect()
|
||||
new_in.out_port(0).connect(dest)
|
||||
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
|
||||
extra={'is_warning': True})
|
||||
else:
|
||||
out_node_port = node.out_port(0).get_destination()
|
||||
in_node_port = node.in_port(0).get_source()
|
||||
node.in_port(0).disconnect()
|
||||
node.out_port(0).disconnect()
|
||||
crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]), 'axis': np.array([0])}).create_node()
|
||||
in_node_port.connect(crop.in_port(0))
|
||||
crop.out_port(0).connect(out_node_port)
|
||||
i = 0
|
||||
node.in_port(0).disconnect()
|
||||
for dest in node.out_port(0).get_destinations():
|
||||
new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
|
||||
'shape': dest.data.get_shape()}).create_node()
|
||||
i += 1
|
||||
dest.disconnect()
|
||||
new_in.out_port(0).connect(dest)
|
||||
log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
|
||||
extra={'is_warning': True})
|
||||
|
||||
|
||||
class CutMemoryOutput(BackReplacementPattern):
|
||||
"""
|
||||
Cut Memory layers and have inputs/outputs in graph instead of them
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['fw'] == "kaldi" and graph.graph['cmd_params'].remove_memory]
|
||||
force_clean_up = True
|
||||
|
||||
def run_before(self):
|
||||
return [ParameterToInput]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='Assign'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['variable_id']
|
||||
|
||||
out_node_port = node.out_port(0).get_destination()
|
||||
in_node_port = node.in_port(0).get_source()
|
||||
node.in_port(0).disconnect()
|
||||
node.out_port(0).disconnect()
|
||||
crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]),
|
||||
'axis': np.array([0])}).create_node()
|
||||
in_node_port.connect(crop.in_port(0))
|
||||
crop.out_port(0).connect(out_node_port)
|
||||
|
@ -17,7 +17,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.CutMemory import CutMemory
|
||||
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
|
||||
@ -29,18 +29,21 @@ class CutMemoryTest(unittest.TestCase):
|
||||
nodes_attrs={
|
||||
'input': {'kind': 'op'},
|
||||
'data_in': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'index': 1, 'id': 'memory_', 'in_ports_count': 1},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
'const_0_data': {'kind': 'data'},
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'variable_id': 'memory_'},
|
||||
'data_mem': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'concat': {'kind': 'op', 'op': 'Concat', 'axis': 0},
|
||||
'concat_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'some_op': {'kind': 'op'},
|
||||
'some_op_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'index': 0, 'id': 'memory_'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign', 'variable_id': 'memory_'},
|
||||
'data_mem_out': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'mem_out_result': {'kind': 'op', 'op': 'Result'}
|
||||
},
|
||||
edges=[
|
||||
('input', 'data_in'), ('memory_in', 'data_mem'),
|
||||
('input', 'data_in'),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'memory_in'), ('memory_in', 'data_mem'),
|
||||
('data_in', 'concat', {'in': 0}), ('data_mem', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'some_op'),
|
||||
('some_op', 'some_op_data'), ('some_op_data', 'memory_out'),
|
||||
@ -69,7 +72,8 @@ class CutMemoryTest(unittest.TestCase):
|
||||
('crop', 'crop_data'), ('crop_data', 'mem_out_result')
|
||||
],
|
||||
)
|
||||
CutMemory().find_and_replace_pattern(graph)
|
||||
CutMemoryInput().find_and_replace_pattern(graph)
|
||||
CutMemoryOutput().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, last_node='mem_out_result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
140
model-optimizer/extensions/back/ReadValueAssignToMemory.py
Normal file
140
model-optimizer/extensions/back/ReadValueAssignToMemory.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.memory import Memory
|
||||
|
||||
|
||||
"""
|
||||
All transformations in this file should be removed after removing IR v7 support
|
||||
"""
|
||||
|
||||
|
||||
class ReplaceReadValueByMemory(BackReplacementPattern):
|
||||
"""
|
||||
Replace ReadValue by Memory. Should be removed after v7 IR support removing.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [CutMemoryInput]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='ReadValue'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['variable_id']
|
||||
|
||||
node.in_port(0).disconnect()
|
||||
new_in = Memory(graph, {'name': node.id, 'id': node_id, 'index': 1, 'size': 2,
|
||||
'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
|
||||
for dest in node.out_port(0).get_destinations():
|
||||
dest.disconnect()
|
||||
new_in.out_port(0).connect(dest)
|
||||
|
||||
|
||||
class ReplaceAssignByMemory(BackReplacementPattern):
|
||||
"""
|
||||
Replace Assign by Memory. Should be removed after v7 IR support removing.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
|
||||
force_clean_up = True
|
||||
|
||||
def run_after(self):
|
||||
return [CutMemoryOutput]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='Assign'))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
node_id = node['variable_id']
|
||||
|
||||
new_out = Memory(graph, {'name': node.id, 'id': node_id, 'index': 0, 'size': 2,
|
||||
'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
|
||||
node.in_port(0).get_source().connect(new_out.in_port(0))
|
||||
node.in_port(0).disconnect()
|
||||
node.out_port(0).get_connection().set_source(new_out.out_port(0))
|
||||
|
||||
|
||||
class KaldiRemoveMemoryOutputBackReplacementPatternV7(BackReplacementPattern):
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
|
||||
|
||||
def run_after(self):
|
||||
from extensions.back.pass_separator import BackFinish
|
||||
return [BackFinish]
|
||||
|
||||
def run_before(self):
|
||||
from extensions.back.SpecialNodesFinalization import RemoveOutputOps
|
||||
return [RemoveOutputOps]
|
||||
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('memory_node', dict(op='Memory')),
|
||||
('data_node', dict(kind='data')),
|
||||
('op_output', dict(op='Result'))
|
||||
],
|
||||
edges=[
|
||||
('memory_node', 'data_node'),
|
||||
('data_node', 'op_output')
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
"""
|
||||
Need to find the pattern: Memory -> Data -> Result
|
||||
|
||||
It is needed to make Memory nodes appear in IR,
|
||||
but they are output nodes by default and we remove the Result node after each output memory.
|
||||
|
||||
DO NOT use graph clean up after it
|
||||
otherwise Memory nodes would be removed as they are not on the path from input to output
|
||||
|
||||
Parameters
|
||||
----------
|
||||
graph : Graph
|
||||
Graph with loaded model.
|
||||
match : dict
|
||||
Patterns which were found in graph structure.
|
||||
"""
|
||||
memory = match['memory_node']
|
||||
data = match['data_node']
|
||||
op_output = match['op_output']
|
||||
|
||||
graph.remove_edge(memory.id, data.id)
|
||||
graph.remove_node(data.id)
|
||||
graph.remove_node(op_output.id)
|
@ -33,7 +33,7 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('memory_node', dict(op='Memory')),
|
||||
('memory_node', dict(op='Assign')),
|
||||
('data_node', dict(kind='data')),
|
||||
('op_output', dict(op='Result'))
|
||||
],
|
||||
@ -63,6 +63,8 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
|
||||
"""
|
||||
memory = match['memory_node']
|
||||
data = match['data_node']
|
||||
op_output = match['op_output']
|
||||
|
||||
graph.remove_edge(memory.id, data.id)
|
||||
graph.remove_node(data.id)
|
||||
graph.remove_node(op_output.id)
|
||||
|
@ -26,7 +26,7 @@ class KaldiRemoveMemoryOutputTest(unittest.TestCase):
|
||||
'kind': 'data'
|
||||
},
|
||||
'memory_node': {
|
||||
'op': 'Memory',
|
||||
'op': 'Assign',
|
||||
'kind': 'op'
|
||||
},
|
||||
'output_node': {
|
||||
|
@ -63,7 +63,7 @@ def apply_biases_to_last_layer(graph, counts):
|
||||
outputs_ids = find_outputs(graph)
|
||||
for output in outputs_ids.copy():
|
||||
node = Node(graph, output)
|
||||
if node.op != 'Memory':
|
||||
if node.op != 'Assign':
|
||||
continue
|
||||
outputs_ids.remove(output)
|
||||
|
||||
|
@ -13,12 +13,12 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.split import Split
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.eltwise import Eltwise
|
||||
from mo.ops.eltwise_n import EltwiseN
|
||||
from mo.utils.error import Error
|
||||
|
||||
@ -43,8 +43,12 @@ class ReplaceEltwiseNin1NodePattern(FrontReplacementOp):
|
||||
edge_attrs = inp[0][1]
|
||||
graph.add_edge(in_node, ss_node.id, **edge_attrs)
|
||||
if ss_node.num_splits == 2:
|
||||
eltwise_node = Eltwise(graph, attrs={'name': 'Eltwise_' + node.name,
|
||||
'operation': node['operation']}).create_node()
|
||||
if node['operation'] == 'mul':
|
||||
eltwise_node = Mul(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
|
||||
elif node['operation'] == 'sum':
|
||||
eltwise_node = Add(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
|
||||
else:
|
||||
raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
|
||||
elif ss_node.num_splits > 2:
|
||||
eltwise_node = EltwiseN(graph, attrs={'name': 'Eltwise_' + node.name,
|
||||
'operation': node['operation']}).create_node()
|
||||
|
@ -20,13 +20,19 @@ from extensions.ops.activation_ops import Tanh, Sigmoid
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.split import Split
|
||||
from mo.front.caffe.extractors.utils import input_as_const
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.graph.graph import Node, Graph, Port
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.broadcast import Broadcast
|
||||
from mo.ops.clamp import Clamp
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.ops.scale_shift import ScaleShiftOp
|
||||
from mo.ops.shape import Shape
|
||||
|
||||
|
||||
def unique_id(prefix: str = 'id') -> str:
|
||||
@ -46,6 +52,35 @@ def unique_id(prefix: str = 'id') -> str:
|
||||
unique_id.names = []
|
||||
|
||||
|
||||
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float):
|
||||
# create init_graph connected to ReadValue
|
||||
graph = input_out_port.node.graph
|
||||
input_name = input_out_port.node.name
|
||||
shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
|
||||
shape_of_input.in_port(0).connect(input_out_port)
|
||||
dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/'+shape_of_input.name,
|
||||
'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
|
||||
get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
|
||||
'axis': int64_array([0]), 'offset': int64_array([0])
|
||||
}).create_node()
|
||||
get_batch.in_port(0).connect(shape_of_input.out_port(0))
|
||||
get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
|
||||
mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/'+input_name,
|
||||
'value': int64_array([second_dim]),
|
||||
'shape': int64_array([1])}).create_node()
|
||||
mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
|
||||
'axis': 0, 'in_ports_count': 2}).create_node()
|
||||
mem_shape.in_port(0).connect(get_batch.out_port(0))
|
||||
mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
|
||||
fill_value = Const(graph, {'name': 'fill_value/'+input_name,
|
||||
'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
|
||||
init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/'+input_name,
|
||||
}).create_node()
|
||||
init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
|
||||
init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
|
||||
return init_value_prev_lstm_output
|
||||
|
||||
|
||||
class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
op = "LSTMCell"
|
||||
enabled = True
|
||||
@ -69,7 +104,7 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
)
|
||||
|
||||
def replace_op(self, graph: Graph, node: Node):
|
||||
input_node = node.in_node()
|
||||
input_out_port = node.in_port(0).get_source()
|
||||
|
||||
memory_pair_input = unique_id('id')
|
||||
memory_pair_output = unique_id('id')
|
||||
@ -81,16 +116,17 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
'bias_term': True,
|
||||
}
|
||||
|
||||
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node([input_node])
|
||||
fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node()
|
||||
fc_layer_after_input.in_port(0).connect(input_out_port)
|
||||
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
|
||||
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
|
||||
|
||||
prev_lstm_output = Memory(graph, {'name': 'prev_memory_output',
|
||||
'id': memory_pair_input,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
|
||||
}).create_node()
|
||||
init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
|
||||
node.gifo_r_weights_shape[1])
|
||||
prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
|
||||
'variable_id': memory_pair_input
|
||||
}).create_node()
|
||||
prev_lstm_output.in_port(0).connect(init_value_prev_lstm_output.out_port(0))
|
||||
|
||||
# *Memory(output) -> FullyConnected
|
||||
fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
|
||||
@ -99,15 +135,16 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
'bias_term': False,
|
||||
}
|
||||
|
||||
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node(
|
||||
[prev_lstm_output])
|
||||
fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node()
|
||||
fc_layer_from_prev_state.in_port(0).connect(prev_lstm_output.out_port(0))
|
||||
input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)
|
||||
|
||||
# Memory -> FullyConnected \
|
||||
# *Eltwise(sum)
|
||||
# Input -> FullyConnected /
|
||||
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise',
|
||||
}).create_node([fc_layer_from_prev_state, fc_layer_after_input])
|
||||
join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise'}).create_node()
|
||||
join_input_prev_state_sum.in_port(0).connect(fc_layer_from_prev_state.out_port(0))
|
||||
join_input_prev_state_sum.in_port(1).connect(fc_layer_after_input.out_port(0))
|
||||
|
||||
# *Eltwise(sum) -> Split
|
||||
# it is split into 4 nodes: Act, Eltw*3
|
||||
@ -120,131 +157,147 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
# |____(4)Eltwise(sum)
|
||||
split_joined_input_axis = Const(graph, {'value': np.int64(1)}).create_node()
|
||||
split_joined_input = Split(graph, {'name': 'join_input_split',
|
||||
'num_splits': 4,
|
||||
}).create_node([join_input_prev_state_sum, split_joined_input_axis])
|
||||
'num_splits': 4, 'out_ports_count': 4}).create_node()
|
||||
split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
|
||||
split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))
|
||||
|
||||
prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
|
||||
'id': memory_pair_output,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
|
||||
}).create_node()
|
||||
# prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
|
||||
# 'id': memory_pair_output,
|
||||
# 'index': 1,
|
||||
# 'size': 2,
|
||||
# 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
|
||||
# }).create_node()
|
||||
init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
|
||||
node.input_gate_weights.shape[0])
|
||||
prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
|
||||
'variable_id': memory_pair_output}).create_node()
|
||||
prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))
|
||||
|
||||
# *Memory(state) -> *ScaleShift(input)
|
||||
state_input_scaleshift_attrs = {'name': 'input_scaleshift',
|
||||
'bias_term': False
|
||||
}
|
||||
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])
|
||||
state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node()
|
||||
state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
|
||||
input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)
|
||||
|
||||
# *Memory(state) -> *ScaleShift(forget)
|
||||
state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
|
||||
'bias_term': False
|
||||
}
|
||||
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])
|
||||
state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node()
|
||||
state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
|
||||
input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)
|
||||
|
||||
# Split \
|
||||
# (2)Eltwise(sum)
|
||||
# Memory(state) -> *ScaleShift(input) /
|
||||
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise',
|
||||
}).create_node([(split_joined_input, 1),
|
||||
state_input_scaleshift
|
||||
])
|
||||
join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise'
|
||||
}).create_node()
|
||||
join_prev_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(1))
|
||||
join_prev_lstm_input_joined_input_sum.in_port(1).connect(state_input_scaleshift.out_port(0))
|
||||
# Split \
|
||||
# (3)Eltwise(sum)
|
||||
# Memory(state) -> *ScaleShift(forget) /
|
||||
join_prev_lstm_input_joined_forget_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
|
||||
}).create_node([(split_joined_input, 2),
|
||||
state_forget_scaleshift
|
||||
])
|
||||
}).create_node()
|
||||
join_prev_lstm_input_joined_forget_sum.in_port(0).connect(split_joined_input.out_port(2))
|
||||
join_prev_lstm_input_joined_forget_sum.in_port(1).connect(state_forget_scaleshift.out_port(0))
|
||||
|
||||
# Split -> Tanh
|
||||
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node([(split_joined_input, 0)])
|
||||
remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
|
||||
remember_tahn.in_port(0).connect(split_joined_input.out_port(0))
|
||||
|
||||
# Split -> (2)Eltwise(sum) -> *Sigmoid
|
||||
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'
|
||||
}).create_node([join_prev_lstm_input_joined_input_sum])
|
||||
remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'}).create_node()
|
||||
remember_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_input_sum.out_port(0))
|
||||
|
||||
# Split -> (3)Eltwise(sum) -> **Sigmoid
|
||||
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'
|
||||
}).create_node([join_prev_lstm_input_joined_forget_sum])
|
||||
forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'}).create_node()
|
||||
forget_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_forget_sum.out_port(0))
|
||||
|
||||
# *Memory(state) \
|
||||
# (6)Eltwise(mul)
|
||||
# Split -> (3)Eltwise(sum) -> **Sigmoid /
|
||||
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul',
|
||||
}).create_node([forget_sigmoid, prev_lstm_state])
|
||||
join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul'}).create_node()
|
||||
join_forget_prev_state_mul.in_port(0).connect(forget_sigmoid.out_port(0))
|
||||
join_forget_prev_state_mul.in_port(1).connect(prev_lstm_state.out_port(0))
|
||||
|
||||
# Split -> Tahn \
|
||||
# (5)Eltwise(mul)
|
||||
# Split -> (2)Eltwise(sum) -> *Sigmoid /
|
||||
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul',
|
||||
}).create_node([remember_tahn, remember_sigmoid])
|
||||
join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul'}).create_node()
|
||||
join_remember_candidates_mul.in_port(0).connect(remember_tahn.out_port(0))
|
||||
join_remember_candidates_mul.in_port(1).connect(remember_sigmoid.out_port(0))
|
||||
|
||||
# (5)Eltwise(mul) \
|
||||
# (7)Eltwise(sum)
|
||||
# (6)Eltwise(mul) /
|
||||
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum',
|
||||
}).create_node(
|
||||
[join_forget_prev_state_mul, join_remember_candidates_mul])
|
||||
join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum'}).create_node()
|
||||
join_forget_remember_sum.in_port(0).connect(join_forget_prev_state_mul.out_port(0))
|
||||
join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))
|
||||
|
||||
# (7)Eltwise(sum) -> Clamp
|
||||
join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
|
||||
'max': node.clip_value,
|
||||
'min': -node.clip_value
|
||||
}).create_node(
|
||||
[join_forget_remember_sum])
|
||||
'min': -node.clip_value}).create_node()
|
||||
join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
|
||||
#
|
||||
# Clamp -> (2)Memory(state)
|
||||
next_lstm_state = Memory(graph, {'name': 'next_lstm_state',
|
||||
'id': memory_pair_output,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
|
||||
}).create_node([join_forget_clamp])
|
||||
Result(graph, {'name': 'next_lstm_state_out'}).create_node([next_lstm_state])
|
||||
next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
|
||||
'variable_id': memory_pair_output}).create_node()
|
||||
next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))
|
||||
|
||||
res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
|
||||
res_node.in_port(0).connect(next_lstm_state.out_port(0))
|
||||
|
||||
# Clamp -> (2)Tahn
|
||||
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node([join_forget_clamp])
|
||||
state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node()
|
||||
state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))
|
||||
|
||||
# Clamp -> (2)ScaleShift
|
||||
clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
|
||||
'bias_term': False}
|
||||
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])
|
||||
clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
|
||||
clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
|
||||
input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)
|
||||
|
||||
# Split \
|
||||
# (4)Eltwise(sum)
|
||||
# Clamp -> (2)ScaleShift /
|
||||
join_next_lstm_input_joined_input_sum = Add(graph, {'name': 'join_next_lstm_input_joined_input_sum',
|
||||
}).create_node([(split_joined_input, 3), clamp_scaleshift])
|
||||
}).create_node()
|
||||
join_next_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(3))
|
||||
join_next_lstm_input_joined_input_sum.in_port(1).connect(clamp_scaleshift.out_port(0))
|
||||
|
||||
# (4)Eltwise(sum) -> (3)Sigmoid
|
||||
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node([join_next_lstm_input_joined_input_sum])
|
||||
output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node()
|
||||
output_sigmoid.in_port(0).connect(join_next_lstm_input_joined_input_sum.out_port(0))
|
||||
|
||||
# (4)Eltwise(sum) -> (3)Sigmoid \
|
||||
# (5)Eltwise(mul)
|
||||
# Clamp -> (2)Tahn /
|
||||
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node([state_filtered_tahn, output_sigmoid])
|
||||
joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node()
|
||||
joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
|
||||
joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))
|
||||
|
||||
# (5)Eltwise(mul) -> (3)FullyConnected
|
||||
fc_output_attrs = {'name': 'FullyConnected',
|
||||
'out-size': node.projection_weights_shape[0],
|
||||
'transpose_weights': True,
|
||||
'bias_term': False}
|
||||
fc_output = FullyConnected(graph, fc_output_attrs).create_node([joined_output_mul])
|
||||
fc_output = FullyConnected(graph, fc_output_attrs).create_node()
|
||||
fc_output.in_port(0).connect(joined_output_mul.out_port(0))
|
||||
input_as_const(fc_output, fc_output_attrs, 1, 'weights', node.projection_weights)
|
||||
|
||||
# / (2)Memory(output)
|
||||
# (3)FullyConnected
|
||||
# \ Output (any next node) (edge created automatically after replacement)
|
||||
next_lstm_output = Memory(graph, {'name': 'next_lstm_output',
|
||||
'id': memory_pair_input,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
|
||||
}).create_node([fc_output])
|
||||
Result(graph, {'name': 'next_lstm_output_out'}).create_node([next_lstm_output])
|
||||
next_lstm_output = Assign(graph, {'name': 'next_lstm_output',
|
||||
'variable_id': memory_pair_input}).create_node()
|
||||
next_lstm_output.in_port(0).connect(fc_output.out_port(0))
|
||||
|
||||
res_node_lstm_output = Result(graph, {'name': 'next_lstm_output_out'}).create_node()
|
||||
res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))
|
||||
|
||||
return [fc_output.id]
|
||||
|
@ -22,7 +22,7 @@ from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.eltwise import Eltwise
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from mo.ops.scale_shift import ScaleShiftOp
|
||||
|
||||
|
||||
@ -41,19 +41,19 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
def replace_op(self, graph: Graph, node: Node):
|
||||
# split input to (i_part, f_part, c_part, o_part, ct_1)
|
||||
split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
|
||||
split_node = Split(graph, {'name': graph.unique_id(prefix='Split_lstm_input_'),
|
||||
split_node = Split(graph, {'name': 'Split_lstm_input_',
|
||||
'num_splits': 5}).create_node()
|
||||
node.in_port(0).get_connection().set_destination(split_node.in_port(0))
|
||||
split_node.in_port(1).connect(split_node_axis.out_port(0))
|
||||
|
||||
# i_t = Sigmoid(i_part + w_ic*ct_1)
|
||||
i_scale_attrs = {'name': graph.unique_id(prefix='i_scaleshift'),
|
||||
i_scale_attrs = {'name': 'i_scaleshift',
|
||||
'bias_term': False}
|
||||
i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
|
||||
input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
|
||||
split_node.out_port(4).connect(i_scale.in_port(0))
|
||||
|
||||
sum_i_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_i_c_'), 'operation': 'sum'}).create_node()
|
||||
sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
|
||||
split_node.out_port(0).connect(sum_i_c.in_port(0))
|
||||
i_scale.out_port(0).connect(sum_i_c.in_port(1))
|
||||
|
||||
@ -61,13 +61,13 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
|
||||
|
||||
# f_t = Sigmoid(f_part + w_fc*ct_1)
|
||||
f_scale_attrs = {'name': graph.unique_id(prefix='f_scaleshift'),
|
||||
f_scale_attrs = {'name': 'f_scaleshift',
|
||||
'bias_term': False}
|
||||
f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
|
||||
input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
|
||||
split_node.out_port(4).connect(f_scale.in_port(0))
|
||||
|
||||
sum_f_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_c_'), 'operation': 'sum'}).create_node()
|
||||
sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
|
||||
split_node.out_port(1).connect(sum_f_c.in_port(0))
|
||||
f_scale.out_port(0).connect(sum_f_c.in_port(1))
|
||||
|
||||
@ -78,28 +78,26 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
|
||||
split_node.out_port(2).connect(c_tanh.in_port(0))
|
||||
|
||||
prod_i_c_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_i_c_tanh_'),
|
||||
'operation': 'mul'}).create_node()
|
||||
prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
|
||||
i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
|
||||
c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))
|
||||
|
||||
prod_f_ct_1 = Eltwise(graph, {'name': graph.unique_id(prefix='prod_f_ct_1_'),
|
||||
'operation': 'mul'}).create_node()
|
||||
prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
|
||||
f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
|
||||
split_node.out_port(4).connect(prod_f_ct_1.in_port(1))
|
||||
|
||||
sum_f_i = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_i_'), 'operation': 'sum'}).create_node()
|
||||
sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
|
||||
prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
|
||||
prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))
|
||||
|
||||
# o_t = Sigmoid(o_part + w_oc*c_t)
|
||||
o_scale_attrs = {'name': graph.unique_id(prefix='o_scaleshift'),
|
||||
o_scale_attrs = {'name': 'o_scaleshift',
|
||||
'bias_term': False}
|
||||
o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
|
||||
input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
|
||||
sum_f_i.out_port(0).connect(o_scale.in_port(0))
|
||||
|
||||
sum_o_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_o_c_'), 'operation': 'sum'}).create_node()
|
||||
sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
|
||||
split_node.out_port(3).connect(sum_o_c.in_port(0))
|
||||
o_scale.out_port(0).connect(sum_o_c.in_port(1))
|
||||
|
||||
@ -110,13 +108,12 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
|
||||
sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
|
||||
|
||||
prod_o_c_t_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_o_c_t_tanh_'),
|
||||
'operation': 'mul'}).create_node()
|
||||
prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
|
||||
o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
|
||||
c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))
|
||||
|
||||
# add concat to create 1 output
|
||||
concat = Concat(graph, {'name': graph.unique_id(prefix='Concat_c_m')}).create_node()
|
||||
concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
|
||||
concat.add_sequence_of_ports('in', range(2))
|
||||
sum_f_i.out_port(0).connect(concat.in_port(0))
|
||||
prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))
|
||||
|
@ -15,15 +15,18 @@
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
|
||||
from extensions.ops.elementwise import Equal
|
||||
from extensions.ops.select import Select
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.utils.error import Error
|
||||
from mo.utils.graph import invert_sub_graph_between_nodes
|
||||
@ -48,7 +51,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
@staticmethod
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[('op', dict(op='Memory', index=0))],
|
||||
nodes=[('op', dict(op='Assign'))],
|
||||
edges=[])
|
||||
|
||||
@staticmethod
|
||||
@ -93,9 +96,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
select_node.in_port(2).connect(zero_else.out_port(0))
|
||||
|
||||
# check if we have already appropriate iteration counter
|
||||
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='Memory', index=1,
|
||||
shape=int64_array([context_len]))),
|
||||
('mem_in_data', dict()),
|
||||
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
|
||||
('mem_in_data', dict(shape=int64_array([context_len]))),
|
||||
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
|
||||
offset=int64_array([1]),
|
||||
dim=int64_array([context_len-1]))),
|
||||
@ -104,8 +106,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
('concat_data', dict()),
|
||||
('const_1', dict(op='Const')),
|
||||
('const_1_data', dict()),
|
||||
('mem_out', dict(op='Memory', index=0,
|
||||
shape=int64_array([context_len]))),
|
||||
('mem_out', dict(op='Assign')),
|
||||
('crop_out', dict(op='Crop', axis=int64_array([1]),
|
||||
offset=int64_array([0]),
|
||||
dim=int64_array([1]))),
|
||||
@ -122,12 +123,13 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
('crop_out_data', 'select')])
|
||||
counter_match = next(existing_counters, None)
|
||||
if counter_match is not None:
|
||||
ones = Node(graph, inverse_dict(counter_match)['const_1'])
|
||||
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
|
||||
else:
|
||||
mem_out = Memory(graph, {'name': 'iteration_number', 'size': 2,
|
||||
'index': 1, 'id': 'iteration_' + node.name,
|
||||
'shape': int64_array([context_len]),
|
||||
'dst_type': np.int32}).create_node()
|
||||
init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
|
||||
mem_out = ReadValue(graph, {'name': 'iteration_number',
|
||||
'variable_id': 'iteration_'+node.name}).create_node()
|
||||
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
|
||||
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
|
||||
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
|
||||
cut_first.in_port(0).connect(mem_out.out_port(0))
|
||||
@ -135,9 +137,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
|
||||
concat.in_port(0).connect(cut_first.out_port(0))
|
||||
concat.in_port(1).connect(ones.out_port(0))
|
||||
mem_in = Memory(graph, {'name': 'iteration_number_out', 'size': 2,
|
||||
'index': 0, 'id': 'iteration_' + node.name,
|
||||
'shape': int64_array([context_len])}).create_node()
|
||||
mem_in = Assign(graph, {'name': 'iteration_number_out',
|
||||
'variable_id': 'iteration_'+node.name}).create_node()
|
||||
mem_in.in_port(0).connect(concat.out_port(0))
|
||||
res = Result(graph, {}).create_node()
|
||||
mem_in.out_port(0).connect(res.in_port(0))
|
||||
@ -146,6 +147,12 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
cut_last.in_port(0).connect(concat.out_port(0))
|
||||
input_port = cut_last.out_port(0)
|
||||
|
||||
select_node.in_port(0).connect(input_port)
|
||||
# Check if data from memory is 1
|
||||
# if it is True, we have correct data and should proceed with saving it to memory
|
||||
# else we have not gathered context and have garbage here, shouldn't change initial state of memory
|
||||
cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
|
||||
cast_in.in_port(0).connect(ones.out_port(0))
|
||||
cast_in.in_port(1).connect(input_port)
|
||||
select_node.in_port(0).connect(cast_in.out_port(0))
|
||||
select_node.out_port(0).connect(node.in_port(0))
|
||||
select_node.out_port(0).data.set_shape(in_node_shape)
|
||||
|
@ -30,23 +30,14 @@ class InsertSelectTests(unittest.TestCase):
|
||||
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
ref_graph = graph.copy()
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
|
||||
self.assertTrue(flag, resp)
|
||||
@ -60,7 +51,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@ -76,15 +67,32 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim':{'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
|
||||
'crop_in_data': {'kind': 'data'},
|
||||
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
|
||||
'crop_out_data': {'kind': 'data'},
|
||||
'equal': {'kind': 'op', 'op': 'Equal'},
|
||||
'equal_data': {'kind': 'data'},
|
||||
'select': {'kind': 'op', 'op': 'Select'},
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
@ -95,22 +103,34 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
('select', 'select_out_data'),
|
||||
('select_out_data', 'memory')
|
||||
],
|
||||
@ -132,7 +152,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@ -151,15 +171,32 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
|
||||
'crop_in_data': {'kind': 'data'},
|
||||
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
|
||||
'crop_out_data': {'kind': 'data'},
|
||||
'equal': {'kind': 'op', 'op': 'Equal'},
|
||||
'equal_data': {'kind': 'data'},
|
||||
'select': {'kind': 'op', 'op': 'Select'},
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
@ -170,7 +207,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@ -178,13 +215,25 @@ class InsertSelectTests(unittest.TestCase):
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'select', {'in': 0}),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
|
||||
('select', 'select_out_data'),
|
||||
@ -208,7 +257,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@ -227,15 +276,32 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([7])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 6},
|
||||
'crop_in_data': {'kind': 'data'},
|
||||
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
|
||||
'crop_out_data': {'kind': 'data'},
|
||||
'equal': {'kind': 'op', 'op': 'Equal'},
|
||||
'equal_data': {'kind': 'data'},
|
||||
'select': {'kind': 'op', 'op': 'Select'},
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
@ -246,7 +312,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
@ -254,13 +320,25 @@ class InsertSelectTests(unittest.TestCase):
|
||||
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'select', {'in': 0}),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
|
||||
('select', 'select_out_data'),
|
||||
|
@ -59,7 +59,18 @@ class RemoveUselessCropsPattern(MiddleReplacementPattern):
|
||||
if out['op'] == 'Crop' and out['axis'] == axis and \
|
||||
len(out.out_port(0).get_destinations()) == 1 and \
|
||||
out.out_port(0).get_destination().node == concat_node:
|
||||
offsets_dims.append((out['offset'], out['dim']))
|
||||
# crop type 1
|
||||
if 'dim' in out:
|
||||
offsets_dims.append((out['offset'], out['dim']))
|
||||
# crop type 3
|
||||
elif 'crop_begin' in out and 'crop_end' in out:
|
||||
offsets_dims.append((out['crop_begin'], out['crop_end']-out['crop_begin']))
|
||||
# crop type 2 with const dim
|
||||
elif not out.in_port(1).disconnected() and out.in_port(1).data.get_value() is not None:
|
||||
offsets_dims.append((out['offset'], out.in_port(1).data.get_value()))
|
||||
# crop type 2 with non-const dim or strange type of crop
|
||||
else:
|
||||
return
|
||||
crop_list.append(out)
|
||||
|
||||
offsets_dims.sort(key=lambda off_dim: off_dim[0])
|
||||
|
@ -84,6 +84,136 @@ class RemoveUselessCropsPatternTests(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_useless_crops_type2(self):
|
||||
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
|
||||
'const_26_data': {'kind': 'data', 'value': 26},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
|
||||
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('crop_data_1', 'concat'),
|
||||
('crop_data_2', 'concat'),
|
||||
('crop_data_3', 'concat'),
|
||||
('crop_data_4', 'concat'),
|
||||
('crop_data_5', 'concat'),
|
||||
('concat', 'concat_data'),
|
||||
('concat_data', 'placeholder')])
|
||||
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
|
||||
'const_26_data': {'kind': 'data', 'value': 26},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'dim': 26, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[
|
||||
('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
|
||||
('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('concat', 'concat_data'),
|
||||
('in_node', 'placeholder')
|
||||
]
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_useless_crops_type3(self):
|
||||
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('crop_data_1', 'concat'),
|
||||
('crop_data_2', 'concat'),
|
||||
('crop_data_3', 'concat'),
|
||||
('crop_data_4', 'concat'),
|
||||
('crop_data_5', 'concat'),
|
||||
('concat', 'concat_data'),
|
||||
('concat_data', 'placeholder')])
|
||||
RemoveUselessCropsPattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
|
||||
'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
|
||||
'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
|
||||
'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
|
||||
'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
|
||||
'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
|
||||
'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 130]},
|
||||
'placeholder': {'kind': 'op', 'op': 'Parameter'},
|
||||
},
|
||||
[
|
||||
('placeholder_in', 'in_node'),
|
||||
('in_node', 'crop1'), ('crop1', 'crop_data_1'),
|
||||
('in_node', 'crop2'), ('crop2', 'crop_data_2'),
|
||||
('in_node', 'crop3'), ('crop3', 'crop_data_3'),
|
||||
('in_node', 'crop4'), ('crop4', 'crop_data_4'),
|
||||
('in_node', 'crop5'), ('crop5', 'crop_data_5'),
|
||||
('concat', 'concat_data'),
|
||||
('in_node', 'placeholder')
|
||||
]
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_useful_crops(self):
|
||||
graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 130]},
|
||||
|
@ -15,13 +15,15 @@
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
|
||||
from extensions.ops.splice import Splice
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.utils.error import Error
|
||||
|
||||
@ -67,7 +69,8 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
|
||||
|
||||
splice = Splice(graph, {'name': node_name,
|
||||
'id': node_id,
|
||||
'context': int64_array(range(node_t, 1)) if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
|
||||
'context': int64_array(range(node_t, 1))
|
||||
if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
|
||||
splice.in_port(0).connect(input_node_out_port)
|
||||
|
||||
# offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
|
||||
@ -106,6 +109,7 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
|
||||
Replace MemoryOffset with Memory if IfDefined used with it to avoid cycles
|
||||
"""
|
||||
enabled = True
|
||||
force_shape_inference = True
|
||||
|
||||
def run_before(self):
|
||||
from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern
|
||||
@ -141,43 +145,34 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
|
||||
in_shape = input_port.data.get_shape()
|
||||
node_t = abs(node.t)
|
||||
|
||||
memory_out = Memory(graph, {'name': pair_name, 'id': node_name+pair_name,
|
||||
'index': 1, 'size': 2,
|
||||
'shape': np.array([in_shape[1]*node_t])}).create_node()
|
||||
init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
|
||||
memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
|
||||
init_value_memory_out.out_port(0).connect(memory_out.in_port(0))
|
||||
|
||||
if node_t > 1:
|
||||
crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
|
||||
'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
|
||||
memory_out.out_port(0).connect(crop_concat.in_port(0))
|
||||
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
|
||||
concat.add_sequence_of_ports('in', range(2))
|
||||
crop_concat.out_port(0).connect(concat.in_port(0))
|
||||
crop_concat.out_port(0).data.set_shape(np.array([in_shape[0], crop_concat.dim]))
|
||||
concat.in_port(1).connect(input_port)
|
||||
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
|
||||
'index': 0, 'size': 2,
|
||||
'shape': memory_out.shape}).create_node()
|
||||
|
||||
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
|
||||
concat.out_port(0).connect(memory_in.in_port(0))
|
||||
concat.out_port(0).data.set_shape(np.array([in_shape[0], memory_in.shape[0]]))
|
||||
out = Result(graph, {'name': 'Memory_output'}).create_node()
|
||||
memory_in.out_port(0).connect(out.in_port(0))
|
||||
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
|
||||
crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
|
||||
'offset': np.array([0]), 'axis': np.array([1])}).create_node()
|
||||
memory_out.out_port(0).connect(crop_out.in_port(0))
|
||||
out_port.get_connection().set_source(crop_out.out_port(0))
|
||||
crop_out.out_port(0).data.set_shape(np.array([in_shape[0], crop_out.dim]))
|
||||
else:
|
||||
memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
|
||||
'index': 0, 'size': 2,
|
||||
'shape': memory_out.shape}).create_node()
|
||||
memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
|
||||
memory_in.in_port(0).connect(input_port)
|
||||
out = Result(graph, {'name': 'Memory_output'}).create_node()
|
||||
memory_in.out_port(0).connect(out.in_port(0))
|
||||
memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
out_port.get_connection().set_source(memory_out.out_port(0))
|
||||
memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
|
||||
|
||||
graph.remove_node(op_output_id)
|
||||
graph.remove_node(node.id)
|
||||
|
@ -13,15 +13,16 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id, create_zero_value_with_batch_from_input
|
||||
from extensions.ops.split import VariadicSplit
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.memory import Memory
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
|
||||
|
||||
@ -39,7 +40,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
So this pass will convert this graph to the next one:
|
||||
|
||||
Input [N, H] __
|
||||
\ /
|
||||
/ /
|
||||
Concat [N, k*H]
|
||||
/ \
|
||||
Memory [N, k*H] -> Slice [N, (k-1)*H] Memory [N, k*H]
|
||||
@ -67,11 +68,9 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
|
||||
memory_pair_id = unique_id('id')
|
||||
# Memory(in)
|
||||
input_memory = Memory(graph, {'name': 'prev_splice_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size])}).create_node()
|
||||
input_memory = ReadValue(graph, {'name': 'prev_splice_memory',
|
||||
'variable_id': memory_pair_id}).create_node()
|
||||
|
||||
# Memory(in) \
|
||||
# Crop
|
||||
# Input(temp) /
|
||||
@ -90,11 +89,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
concat_node.in_port(0).connect(crop.out_port(0))
|
||||
|
||||
# Concat -> Memory(out)
|
||||
mem_out = Memory(graph, {'name': 'out_splice_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size])}).create_node()
|
||||
mem_out = Assign(graph, {'name': 'out_splice_memory', 'variable_id': memory_pair_id}).create_node()
|
||||
mem_out.in_port(0).connect(concat_node.out_port(0))
|
||||
Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))
|
||||
|
||||
@ -110,11 +105,12 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
|
||||
# create separate splice construction for const_dim
|
||||
memory_pair_id = unique_id('memory_for_const_dim')
|
||||
input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 1,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size_constdim])}).create_node()
|
||||
init_value_input_memory_const_dim = create_zero_value_with_batch_from_input(split.out_port(1),
|
||||
memory_size_constdim)
|
||||
input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
|
||||
'variable_id': memory_pair_id}).create_node()
|
||||
init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))
|
||||
|
||||
crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
|
||||
'axis': int64_array([1]),
|
||||
'offset': int64_array([memory_element_constdim]),
|
||||
@ -127,11 +123,8 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
'axis': 1}).create_node()
|
||||
concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))
|
||||
|
||||
mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory',
|
||||
'id': memory_pair_id,
|
||||
'index': 0,
|
||||
'size': 2,
|
||||
'shape': int64_array([memory_size_constdim])}).create_node()
|
||||
mem_out_const_dim = Assign(graph, {'name': 'const_dim_out_memory',
|
||||
'variable_id': memory_pair_id}).create_node()
|
||||
mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
|
||||
Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))
|
||||
|
||||
@ -148,9 +141,15 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
concat_const.in_port(1).connect(crop_first.out_port(0))
|
||||
concat_const.in_port(0).connect(concat_node.out_port(0))
|
||||
|
||||
init_value_input_memory = create_zero_value_with_batch_from_input(split.out_port(0),
|
||||
memory_size)
|
||||
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
|
||||
node.in_port(0).get_connection().set_destination(split.in_port(0))
|
||||
node.out_port(0).get_connection().set_source(concat_const.out_port(0))
|
||||
else:
|
||||
init_value_input_memory = create_zero_value_with_batch_from_input(node.in_port(0).get_source(),
|
||||
memory_size)
|
||||
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
|
||||
node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
|
||||
node.out_port(0).get_connection().set_source(concat_node.out_port(0))
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
import unittest
|
||||
|
||||
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
@ -42,19 +43,47 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
|
||||
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([143])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
|
||||
'crop_mem_data': {'kind': 'data'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data', 'shape': [1, 143]},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
|
||||
},
|
||||
[
|
||||
('in_placeholder', 'in_node'),
|
||||
|
||||
('in_node', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}),
|
||||
('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'),
|
||||
('memory_in_data', 'crop_mem'),
|
||||
('crop_mem', 'crop_mem_data'),
|
||||
@ -86,22 +115,54 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
'split': {'kind': 'op', 'op': 'Split'},
|
||||
'split_data_0': {'kind': 'data'},
|
||||
'split_data_1': {'kind': 'data'},
|
||||
'memory_in': {'kind': 'op', 'op': 'Memory'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
|
||||
'crop_mem_data': {'kind': 'data'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Memory'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
|
||||
|
||||
'shape_2': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_2_data': {'kind': 'data'},
|
||||
'crop_batch_2': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_2_data': {'kind': 'data'},
|
||||
'crop_batch_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_2_data': {'kind': 'data'},
|
||||
'second_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
|
||||
'second_dim_2_data': {'kind': 'data'},
|
||||
'gather_shape_2': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_2_data': {'kind': 'data'},
|
||||
'fill_value_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_2_data': {'kind': 'data'},
|
||||
'broadcast_2': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_2_data': {'kind': 'data'},
|
||||
|
||||
'memory_in_constdims': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_constdims_data': {'kind': 'data'},
|
||||
'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
|
||||
'crop_mem_constdims_data': {'kind': 'data'},
|
||||
'concat_constdims': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_constdims_data': {'kind': 'data'},
|
||||
'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
|
||||
'memory_out_constdims': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out_constdims_data': {'kind': 'data'},
|
||||
'result_constdims': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
|
||||
@ -121,6 +182,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
('in_node', 'split', {'in': 0}),
|
||||
('split', 'split_data_0', {'out': 0}),
|
||||
('split', 'split_data_1', {'out': 1}),
|
||||
|
||||
('split_data_0', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}),
|
||||
('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'),
|
||||
('memory_in_data', 'crop_mem'),
|
||||
('crop_mem', 'crop_mem_data'),
|
||||
@ -130,6 +203,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'),
|
||||
('memory_out_data', 'result'),
|
||||
|
||||
('split_data_1', 'shape_2'), ('shape_2', 'shape_2_data'),
|
||||
('shape_2_data', 'crop_batch_2'), ('crop_batch_2', 'crop_batch_2_data'),
|
||||
('crop_batch_dim_2', 'crop_batch_dim_2_data'),
|
||||
('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
|
||||
('second_dim_2', 'second_dim_2_data'), ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
|
||||
('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
|
||||
('gather_shape_2', 'gather_shape_2_data'),
|
||||
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'broadcast_2', {'in': 0}),
|
||||
('gather_shape_2_data', 'broadcast_2', {'in': 1}), ('broadcast_2', 'broadcast_2_data'),
|
||||
('broadcast_2_data', 'memory_in_constdims'),
|
||||
|
||||
('memory_in_constdims', 'memory_in_constdims_data'),
|
||||
('memory_in_constdims_data', 'crop_mem_constdims'),
|
||||
('crop_mem_constdims', 'crop_mem_constdims_data'),
|
||||
|
@ -113,9 +113,6 @@ def prepare_ir(argv: argparse.Namespace):
|
||||
elif (is_kaldi or is_onnx) and not argv.input_model:
|
||||
raise Error('Path to input model is required: use --input_model.')
|
||||
|
||||
if is_kaldi:
|
||||
argv.generate_experimental_IR_V10 = False
|
||||
|
||||
log.debug(str(argv))
|
||||
log.debug("Model Optimizer started")
|
||||
|
||||
|
40
model-optimizer/mo/ops/assign.py
Normal file
40
model-optimizer/mo/ops/assign.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class Assign(Op):
|
||||
op = 'Assign'
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
super().__init__(graph, {
|
||||
'type': self.op,
|
||||
'op': self.op,
|
||||
'version': 'opset3',
|
||||
'infer': self.infer,
|
||||
'in_ports_count': 1,
|
||||
'out_ports_count': 1,
|
||||
}, attrs)
|
||||
|
||||
def backend_attrs(self):
|
||||
return ['variable_id']
|
||||
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
assert node.has_valid('variable_id'), \
|
||||
"There is no required attribute variable_id in Assign op with name " + node.id
|
||||
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
|
45
model-optimizer/mo/ops/read_value.py
Normal file
45
model-optimizer/mo/ops/read_value.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class ReadValue(Op):
|
||||
op = 'ReadValue'
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
super().__init__(graph, {
|
||||
'type': self.op,
|
||||
'op': self.op,
|
||||
'version': 'opset3',
|
||||
'infer': self.infer,
|
||||
'type_infer': self.type_infer,
|
||||
'in_ports_count': 1,
|
||||
'out_ports_count': 1,
|
||||
}, attrs)
|
||||
|
||||
def backend_attrs(self):
|
||||
return ['variable_id']
|
||||
|
||||
@staticmethod
|
||||
def type_infer(node: Node):
|
||||
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
|
||||
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
assert node.has_valid('variable_id'), \
|
||||
"There is no required attribute variable_id in ReadValue op with name " + node.id
|
||||
node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
|
@ -32,6 +32,7 @@ from mo.ops.convolution import Convolution
|
||||
from mo.ops.deconvolution import Deconvolution
|
||||
from mo.ops.op import Op
|
||||
from mo.ops.pooling import Pooling
|
||||
from mo.ops.result import Result
|
||||
from mo.utils.class_registration import update_registration
|
||||
from mo.utils.import_extensions import import_by_path
|
||||
from mo.utils.ir_reader.extender import Extender
|
||||
@ -218,6 +219,18 @@ def ti_add_edge_attrs(op: Node):
|
||||
i += 1
|
||||
|
||||
|
||||
def assign_add_output_result(op: Node):
|
||||
"""
|
||||
Function adds necessary output result node for Assign node
|
||||
:param op:
|
||||
:return:
|
||||
"""
|
||||
assert op.soft_get('type') == 'Assign', 'Wrong operation type, {} instead of Assign!' \
|
||||
''.format(op.soft_get('type'))
|
||||
tmp_result = Result(op.graph, {'name': op.soft_get('name', op.id) + '/Result'}).create_node()
|
||||
op.out_port(0).connect(tmp_result.in_port(0))
|
||||
|
||||
|
||||
def copy_input_blobs(op: Node, copy_op: Node):
|
||||
"""
|
||||
Function copy input blob data nodes from restored graph to copied one
|
||||
@ -243,6 +256,7 @@ preprocessing_op_nodes = {
|
||||
|
||||
# Map with postprocessing functions for nodes
|
||||
postprocessing_op_nodes = {
|
||||
'Assign': assign_add_output_result,
|
||||
'TensorIterator': ti_add_edge_attrs,
|
||||
}
|
||||
|
||||
|
@ -22,6 +22,7 @@ from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConst
|
||||
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
|
||||
from extensions.back.TopKNormalizer import TopKNormalizer
|
||||
from extensions.back.blob_normalizer import BlobNormalizer
|
||||
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_precision
|
||||
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
|
||||
@ -77,6 +78,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
|
||||
PackBinaryWeights,
|
||||
BlobNormalizer,
|
||||
ConvolutionNormalizer,
|
||||
KaldiRemoveMemoryOutputBackReplacementPattern,
|
||||
]
|
||||
|
||||
# We need to run some specific passes from MO back stage.
|
||||
|
@ -137,6 +137,8 @@ set (SRC
|
||||
op/asin.hpp
|
||||
op/asinh.cpp
|
||||
op/asinh.hpp
|
||||
op/assign.cpp
|
||||
op/assign.hpp
|
||||
op/atan.cpp
|
||||
op/atan.hpp
|
||||
op/atanh.cpp
|
||||
@ -314,6 +316,8 @@ set (SRC
|
||||
op/proposal.cpp
|
||||
op/psroi_pooling.hpp
|
||||
op/psroi_pooling.cpp
|
||||
op/read_value.hpp
|
||||
op/read_value.cpp
|
||||
op/reduce_logical_and.cpp
|
||||
op/reduce_logical_and.hpp
|
||||
op/reduce_logical_or.cpp
|
||||
@ -518,6 +522,7 @@ set (SRC
|
||||
op/util/scatter_base.hpp
|
||||
op/util/unary_elementwise_arithmetic.cpp
|
||||
op/util/unary_elementwise_arithmetic.hpp
|
||||
op/util/variable.hpp
|
||||
ops.hpp
|
||||
opsets/opset.cpp
|
||||
partial_shape.cpp
|
||||
|
81
ngraph/src/ngraph/op/assign.cpp
Normal file
81
ngraph/src/ngraph/op/assign.cpp
Normal file
@ -0,0 +1,81 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/assign.hpp"
|
||||
#include <ops.hpp>
|
||||
#include "ngraph/op/read_value.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v3::Assign::type_info;
|
||||
|
||||
op::v3::Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
|
||||
: Op({new_value})
|
||||
, m_variable_id(variable_id)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v3::Assign::validate_and_infer_types()
|
||||
{
|
||||
auto value = input_value(0);
|
||||
auto arg_t = get_input_element_type(0);
|
||||
auto output_shape = get_input_partial_shape(0);
|
||||
if (!m_variable)
|
||||
{
|
||||
NodeVector start_nodes;
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
start_nodes.push_back(input.get_source_output().get_node_shared_ptr());
|
||||
}
|
||||
auto nodes = topological_sort(start_nodes);
|
||||
for (const auto& node : nodes)
|
||||
{
|
||||
if (auto read_value = as_type_ptr<op::v3::ReadValue>(node))
|
||||
{
|
||||
if (read_value->get_variable_id() == m_variable_id)
|
||||
m_variable = read_value->get_variable();
|
||||
}
|
||||
}
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id);
|
||||
}
|
||||
|
||||
auto variable_info = m_variable->get_info();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
m_variable_id == variable_info.variable_id,
|
||||
"Variables identifiers are inconsistent.");
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
output_shape == variable_info.data_shape,
|
||||
"Variables output shapes are inconsistent.");
|
||||
|
||||
set_output_type(0, arg_t, output_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<Assign>(new_args.at(0), m_variable_id);
|
||||
}
|
||||
|
||||
bool op::v3::Assign::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("variable_id", m_variable_id);
|
||||
return true;
|
||||
}
|
67
ngraph/src/ngraph/op/assign.hpp
Normal file
67
ngraph/src/ngraph/op/assign.hpp
Normal file
@ -0,0 +1,67 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/variable.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v3
|
||||
{
|
||||
/// \brief Assign operation sets an input value to the variable with `variable_id`
|
||||
class NGRAPH_API Assign : public Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Assign", 3};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
Assign() = default;
|
||||
|
||||
/// \brief Constructs an Assign operation.
|
||||
///
|
||||
/// \param new_value Node that produces the input tensor.
|
||||
/// \param variable_id identificator of the variable to be updated.
|
||||
Assign(const Output<Node>& new_value, const std::string& variable_id);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
std::string get_variable_id() { return m_variable_id; }
|
||||
std::shared_ptr<ngraph::Variable> get_variable() { return m_variable; }
|
||||
void set_variable_id(const std::string& variable_id)
|
||||
{
|
||||
m_variable_id = variable_id;
|
||||
}
|
||||
void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
|
||||
{
|
||||
m_variable = variable;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string m_variable_id;
|
||||
std::shared_ptr<ngraph::Variable> m_variable;
|
||||
};
|
||||
}
|
||||
using v3::Assign;
|
||||
}
|
||||
}
|
@ -265,3 +265,5 @@ NGRAPH_OP(Transpose, ngraph::op::v1, 1)
|
||||
NGRAPH_OP(Unsqueeze, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(VariadicSplit, ngraph::op::v1, 1)
|
||||
NGRAPH_OP(Xor, ngraph::op::v0, 0)
|
||||
NGRAPH_OP(Assign, ngraph::op::v3, 3)
|
||||
NGRAPH_OP(ReadValue, ngraph::op::v3, 3)
|
||||
|
51
ngraph/src/ngraph/op/read_value.cpp
Normal file
51
ngraph/src/ngraph/op/read_value.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/read_value.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::ReadValue::type_info;
|
||||
|
||||
op::ReadValue::ReadValue(const Output<Node>& new_value, const std::string& variable_id)
|
||||
: Op({new_value})
|
||||
, m_variable_id(variable_id)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::ReadValue::validate_and_infer_types()
|
||||
{
|
||||
auto arg_t = get_input_element_type(0);
|
||||
auto output_shape = get_input_partial_shape(0);
|
||||
|
||||
VariableInfo info = {output_shape, arg_t, m_variable_id};
|
||||
m_variable = std::make_shared<Variable>(info);
|
||||
set_output_type(0, arg_t, output_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::ReadValue::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ReadValue>(new_args.at(0), m_variable_id);
|
||||
}
|
||||
|
||||
bool op::v3::ReadValue::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("variable_id", m_variable_id);
|
||||
return true;
|
||||
}
|
68
ngraph/src/ngraph/op/read_value.hpp
Normal file
68
ngraph/src/ngraph/op/read_value.hpp
Normal file
@ -0,0 +1,68 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/variable.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v3
|
||||
{
|
||||
/// \brief ReadValue operation creates the variable with `variable_id` and returns value
|
||||
/// of this variable.
|
||||
class NGRAPH_API ReadValue : public Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ReadValue", 3};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
ReadValue() = default;
|
||||
|
||||
/// \brief Constructs a ReadValue operation.
|
||||
///
|
||||
/// \param new_value Node that produces the input tensor.
|
||||
/// \param variable_id identificator of the variable to create.
|
||||
ReadValue(const Output<Node>& new_value, const std::string& variable_id);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
std::string get_variable_id() { return m_variable_id; }
|
||||
std::shared_ptr<ngraph::Variable> get_variable() { return m_variable; }
|
||||
void set_variable_id(const std::string& variable_id)
|
||||
{
|
||||
m_variable_id = variable_id;
|
||||
}
|
||||
void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
|
||||
{
|
||||
m_variable = variable;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string m_variable_id;
|
||||
std::shared_ptr<ngraph::Variable> m_variable;
|
||||
};
|
||||
}
|
||||
using v3::ReadValue;
|
||||
}
|
||||
}
|
@ -23,6 +23,9 @@
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/runtime/reference/topk.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
@ -225,9 +228,224 @@ void op::v0::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */,
|
||||
throw ngraph_error("Forward-propagation-only operation");
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
template <element::Type_t INPUT_ET, element::Type_t INDEX_ET>
|
||||
inline bool evaluate_execute(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& out_indices,
|
||||
const HostTensorPtr& out_values,
|
||||
const Shape out_shape,
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool compute_max,
|
||||
const op::TopK::SortType sort)
|
||||
{
|
||||
using T = typename element_type_traits<INPUT_ET>::value_type;
|
||||
using U = typename element_type_traits<INDEX_ET>::value_type;
|
||||
const Shape in_shape = arg0->get_shape();
|
||||
out_indices->set_shape(out_shape);
|
||||
out_indices->set_element_type(INDEX_ET);
|
||||
|
||||
out_values->set_shape(out_shape);
|
||||
out_values->set_element_type(arg0->get_element_type());
|
||||
|
||||
runtime::reference::topk<T, U>(arg0->get_data_ptr<INPUT_ET>(),
|
||||
out_indices->get_data_ptr<INDEX_ET>(),
|
||||
out_values->get_data_ptr<INPUT_ET>(),
|
||||
in_shape,
|
||||
out_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <element::Type_t INPUT_ET>
|
||||
bool evaluate(const HostTensorPtr& arg,
|
||||
const HostTensorPtr& out_indices,
|
||||
const HostTensorPtr& out_values,
|
||||
const Shape out_shape,
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool max,
|
||||
const op::TopK::SortType sort,
|
||||
const element::Type index_et)
|
||||
{
|
||||
bool rc = true;
|
||||
switch (index_et)
|
||||
{
|
||||
case element::Type_t::i64:
|
||||
evaluate_execute<INPUT_ET, element::Type_t::i64>(
|
||||
arg, out_indices, out_values, out_shape, axis, k, max, sort);
|
||||
break;
|
||||
case element::Type_t::i32:
|
||||
evaluate_execute<INPUT_ET, element::Type_t::i32>(
|
||||
arg, out_indices, out_values, out_shape, axis, k, max, sort);
|
||||
break;
|
||||
default: rc = false; break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
bool evaluate_topk(const HostTensorPtr& arg,
|
||||
const HostTensorPtr& out_indices,
|
||||
const HostTensorPtr& out_values,
|
||||
const Shape out_shape,
|
||||
const size_t axis,
|
||||
const size_t k,
|
||||
const bool max,
|
||||
const op::TopK::SortType sort,
|
||||
const element::Type index_et)
|
||||
{
|
||||
bool rc = true;
|
||||
switch (arg->get_element_type())
|
||||
{
|
||||
TYPE_CASE(i8)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(i16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(i32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(i64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(u8)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(u16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(u32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(u64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(bf16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(f16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(f32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
TYPE_CASE(f64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
|
||||
break;
|
||||
default: rc = false; break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
template <element::Type_t K_ET>
|
||||
size_t get_k_from_hosttensor(const HostTensorPtr& arg)
|
||||
{
|
||||
using T = typename element_type_traits<K_ET>::value_type;
|
||||
auto p = arg->get_data_ptr<T>();
|
||||
size_t k = p[0];
|
||||
return k;
|
||||
}
|
||||
|
||||
#define CASE_GET_K(a) \
|
||||
case element::Type_t::a: k = get_k_from_hosttensor<element::Type_t::a>
|
||||
|
||||
size_t read_k_from_host_tensor(const HostTensorPtr& arg_k)
|
||||
{
|
||||
size_t k = 0;
|
||||
switch (arg_k->get_element_type())
|
||||
{
|
||||
CASE_GET_K(i8)(arg_k);
|
||||
break;
|
||||
CASE_GET_K(i16)(arg_k);
|
||||
break;
|
||||
CASE_GET_K(i32)(arg_k);
|
||||
break;
|
||||
CASE_GET_K(i64)(arg_k);
|
||||
break;
|
||||
CASE_GET_K(u8)(arg_k);
|
||||
break;
|
||||
CASE_GET_K(u16)(arg_k);
|
||||
break;
|
||||
CASE_GET_K(u32)(arg_k);
|
||||
break;
|
||||
CASE_GET_K(u64)(arg_k);
|
||||
break;
|
||||
default:
|
||||
// other types are not supported and would have thrown in ctor
|
||||
ngraph_error("read_k_from_host_tensor: type is not integral\n");
|
||||
break;
|
||||
}
|
||||
return k;
|
||||
}
|
||||
|
||||
// used in only v0, where type is set as int64_t
|
||||
size_t read_top_k_axis_from_host_tensor(const HostTensorPtr& arg)
|
||||
{
|
||||
NGRAPH_CHECK(arg->get_element_type() == element::i64,
|
||||
"TopK axis element type should be i64");
|
||||
auto p = arg->get_data_ptr<int64_t>();
|
||||
size_t axis = static_cast<size_t>(p[0]);
|
||||
return axis;
|
||||
}
|
||||
}
|
||||
|
||||
Shape op::v0::TopK::compute_output_shape(const Shape input_shape,
|
||||
const int64_t k,
|
||||
const size_t axis)
|
||||
{
|
||||
Shape output_shape{input_shape};
|
||||
if (k != 0)
|
||||
{
|
||||
output_shape[axis] = k;
|
||||
}
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
bool op::v0::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
|
||||
{
|
||||
// check data types for arg, k and output element type
|
||||
Shape arg_shape = inputs[0]->get_shape();
|
||||
|
||||
// 1. get axis, mode ( max/min), sort_type
|
||||
size_t axis = 0;
|
||||
Dimension axis_dim = get_top_k_axis_dynamic();
|
||||
if (axis_dim.is_static())
|
||||
{
|
||||
axis = axis_dim.get_length();
|
||||
}
|
||||
else
|
||||
{
|
||||
axis = read_top_k_axis_from_host_tensor(inputs[2]);
|
||||
NGRAPH_CHECK(axis <= arg_shape.size(), "TopK axis is out of bounds");
|
||||
}
|
||||
bool compute_max = get_compute_max();
|
||||
SortType sort_type = get_sort();
|
||||
|
||||
// 2. get value of k - from constant node or from HT
|
||||
size_t k = get_k();
|
||||
if (k == 0)
|
||||
{
|
||||
k = read_k_from_host_tensor(inputs[1]);
|
||||
if (k == 0)
|
||||
{
|
||||
// the kernel can't handle k = 0, but output_shape[axis] = arg_shape[axis]
|
||||
k = arg_shape[axis];
|
||||
}
|
||||
}
|
||||
NGRAPH_CHECK(k <= arg_shape.at(axis), "K exceeds the dimension of the TopK axis");
|
||||
|
||||
// 3. Compute output_shape
|
||||
auto output_shape = compute_output_shape(inputs[0]->get_shape(), k, axis);
|
||||
|
||||
return evaluate_topk(inputs[0],
|
||||
outputs[0],
|
||||
outputs[1],
|
||||
output_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort_type,
|
||||
get_index_element_type());
|
||||
}
|
||||
|
||||
// v1 version starts
|
||||
constexpr NodeTypeInfo op::v1::TopK::type_info;
|
||||
|
||||
static const std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
|
||||
|
||||
op::v1::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
@ -236,7 +454,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
|
||||
const element::Type& index_element_type)
|
||||
: Op{{data, k}}
|
||||
, m_axis{axis}
|
||||
, m_normalized_axis{0}
|
||||
, m_normalized_axis{UNKNOWN_NORMALIZED_AXIS}
|
||||
, m_mode{as_enum<Mode>(mode)}
|
||||
, m_sort{as_enum<SortType>(sort)}
|
||||
, m_index_element_type{index_element_type}
|
||||
@ -244,8 +462,6 @@ op::v1::TopK::TopK(const Output<Node>& data,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
static const std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();
|
||||
|
||||
op::v1::TopK::TopK(const Output<Node>& data,
|
||||
const Output<Node>& k,
|
||||
const int64_t axis,
|
||||
@ -318,6 +534,25 @@ void op::v1::TopK::validate_and_infer_types()
|
||||
set_output_type(1, m_index_element_type, output_shape);
|
||||
}
|
||||
|
||||
Shape op::v1::TopK::compute_output_shape(const std::string& node_description,
|
||||
const PartialShape input_partial_shape,
|
||||
const int64_t k)
|
||||
{
|
||||
PartialShape output_shape{input_partial_shape};
|
||||
|
||||
m_normalized_axis = ngraph::normalize_axis(node_description, m_axis, output_shape.rank());
|
||||
if (k != 0)
|
||||
{
|
||||
output_shape[m_normalized_axis] = k;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_shape[m_normalized_axis] = input_partial_shape[m_normalized_axis];
|
||||
}
|
||||
|
||||
return output_shape.get_shape();
|
||||
}
|
||||
|
||||
void op::v1::TopK::set_axis(const int64_t axis)
|
||||
{
|
||||
const auto input_rank = get_input_partial_shape(0).rank();
|
||||
@ -332,6 +567,19 @@ void op::v1::TopK::set_axis(const int64_t axis)
|
||||
m_axis = axis;
|
||||
}
|
||||
|
||||
void op::v1::TopK::set_axis(const Rank input_rank, const int64_t axis)
|
||||
{
|
||||
if (input_rank.is_static())
|
||||
{
|
||||
m_normalized_axis = ngraph::normalize_axis(this, axis, input_rank);
|
||||
}
|
||||
else
|
||||
{
|
||||
m_normalized_axis = UNKNOWN_NORMALIZED_AXIS;
|
||||
}
|
||||
m_axis = axis;
|
||||
}
|
||||
|
||||
uint64_t op::v1::TopK::get_axis() const
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
@ -433,6 +681,49 @@ void op::v1::TopK::set_k(size_t k)
|
||||
op::Constant::create(element::i64, Shape{}, {k})->output(0));
|
||||
}
|
||||
|
||||
bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
|
||||
{
|
||||
Shape arg_shape = inputs[0]->get_shape();
|
||||
// 1. get axis, mode ( max/min), sort_type
|
||||
set_axis(arg_shape.size(), m_axis);
|
||||
size_t axis = get_axis();
|
||||
bool compute_max = get_mode() == TopKMode::MAX ? true : false;
|
||||
SortType sort_type = get_sort_type();
|
||||
|
||||
// 2. get value of k - from constant node or from HT
|
||||
size_t k = 0;
|
||||
if (input_value(1).get_node_shared_ptr()->is_constant())
|
||||
{
|
||||
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
|
||||
get_input_element_type(1));
|
||||
NGRAPH_CHECK(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
|
||||
}
|
||||
else
|
||||
{
|
||||
k = read_k_from_host_tensor(inputs[1]);
|
||||
}
|
||||
|
||||
// 3. Compute output_shape
|
||||
auto output_shape = compute_output_shape(this->description(), inputs[0]->get_shape(), k);
|
||||
|
||||
// do this after compute_output_shape
|
||||
if (k == 0)
|
||||
{
|
||||
// the kernel can't handle k = 0, but output_shape[axis] = arg_shape[axis]
|
||||
k = arg_shape[axis];
|
||||
}
|
||||
|
||||
return evaluate_topk(inputs[0],
|
||||
outputs[1],
|
||||
outputs[0],
|
||||
output_shape,
|
||||
axis,
|
||||
k,
|
||||
compute_max,
|
||||
sort_type,
|
||||
get_index_element_type());
|
||||
}
|
||||
|
||||
// v3 version starts
|
||||
constexpr NodeTypeInfo op::v3::TopK::type_info;
|
||||
|
||||
@ -516,3 +807,8 @@ shared_ptr<Node> op::v3::TopK::clone_with_new_inputs(const OutputVector& new_arg
|
||||
|
||||
return std::move(new_v3_topk);
|
||||
}
|
||||
|
||||
bool op::v3::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
|
||||
{
|
||||
return op::v1::TopK::evaluate(outputs, inputs);
|
||||
}
|
||||
|
@ -102,12 +102,18 @@ namespace ngraph
|
||||
bool get_compute_max() const { return m_compute_max; }
|
||||
SortType get_sort() const { return m_sort; }
|
||||
size_t get_default_output_index() const override { return no_default_index(); }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
|
||||
protected:
|
||||
element::Type m_index_element_type;
|
||||
bool m_compute_max{false};
|
||||
SortType m_sort{SortType::NONE};
|
||||
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
|
||||
const OutputVector& deltas) override;
|
||||
Shape compute_output_shape(const Shape input_shape,
|
||||
const int64_t k,
|
||||
const size_t axis);
|
||||
};
|
||||
} // namespace v0
|
||||
|
||||
@ -181,6 +187,9 @@ namespace ngraph
|
||||
size_t get_k() const;
|
||||
void set_k(size_t k);
|
||||
size_t get_default_output_index() const override { return no_default_index(); }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
|
||||
protected:
|
||||
int64_t m_axis;
|
||||
uint64_t m_normalized_axis;
|
||||
@ -196,6 +205,10 @@ namespace ngraph
|
||||
|
||||
template <typename T>
|
||||
size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
|
||||
Shape compute_output_shape(const std::string& node_description,
|
||||
const PartialShape input_partial_shape,
|
||||
const int64_t k);
|
||||
void set_axis(const Rank input_rank, const int64_t axis);
|
||||
};
|
||||
} // namespace v1
|
||||
|
||||
@ -240,6 +253,9 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
|
||||
protected:
|
||||
virtual size_t
|
||||
read_k_from_constant_node(const std::shared_ptr<Node>& node,
|
||||
|
@ -142,7 +142,8 @@ void op::util::BroadcastBase::validate_and_infer_types()
|
||||
auto output_rank = input_value(1).get_partial_shape();
|
||||
if (input_rank.is_static() && output_rank.is_static() && output_rank[0].is_static())
|
||||
{
|
||||
result_shape = PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
|
||||
result_shape =
|
||||
PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
|
||||
}
|
||||
const auto shape_constant = as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
|
||||
|
||||
|
47
ngraph/src/ngraph/op/util/variable.hpp
Normal file
47
ngraph/src/ngraph/op/util/variable.hpp
Normal file
@ -0,0 +1,47 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
struct VariableInfo
|
||||
{
|
||||
PartialShape data_shape;
|
||||
element::Type data_type;
|
||||
std::string variable_id;
|
||||
};
|
||||
|
||||
class NGRAPH_API Variable
|
||||
{
|
||||
public:
|
||||
Variable() = default;
|
||||
|
||||
explicit Variable(const VariableInfo& variable_info)
|
||||
: m_info(variable_info)
|
||||
{
|
||||
}
|
||||
|
||||
VariableInfo get_info() { return m_info; }
|
||||
void update(const VariableInfo& variable_info) { m_info = variable_info; }
|
||||
private:
|
||||
VariableInfo m_info;
|
||||
};
|
||||
}
|
@ -30,6 +30,7 @@
|
||||
#include "ngraph/op/argmin.hpp"
|
||||
#include "ngraph/op/asin.hpp"
|
||||
#include "ngraph/op/asinh.hpp"
|
||||
#include "ngraph/op/assign.hpp"
|
||||
#include "ngraph/op/atan.hpp"
|
||||
#include "ngraph/op/atan2.hpp"
|
||||
#include "ngraph/op/atanh.hpp"
|
||||
@ -149,6 +150,7 @@
|
||||
#include "ngraph/op/quantized_convolution.hpp"
|
||||
#include "ngraph/op/quantized_dot.hpp"
|
||||
#include "ngraph/op/range.hpp"
|
||||
#include "ngraph/op/read_value.hpp"
|
||||
#include "ngraph/op/recv.hpp"
|
||||
#include "ngraph/op/reduce_logical_and.hpp"
|
||||
#include "ngraph/op/reduce_logical_or.hpp"
|
||||
|
@ -166,4 +166,6 @@ NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
|
||||
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
|
||||
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
|
||||
NGRAPH_OP(ShapeOf, ngraph::op::v3)
|
||||
NGRAPH_OP(Assign, ngraph::op::v3)
|
||||
NGRAPH_OP(ReadValue, ngraph::op::v3)
|
||||
NGRAPH_OP(TopK, ngraph::op::v3)
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/specialize_function.hpp"
|
||||
#include <pass/constant_folding.hpp>
|
||||
#include "ngraph/op/assign.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/tensor_iterator.hpp"
|
||||
|
||||
@ -84,7 +85,18 @@ std::shared_ptr<Function>
|
||||
}
|
||||
else
|
||||
{
|
||||
m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
|
||||
NodeVector cloned_dependencies;
|
||||
for (auto& dependency : old_node->get_control_dependencies())
|
||||
{
|
||||
std::shared_ptr<Node> dependent = m.at(dependency.get());
|
||||
if (find(cloned_dependencies.begin(), cloned_dependencies.end(), dependent) ==
|
||||
cloned_dependencies.end())
|
||||
{
|
||||
cloned_dependencies.push_back(dependent);
|
||||
}
|
||||
}
|
||||
m[old_node.get()] = old_node->copy_with_new_inputs(new_args, cloned_dependencies);
|
||||
|
||||
// TODO: workaround for shape inference, delete it after fix
|
||||
if (::ngraph::as_type_ptr<ngraph::op::TensorIterator>(m[old_node.get()]))
|
||||
{
|
||||
|
@ -114,6 +114,7 @@ set(SRC
|
||||
tensor.cpp
|
||||
type_prop/all.cpp
|
||||
type_prop/any.cpp
|
||||
type_prop/assign.cpp
|
||||
type_prop/avg_pool.cpp
|
||||
type_prop/batch_mat_mul.cpp
|
||||
type_prop/batch_mat_mul_transpose.cpp
|
||||
@ -178,6 +179,7 @@ set(SRC
|
||||
type_prop/quantized_dot.cpp
|
||||
type_prop/random_uniform.cpp
|
||||
type_prop/range.cpp
|
||||
type_prop/read_value.cpp
|
||||
type_prop/replace_slice.cpp
|
||||
type_prop/reshape.cpp
|
||||
type_prop/reverse.cpp
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -64,10 +64,10 @@
|
||||
#include "ngraph/op/stop_gradient.hpp"
|
||||
#include "ngraph/op/tan.hpp"
|
||||
#include "ngraph/op/tanh.hpp"
|
||||
#include "ngraph/op/topk.hpp"
|
||||
#include "ngraph/op/transpose.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
#include "runtime/backend.hpp"
|
||||
#include "util/all_close_f.hpp"
|
||||
#include "util/ndarray.hpp"
|
||||
#include "util/test_tools.hpp"
|
||||
@ -1555,3 +1555,358 @@ TEST(eval, evaluate_dynamic_scatter_elements_update_one_elem_i32)
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
||||
ASSERT_EQ(cval, out);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v1)
|
||||
{
|
||||
Shape shape{2, 3, 2};
|
||||
Shape rshape{2, 2, 2};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
const auto k = op::Constant::create(element::i32, Shape{}, {2});
|
||||
auto B = make_shared<op::v1::TopK>(A, k, 1, "max", "index", element::i32);
|
||||
|
||||
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
auto result0_val = read_vector<float>(result0);
|
||||
|
||||
auto result1_val = read_vector<int32_t>(result1);
|
||||
|
||||
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
|
||||
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v1_dyn)
|
||||
{
|
||||
Shape shape{2, 3, 2};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||
auto B = make_shared<op::v1::TopK>(A, k, 1, "max", "index", element::i32);
|
||||
|
||||
auto fun =
|
||||
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
auto result0_val = read_vector<float>(result0);
|
||||
auto result1_val = read_vector<int32_t>(result1);
|
||||
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
|
||||
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v3_dyn)
|
||||
{
|
||||
Shape shape{2, 3, 2};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "index", element::i32);
|
||||
|
||||
auto fun =
|
||||
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
auto result0_val = read_vector<float>(result0);
|
||||
auto result1_val = read_vector<int32_t>(result1);
|
||||
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
|
||||
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v3_dyn_values)
|
||||
{
|
||||
Shape shape{2, 3, 2};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
|
||||
|
||||
auto fun =
|
||||
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
auto result0_val = read_vector<float>(result0);
|
||||
auto result1_val = read_vector<int32_t>(result1);
|
||||
vector<float> expec0{12, 9, 10, 4, 11, 7, 6, 3};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
|
||||
vector<int32_t> expec1{0, 1, 1, 2, 2, 2, 0, 1};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v3_dyn_values_k0)
|
||||
{
|
||||
Shape shape{2, 3, 2};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
|
||||
|
||||
auto fun =
|
||||
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i32>(Shape{}, {0})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
auto result0_val = read_vector<float>(result0);
|
||||
auto result1_val = read_vector<int32_t>(result1);
|
||||
vector<float> expec0{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
|
||||
vector<int32_t> expec1{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v0_dyn)
|
||||
{
|
||||
Shape shape{2, 3, 2};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
auto k = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
|
||||
element::Type result_et{element::i32};
|
||||
bool compute_max = true;
|
||||
|
||||
auto B = make_shared<op::v0::TopK>(
|
||||
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
|
||||
|
||||
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
|
||||
ParameterVector{A, k, axis});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {2}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
auto result1_val = read_vector<float>(result1);
|
||||
auto result0_val = read_vector<int32_t>(result0);
|
||||
|
||||
vector<float> expec1{12, 9, 10, 4, 11, 7, 6, 3};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
|
||||
vector<int32_t> expec0{0, 1, 1, 2, 2, 2, 0, 1};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v0_dyn_k0)
|
||||
{
|
||||
Shape shape{2, 3, 2};
|
||||
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||
auto k = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
|
||||
element::Type result_et{element::i32};
|
||||
bool compute_max = true;
|
||||
|
||||
auto B = make_shared<op::v0::TopK>(
|
||||
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
|
||||
|
||||
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
|
||||
ParameterVector{A, k, axis});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {0}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
auto result1_val = read_vector<float>(result1);
|
||||
auto result0_val = read_vector<int32_t>(result0);
|
||||
|
||||
vector<float> expec1{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
|
||||
vector<int32_t> expec0{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v3_param_dyn_values_k0)
|
||||
{
|
||||
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
|
||||
|
||||
auto fun =
|
||||
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i32>(Shape{}, {0})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
auto result0_val = read_vector<float>(result0);
|
||||
auto result1_val = read_vector<int32_t>(result1);
|
||||
vector<float> expec0{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
|
||||
vector<int32_t> expec1{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v3_param_dyn_values_k2)
|
||||
{
|
||||
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
|
||||
|
||||
auto fun =
|
||||
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
auto result0_val = read_vector<float>(result0);
|
||||
auto result1_val = read_vector<int32_t>(result1);
|
||||
vector<float> expec0{12, 9, 10, 4, 11, 7, 6, 3};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
|
||||
vector<int32_t> expec1{0, 1, 1, 2, 2, 2, 0, 1};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v0_param_dyn_k2)
|
||||
{
|
||||
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto k = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
|
||||
element::Type result_et{element::i32};
|
||||
bool compute_max = true;
|
||||
|
||||
auto B = make_shared<op::v0::TopK>(
|
||||
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
|
||||
|
||||
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
|
||||
ParameterVector{A, k, axis});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {2}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||
auto result1_val = read_vector<float>(result1);
|
||||
auto result0_val = read_vector<int32_t>(result0);
|
||||
|
||||
vector<float> expec1{12, 9, 10, 4, 11, 7, 6, 3};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
|
||||
vector<int32_t> expec0{0, 1, 1, 2, 2, 2, 0, 1};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
}
|
||||
|
||||
TEST(eval, topk_v0_param_dyn_k0)
|
||||
{
|
||||
auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto k = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
auto axis = make_shared<op::Parameter>(element::i64, Shape{});
|
||||
|
||||
element::Type result_et{element::i32};
|
||||
bool compute_max = true;
|
||||
|
||||
auto B = make_shared<op::v0::TopK>(
|
||||
A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);
|
||||
|
||||
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
|
||||
ParameterVector{A, k, axis});
|
||||
|
||||
auto result0 = make_shared<HostTensor>();
|
||||
auto result1 = make_shared<HostTensor>();
|
||||
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||
{make_host_tensor<element::Type_t::f32>(
|
||||
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {0}),
|
||||
make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
|
||||
EXPECT_EQ(result0->get_element_type(), element::i32);
|
||||
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
EXPECT_EQ(result1->get_element_type(), element::f32);
|
||||
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
|
||||
auto result1_val = read_vector<float>(result1);
|
||||
auto result0_val = read_vector<int32_t>(result0);
|
||||
|
||||
vector<float> expec1{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
|
||||
ASSERT_EQ(result1_val, expec1);
|
||||
|
||||
vector<int32_t> expec0{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
|
||||
ASSERT_EQ(result0_val, expec0);
|
||||
}
|
||||
|
52
ngraph/test/type_prop/assign.cpp
Normal file
52
ngraph/test/type_prop/assign.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, assign_variable_not_found)
|
||||
{
|
||||
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
|
||||
try
|
||||
{
|
||||
auto space_to_depth = make_shared<op::Assign>(A, "variable_id");
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Should not find variable with variable_id";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Can't find variable with id = variable_id"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, assign_deduce)
|
||||
{
|
||||
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
|
||||
auto read_value = make_shared<op::ReadValue>(input, "variable_id");
|
||||
auto assign = make_shared<op::Assign>(read_value, "variable_id");
|
||||
|
||||
ASSERT_EQ(assign->get_element_type(), element::f32);
|
||||
ASSERT_EQ(assign->get_shape(), (Shape{1, 2, 64, 64}));
|
||||
}
|
31
ngraph/test/type_prop/read_value.cpp
Normal file
31
ngraph/test/type_prop/read_value.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, read_value_deduce)
|
||||
{
|
||||
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
|
||||
auto read_value = make_shared<op::ReadValue>(input, "variable_id");
|
||||
|
||||
ASSERT_EQ(read_value->get_element_type(), element::f32);
|
||||
ASSERT_EQ(read_value->get_shape(), (Shape{1, 2, 64, 64}));
|
||||
}
|
@ -49,7 +49,7 @@ ngraph::test::NgraphTestCase::NgraphTestCase(const std::shared_ptr<Function>& fu
|
||||
}
|
||||
}
|
||||
|
||||
void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
|
||||
::testing::AssertionResult ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
|
||||
{
|
||||
m_tolerance_bits = tolerance_bits;
|
||||
const auto& function_results = m_function->get_results();
|
||||
@ -85,6 +85,7 @@ void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
|
||||
m_output_index = 0;
|
||||
m_expected_outputs.clear();
|
||||
m_input_tensors.clear();
|
||||
return ::testing::AssertionSuccess();
|
||||
}
|
||||
|
||||
ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
|
||||
|
@ -187,7 +187,7 @@ namespace ngraph
|
||||
add_expected_output<T>(expected_shape, value);
|
||||
}
|
||||
|
||||
void run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
|
||||
::testing::AssertionResult run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
|
Loading…
Reference in New Issue
Block a user