publish master branch snapshot, revision ea98a886d925eb152931aab13856e68037665562
parent deb008a26f, commit ccb7438803
@@ -85,7 +85,7 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
         THROW_IE_EXCEPTION << "Invalid IR! " << node_param.name << " name is not unique!";
     opName.insert(node_param.name);
     params[node_param.layerId] = {node, node_param};
-    if (node_param.type == "Result") {
+    if (node_param.type == "Result" || node_param.type == "Assign") {
         outputs.push_back(node_param.layerId);
     }
 }
@@ -118,7 +118,9 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c

     ngraph::ParameterVector parameter_nodes;
     ngraph::ResultVector result_nodes;
-    std::vector<std::shared_ptr<ngraph::Node>> allNodes;
+    ngraph::NodeVector allNodes;
+    std::vector<std::shared_ptr<ngraph::op::Assign>> assign_nodes;
+    std::map<std::string, std::shared_ptr<ngraph::Node>> variable_id_to_read_value;

     // Following topological order create nGraph operations
     for (auto& layer_id : order) {
@@ -159,12 +161,28 @@ std::shared_ptr<ngraph::Function> V10Parser::parse(const pugi::xml_node& root, c
         if (auto result_node = std::dynamic_pointer_cast<ngraph::op::Result>(node)) {
             result_nodes.emplace_back(result_node);
         }
+
+        if (auto assign_node = std::dynamic_pointer_cast<ngraph::op::Assign>(node)) {
+            assign_nodes.emplace_back(assign_node);
+        }
+
+        if (auto read_value_node = std::dynamic_pointer_cast<ngraph::op::ReadValue>(node)) {
+            variable_id_to_read_value[read_value_node->get_variable_id()] = read_value_node;
+        }
         allNodes.emplace_back(node);
     }

     ::ngraph::op::GenericIE::DisableReshape noReshape(allNodes);
-    return std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
+    auto function = std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
+    if (!result_nodes.empty()) {
+        for (const auto& assign : assign_nodes) {
+            assign->add_control_dependency(variable_id_to_read_value.at(assign->get_variable_id()));
+            // often Assign node is a leaf of the graph, we add control_dependency for one of the results
+            // to make Assign node visible for traversals get_ops(), get_ordered_ops()
+            result_nodes[0]->add_control_dependency(assign);
+        }
+    }
+    return function;
 }

 V10Parser::GenericLayerParams V10Parser::parseGenericParams(const pugi::xml_node& node) {
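Note: the control-dependency comments in the hunk above carry the key reasoning of this change. A hedged illustration in Python (a toy model, not the nGraph C++ API) of why a leaf Assign is invisible to result-rooted traversals such as get_ops()/get_ordered_ops() until a result depends on it:

```python
# Toy graph model: ops are discovered by walking backward from results,
# so a sink node like Assign needs an incoming (control) edge to be seen.
class Node:
    def __init__(self, name, inputs=()):
        self.name, self.inputs, self.control_deps = name, list(inputs), []

def reachable(results):
    seen, stack = set(), list(results)
    while stack:
        n = stack.pop()
        if n.name in seen:
            continue
        seen.add(n.name)
        stack.extend(n.inputs + n.control_deps)  # control deps count as edges
    return seen

read = Node('ReadValue')
add = Node('Add', [read])
result = Node('Result', [add])
assign = Node('Assign', [add])           # leaf: nothing consumes it
print('Assign' in reachable([result]))   # False
result.control_deps.append(assign)       # what the parser now does
print('Assign' in reachable([result]))   # True
```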
@@ -334,6 +334,28 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
         return res;
     });
+
+    addSpecificCreator({"Assign"}, [](const std::shared_ptr<::ngraph::Node>& node,
+                                      const std::map<std::string, std::string> params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), "Memory",
+                             details::convertPrecision(node->get_output_element_type(0))};
+        auto res = std::make_shared<CNNLayer>(attrs);
+        res->params["id"] = params.at("variable_id");
+        res->params["index"] = "0";
+        res->params["size"] = "2";
+        return res;
+    });
+
+    addSpecificCreator({"ReadValue"}, [](const std::shared_ptr<::ngraph::Node>& node,
+                                         const std::map<std::string, std::string> params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), "Memory",
+                             details::convertPrecision(node->get_output_element_type(0))};
+        auto res = std::make_shared<CNNLayer>(attrs);
+        res->params["id"] = params.at("variable_id");
+        res->params["index"] = "1";
+        res->params["size"] = "2";
+        return res;
+    });

     addSpecificCreator({"RNNCell"}, [](const std::shared_ptr<::ngraph::Node>& node,
                                        const std::map<std::string, std::string> params) -> CNNLayerPtr {
         THROW_IE_EXCEPTION << "RNNCell operation has a form that is not supported." << node->get_friendly_name()
@@ -656,14 +678,16 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
     // Collect all names from current graph
     // It is necessary in order to differentiate outputs from constant layers when we share constants
     // (Constant operations contains outputs for converted and original functions)
+    const ngraph::NodeVector& nodes = graph->get_ops();
+
     std::unordered_set<std::string> op_names;
-    for (const auto &layer : graph->get_ops())
+    for (const auto &layer : nodes)
         op_names.insert(layer->get_name());

     bool keep_constants = ::ngraph::op::util::has_op_with_type<::ngraph::op::FakeQuantize>(graph);

     // Create layers and output data
-    for (const auto &layer : graph->get_ops()) {
+    for (const auto &layer : nodes) {
         if (isInternalLayer(layer, op_names, keep_constants)) continue;

         // TODO: remove this rt info when all blobs will be inputs
@@ -703,8 +727,18 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p
             }
             inputCount++;
         }
+
+        if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "1") {
+            inputCount = 0;
+        }
+
         cnnLayer->insData.resize(inputCount);

         for (size_t i = 0; i < layer->get_output_size(); i++) {
+            if (cnnLayer->type == "Memory" && cnnLayer->params["index"] == "0") {
+                cnnLayer->outData.clear();
+                continue;
+            }
             std::string outName = layer->get_friendly_name();
             if (layer->get_output_size() != 1) outName += "." + std::to_string(i);
             DataPtr &ptr = cnnNetworkImpl->getData(outName.c_str());
@@ -747,6 +781,8 @@ std::shared_ptr<CNNNetworkImpl> convertFunctionToICNNNetwork(const std::shared_p

     // Set input data
     for (const auto &layer : graph->get_ordered_ops()) {
+        if (std::dynamic_pointer_cast<::ngraph::op::ReadValue>(layer))
+            continue;
         if (std::dynamic_pointer_cast<::ngraph::op::Result>(layer)) {
             IE_ASSERT(layer->get_inputs().size() == 1);
             const auto &input = layer->input_value(0);
@@ -62,9 +62,9 @@ void ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::convert_normalize_l2_
                                                      across_spatial,
                                                      channel_shared);

-        normalize_ie->set_friendly_name(normalize->get_friendly_name());
-        ngraph::copy_runtime_info(normalize, normalize_ie);
-        ngraph::replace_node(normalize, normalize_ie);
+        normalize_ie->set_friendly_name(mul->get_friendly_name());
+        ngraph::copy_runtime_info({normalize, mul}, normalize_ie);
+        ngraph::replace_node(mul, normalize_ie);
         return true;
     };

@@ -46,6 +46,7 @@ extensions/back/ParameterToPlaceholder.py
 extensions/back/pass_separator.py
 extensions/back/priorbox_mutation.py
 extensions/back/ProposalMutation.py
+extensions/back/ReadValueAssignToMemory.py
 extensions/back/ReduceToPooling.py
 extensions/back/ReduceTransposeDimensions.py
 extensions/back/remove_last_softmax_pattern.py
@@ -872,6 +873,7 @@ mo/middle/pattern_match.py
 mo/middle/replacement.py
 mo/ops/__init__.py
 mo/ops/activation.py
+mo/ops/assign.py
 mo/ops/broadcast.py
 mo/ops/clamp.py
 mo/ops/concat.py
@@ -897,6 +899,7 @@ mo/ops/pad.py
 mo/ops/permute.py
 mo/ops/pooling.py
 mo/ops/power.py
+mo/ops/read_value.py
 mo/ops/reshape.py
 mo/ops/result.py
 mo/ops/roipooling.py
@@ -23,7 +23,7 @@ from mo.ops.crop import Crop
 from mo.utils.logger import log


-class CutMemory(BackReplacementPattern):
+class CutMemoryInput(BackReplacementPattern):
     """
     Cut Memory layers and have inputs/outputs in graph instead of them
     """
@@ -38,30 +38,56 @@ class CutMemory(BackReplacementPattern):
     def pattern():
         return dict(
             nodes=[
-                ('op', dict(kind='op', op='Memory'))],
+                ('op', dict(kind='op', op='ReadValue'))],
             edges=[]
         )

     @staticmethod
     def replace_pattern(graph: Graph, match: dict):
         node = match['op']
-        node_id = node['id']
+        node_id = node['variable_id']

-        if node.in_port(0).disconnected():
-            i = 0
+        i = 0
+        node.in_port(0).disconnect()
         for dest in node.out_port(0).get_destinations():
             new_in = Parameter(graph, {'name': "Parameter_"+str(i)+"_for_"+node_id,
                                        'shape': dest.data.get_shape()}).create_node()
             i += 1
             dest.disconnect()
             new_in.out_port(0).connect(dest)
             log.error("Add input/output mapped {} -> {} ".format(new_in.name, "Result_for_"+node_id),
                       extra={'is_warning': True})
-        else:
-            out_node_port = node.out_port(0).get_destination()
-            in_node_port = node.in_port(0).get_source()
-            node.in_port(0).disconnect()
-            node.out_port(0).disconnect()
-            crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]), 'axis': np.array([0])}).create_node()
-            in_node_port.connect(crop.in_port(0))
-            crop.out_port(0).connect(out_node_port)
+
+
+class CutMemoryOutput(BackReplacementPattern):
+    """
+    Cut Memory layers and have inputs/outputs in graph instead of them
+    """
+    enabled = True
+    graph_condition = [lambda graph: graph.graph['fw'] == "kaldi" and graph.graph['cmd_params'].remove_memory]
+    force_clean_up = True
+
+    def run_before(self):
+        return [ParameterToInput]
+
+    @staticmethod
+    def pattern():
+        return dict(
+            nodes=[
+                ('op', dict(kind='op', op='Assign'))],
+            edges=[]
+        )
+
+    @staticmethod
+    def replace_pattern(graph: Graph, match: dict):
+        node = match['op']
+        node_id = node['variable_id']
+
+        out_node_port = node.out_port(0).get_destination()
+        in_node_port = node.in_port(0).get_source()
+        node.in_port(0).disconnect()
+        node.out_port(0).disconnect()
+        crop = Crop(graph, {'name': 'Result_for_'+node_id, 'dim': np.array([1]), 'offset': np.array([0]),
+                            'axis': np.array([0])}).create_node()
+        in_node_port.connect(crop.in_port(0))
+        crop.out_port(0).connect(out_node_port)
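Note: taken together, the two passes cut the recurrent state edge and expose it as an ordinary network input/output pair. A sketch of the net effect as edge lists, following the node names used by the unit test below (the names are the only assumption here):

```python
# Before: state flows through a ReadValue/Assign pair keyed by variable_id.
before = [('const_0', 'memory_in[ReadValue:memory_]'),
          ('memory_in[ReadValue:memory_]', 'concat'), ('input', 'concat'),
          ('concat', 'some_op'), ('some_op', 'memory_out[Assign:memory_]'),
          ('memory_out[Assign:memory_]', 'mem_out_result')]
# After: CutMemoryInput splices in a Parameter, CutMemoryOutput a Crop->Result.
after = [('Parameter_0_for_memory_', 'concat'), ('input', 'concat'),
         ('concat', 'some_op'), ('some_op', 'Result_for_memory_[Crop]'),
         ('Result_for_memory_[Crop]', 'mem_out_result')]
```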
@@ -17,7 +17,7 @@ import unittest

 import numpy as np

-from extensions.back.CutMemory import CutMemory
+from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
 from mo.utils.ir_engine.compare_graphs import compare_graphs
 from mo.utils.unittest.graph import build_graph

@@ -29,18 +29,21 @@ class CutMemoryTest(unittest.TestCase):
             nodes_attrs={
                 'input': {'kind': 'op'},
                 'data_in': {'kind': 'data', 'shape': None, 'value': None},
-                'memory_in': {'kind': 'op', 'op': 'Memory', 'index': 1, 'id': 'memory_', 'in_ports_count': 1},
+                'const_0': {'kind': 'op', 'op': 'Const'},
+                'const_0_data': {'kind': 'data'},
+                'memory_in': {'kind': 'op', 'op': 'ReadValue', 'variable_id': 'memory_'},
                 'data_mem': {'kind': 'data', 'shape': None, 'value': None},
                 'concat': {'kind': 'op', 'op': 'Concat', 'axis': 0},
                 'concat_data': {'kind': 'data', 'shape': None, 'value': None},
                 'some_op': {'kind': 'op'},
                 'some_op_data': {'kind': 'data', 'shape': None, 'value': None},
-                'memory_out': {'kind': 'op', 'op': 'Memory', 'index': 0, 'id': 'memory_'},
+                'memory_out': {'kind': 'op', 'op': 'Assign', 'variable_id': 'memory_'},
                 'data_mem_out': {'kind': 'data', 'shape': None, 'value': None},
                 'mem_out_result': {'kind': 'op', 'op': 'Result'}
             },
             edges=[
-                ('input', 'data_in'), ('memory_in', 'data_mem'),
+                ('input', 'data_in'),
+                ('const_0', 'const_0_data'), ('const_0_data', 'memory_in'), ('memory_in', 'data_mem'),
                 ('data_in', 'concat', {'in': 0}), ('data_mem', 'concat', {'in': 1}),
                 ('concat', 'concat_data'), ('concat_data', 'some_op'),
                 ('some_op', 'some_op_data'), ('some_op_data', 'memory_out'),
@@ -69,7 +72,8 @@ class CutMemoryTest(unittest.TestCase):
                 ('crop', 'crop_data'), ('crop_data', 'mem_out_result')
             ],
         )
-        CutMemory().find_and_replace_pattern(graph)
+        CutMemoryInput().find_and_replace_pattern(graph)
+        CutMemoryOutput().find_and_replace_pattern(graph)

         (flag, resp) = compare_graphs(graph, graph_ref, last_node='mem_out_result', check_op_attrs=True)
         self.assertTrue(flag, resp)
model-optimizer/extensions/back/ReadValueAssignToMemory.py (new file, +140 lines)
@@ -0,0 +1,140 @@
+"""
+Copyright (C) 2020 Intel Corporation
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from extensions.back.CutMemory import CutMemoryInput, CutMemoryOutput
+from mo.back.replacement import BackReplacementPattern
+from mo.graph.graph import Graph
+from mo.ops.memory import Memory
+
+
+"""
+All transformations in this file should be removed after removing IR v7 support
+"""
+
+
+class ReplaceReadValueByMemory(BackReplacementPattern):
+    """
+    Replace ReadValue by Memory. Should be removed after v7 IR support removing.
+    """
+    enabled = True
+    graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
+    force_clean_up = True
+
+    def run_after(self):
+        return [CutMemoryInput]
+
+    @staticmethod
+    def pattern():
+        return dict(
+            nodes=[
+                ('op', dict(kind='op', op='ReadValue'))],
+            edges=[]
+        )
+
+    @staticmethod
+    def replace_pattern(graph: Graph, match: dict):
+        node = match['op']
+        node_id = node['variable_id']
+
+        node.in_port(0).disconnect()
+        new_in = Memory(graph, {'name': node.id, 'id': node_id, 'index': 1, 'size': 2,
+                                'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
+        for dest in node.out_port(0).get_destinations():
+            dest.disconnect()
+            new_in.out_port(0).connect(dest)
+
+
+class ReplaceAssignByMemory(BackReplacementPattern):
+    """
+    Replace Assign by Memory. Should be removed after v7 IR support removing.
+    """
+    enabled = True
+    graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
+    force_clean_up = True
+
+    def run_after(self):
+        return [CutMemoryOutput]
+
+    @staticmethod
+    def pattern():
+        return dict(
+            nodes=[
+                ('op', dict(kind='op', op='Assign'))],
+            edges=[]
+        )
+
+    @staticmethod
+    def replace_pattern(graph: Graph, match: dict):
+        node = match['op']
+        node_id = node['variable_id']
+
+        new_out = Memory(graph, {'name': node.id, 'id': node_id, 'index': 0, 'size': 2,
+                                 'shape': list(node.out_port(0).data.get_shape())[1:]}).create_node()
+        node.in_port(0).get_source().connect(new_out.in_port(0))
+        node.in_port(0).disconnect()
+        node.out_port(0).get_connection().set_source(new_out.out_port(0))
+
+
+class KaldiRemoveMemoryOutputBackReplacementPatternV7(BackReplacementPattern):
+    enabled = True
+    graph_condition = [lambda graph: not graph.graph['cmd_params'].generate_experimental_IR_V10]
+
+    def run_after(self):
+        from extensions.back.pass_separator import BackFinish
+        return [BackFinish]
+
+    def run_before(self):
+        from extensions.back.SpecialNodesFinalization import RemoveOutputOps
+        return [RemoveOutputOps]
+
+    @staticmethod
+    def pattern():
+        return dict(
+            nodes=[
+                ('memory_node', dict(op='Memory')),
+                ('data_node', dict(kind='data')),
+                ('op_output', dict(op='Result'))
+            ],
+            edges=[
+                ('memory_node', 'data_node'),
+                ('data_node', 'op_output')
+            ]
+        )
+
+    @staticmethod
+    def replace_pattern(graph: Graph, match: dict):
+        """
+        Need to find the pattern: Memory -> Data -> Result
+
+        It is needed to make Memory nodes appear in IR,
+        but they are output nodes by default and we remove the Result node after each output memory.
+
+        DO NOT use graph clean up after it
+        otherwise Memory nodes would be removed as they are not on the path from input to output
+
+        Parameters
+        ----------
+        graph : Graph
+               Graph with loaded model.
+        match : dict
+               Patterns which were found in graph structure.
+        """
+        memory = match['memory_node']
+        data = match['data_node']
+        op_output = match['op_output']
+
+        graph.remove_edge(memory.id, data.id)
+        graph.remove_node(data.id)
+        graph.remove_node(op_output.id)
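Note: the `index`/`size` attributes written above encode the legacy v7 "Memory" pairing convention that the Assign/ReadValue creators earlier in this commit also use. A minimal sketch of that mapping (my summary of the patch, not code from it):

```python
# A ReadValue/Assign pair sharing a variable_id becomes two v7 Memory layers
# sharing an id: index=1 marks the reading end, index=0 the writing end,
# and size=2 records that the pair has two members.
def memory_attrs(op: str, variable_id: str) -> dict:
    assert op in ('ReadValue', 'Assign')
    return {'op': 'Memory', 'id': variable_id,
            'index': 1 if op == 'ReadValue' else 0, 'size': 2}

print(memory_attrs('ReadValue', 'mem'))  # {'op': 'Memory', 'id': 'mem', 'index': 1, 'size': 2}
```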
@@ -33,7 +33,7 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
     def pattern():
         return dict(
             nodes=[
-                ('memory_node', dict(op='Memory')),
+                ('memory_node', dict(op='Assign')),
                 ('data_node', dict(kind='data')),
                 ('op_output', dict(op='Result'))
             ],
@@ -63,6 +63,8 @@ class KaldiRemoveMemoryOutputBackReplacementPattern(BackReplacementPattern):
         """
         memory = match['memory_node']
         data = match['data_node']
+        op_output = match['op_output']
+
         graph.remove_edge(memory.id, data.id)
         graph.remove_node(data.id)
+        graph.remove_node(op_output.id)
@@ -26,7 +26,7 @@ class KaldiRemoveMemoryOutputTest(unittest.TestCase):
                 'kind': 'data'
             },
             'memory_node': {
-                'op': 'Memory',
+                'op': 'Assign',
                 'kind': 'op'
             },
             'output_node': {
@@ -63,7 +63,7 @@ def apply_biases_to_last_layer(graph, counts):
     outputs_ids = find_outputs(graph)
     for output in outputs_ids.copy():
         node = Node(graph, output)
-        if node.op != 'Memory':
+        if node.op != 'Assign':
             continue
         outputs_ids.remove(output)

@@ -13,12 +13,12 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+from extensions.ops.elementwise import Add, Mul
 from extensions.ops.split import Split
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
 from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Node, Graph
-from mo.ops.eltwise import Eltwise
 from mo.ops.eltwise_n import EltwiseN
 from mo.utils.error import Error

@@ -43,8 +43,12 @@ class ReplaceEltwiseNin1NodePattern(FrontReplacementOp):
             edge_attrs = inp[0][1]
             graph.add_edge(in_node, ss_node.id, **edge_attrs)
         if ss_node.num_splits == 2:
-            eltwise_node = Eltwise(graph, attrs={'name': 'Eltwise_' + node.name,
-                                                 'operation': node['operation']}).create_node()
+            if node['operation'] == 'mul':
+                eltwise_node = Mul(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
+            elif node['operation'] == 'sum':
+                eltwise_node = Add(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
+            else:
+                raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
         elif ss_node.num_splits > 2:
             eltwise_node = EltwiseN(graph, attrs={'name': 'Eltwise_' + node.name,
                                                   'operation': node['operation']}).create_node()
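Note: the replacement above hinges on one dispatch — the generic v7 Eltwise with an 'operation' attribute becomes a typed elementwise op. Reduced to its logic (a sketch; Mul/Add stand for extensions.ops.elementwise.Mul/Add):

```python
def typed_eltwise_name(operation: str) -> str:
    # mirrors the if/elif/else chain in the hunk above
    mapping = {'mul': 'Mul', 'sum': 'Add'}
    if operation not in mapping:
        raise ValueError('Error on replacing Kaldi eltwise: unknown type ' + operation)
    return mapping[operation]

print(typed_eltwise_name('sum'))  # Add
```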
@@ -20,13 +20,19 @@ from extensions.ops.activation_ops import Tanh, Sigmoid
 from extensions.ops.elementwise import Add, Mul
 from extensions.ops.split import Split
 from mo.front.caffe.extractors.utils import input_as_const
+from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementOp
-from mo.graph.graph import Node, Graph
+from mo.graph.graph import Node, Graph, Port
+from mo.ops.assign import Assign
+from mo.ops.broadcast import Broadcast
 from mo.ops.clamp import Clamp
+from mo.ops.crop import Crop
+from mo.ops.concat import Concat
 from mo.ops.const import Const
-from mo.ops.memory import Memory
+from mo.ops.read_value import ReadValue
 from mo.ops.result import Result
 from mo.ops.scale_shift import ScaleShiftOp
+from mo.ops.shape import Shape


 def unique_id(prefix: str = 'id') -> str:
@@ -46,6 +52,35 @@ def unique_id(prefix: str = 'id') -> str:
 unique_id.names = []


+def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float):
+    # create init_graph connected to ReadValue
+    graph = input_out_port.node.graph
+    input_name = input_out_port.node.name
+    shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
+    shape_of_input.in_port(0).connect(input_out_port)
+    dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/'+shape_of_input.name,
+                                      'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
+    get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
+                             'axis': int64_array([0]), 'offset': int64_array([0])
+                             }).create_node()
+    get_batch.in_port(0).connect(shape_of_input.out_port(0))
+    get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
+    mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/'+input_name,
+                                      'value': int64_array([second_dim]),
+                                      'shape': int64_array([1])}).create_node()
+    mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
+                               'axis': 0, 'in_ports_count': 2}).create_node()
+    mem_shape.in_port(0).connect(get_batch.out_port(0))
+    mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
+    fill_value = Const(graph, {'name': 'fill_value/'+input_name,
+                               'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
+    init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/'+input_name,
+                                                    }).create_node()
+    init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
+    init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
+    return init_value_prev_lstm_output
+
+
 class ReplaceLSTMNodePattern(FrontReplacementOp):
     op = "LSTMCell"
     enabled = True
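Note: create_zero_value_with_batch_from_input builds a small init subgraph feeding ReadValue. What that Shape -> Crop -> Concat -> Broadcast chain computes at runtime, condensed to NumPy (an illustrative sketch; float32 substituted for the np.float default):

```python
import numpy as np

def zero_state_like(input_tensor: np.ndarray, second_dim: int, precision=np.float32) -> np.ndarray:
    batch = np.array(input_tensor.shape)[0:1]          # Shape, then Crop(axis=0, offset=0, dim=1)
    mem_shape = np.concatenate([batch, [second_dim]])  # Concat(axis=0) with the second_dim Const
    fill_value = np.array([0.0], dtype=precision)      # Const fill value
    return np.broadcast_to(fill_value, mem_shape)      # Broadcast to [batch, second_dim]

print(zero_state_like(np.ones((4, 10)), 7).shape)  # (4, 7)
```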
@@ -69,7 +104,7 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
     )

     def replace_op(self, graph: Graph, node: Node):
-        input_node = node.in_node()
+        input_out_port = node.in_port(0).get_source()

         memory_pair_input = unique_id('id')
         memory_pair_output = unique_id('id')
@@ -81,16 +116,17 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
                                       'bias_term': True,
                                       }

-        fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node([input_node])
+        fc_layer_after_input = FullyConnected(graph, fc_layer_after_input_attrs).create_node()
+        fc_layer_after_input.in_port(0).connect(input_out_port)
         input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
         input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)

-        prev_lstm_output = Memory(graph, {'name': 'prev_memory_output',
-                                          'id': memory_pair_input,
-                                          'index': 1,
-                                          'size': 2,
-                                          'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
-                                          }).create_node()
+        init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
+                                                                              node.gifo_r_weights_shape[1])
+        prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
+                                             'variable_id': memory_pair_input
+                                             }).create_node()
+        prev_lstm_output.in_port(0).connect(init_value_prev_lstm_output.out_port(0))

         # *Memory(output) -> FullyConnected
         fc_layer_from_prev_state_attrs = {'name': 'prev_memory_output_fullyconnected',
@@ -99,15 +135,16 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
                                           'bias_term': False,
                                           }

-        fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node(
-            [prev_lstm_output])
+        fc_layer_from_prev_state = FullyConnected(graph, fc_layer_from_prev_state_attrs).create_node()
+        fc_layer_from_prev_state.in_port(0).connect(prev_lstm_output.out_port(0))
         input_as_const(fc_layer_from_prev_state, fc_layer_from_prev_state_attrs, 1, 'weights', node.gifo_r_weights)

         # Memory -> FullyConnected \
         #                           *Eltwise(sum)
         # Input -> FullyConnected  /
-        join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise',
-                                                }).create_node([fc_layer_from_prev_state, fc_layer_after_input])
+        join_input_prev_state_sum = Add(graph, {'name': 'join_input_eltwise'}).create_node()
+        join_input_prev_state_sum.in_port(0).connect(fc_layer_from_prev_state.out_port(0))
+        join_input_prev_state_sum.in_port(1).connect(fc_layer_after_input.out_port(0))

         # *Eltwise(sum) -> Split
         # it is split into 4 nodes: Act, Eltw*3
@@ -120,131 +157,147 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
         # |____(4)Eltwise(sum)
         split_joined_input_axis = Const(graph, {'value': np.int64(1)}).create_node()
         split_joined_input = Split(graph, {'name': 'join_input_split',
-                                           'num_splits': 4,
-                                           }).create_node([join_input_prev_state_sum, split_joined_input_axis])
+                                           'num_splits': 4, 'out_ports_count': 4}).create_node()
+        split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
+        split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))

-        prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
-                                         'id': memory_pair_output,
-                                         'index': 1,
-                                         'size': 2,
-                                         'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
-                                         }).create_node()
+        # prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
+        #                                  'id': memory_pair_output,
+        #                                  'index': 1,
+        #                                  'size': 2,
+        #                                  'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
+        #                                  }).create_node()
+        init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
+                                                                             node.input_gate_weights.shape[0])
+        prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
+                                            'variable_id': memory_pair_output}).create_node()
+        prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))

         # *Memory(state) -> *ScaleShift(input)
         state_input_scaleshift_attrs = {'name': 'input_scaleshift',
                                         'bias_term': False
                                         }
-        state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node([prev_lstm_state])
+        state_input_scaleshift = ScaleShiftOp(graph, state_input_scaleshift_attrs).create_node()
+        state_input_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
         input_as_const(state_input_scaleshift, state_input_scaleshift_attrs, 1, 'weights', node.input_gate_weights)

         # *Memory(state) -> *ScaleShift(forget)
         state_forget_scaleshift_attrs = {'name': 'forget_scaleshift',
                                          'bias_term': False
                                          }
-        state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node([prev_lstm_state])
+        state_forget_scaleshift = ScaleShiftOp(graph, state_forget_scaleshift_attrs).create_node()
+        state_forget_scaleshift.in_port(0).connect(prev_lstm_state.out_port(0))
         input_as_const(state_forget_scaleshift, state_forget_scaleshift_attrs, 1, 'weights', node.forget_gate_weights)

         # Split                                \
         #                                       (2)Eltwise(sum)
         # Memory(state) -> *ScaleShift(input)  /
-        join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise',
-                                                            }).create_node([(split_joined_input, 1),
-                                                                            state_input_scaleshift
-                                                                            ])
+        join_prev_lstm_input_joined_input_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_input_eltwise'
+                                                            }).create_node()
+        join_prev_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(1))
+        join_prev_lstm_input_joined_input_sum.in_port(1).connect(state_input_scaleshift.out_port(0))
         # Split                                \
         #                                       (3)Eltwise(sum)
         # Memory(state) -> *ScaleShift(forget) /
         join_prev_lstm_input_joined_forget_sum = Add(graph, {'name': 'join_prev_lstm_input_joined_forget_sum',
-                                                             }).create_node([(split_joined_input, 2),
-                                                                             state_forget_scaleshift
-                                                                             ])
+                                                             }).create_node()
+        join_prev_lstm_input_joined_forget_sum.in_port(0).connect(split_joined_input.out_port(2))
+        join_prev_lstm_input_joined_forget_sum.in_port(1).connect(state_forget_scaleshift.out_port(0))

         # Split -> Tanh
-        remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node([(split_joined_input, 0)])
+        remember_tahn = Tanh(graph, {'name': 'remember_tahnv'}).create_node()
+        remember_tahn.in_port(0).connect(split_joined_input.out_port(0))

         # Split -> (2)Eltwise(sum) -> *Sigmoid
-        remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'
-                                           }).create_node([join_prev_lstm_input_joined_input_sum])
+        remember_sigmoid = Sigmoid(graph, {'name': 'remember_sigmoid'}).create_node()
+        remember_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_input_sum.out_port(0))

         # Split -> (3)Eltwise(sum) -> **Sigmoid
-        forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'
-                                         }).create_node([join_prev_lstm_input_joined_forget_sum])
+        forget_sigmoid = Sigmoid(graph, {'name': 'forget_sigmoid'}).create_node()
+        forget_sigmoid.in_port(0).connect(join_prev_lstm_input_joined_forget_sum.out_port(0))

         # *Memory(state)                        \
         #                                        (6)Eltwise(mul)
         # Split -> (3)Eltwise(sum) -> **Sigmoid /
-        join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul',
-                                                 }).create_node([forget_sigmoid, prev_lstm_state])
+        join_forget_prev_state_mul = Mul(graph, {'name': 'join_forget_prev_state_mul'}).create_node()
+        join_forget_prev_state_mul.in_port(0).connect(forget_sigmoid.out_port(0))
+        join_forget_prev_state_mul.in_port(1).connect(prev_lstm_state.out_port(0))

         # Split -> Tahn                         \
         #                                        (5)Eltwise(mul)
         # Split -> (2)Eltwise(sum) -> *Sigmoid  /
-        join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul',
-                                                   }).create_node([remember_tahn, remember_sigmoid])
+        join_remember_candidates_mul = Mul(graph, {'name': 'join_remember_candidates_mul'}).create_node()
+        join_remember_candidates_mul.in_port(0).connect(remember_tahn.out_port(0))
+        join_remember_candidates_mul.in_port(1).connect(remember_sigmoid.out_port(0))

         # (5)Eltwise(mul) \
         #                  (7)Eltwise(sum)
         # (6)Eltwise(mul) /
-        join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum',
-                                               }).create_node(
-            [join_forget_prev_state_mul, join_remember_candidates_mul])
+        join_forget_remember_sum = Add(graph, {'name': 'join_forget_remember_sum'}).create_node()
+        join_forget_remember_sum.in_port(0).connect(join_forget_prev_state_mul.out_port(0))
+        join_forget_remember_sum.in_port(1).connect(join_remember_candidates_mul.out_port(0))

         # (7)Eltwise(sum) -> Clamp
         join_forget_clamp = Clamp(graph, {'name': 'join_forget_clamp',
                                           'max': node.clip_value,
-                                          'min': -node.clip_value
-                                          }).create_node(
-            [join_forget_remember_sum])
+                                          'min': -node.clip_value}).create_node()
+        join_forget_clamp.in_port(0).connect(join_forget_remember_sum.out_port(0))
         #
         # Clamp -> (2)Memory(state)
-        next_lstm_state = Memory(graph, {'name': 'next_lstm_state',
-                                         'id': memory_pair_output,
-                                         'index': 0,
-                                         'size': 2,
-                                         'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
-                                         }).create_node([join_forget_clamp])
-        Result(graph, {'name': 'next_lstm_state_out'}).create_node([next_lstm_state])
+        next_lstm_state = Assign(graph, {'name': 'next_lstm_state',
+                                         'variable_id': memory_pair_output}).create_node()
+        next_lstm_state.in_port(0).connect(join_forget_clamp.out_port(0))
+
+        res_node = Result(graph, {'name': 'next_lstm_state_out'}).create_node()
+        res_node.in_port(0).connect(next_lstm_state.out_port(0))

         # Clamp -> (2)Tahn
-        state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node([join_forget_clamp])
+        state_filtered_tahn = Tanh(graph, {'name': 'state_filtered_tahn'}).create_node()
+        state_filtered_tahn.in_port(0).connect(join_forget_clamp.out_port(0))

         # Clamp -> (2)ScaleShift
         clamp_scaleshift_attrs = {'name': 'clamp_scaleshift',
                                   'bias_term': False}
-        clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node([join_forget_clamp])
+        clamp_scaleshift = ScaleShiftOp(graph, clamp_scaleshift_attrs).create_node()
+        clamp_scaleshift.in_port(0).connect(join_forget_clamp.out_port(0))
         input_as_const(clamp_scaleshift, clamp_scaleshift_attrs, 1, 'weights', node.output_gate_weights)

         # Split                  \
         #                         (4)Eltwise(sum)
         # Clamp -> (2)ScaleShift /
         join_next_lstm_input_joined_input_sum = Add(graph, {'name': 'join_next_lstm_input_joined_input_sum',
-                                                            }).create_node([(split_joined_input, 3), clamp_scaleshift])
+                                                            }).create_node()
+        join_next_lstm_input_joined_input_sum.in_port(0).connect(split_joined_input.out_port(3))
+        join_next_lstm_input_joined_input_sum.in_port(1).connect(clamp_scaleshift.out_port(0))

         # (4)Eltwise(sum) -> (3)Sigmoid
-        output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node([join_next_lstm_input_joined_input_sum])
+        output_sigmoid = Sigmoid(graph, {'name': 'output_sigmoid'}).create_node()
+        output_sigmoid.in_port(0).connect(join_next_lstm_input_joined_input_sum.out_port(0))

         # (4)Eltwise(sum) -> (3)Sigmoid \
         #                                (5)Eltwise(mul)
         # Clamp -> (2)Tahn              /
-        joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node([state_filtered_tahn, output_sigmoid])
+        joined_output_mul = Mul(graph, {'name': 'joined_output_mul'}).create_node()
+        joined_output_mul.in_port(0).connect(state_filtered_tahn.out_port(0))
+        joined_output_mul.in_port(1).connect(output_sigmoid.out_port(0))

         # (5)Eltwise(mul) -> (3)FullyConnected
         fc_output_attrs = {'name': 'FullyConnected',
                            'out-size': node.projection_weights_shape[0],
                            'transpose_weights': True,
                            'bias_term': False}
-        fc_output = FullyConnected(graph, fc_output_attrs).create_node([joined_output_mul])
+        fc_output = FullyConnected(graph, fc_output_attrs).create_node()
+        fc_output.in_port(0).connect(joined_output_mul.out_port(0))
         input_as_const(fc_output, fc_output_attrs, 1, 'weights', node.projection_weights)

         #                   / (2)Memory(output)
         # (3)FullyConnected
         #                   \ Output (any next node) (edge created automatically after replacement)
-        next_lstm_output = Memory(graph, {'name': 'next_lstm_output',
-                                          'id': memory_pair_input,
-                                          'index': 0,
-                                          'size': 2,
-                                          'shape': np.array([node.gifo_r_weights_shape[1]], dtype=np.int64)
-                                          }).create_node([fc_output])
-        Result(graph, {'name': 'next_lstm_output_out'}).create_node([next_lstm_output])
+        next_lstm_output = Assign(graph, {'name': 'next_lstm_output',
+                                          'variable_id': memory_pair_input}).create_node()
+        next_lstm_output.in_port(0).connect(fc_output.out_port(0))
+
+        res_node_lstm_output = Result(graph, {'name': 'next_lstm_output_out'}).create_node()
+        res_node_lstm_output.in_port(0).connect(next_lstm_output.out_port(0))

         return [fc_output.id]
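Note: besides swapping Memory for ReadValue/Assign, this hunk mechanically converts every positional `create_node([...])` call into explicit port wiring. The two styles side by side (a commented sketch; it needs a Model Optimizer graph to actually run, and `a`/`split` are hypothetical nodes):

```python
# old style: inputs passed positionally, with tuples selecting an output port
#   add = Add(graph, {'name': 'sum'}).create_node([a, (split, 1)])

# new style: create the node first, then connect ports explicitly
#   add = Add(graph, {'name': 'sum'}).create_node()
#   add.in_port(0).connect(a.out_port(0))
#   add.in_port(1).connect(split.out_port(1))
```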
@@ -22,7 +22,7 @@ from mo.front.common.replacement import FrontReplacementOp
 from mo.graph.graph import Node, Graph
 from mo.ops.concat import Concat
 from mo.ops.const import Const
-from mo.ops.eltwise import Eltwise
+from extensions.ops.elementwise import Add, Mul
 from mo.ops.scale_shift import ScaleShiftOp


@@ -41,19 +41,19 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
     def replace_op(self, graph: Graph, node: Node):
         # split input to (i_part, f_part, c_part, o_part, ct_1)
         split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
-        split_node = Split(graph, {'name': graph.unique_id(prefix='Split_lstm_input_'),
+        split_node = Split(graph, {'name': 'Split_lstm_input_',
                                    'num_splits': 5}).create_node()
         node.in_port(0).get_connection().set_destination(split_node.in_port(0))
         split_node.in_port(1).connect(split_node_axis.out_port(0))

         # i_t = Sigmoid(i_part + w_ic*ct_1)
-        i_scale_attrs = {'name': graph.unique_id(prefix='i_scaleshift'),
+        i_scale_attrs = {'name': 'i_scaleshift',
                          'bias_term': False}
         i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
         input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
         split_node.out_port(4).connect(i_scale.in_port(0))

-        sum_i_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_i_c_'), 'operation': 'sum'}).create_node()
+        sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
         split_node.out_port(0).connect(sum_i_c.in_port(0))
         i_scale.out_port(0).connect(sum_i_c.in_port(1))

@@ -61,13 +61,13 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
         sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))

         # f_t = Sigmoid(f_part + w_fc*ct_1)
-        f_scale_attrs = {'name': graph.unique_id(prefix='f_scaleshift'),
+        f_scale_attrs = {'name': 'f_scaleshift',
                          'bias_term': False}
         f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
         input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
         split_node.out_port(4).connect(f_scale.in_port(0))

-        sum_f_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_c_'), 'operation': 'sum'}).create_node()
+        sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
         split_node.out_port(1).connect(sum_f_c.in_port(0))
         f_scale.out_port(0).connect(sum_f_c.in_port(1))

@@ -78,28 +78,26 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
         c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
         split_node.out_port(2).connect(c_tanh.in_port(0))

-        prod_i_c_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_i_c_tanh_'),
-                                        'operation': 'mul'}).create_node()
+        prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
         i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
         c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))

-        prod_f_ct_1 = Eltwise(graph, {'name': graph.unique_id(prefix='prod_f_ct_1_'),
-                                      'operation': 'mul'}).create_node()
+        prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
         f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
         split_node.out_port(4).connect(prod_f_ct_1.in_port(1))

-        sum_f_i = Eltwise(graph, {'name': graph.unique_id(prefix='sum_f_i_'), 'operation': 'sum'}).create_node()
+        sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
         prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
         prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))

         # o_t = Sigmoid(o_part + w_oc*c_t)
-        o_scale_attrs = {'name': graph.unique_id(prefix='o_scaleshift'),
+        o_scale_attrs = {'name': 'o_scaleshift',
                          'bias_term': False}
         o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
         input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
         sum_f_i.out_port(0).connect(o_scale.in_port(0))

-        sum_o_c = Eltwise(graph, {'name': graph.unique_id(prefix='sum_o_c_'), 'operation': 'sum'}).create_node()
+        sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
         split_node.out_port(3).connect(sum_o_c.in_port(0))
         o_scale.out_port(0).connect(sum_o_c.in_port(1))

@@ -110,13 +108,12 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
         c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
         sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

-        prod_o_c_t_tanh = Eltwise(graph, {'name': graph.unique_id(prefix='prod_o_c_t_tanh_'),
-                                          'operation': 'mul'}).create_node()
+        prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
         o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
         c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))

         # add concat to create 1 output
-        concat = Concat(graph, {'name': graph.unique_id(prefix='Concat_c_m')}).create_node()
+        concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
         concat.add_sequence_of_ports('in', range(2))
         sum_f_i.out_port(0).connect(concat.in_port(0))
         prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))
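Note: the inline comments in this pattern spell out the arithmetic being assembled node by node. Condensed into NumPy for reference (my condensation, assuming split axis 1 and elementwise gate weights; names mirror the graph nodes):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_nonlinearity(i_part, f_part, c_part, o_part, ct_1, w_ic, w_fc, w_oc):
    i_t = sigmoid(i_part + w_ic * ct_1)        # sum_i_c -> i_sigmoid
    f_t = sigmoid(f_part + w_fc * ct_1)        # sum_f_c -> f_sigmoid
    c_t = f_t * ct_1 + i_t * np.tanh(c_part)   # prod_f_ct_1 + prod_i_c_tanh
    o_t = sigmoid(o_part + w_oc * c_t)         # sum_o_c -> o_sigmoid
    m_t = o_t * np.tanh(c_t)                   # prod_o_c_t_tanh
    return np.concatenate([c_t, m_t], axis=1)  # Concat_c_m
```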
@@ -15,15 +15,18 @@
 """
 import numpy as np

+from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
+from extensions.ops.elementwise import Equal
 from extensions.ops.select import Select
 from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Graph, Node
 from mo.middle.pattern_match import find_pattern_matches, inverse_dict
 from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.assign import Assign
 from mo.ops.concat import Concat
 from mo.ops.const import Const
 from mo.ops.crop import Crop
-from mo.ops.memory import Memory
+from mo.ops.read_value import ReadValue
 from mo.ops.result import Result
 from mo.utils.error import Error
 from mo.utils.graph import invert_sub_graph_between_nodes
@ -48,7 +51,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def pattern():
|
def pattern():
|
||||||
return dict(
|
return dict(
|
||||||
nodes=[('op', dict(op='Memory', index=0))],
|
nodes=[('op', dict(op='Assign'))],
|
||||||
edges=[])
|
edges=[])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -93,9 +96,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
|||||||
select_node.in_port(2).connect(zero_else.out_port(0))
|
select_node.in_port(2).connect(zero_else.out_port(0))
|
||||||
|
|
||||||
# check if we have already appropriate iteration counter
|
# check if we have already appropriate iteration counter
|
||||||
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='Memory', index=1,
|
existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
|
||||||
shape=int64_array([context_len]))),
|
('mem_in_data', dict(shape=int64_array([context_len]))),
|
||||||
('mem_in_data', dict()),
|
|
||||||
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
|
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
|
||||||
offset=int64_array([1]),
|
offset=int64_array([1]),
|
||||||
dim=int64_array([context_len-1]))),
|
dim=int64_array([context_len-1]))),
|
||||||
@ -104,8 +106,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
|||||||
('concat_data', dict()),
|
('concat_data', dict()),
|
||||||
('const_1', dict(op='Const')),
|
('const_1', dict(op='Const')),
|
||||||
('const_1_data', dict()),
|
('const_1_data', dict()),
|
||||||
('mem_out', dict(op='Memory', index=0,
|
('mem_out', dict(op='Assign')),
|
||||||
shape=int64_array([context_len]))),
|
|
||||||
('crop_out', dict(op='Crop', axis=int64_array([1]),
|
('crop_out', dict(op='Crop', axis=int64_array([1]),
|
||||||
offset=int64_array([0]),
|
offset=int64_array([0]),
|
||||||
dim=int64_array([1]))),
|
dim=int64_array([1]))),
|
||||||
@ -122,12 +123,13 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
|||||||
('crop_out_data', 'select')])
|
('crop_out_data', 'select')])
|
||||||
counter_match = next(existing_counters, None)
|
counter_match = next(existing_counters, None)
|
||||||
if counter_match is not None:
|
if counter_match is not None:
|
||||||
|
ones = Node(graph, inverse_dict(counter_match)['const_1'])
|
||||||
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
|
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
|
||||||
else:
|
else:
|
||||||
mem_out = Memory(graph, {'name': 'iteration_number', 'size': 2,
|
init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
|
||||||
'index': 1, 'id': 'iteration_' + node.name,
|
mem_out = ReadValue(graph, {'name': 'iteration_number',
|
||||||
'shape': int64_array([context_len]),
|
'variable_id': 'iteration_'+node.name}).create_node()
|
||||||
'dst_type': np.int32}).create_node()
|
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
|
||||||
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
|
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
|
||||||
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
|
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
|
||||||
cut_first.in_port(0).connect(mem_out.out_port(0))
|
cut_first.in_port(0).connect(mem_out.out_port(0))
|
||||||
@ -135,9 +137,8 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
|||||||
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
|
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
|
||||||
concat.in_port(0).connect(cut_first.out_port(0))
|
concat.in_port(0).connect(cut_first.out_port(0))
|
||||||
concat.in_port(1).connect(ones.out_port(0))
|
concat.in_port(1).connect(ones.out_port(0))
|
||||||
mem_in = Memory(graph, {'name': 'iteration_number_out', 'size': 2,
|
mem_in = Assign(graph, {'name': 'iteration_number_out',
|
||||||
'index': 0, 'id': 'iteration_' + node.name,
|
'variable_id': 'iteration_'+node.name}).create_node()
|
||||||
'shape': int64_array([context_len])}).create_node()
|
|
||||||
mem_in.in_port(0).connect(concat.out_port(0))
|
mem_in.in_port(0).connect(concat.out_port(0))
|
||||||
res = Result(graph, {}).create_node()
|
res = Result(graph, {}).create_node()
|
||||||
mem_in.out_port(0).connect(res.in_port(0))
|
mem_in.out_port(0).connect(res.in_port(0))
|
||||||
@ -146,6 +147,12 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
|||||||
cut_last.in_port(0).connect(concat.out_port(0))
|
cut_last.in_port(0).connect(concat.out_port(0))
|
||||||
input_port = cut_last.out_port(0)
|
input_port = cut_last.out_port(0)
|
||||||
|
|
||||||
select_node.in_port(0).connect(input_port)
|
# Check if data from memory is 1
|
||||||
|
# if it is True, we have correct data and should proceed with saving it to memory
|
||||||
|
# else we have not gathered context and have garbage here, shouldn't change initial state of memory
|
||||||
|
cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
|
||||||
|
cast_in.in_port(0).connect(ones.out_port(0))
|
||||||
|
cast_in.in_port(1).connect(input_port)
|
||||||
|
select_node.in_port(0).connect(cast_in.out_port(0))
|
||||||
select_node.out_port(0).connect(node.in_port(0))
|
select_node.out_port(0).connect(node.in_port(0))
|
||||||
select_node.out_port(0).data.set_shape(in_node_shape)
|
select_node.out_port(0).data.set_shape(in_node_shape)
|
||||||
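The hunk above changes what drives Select: instead of wiring the cropped counter straight into the condition input, it first compares it against the constant ones via Equal, so the condition is a proper boolean. A framework-free numpy analogue of that dataflow (a hedged sketch; the function and argument names are illustrative):

    import numpy as np

    def gate_memory_update(counter_tail, new_data, initial_state, ones):
        # cast_in = Equal(...): the counter tail equals 1 only once the full
        # temporal context has been gathered
        context_ready = np.equal(ones, counter_tail)
        # select_node: commit new data where the context is valid,
        # otherwise keep the initial state of the memory untouched
        return np.where(context_ready, new_data, initial_state)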
@@ -30,23 +30,14 @@ class InsertSelectTests(unittest.TestCase):
 graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
 'placeholder_1': {'kind': 'op', 'op': None},
 'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
+'memory': {'kind': 'op', 'op': 'Assign'},
 },
 [('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
 ('placeholder_data_1', 'memory')
 ],
 nodes_with_edges_only=True)
+ref_graph = graph.copy()
 AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
-ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
-'placeholder_1': {'kind': 'op', 'op': None},
-'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
-},
-[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
-('placeholder_data_1', 'memory')
-],
-nodes_with_edges_only=True
-)

 (flag, resp) = compare_graphs(graph, ref_graph, 'memory')
 self.assertTrue(flag, resp)
@@ -60,7 +51,7 @@ class InsertSelectTests(unittest.TestCase):
 'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
 'placeholder_2': {'kind': 'op', 'op': None},
 'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
+'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
 },
 [('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
 ('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -76,15 +67,32 @@ class InsertSelectTests(unittest.TestCase):
 'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
 'placeholder_2': {'kind': 'op', 'op': None},

-'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
+'shape': {'kind': 'op', 'op': 'ShapeOf'},
+'shape_data': {'kind': 'data'},
+'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
+'crop_batch_data': {'kind': 'data'},
+'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
+'crop_batch_dim_data': {'kind': 'data'},
+'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
+'second_dim_data': {'kind': 'data'},
+'gather_shape': {'kind': 'op', 'op': 'Concat'},
+'gather_shape_data': {'kind': 'data'},
+'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
+'fill_value_data': {'kind': 'data'},
+'broadcast': {'kind': 'op', 'op': 'Broadcast'},
+'broadcast_data': {'kind': 'data'},
+
+'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
 'memory_in_data': {'kind': 'data'},
-'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
+'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
 'memory_out_data': {'kind': 'data'},
 'result': {'kind': 'op', 'op': 'Result'},
 'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
 'crop_in_data': {'kind': 'data'},
 'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
 'crop_out_data': {'kind': 'data'},
+'equal': {'kind': 'op', 'op': 'Equal'},
+'equal_data': {'kind': 'data'},
 'select': {'kind': 'op', 'op': 'Select'},
 'select_out_data': {'kind': 'data', 'shape': [1, 26]},
 'const_0': {'kind': 'op', 'op': 'Const'},
@@ -95,22 +103,34 @@ class InsertSelectTests(unittest.TestCase):
 'concat_data': {'kind': 'data'},

 'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
+'memory': {'kind': 'op', 'op': 'Assign'},
 },
 [('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
 ('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
 ('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
 ('placeholder_data_2', 'select', {'in': 1}),

+('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
+('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
+('crop_batch_dim', 'crop_batch_dim_data'),
+('crop_batch_dim_data', 'crop_batch', {'in': 1}),
+('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
+('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
+('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
+('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
+('broadcast_data', 'memory_in'),
+
 ('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
 ('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
 ('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
 ('concat', 'concat_data'), ('concat_data', 'memory_out'),
 ('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
 ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
-('crop_out_data', 'select', {'in': 0}),
-('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
+('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
+('equal', 'equal_data'),
+('equal_data', 'select', {'in': 0}),
+
+('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
 ('select', 'select_out_data'),
 ('select_out_data', 'memory')
 ],
@@ -132,7 +152,7 @@ class InsertSelectTests(unittest.TestCase):
 'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
 'placeholder_2': {'kind': 'op', 'op': None},
 'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
+'memory': {'kind': 'op', 'op': 'Assign'},
 },
 [('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
 ('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -151,15 +171,32 @@ class InsertSelectTests(unittest.TestCase):
 'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
 'placeholder_2': {'kind': 'op', 'op': None},

-'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
+'shape': {'kind': 'op', 'op': 'ShapeOf'},
+'shape_data': {'kind': 'data'},
+'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
+'crop_batch_data': {'kind': 'data'},
+'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
+'crop_batch_dim_data': {'kind': 'data'},
+'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
+'second_dim_data': {'kind': 'data'},
+'gather_shape': {'kind': 'op', 'op': 'Concat'},
+'gather_shape_data': {'kind': 'data'},
+'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
+'fill_value_data': {'kind': 'data'},
+'broadcast': {'kind': 'op', 'op': 'Broadcast'},
+'broadcast_data': {'kind': 'data'},
+
+'memory_in': {'kind': 'op', 'op': 'ReadValue'},
 'memory_in_data': {'kind': 'data'},
-'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([5])},
+'memory_out': {'kind': 'op', 'op': 'Assign'},
 'memory_out_data': {'kind': 'data'},
 'result': {'kind': 'op', 'op': 'Result'},
 'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
 'crop_in_data': {'kind': 'data'},
 'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
 'crop_out_data': {'kind': 'data'},
+'equal': {'kind': 'op', 'op': 'Equal'},
+'equal_data': {'kind': 'data'},
 'select': {'kind': 'op', 'op': 'Select'},
 'select_out_data': {'kind': 'data', 'shape': [1, 26]},
 'const_0': {'kind': 'op', 'op': 'Const'},
@@ -170,7 +207,7 @@ class InsertSelectTests(unittest.TestCase):
 'concat_data': {'kind': 'data'},

 'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
+'memory': {'kind': 'op', 'op': 'Assign'},
 },
 [('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
 ('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -178,13 +215,25 @@ class InsertSelectTests(unittest.TestCase):
 ('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
 ('placeholder_data_2', 'select', {'in': 1}),

+('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
+('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
+('crop_batch_dim', 'crop_batch_dim_data'),
+('crop_batch_dim_data', 'crop_batch', {'in': 1}),
+('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
+('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
+('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
+('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
+('broadcast_data', 'memory_in'),
+
 ('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
 ('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
 ('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
 ('concat', 'concat_data'), ('concat_data', 'memory_out'),
 ('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
 ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
-('crop_out_data', 'select', {'in': 0}),
+('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
+('equal', 'equal_data'),
+('equal_data', 'select', {'in': 0}),
 ('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),

 ('select', 'select_out_data'),
@@ -208,7 +257,7 @@ class InsertSelectTests(unittest.TestCase):
 'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
 'placeholder_2': {'kind': 'op', 'op': None},
 'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
+'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
 },
 [('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
 ('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -227,15 +276,32 @@ class InsertSelectTests(unittest.TestCase):
 'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
 'placeholder_2': {'kind': 'op', 'op': None},

-'memory_in': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
+'shape': {'kind': 'op', 'op': 'ShapeOf'},
+'shape_data': {'kind': 'data'},
+'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
+'crop_batch_data': {'kind': 'data'},
+'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
+'crop_batch_dim_data': {'kind': 'data'},
+'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([7])},
+'second_dim_data': {'kind': 'data'},
+'gather_shape': {'kind': 'op', 'op': 'Concat'},
+'gather_shape_data': {'kind': 'data'},
+'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
+'fill_value_data': {'kind': 'data'},
+'broadcast': {'kind': 'op', 'op': 'Broadcast'},
+'broadcast_data': {'kind': 'data'},
+
+'memory_in': {'kind': 'op', 'op': 'ReadValue'},
 'memory_in_data': {'kind': 'data'},
-'memory_out': {'kind': 'op', 'op': 'Memory', 'shape': int64_array([7])},
+'memory_out': {'kind': 'op', 'op': 'Assign'},
 'memory_out_data': {'kind': 'data'},
 'result': {'kind': 'op', 'op': 'Result'},
 'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 6},
 'crop_in_data': {'kind': 'data'},
 'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
 'crop_out_data': {'kind': 'data'},
+'equal': {'kind': 'op', 'op': 'Equal'},
+'equal_data': {'kind': 'data'},
 'select': {'kind': 'op', 'op': 'Select'},
 'select_out_data': {'kind': 'data', 'shape': [1, 26]},
 'const_0': {'kind': 'op', 'op': 'Const'},
@@ -246,7 +312,7 @@ class InsertSelectTests(unittest.TestCase):
 'concat_data': {'kind': 'data'},

 'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
-'memory': {'kind': 'op', 'op': 'Memory', 'index': 0},
+'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
 },
 [('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
 ('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
@@ -254,13 +320,25 @@ class InsertSelectTests(unittest.TestCase):
 ('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
 ('placeholder_data_2', 'select', {'in': 1}),

+('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
+('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
+('crop_batch_dim', 'crop_batch_dim_data'),
+('crop_batch_dim_data', 'crop_batch', {'in': 1}),
+('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
+('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
+('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
+('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
+('broadcast_data', 'memory_in'),
+
 ('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
 ('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
 ('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
 ('concat', 'concat_data'), ('concat_data', 'memory_out'),
 ('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
 ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
-('crop_out_data', 'select', {'in': 0}),
+('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
+('equal', 'equal_data'),
+('equal_data', 'select', {'in': 0}),
 ('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),

 ('select', 'select_out_data'),
@@ -59,7 +59,18 @@ class RemoveUselessCropsPattern(MiddleReplacementPattern):
 if out['op'] == 'Crop' and out['axis'] == axis and \
 len(out.out_port(0).get_destinations()) == 1 and \
 out.out_port(0).get_destination().node == concat_node:
-offsets_dims.append((out['offset'], out['dim']))
+# crop type 1
+if 'dim' in out:
+offsets_dims.append((out['offset'], out['dim']))
+# crop type 3
+elif 'crop_begin' in out and 'crop_end' in out:
+offsets_dims.append((out['crop_begin'], out['crop_end']-out['crop_begin']))
+# crop type 2 with const dim
+elif not out.in_port(1).disconnected() and out.in_port(1).data.get_value() is not None:
+offsets_dims.append((out['offset'], out.in_port(1).data.get_value()))
+# crop type 2 with non-const dim or strange type of crop
+else:
+return
 crop_list.append(out)

 offsets_dims.sort(key=lambda off_dim: off_dim[0])
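The RemoveUselessCropsPattern hunk above now recognizes the three ways a Crop can be parameterized: offset/dim attributes (type 1), a crop_begin/crop_end pair (type 3), or an offset attribute with the dim supplied by a constant second input (type 2). A minimal stand-alone sketch of that dispatch over plain dicts (the helper name and dict layout are illustrative, not the MO API):

    def crop_offset_dim(attrs, dim_input_value=None):
        # type 1: explicit 'offset'/'dim' attributes
        if 'dim' in attrs:
            return attrs['offset'], attrs['dim']
        # type 3: 'crop_begin'/'crop_end' pair; dim is their difference
        if 'crop_begin' in attrs and 'crop_end' in attrs:
            return attrs['crop_begin'], attrs['crop_end'] - attrs['crop_begin']
        # type 2: 'offset' attribute, dim comes from a constant second input
        if dim_input_value is not None:
            return attrs['offset'], dim_input_value
        # non-const dim or an unexpected Crop flavor: give up, as the pass does
        return None

    assert crop_offset_dim({'offset': 26, 'dim': 26}) == (26, 26)
    assert crop_offset_dim({'crop_begin': 26, 'crop_end': 52}) == (26, 26)
    assert crop_offset_dim({'offset': 26}, dim_input_value=26) == (26, 26)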
@@ -84,6 +84,136 @@ class RemoveUselessCropsPatternTests(unittest.TestCase):
 (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
 self.assertTrue(flag, resp)

+def test_useless_crops_type2(self):
+graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
+'in_node': {'kind': 'data', 'shape': [1, 130]},
+'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
+'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
+'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
+'const_26_data': {'kind': 'data', 'value': 26},
+'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'axis': -1},
+'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
+'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
+'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
+'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
+'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
+'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
+'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
+'concat': {'kind': 'op', 'op': 'Concat'},
+'concat_data': {'kind': 'data', 'shape': [1, 130]},
+'placeholder': {'kind': 'op', 'op': 'Parameter'},
+},
+[('placeholder_in', 'in_node'),
+('in_node', 'crop1'), ('crop1', 'crop_data_1'),
+('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
+('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
+('in_node', 'crop3'), ('crop3', 'crop_data_3'),
+('in_node', 'crop4'), ('crop4', 'crop_data_4'),
+('in_node', 'crop5'), ('crop5', 'crop_data_5'),
+('crop_data_1', 'concat'),
+('crop_data_2', 'concat'),
+('crop_data_3', 'concat'),
+('crop_data_4', 'concat'),
+('crop_data_5', 'concat'),
+('concat', 'concat_data'),
+('concat_data', 'placeholder')])
+RemoveUselessCropsPattern().find_and_replace_pattern(graph)
+ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
+'in_node': {'kind': 'data', 'shape': [1, 130]},
+'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
+'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
+'const_26': {'kind': 'op', 'op': 'Const', 'value': 26},
+'const_26_data': {'kind': 'data', 'value': 26},
+'crop2': {'kind': 'op', 'op': 'Crop', 'offset': 26, 'dim': 26, 'axis': -1},
+'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
+'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
+'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
+'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
+'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
+'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
+'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
+'concat': {'kind': 'op', 'op': 'Concat'},
+'concat_data': {'kind': 'data', 'shape': [1, 130]},
+'placeholder': {'kind': 'op', 'op': 'Parameter'},
+},
+[
+('placeholder_in', 'in_node'),
+('in_node', 'crop1'), ('crop1', 'crop_data_1'),
+('in_node', 'crop2', {'in': 0}), ('const_26', 'const_26_data'),
+('const_26_data', 'crop2', {'in': 1}), ('crop2', 'crop_data_2'),
+('in_node', 'crop3'), ('crop3', 'crop_data_3'),
+('in_node', 'crop4'), ('crop4', 'crop_data_4'),
+('in_node', 'crop5'), ('crop5', 'crop_data_5'),
+('concat', 'concat_data'),
+('in_node', 'placeholder')
+]
+)
+
+(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
+self.assertTrue(flag, resp)
+
+def test_useless_crops_type3(self):
+graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
+'in_node': {'kind': 'data', 'shape': [1, 130]},
+'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
+'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
+'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
+'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
+'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
+'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
+'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
+'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
+'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
+'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
+'concat': {'kind': 'op', 'op': 'Concat'},
+'concat_data': {'kind': 'data', 'shape': [1, 130]},
+'placeholder': {'kind': 'op', 'op': 'Parameter'},
+},
+[('placeholder_in', 'in_node'),
+('in_node', 'crop1'), ('crop1', 'crop_data_1'),
+('in_node', 'crop2'), ('crop2', 'crop_data_2'),
+('in_node', 'crop3'), ('crop3', 'crop_data_3'),
+('in_node', 'crop4'), ('crop4', 'crop_data_4'),
+('in_node', 'crop5'), ('crop5', 'crop_data_5'),
+('crop_data_1', 'concat'),
+('crop_data_2', 'concat'),
+('crop_data_3', 'concat'),
+('crop_data_4', 'concat'),
+('crop_data_5', 'concat'),
+('concat', 'concat_data'),
+('concat_data', 'placeholder')])
+RemoveUselessCropsPattern().find_and_replace_pattern(graph)
+ref_graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
+'in_node': {'kind': 'data', 'shape': [1, 130]},
+'crop1': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 26, 'axis': -1},
+'crop_data_1': {'kind': 'data', 'shape': [1, 26]},
+'crop2': {'kind': 'op', 'op': 'Crop', 'crop_begin': 26, 'crop_end': 52, 'axis': -1},
+'crop_data_2': {'kind': 'data', 'shape': [1, 26]},
+'crop3': {'kind': 'op', 'op': 'Crop', 'offset': 52, 'dim': 26, 'axis': -1},
+'crop_data_3': {'kind': 'data', 'shape': [1, 26]},
+'crop4': {'kind': 'op', 'op': 'Crop', 'offset': 78, 'dim': 26, 'axis': -1},
+'crop_data_4': {'kind': 'data', 'shape': [1, 26]},
+'crop5': {'kind': 'op', 'op': 'Crop', 'offset': 104, 'dim': 26, 'axis': -1},
+'crop_data_5': {'kind': 'data', 'shape': [1, 26]},
+'concat': {'kind': 'op', 'op': 'Concat'},
+'concat_data': {'kind': 'data', 'shape': [1, 130]},
+'placeholder': {'kind': 'op', 'op': 'Parameter'},
+},
+[
+('placeholder_in', 'in_node'),
+('in_node', 'crop1'), ('crop1', 'crop_data_1'),
+('in_node', 'crop2'), ('crop2', 'crop_data_2'),
+('in_node', 'crop3'), ('crop3', 'crop_data_3'),
+('in_node', 'crop4'), ('crop4', 'crop_data_4'),
+('in_node', 'crop5'), ('crop5', 'crop_data_5'),
+('concat', 'concat_data'),
+('in_node', 'placeholder')
+]
+)
+
+(flag, resp) = compare_graphs(graph, ref_graph, 'placeholder')
+self.assertTrue(flag, resp)
+
 def test_useful_crops(self):
 graph = build_graph({'placeholder_in': {'kind': 'op', 'op': 'Parameter'},
 'in_node': {'kind': 'data', 'shape': [1, 130]},
@@ -15,13 +15,15 @@
 """
 import numpy as np

+from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
 from extensions.ops.splice import Splice
 from mo.front.common.partial_infer.utils import int64_array
 from mo.graph.graph import Graph, Node
 from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.assign import Assign
 from mo.ops.concat import Concat
 from mo.ops.crop import Crop
-from mo.ops.memory import Memory
+from mo.ops.read_value import ReadValue
 from mo.ops.result import Result
 from mo.utils.error import Error

@@ -67,7 +69,8 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):

 splice = Splice(graph, {'name': node_name,
 'id': node_id,
-'context': int64_array(range(node_t, 1)) if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
+'context': int64_array(range(node_t, 1))
+if node_t < 0 else int64_array(range(0, node_t+1))}).create_node()
 splice.in_port(0).connect(input_node_out_port)

 # offset of Crop will be 0 (first element) if node_t < 0 and in_shape[1]*node_t (last element) if node_t > 0
@@ -106,6 +109,7 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
 Replace MemoryOffset with Memory if IfDefined used with it to avoid cycles
 """
 enabled = True
+force_shape_inference = True

 def run_before(self):
 from extensions.middle.RemoveDuplicationMemory import RemoveMemoryDuplicationPattern
@@ -141,43 +145,34 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
 in_shape = input_port.data.get_shape()
 node_t = abs(node.t)

-memory_out = Memory(graph, {'name': pair_name, 'id': node_name+pair_name,
-'index': 1, 'size': 2,
-'shape': np.array([in_shape[1]*node_t])}).create_node()
+init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
+memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
+init_value_memory_out.out_port(0).connect(memory_out.in_port(0))

 if node_t > 1:
 crop_concat = Crop(graph, {'name': 'Memory_crop', 'dim': np.array([in_shape[1]*(node_t-1)]),
 'offset': np.array([in_shape[1]]), 'axis': np.array([1])}).create_node()
 memory_out.out_port(0).connect(crop_concat.in_port(0))
-memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
 concat = Concat(graph, {'name': 'Memory_concat'}).create_node()
 concat.add_sequence_of_ports('in', range(2))
 crop_concat.out_port(0).connect(concat.in_port(0))
-crop_concat.out_port(0).data.set_shape(np.array([in_shape[0], crop_concat.dim]))
 concat.in_port(1).connect(input_port)
-memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
-'index': 0, 'size': 2,
-'shape': memory_out.shape}).create_node()
+memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
 concat.out_port(0).connect(memory_in.in_port(0))
-concat.out_port(0).data.set_shape(np.array([in_shape[0], memory_in.shape[0]]))
 out = Result(graph, {'name': 'Memory_output'}).create_node()
 memory_in.out_port(0).connect(out.in_port(0))
-memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))

 crop_out = Crop(graph, {'name': 'Memory_crop_out', 'dim': np.array([in_shape[1]]),
 'offset': np.array([0]), 'axis': np.array([1])}).create_node()
 memory_out.out_port(0).connect(crop_out.in_port(0))
 out_port.get_connection().set_source(crop_out.out_port(0))
-crop_out.out_port(0).data.set_shape(np.array([in_shape[0], crop_out.dim]))
 else:
-memory_in = Memory(graph, {'name': node_name, 'id': node_name + pair_name,
-'index': 0, 'size': 2,
-'shape': memory_out.shape}).create_node()
+memory_in = Assign(graph, {'name': node_name, 'variable_id': node_name + pair_name}).create_node()
 memory_in.in_port(0).connect(input_port)
 out = Result(graph, {'name': 'Memory_output'}).create_node()
 memory_in.out_port(0).connect(out.in_port(0))
-memory_in.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))
 out_port.get_connection().set_source(memory_out.out_port(0))
-memory_out.out_port(0).data.set_shape(np.array([in_shape[0], memory_out.shape[0]]))

 graph.remove_node(op_output_id)
 graph.remove_node(node.id)
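Throughout this commit the old Memory pair (two ops sharing an 'id', with 'index' 0 for the write side and 1 for the read side) becomes an explicit ReadValue/Assign pair sharing a 'variable_id', with the initial state supplied as a real input to ReadValue. A framework-free sketch of that contract (class and method names are illustrative, not the MO API):

    # models the persistent buffer keyed by variable_id
    class VariableState:
        def __init__(self):
            self.buffers = {}

        def read_value(self, variable_id, init_value):
            # ReadValue: the previous step's value, or the explicit initializer
            return self.buffers.get(variable_id, init_value)

        def assign(self, variable_id, value):
            # Assign: store this step's update for the next iteration
            self.buffers[variable_id] = value
            return value

    state = VariableState()
    assert state.read_value('iteration_number', 0) == 0   # first step: initializer
    state.assign('iteration_number', 1)
    assert state.read_value('iteration_number', 0) == 1   # later steps: stored value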
@@ -13,15 +13,16 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
-from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
+from extensions.front.kaldi.replace_lstm_node_pattern import unique_id, create_zero_value_with_batch_from_input
 from extensions.ops.split import VariadicSplit
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
+from mo.ops.assign import Assign
 from mo.ops.concat import Concat
 from mo.ops.crop import Crop
-from mo.ops.memory import Memory
+from mo.ops.read_value import ReadValue
 from mo.ops.result import Result


@@ -39,7 +40,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
 So this pass will convert this graph to the next one:

 Input [N, H] __
-\ /
+/ /
 Concat [N, k*H]
 / \
 Memory [N, k*H] -> Slice [N, (k-1)*H] Memory [N, k*H]
@@ -67,11 +68,9 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):

 memory_pair_id = unique_id('id')
 # Memory(in)
-input_memory = Memory(graph, {'name': 'prev_splice_memory',
-'id': memory_pair_id,
-'index': 1,
-'size': 2,
-'shape': int64_array([memory_size])}).create_node()
+input_memory = ReadValue(graph, {'name': 'prev_splice_memory',
+'variable_id': memory_pair_id}).create_node()
 # Memory(in) \
 # Crop
 # Input(temp) /
@@ -90,11 +89,7 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
 concat_node.in_port(0).connect(crop.out_port(0))

 # Concat -> Memory(out)
-mem_out = Memory(graph, {'name': 'out_splice_memory',
-'id': memory_pair_id,
-'index': 0,
-'size': 2,
-'shape': int64_array([memory_size])}).create_node()
+mem_out = Assign(graph, {'name': 'out_splice_memory', 'variable_id': memory_pair_id}).create_node()
 mem_out.in_port(0).connect(concat_node.out_port(0))
 Result(graph).create_node().in_port(0).connect(mem_out.out_port(0))

@@ -110,11 +105,12 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):

 # create separate splice construction for const_dim
 memory_pair_id = unique_id('memory_for_const_dim')
-input_memory_const_dim = Memory(graph, {'name': 'const_dim_in_memory',
-'id': memory_pair_id,
-'index': 1,
-'size': 2,
-'shape': int64_array([memory_size_constdim])}).create_node()
+init_value_input_memory_const_dim = create_zero_value_with_batch_from_input(split.out_port(1),
+memory_size_constdim)
+input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
+'variable_id': memory_pair_id}).create_node()
+init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))

 crop_const_dim = Crop(graph, {'name': 'const_dim_crop',
 'axis': int64_array([1]),
 'offset': int64_array([memory_element_constdim]),
@@ -127,11 +123,8 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
 'axis': 1}).create_node()
 concat_node_const_dim.in_port(0).connect(crop_const_dim.out_port(0))

-mem_out_const_dim = Memory(graph, {'name': 'const_dim_out_memory',
-'id': memory_pair_id,
-'index': 0,
-'size': 2,
-'shape': int64_array([memory_size_constdim])}).create_node()
+mem_out_const_dim = Assign(graph, {'name': 'const_dim_out_memory',
+'variable_id': memory_pair_id}).create_node()
 mem_out_const_dim.in_port(0).connect(concat_node_const_dim.out_port(0))
 Result(graph).create_node().in_port(0).connect(mem_out_const_dim.out_port(0))

@@ -148,9 +141,15 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
 concat_const.in_port(1).connect(crop_first.out_port(0))
 concat_const.in_port(0).connect(concat_node.out_port(0))

+init_value_input_memory = create_zero_value_with_batch_from_input(split.out_port(0),
+memory_size)
+init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
 node.in_port(0).get_connection().set_destination(split.in_port(0))
 node.out_port(0).get_connection().set_source(concat_const.out_port(0))
 else:
+init_value_input_memory = create_zero_value_with_batch_from_input(node.in_port(0).get_source(),
+memory_size)
+init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
 node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
 node.out_port(0).get_connection().set_source(concat_node.out_port(0))

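The test graphs below spell out what create_zero_value_with_batch_from_input expands to in the IR: ShapeOf feeds a Crop that isolates the batch dimension, a Const supplies the second dimension, Concat gathers the target shape, and Broadcast tiles a zero fill value over it. A numpy analogue of that subgraph (a hedged sketch; the function name mirrors the helper, the body is illustrative):

    import numpy as np

    def zero_value_with_batch(input_array, second_dim, dtype=np.float32):
        batch = np.array(input_array.shape[:1], dtype=np.int64)   # ShapeOf -> Crop (batch only)
        target_shape = np.concatenate([batch, [second_dim]])      # Const + Concat (gather_shape)
        return np.zeros(target_shape, dtype=dtype)                # Broadcast of the 0 fill_value

    # e.g. a [1, 13] input with second_dim=143 yields a [1, 143] zero initial state
    init = zero_value_with_batch(np.ones((1, 13)), 143)
    assert init.shape == (1, 143)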
@ -16,6 +16,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
|
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
|
||||||
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
from mo.graph.graph import Node
|
from mo.graph.graph import Node
|
||||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
from mo.utils.unittest.graph import build_graph
|
from mo.utils.unittest.graph import build_graph
|
||||||
@ -42,19 +43,47 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
|||||||
|
|
||||||
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
|
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
|
||||||
'in_node': {'kind': 'data', 'shape': [1, 13]},
|
'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||||
'memory_in': {'kind': 'op', 'op': 'Memory'},
|
|
||||||
|
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||||
|
'shape_data': {'kind': 'data'},
|
||||||
|
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||||
|
'crop_batch_data': {'kind': 'data'},
|
||||||
|
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||||
|
'crop_batch_dim_data': {'kind': 'data'},
|
||||||
|
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([143])},
|
||||||
|
'second_dim_data': {'kind': 'data'},
|
||||||
|
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||||
|
'gather_shape_data': {'kind': 'data'},
|
||||||
|
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||||
|
'fill_value_data': {'kind': 'data'},
|
||||||
|
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||||
|
'broadcast_data': {'kind': 'data'},
|
||||||
|
|
||||||
|
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||||
'memory_in_data': {'kind': 'data'},
|
'memory_in_data': {'kind': 'data'},
|
||||||
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
|
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 13, 'dim': 130},
|
||||||
'crop_mem_data': {'kind': 'data'},
|
'crop_mem_data': {'kind': 'data'},
|
||||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||||
'concat_data': {'kind': 'data', 'shape': [1, 143]},
|
'concat_data': {'kind': 'data', 'shape': [1, 143]},
|
||||||
'memory_out': {'kind': 'op', 'op': 'Memory'},
|
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||||
'memory_out_data': {'kind': 'data'},
|
'memory_out_data': {'kind': 'data'},
|
||||||
'result': {'kind': 'op', 'op': 'Result'},
|
'result': {'kind': 'op', 'op': 'Result'},
|
||||||
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
|
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
('in_placeholder', 'in_node'),
|
('in_placeholder', 'in_node'),
|
||||||
|
|
||||||
|
('in_node', 'shape'), ('shape', 'shape_data'),
|
||||||
|
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||||
|
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||||
|
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||||
|
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||||
|
('crop_batch_data', 'gather_shape', {'in': 0}),
|
||||||
|
('gather_shape', 'gather_shape_data'),
|
||||||
|
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||||
|
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||||
|
('broadcast_data', 'memory_in'),
|
||||||
|
|
||||||
('memory_in', 'memory_in_data'),
|
('memory_in', 'memory_in_data'),
|
||||||
('memory_in_data', 'crop_mem'),
|
('memory_in_data', 'crop_mem'),
|
||||||
('crop_mem', 'crop_mem_data'),
|
('crop_mem', 'crop_mem_data'),
|
||||||
@ -86,22 +115,54 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
'split': {'kind': 'op', 'op': 'Split'},
'split_data_0': {'kind': 'data'},
'split_data_1': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'Memory'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},

'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
'crop_mem': {'kind': 'op', 'op': 'Crop', 'offset': 3, 'dim': 30},
'crop_mem_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Memory'},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'memory_in_constdims': {'kind': 'op', 'op': 'Memory'},
'shape_2': {'kind': 'op', 'op': 'ShapeOf'},
'shape_2_data': {'kind': 'data'},
'crop_batch_2': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_2_data': {'kind': 'data'},
'crop_batch_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_2_data': {'kind': 'data'},
'second_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_2_data': {'kind': 'data'},
'gather_shape_2': {'kind': 'op', 'op': 'Concat'},
'gather_shape_2_data': {'kind': 'data'},
'fill_value_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_2_data': {'kind': 'data'},
'broadcast_2': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_2_data': {'kind': 'data'},

'memory_in_constdims': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_constdims_data': {'kind': 'data'},
'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
'crop_mem_constdims_data': {'kind': 'data'},
'concat_constdims': {'kind': 'op', 'op': 'Concat'},
'concat_constdims_data': {'kind': 'data'},
'memory_out_constdims': {'kind': 'op', 'op': 'Memory'},
'memory_out_constdims': {'kind': 'op', 'op': 'Assign'},
'memory_out_constdims_data': {'kind': 'data'},
'result_constdims': {'kind': 'op', 'op': 'Result'},
'crop_first_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 10},
@ -121,6 +182,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('in_node', 'split', {'in': 0}),
('split', 'split_data_0', {'out': 0}),
('split', 'split_data_1', {'out': 1}),

('split_data_0', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}),
('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),

('memory_in', 'memory_in_data'),
('memory_in_data', 'crop_mem'),
('crop_mem', 'crop_mem_data'),
@ -130,6 +203,18 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'),
('memory_out_data', 'result'),

('split_data_1', 'shape_2'), ('shape_2', 'shape_2_data'),
('shape_2_data', 'crop_batch_2'), ('crop_batch_2', 'crop_batch_2_data'),
('crop_batch_dim_2', 'crop_batch_dim_2_data'),
('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
('second_dim_2', 'second_dim_2_data'), ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
('gather_shape_2', 'gather_shape_2_data'),
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'broadcast_2', {'in': 0}),
('gather_shape_2_data', 'broadcast_2', {'in': 1}), ('broadcast_2', 'broadcast_2_data'),
('broadcast_2_data', 'memory_in_constdims'),

('memory_in_constdims', 'memory_in_constdims_data'),
('memory_in_constdims_data', 'crop_mem_constdims'),
('crop_mem_constdims', 'crop_mem_constdims_data'),
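Editorial note: for reference, a minimal hand-built nGraph function with the same ReadValue/Assign state loop the transformed test graph relies on. This is an illustrative sketch, not code from the commit; the Result attached to Assign mirrors what assign_add_output_result does further down.

    #include "ngraph/ngraph.hpp"

    using namespace ngraph;

    // Illustrative sketch: the smallest ReadValue/Assign state loop.
    std::shared_ptr<Function> make_stateful_function()
    {
        auto input = std::make_shared<op::Parameter>(element::f32, Shape{1, 33});
        auto read = std::make_shared<op::v3::ReadValue>(input, "mem");  // creates variable "mem"
        auto update = std::make_shared<op::v1::Add>(read, input);       // new state value
        auto assign = std::make_shared<op::v3::Assign>(update, "mem");  // writes it back
        // Assign has no data consumers of its own, so give it a Result to keep it in the graph.
        auto keep = std::make_shared<op::Result>(assign);
        auto out = std::make_shared<op::Result>(update);
        return std::make_shared<Function>(ResultVector{out, keep}, ParameterVector{input});
    }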
@ -113,9 +113,6 @@ def prepare_ir(argv: argparse.Namespace):
    elif (is_kaldi or is_onnx) and not argv.input_model:
        raise Error('Path to input model is required: use --input_model.')

    if is_kaldi:
        argv.generate_experimental_IR_V10 = False

    log.debug(str(argv))
    log.debug("Model Optimizer started")
40 model-optimizer/mo/ops/assign.py Normal file
@ -0,0 +1,40 @@
"""
 Copyright (C) 2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from mo.graph.graph import Graph, Node
from mo.ops.op import Op


class Assign(Op):
    op = 'Assign'

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': self.op,
            'op': self.op,
            'version': 'opset3',
            'infer': self.infer,
            'in_ports_count': 1,
            'out_ports_count': 1,
        }, attrs)

    def backend_attrs(self):
        return ['variable_id']

    @staticmethod
    def infer(node: Node):
        assert node.has_valid('variable_id'), \
            "There is no required attribute variable_id in Assign op with name " + node.id
        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
45 model-optimizer/mo/ops/read_value.py Normal file
@ -0,0 +1,45 @@
"""
 Copyright (C) 2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

from mo.graph.graph import Graph, Node
from mo.ops.op import Op


class ReadValue(Op):
    op = 'ReadValue'

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': self.op,
            'op': self.op,
            'version': 'opset3',
            'infer': self.infer,
            'type_infer': self.type_infer,
            'in_ports_count': 1,
            'out_ports_count': 1,
        }, attrs)

    def backend_attrs(self):
        return ['variable_id']

    @staticmethod
    def type_infer(node: Node):
        node.out_port(0).set_data_type(node.in_port(0).get_data_type())

    @staticmethod
    def infer(node: Node):
        assert node.has_valid('variable_id'), \
            "There is no required attribute variable_id in ReadValue op with name " + node.id
        node.out_port(0).data.set_shape(node.in_port(0).data.get_shape())
@ -32,6 +32,7 @@ from mo.ops.convolution import Convolution
from mo.ops.deconvolution import Deconvolution
from mo.ops.op import Op
from mo.ops.pooling import Pooling
from mo.ops.result import Result
from mo.utils.class_registration import update_registration
from mo.utils.import_extensions import import_by_path
from mo.utils.ir_reader.extender import Extender
@ -218,6 +219,18 @@ def ti_add_edge_attrs(op: Node):
        i += 1


def assign_add_output_result(op: Node):
    """
    Function adds the necessary output Result node for an Assign node
    :param op: Assign operation node
    :return:
    """
    assert op.soft_get('type') == 'Assign', 'Wrong operation type, {} instead of Assign!' \
                                            ''.format(op.soft_get('type'))
    tmp_result = Result(op.graph, {'name': op.soft_get('name', op.id) + '/Result'}).create_node()
    op.out_port(0).connect(tmp_result.in_port(0))


def copy_input_blobs(op: Node, copy_op: Node):
    """
    Function copies input blob data nodes from restored graph to copied one
@ -243,6 +256,7 @@ preprocessing_op_nodes = {

# Map with postprocessing functions for nodes
postprocessing_op_nodes = {
    'Assign': assign_add_output_result,
    'TensorIterator': ti_add_edge_attrs,
}

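Editorial note: assign_add_output_result exists because Assign produces no data consumers of its own — without a Result on its output, the restored graph would drop it during traversal and serialization. The same guard in nGraph C++ terms, as a hedged one-liner (here `assign` names a previously built op::v3::Assign node):

    // Keep the Assign leaf alive for get_ops()/serialization by giving it a Result.
    auto assign_result = std::make_shared<op::Result>(assign->output(0));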
@ -22,6 +22,7 @@ from extensions.back.SpecialNodesFinalization import RemoveConstOps, CreateConst
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
from extensions.back.TopKNormalizer import TopKNormalizer
from extensions.back.blob_normalizer import BlobNormalizer
from extensions.back.kaldi_remove_memory_output import KaldiRemoveMemoryOutputBackReplacementPattern
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_precision
from mo.middle.pattern_match import for_graph_and_each_sub_graph_recursively
@ -77,6 +78,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
        PackBinaryWeights,
        BlobNormalizer,
        ConvolutionNormalizer,
        KaldiRemoveMemoryOutputBackReplacementPattern,
    ]

    # We need to run some specific passes from MO back stage.
@ -137,6 +137,8 @@ set (SRC
    op/asin.hpp
    op/asinh.cpp
    op/asinh.hpp
    op/assign.cpp
    op/assign.hpp
    op/atan.cpp
    op/atan.hpp
    op/atanh.cpp
@ -314,6 +316,8 @@ set (SRC
    op/proposal.cpp
    op/psroi_pooling.hpp
    op/psroi_pooling.cpp
    op/read_value.hpp
    op/read_value.cpp
    op/reduce_logical_and.cpp
    op/reduce_logical_and.hpp
    op/reduce_logical_or.cpp
@ -518,6 +522,7 @@ set (SRC
    op/util/scatter_base.hpp
    op/util/unary_elementwise_arithmetic.cpp
    op/util/unary_elementwise_arithmetic.hpp
    op/util/variable.hpp
    ops.hpp
    opsets/opset.cpp
    partial_shape.cpp
81 ngraph/src/ngraph/op/assign.cpp Normal file
@ -0,0 +1,81 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/op/assign.hpp"
#include <ops.hpp>
#include "ngraph/op/read_value.hpp"

using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::v3::Assign::type_info;

op::v3::Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
    : Op({new_value})
    , m_variable_id(variable_id)
{
    constructor_validate_and_infer_types();
}

void op::v3::Assign::validate_and_infer_types()
{
    auto value = input_value(0);
    auto arg_t = get_input_element_type(0);
    auto output_shape = get_input_partial_shape(0);
    if (!m_variable)
    {
        NodeVector start_nodes;
        for (const auto& input : inputs())
        {
            start_nodes.push_back(input.get_source_output().get_node_shared_ptr());
        }
        auto nodes = topological_sort(start_nodes);
        for (const auto& node : nodes)
        {
            if (auto read_value = as_type_ptr<op::v3::ReadValue>(node))
            {
                if (read_value->get_variable_id() == m_variable_id)
                    m_variable = read_value->get_variable();
            }
        }
        NODE_VALIDATION_CHECK(
            this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id);
    }

    auto variable_info = m_variable->get_info();
    NODE_VALIDATION_CHECK(this,
                          m_variable_id == variable_info.variable_id,
                          "Variables identifiers are inconsistent.");
    NODE_VALIDATION_CHECK(
        this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
    NODE_VALIDATION_CHECK(this,
                          output_shape == variable_info.data_shape,
                          "Variables output shapes are inconsistent.");

    set_output_type(0, arg_t, output_shape);
}

shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<Assign>(new_args.at(0), m_variable_id);
}

bool op::v3::Assign::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("variable_id", m_variable_id);
    return true;
}
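Editorial note: validate_and_infer_types resolves the variable lazily — Assign walks its producer subgraph in topological order and adopts the Variable of the ReadValue with a matching id, so construction order matters. A small sketch of the resulting invariant (assumption: both ops sit in the same connected subgraph):

    auto init = std::make_shared<op::Parameter>(element::f32, Shape{1, 3});
    auto read = std::make_shared<op::v3::ReadValue>(init, "state");
    auto assign = std::make_shared<op::v3::Assign>(read, "state");
    // Assign found ReadValue's variable during validation: same object, same info.
    NGRAPH_CHECK(assign->get_variable() == read->get_variable());
    NGRAPH_CHECK(assign->get_variable()->get_info().variable_id == "state");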
67 ngraph/src/ngraph/op/assign.hpp Normal file
@ -0,0 +1,67 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/op/op.hpp"
#include "ngraph/op/util/variable.hpp"

namespace ngraph
{
    namespace op
    {
        namespace v3
        {
            /// \brief Assign operation sets an input value to the variable with `variable_id`
            class NGRAPH_API Assign : public Op
            {
            public:
                static constexpr NodeTypeInfo type_info{"Assign", 3};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                Assign() = default;

                /// \brief Constructs an Assign operation.
                ///
                /// \param new_value   Node that produces the input tensor.
                /// \param variable_id identifier of the variable to be updated.
                Assign(const Output<Node>& new_value, const std::string& variable_id);

                void validate_and_infer_types() override;

                std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

                bool visit_attributes(AttributeVisitor& visitor) override;

                std::string get_variable_id() { return m_variable_id; }
                std::shared_ptr<ngraph::Variable> get_variable() { return m_variable; }
                void set_variable_id(const std::string& variable_id)
                {
                    m_variable_id = variable_id;
                }
                void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
                {
                    m_variable = variable;
                }

            private:
                std::string m_variable_id;
                std::shared_ptr<ngraph::Variable> m_variable;
            };
        }
        using v3::Assign;
    }
}
@ -265,3 +265,5 @@ NGRAPH_OP(Transpose, ngraph::op::v1, 1)
NGRAPH_OP(Unsqueeze, ngraph::op::v0, 0)
NGRAPH_OP(VariadicSplit, ngraph::op::v1, 1)
NGRAPH_OP(Xor, ngraph::op::v0, 0)
NGRAPH_OP(Assign, ngraph::op::v3, 3)
NGRAPH_OP(ReadValue, ngraph::op::v3, 3)
51 ngraph/src/ngraph/op/read_value.cpp Normal file
@ -0,0 +1,51 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/op/read_value.hpp"

using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::ReadValue::type_info;

op::ReadValue::ReadValue(const Output<Node>& new_value, const std::string& variable_id)
    : Op({new_value})
    , m_variable_id(variable_id)
{
    constructor_validate_and_infer_types();
}

void op::ReadValue::validate_and_infer_types()
{
    auto arg_t = get_input_element_type(0);
    auto output_shape = get_input_partial_shape(0);

    VariableInfo info = {output_shape, arg_t, m_variable_id};
    m_variable = std::make_shared<Variable>(info);
    set_output_type(0, arg_t, output_shape);
}

shared_ptr<Node> op::ReadValue::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<ReadValue>(new_args.at(0), m_variable_id);
}

bool op::v3::ReadValue::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("variable_id", m_variable_id);
    return true;
}
68 ngraph/src/ngraph/op/read_value.hpp Normal file
@ -0,0 +1,68 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/op/op.hpp"
#include "ngraph/op/util/variable.hpp"

namespace ngraph
{
    namespace op
    {
        namespace v3
        {
            /// \brief ReadValue operation creates the variable with `variable_id` and returns
            ///        the value of this variable.
            class NGRAPH_API ReadValue : public Op
            {
            public:
                static constexpr NodeTypeInfo type_info{"ReadValue", 3};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                ReadValue() = default;

                /// \brief Constructs a ReadValue operation.
                ///
                /// \param new_value   Node that produces the input tensor.
                /// \param variable_id identifier of the variable to create.
                ReadValue(const Output<Node>& new_value, const std::string& variable_id);

                void validate_and_infer_types() override;

                std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

                bool visit_attributes(AttributeVisitor& visitor) override;

                std::string get_variable_id() { return m_variable_id; }
                std::shared_ptr<ngraph::Variable> get_variable() { return m_variable; }
                void set_variable_id(const std::string& variable_id)
                {
                    m_variable_id = variable_id;
                }
                void set_variable(const std::shared_ptr<ngraph::Variable>& variable)
                {
                    m_variable = variable;
                }

            private:
                std::string m_variable_id;
                std::shared_ptr<ngraph::Variable> m_variable;
            };
        }
        using v3::ReadValue;
    }
}
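Editorial note: ReadValue doubles as the variable's declaration — its input supplies the initial value, and validate_and_infer_types snapshots shape, type and id into a fresh Variable. A quick hedged check of that behavior:

    auto init = op::Constant::create(element::f32, Shape{1, 4}, {0.f, 0.f, 0.f, 0.f});
    auto read = std::make_shared<op::v3::ReadValue>(init, "mem");
    // The Variable records the declaring tensor's type, shape and id.
    NGRAPH_CHECK(read->get_variable()->get_info().data_type == element::f32);
    NGRAPH_CHECK((read->get_variable()->get_info().data_shape == PartialShape{1, 4}));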
@ -23,6 +23,9 @@
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"

#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/topk.hpp"

using namespace std;
using namespace ngraph;
@ -225,9 +228,224 @@ void op::v0::TopK::generate_adjoints(autodiff::Adjoints& /* adjoints */,
    throw ngraph_error("Forward-propagation-only operation");
}

namespace
{
    template <element::Type_t INPUT_ET, element::Type_t INDEX_ET>
    inline bool evaluate_execute(const HostTensorPtr& arg0,
                                 const HostTensorPtr& out_indices,
                                 const HostTensorPtr& out_values,
                                 const Shape out_shape,
                                 const size_t axis,
                                 const size_t k,
                                 const bool compute_max,
                                 const op::TopK::SortType sort)
    {
        using T = typename element_type_traits<INPUT_ET>::value_type;
        using U = typename element_type_traits<INDEX_ET>::value_type;
        const Shape in_shape = arg0->get_shape();
        out_indices->set_shape(out_shape);
        out_indices->set_element_type(INDEX_ET);

        out_values->set_shape(out_shape);
        out_values->set_element_type(arg0->get_element_type());

        runtime::reference::topk<T, U>(arg0->get_data_ptr<INPUT_ET>(),
                                       out_indices->get_data_ptr<INDEX_ET>(),
                                       out_values->get_data_ptr<INPUT_ET>(),
                                       in_shape,
                                       out_shape,
                                       axis,
                                       k,
                                       compute_max,
                                       sort);
        return true;
    }

    template <element::Type_t INPUT_ET>
    bool evaluate(const HostTensorPtr& arg,
                  const HostTensorPtr& out_indices,
                  const HostTensorPtr& out_values,
                  const Shape out_shape,
                  const size_t axis,
                  const size_t k,
                  const bool max,
                  const op::TopK::SortType sort,
                  const element::Type index_et)
    {
        bool rc = true;
        switch (index_et)
        {
        case element::Type_t::i64:
            evaluate_execute<INPUT_ET, element::Type_t::i64>(
                arg, out_indices, out_values, out_shape, axis, k, max, sort);
            break;
        case element::Type_t::i32:
            evaluate_execute<INPUT_ET, element::Type_t::i32>(
                arg, out_indices, out_values, out_shape, axis, k, max, sort);
            break;
        default: rc = false; break;
        }
        return rc;
    }

    bool evaluate_topk(const HostTensorPtr& arg,
                       const HostTensorPtr& out_indices,
                       const HostTensorPtr& out_values,
                       const Shape out_shape,
                       const size_t axis,
                       const size_t k,
                       const bool max,
                       const op::TopK::SortType sort,
                       const element::Type index_et)
    {
        bool rc = true;
        switch (arg->get_element_type())
        {
            TYPE_CASE(i8)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(i16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(i32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(i64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(u8)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(u16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(u32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(u64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(bf16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(f16)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(f32)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            TYPE_CASE(f64)(arg, out_indices, out_values, out_shape, axis, k, max, sort, index_et);
            break;
            default: rc = false; break;
        }
        return rc;
    }

    template <element::Type_t K_ET>
    size_t get_k_from_hosttensor(const HostTensorPtr& arg)
    {
        using T = typename element_type_traits<K_ET>::value_type;
        auto p = arg->get_data_ptr<T>();
        size_t k = p[0];
        return k;
    }

#define CASE_GET_K(a)                                                                              \
    case element::Type_t::a: k = get_k_from_hosttensor<element::Type_t::a>

    size_t read_k_from_host_tensor(const HostTensorPtr& arg_k)
    {
        size_t k = 0;
        switch (arg_k->get_element_type())
        {
            CASE_GET_K(i8)(arg_k);
            break;
            CASE_GET_K(i16)(arg_k);
            break;
            CASE_GET_K(i32)(arg_k);
            break;
            CASE_GET_K(i64)(arg_k);
            break;
            CASE_GET_K(u8)(arg_k);
            break;
            CASE_GET_K(u16)(arg_k);
            break;
            CASE_GET_K(u32)(arg_k);
            break;
            CASE_GET_K(u64)(arg_k);
            break;
        default:
            // other types are not supported and would have thrown in ctor
            ngraph_error("read_k_from_host_tensor: type is not integral\n");
            break;
        }
        return k;
    }

    // used only in v0, where the type is set as int64_t
    size_t read_top_k_axis_from_host_tensor(const HostTensorPtr& arg)
    {
        NGRAPH_CHECK(arg->get_element_type() == element::i64,
                     "TopK axis element type should be i64");
        auto p = arg->get_data_ptr<int64_t>();
        size_t axis = static_cast<size_t>(p[0]);
        return axis;
    }
}

Shape op::v0::TopK::compute_output_shape(const Shape input_shape,
                                         const int64_t k,
                                         const size_t axis)
{
    Shape output_shape{input_shape};
    if (k != 0)
    {
        output_shape[axis] = k;
    }
    return output_shape;
}

bool op::v0::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
{
    // check data types for arg, k and output element type
    Shape arg_shape = inputs[0]->get_shape();

    // 1. get axis, mode (max/min), sort_type
    size_t axis = 0;
    Dimension axis_dim = get_top_k_axis_dynamic();
    if (axis_dim.is_static())
    {
        axis = axis_dim.get_length();
    }
    else
    {
        axis = read_top_k_axis_from_host_tensor(inputs[2]);
        NGRAPH_CHECK(axis <= arg_shape.size(), "TopK axis is out of bounds");
    }
    bool compute_max = get_compute_max();
    SortType sort_type = get_sort();

    // 2. get value of k - from constant node or from HT
    size_t k = get_k();
    if (k == 0)
    {
        k = read_k_from_host_tensor(inputs[1]);
        if (k == 0)
        {
            // the kernel can't handle k = 0, but output_shape[axis] = arg_shape[axis]
            k = arg_shape[axis];
        }
    }
    NGRAPH_CHECK(k <= arg_shape.at(axis), "K exceeds the dimension of the TopK axis");

    // 3. Compute output_shape
    auto output_shape = compute_output_shape(inputs[0]->get_shape(), k, axis);

    return evaluate_topk(inputs[0],
                         outputs[0],
                         outputs[1],
                         output_shape,
                         axis,
                         k,
                         compute_max,
                         sort_type,
                         get_index_element_type());
}
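Editorial note: evaluate_topk relies on a TYPE_CASE dispatch macro that this hunk does not define; elsewhere in the ngraph sources it expands to a case label that forwards to the typed evaluate helper. Presumably it looks like the following (an assumption, not part of this diff):

    // Assumed helper (not shown in this hunk): TYPE_CASE(x)(args...) expands into
    // "case element::Type_t::x: rc = evaluate<element::Type_t::x>(args...);".
    #define TYPE_CASE(a)                                                                           \
        case element::Type_t::a: rc = evaluate<element::Type_t::a>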
// v1 version starts
constexpr NodeTypeInfo op::v1::TopK::type_info;

static const std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();

op::v1::TopK::TopK(const Output<Node>& data,
                   const Output<Node>& k,
                   const int64_t axis,
@ -236,7 +454,7 @@ op::v1::TopK::TopK(const Output<Node>& data,
                   const element::Type& index_element_type)
    : Op{{data, k}}
    , m_axis{axis}
    , m_normalized_axis{0}
    , m_normalized_axis{UNKNOWN_NORMALIZED_AXIS}
    , m_mode{as_enum<Mode>(mode)}
    , m_sort{as_enum<SortType>(sort)}
    , m_index_element_type{index_element_type}
@ -244,8 +462,6 @@ op::v1::TopK::TopK(const Output<Node>& data,
    constructor_validate_and_infer_types();
}

static const std::uint64_t UNKNOWN_NORMALIZED_AXIS = std::numeric_limits<uint64_t>::max();

op::v1::TopK::TopK(const Output<Node>& data,
                   const Output<Node>& k,
                   const int64_t axis,
@ -318,6 +534,25 @@ void op::v1::TopK::validate_and_infer_types()
    set_output_type(1, m_index_element_type, output_shape);
}

Shape op::v1::TopK::compute_output_shape(const std::string& node_description,
                                         const PartialShape input_partial_shape,
                                         const int64_t k)
{
    PartialShape output_shape{input_partial_shape};

    m_normalized_axis = ngraph::normalize_axis(node_description, m_axis, output_shape.rank());
    if (k != 0)
    {
        output_shape[m_normalized_axis] = k;
    }
    else
    {
        output_shape[m_normalized_axis] = input_partial_shape[m_normalized_axis];
    }

    return output_shape.get_shape();
}

void op::v1::TopK::set_axis(const int64_t axis)
{
    const auto input_rank = get_input_partial_shape(0).rank();
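Editorial note: the k == 0 branch is what lets a dynamic k flow through shape inference — the axis keeps its input extent until evaluate() reads the real k from the host tensor. A quick sketch with a static k for contrast (shapes follow the tests at the end of this commit):

    auto data = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 2});
    auto k = op::Constant::create(element::i64, Shape{}, {2});
    auto topk = std::make_shared<op::v1::TopK>(data, k, 1, "max", "value", element::i32);
    // Axis 1 is truncated to k: both outputs are {2, 2, 2}. With k == 0 (or a
    // non-constant k) the axis would keep its input extent of 3.
    NGRAPH_CHECK(topk->get_output_shape(0) == (Shape{2, 2, 2}));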
@ -332,6 +567,19 @@ void op::v1::TopK::set_axis(const int64_t axis)
    m_axis = axis;
}

void op::v1::TopK::set_axis(const Rank input_rank, const int64_t axis)
{
    if (input_rank.is_static())
    {
        m_normalized_axis = ngraph::normalize_axis(this, axis, input_rank);
    }
    else
    {
        m_normalized_axis = UNKNOWN_NORMALIZED_AXIS;
    }
    m_axis = axis;
}

uint64_t op::v1::TopK::get_axis() const
{
    NODE_VALIDATION_CHECK(
@ -433,6 +681,49 @@ void op::v1::TopK::set_k(size_t k)
        op::Constant::create(element::i64, Shape{}, {k})->output(0));
}

bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
{
    Shape arg_shape = inputs[0]->get_shape();
    // 1. get axis, mode (max/min), sort_type
    set_axis(arg_shape.size(), m_axis);
    size_t axis = get_axis();
    bool compute_max = get_mode() == TopKMode::MAX ? true : false;
    SortType sort_type = get_sort_type();

    // 2. get value of k - from constant node or from HT
    size_t k = 0;
    if (input_value(1).get_node_shared_ptr()->is_constant())
    {
        k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
                                      get_input_element_type(1));
        NGRAPH_CHECK(k <= arg_shape[axis], "'K' exceeds the dimension of top_k_axis");
    }
    else
    {
        k = read_k_from_host_tensor(inputs[1]);
    }

    // 3. Compute output_shape
    auto output_shape = compute_output_shape(this->description(), inputs[0]->get_shape(), k);

    // do this after compute_output_shape
    if (k == 0)
    {
        // the kernel can't handle k = 0, but output_shape[axis] = arg_shape[axis]
        k = arg_shape[axis];
    }

    return evaluate_topk(inputs[0],
                         outputs[1],
                         outputs[0],
                         output_shape,
                         axis,
                         k,
                         compute_max,
                         sort_type,
                         get_index_element_type());
}

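Editorial note: one asymmetry worth flagging — v0 TopK emits (indices, values) while v1/v3 emit (values, indices), which is why this override passes outputs[1] as out_indices and outputs[0] as out_values, the mirror image of the v0 evaluate above. A hedged fragment (assuming `fun` wraps a v1 TopK and `data` holds the input values, as in the tests further down):

    auto values = make_shared<HostTensor>();  // v1/v3: output 0 = values
    auto indices = make_shared<HostTensor>(); // v1/v3: output 1 = indices
    fun->evaluate({values, indices},
                  {make_host_tensor<element::Type_t::f32>(Shape{2, 3, 2}, data)});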
// v3 version starts
constexpr NodeTypeInfo op::v3::TopK::type_info;

@ -516,3 +807,8 @@ shared_ptr<Node> op::v3::TopK::clone_with_new_inputs(const OutputVector& new_arg
    return std::move(new_v3_topk);
}

bool op::v3::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
{
    return op::v1::TopK::evaluate(outputs, inputs);
}
@ -102,12 +102,18 @@ namespace ngraph
                bool get_compute_max() const { return m_compute_max; }
                SortType get_sort() const { return m_sort; }
                size_t get_default_output_index() const override { return no_default_index(); }
                bool evaluate(const HostTensorVector& outputs,
                              const HostTensorVector& inputs) override;

            protected:
                element::Type m_index_element_type;
                bool m_compute_max{false};
                SortType m_sort{SortType::NONE};
                virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                               const OutputVector& deltas) override;
                Shape compute_output_shape(const Shape input_shape,
                                           const int64_t k,
                                           const size_t axis);
            };
        } // namespace v0
@ -181,6 +187,9 @@ namespace ngraph
                size_t get_k() const;
                void set_k(size_t k);
                size_t get_default_output_index() const override { return no_default_index(); }
                bool evaluate(const HostTensorVector& outputs,
                              const HostTensorVector& inputs) override;

            protected:
                int64_t m_axis;
                uint64_t m_normalized_axis;
@ -196,6 +205,10 @@ namespace ngraph
                template <typename T>
                size_t validate_and_get_k(const std::shared_ptr<op::Constant>& k_constant) const;
                Shape compute_output_shape(const std::string& node_description,
                                           const PartialShape input_partial_shape,
                                           const int64_t k);
                void set_axis(const Rank input_rank, const int64_t axis);
            };
        } // namespace v1
@ -240,6 +253,9 @@ namespace ngraph
                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

                bool evaluate(const HostTensorVector& outputs,
                              const HostTensorVector& inputs) override;

            protected:
                virtual size_t
                    read_k_from_constant_node(const std::shared_ptr<Node>& node,
@ -142,7 +142,8 @@ void op::util::BroadcastBase::validate_and_infer_types()
    auto output_rank = input_value(1).get_partial_shape();
    if (input_rank.is_static() && output_rank.is_static() && output_rank[0].is_static())
    {
        result_shape = PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
        result_shape =
            PartialShape::dynamic(std::max(input_rank.get_length(), output_rank[0].get_length()));
    }
    const auto shape_constant = as_type_ptr<op::v0::Constant>(input_value(1).get_node_shared_ptr());
47 ngraph/src/ngraph/op/util/variable.hpp Normal file
@ -0,0 +1,47 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <utility>

#include "ngraph/op/op.hpp"

namespace ngraph
{
    struct VariableInfo
    {
        PartialShape data_shape;
        element::Type data_type;
        std::string variable_id;
    };

    class NGRAPH_API Variable
    {
    public:
        Variable() = default;

        explicit Variable(const VariableInfo& variable_info)
            : m_info(variable_info)
        {
        }

        VariableInfo get_info() { return m_info; }
        void update(const VariableInfo& variable_info) { m_info = variable_info; }
    private:
        VariableInfo m_info;
    };
}
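Editorial note: Variable is just a mutable metadata record shared between the ReadValue that declares it and the Assigns that update it; there is no tensor storage here. Minimal usage sketch:

    VariableInfo info{PartialShape{1, 3}, element::f32, "state"};
    auto var = std::make_shared<Variable>(info);
    NGRAPH_CHECK(var->get_info().variable_id == "state");
    var->update({PartialShape{1, 3}, element::f32, "state_renamed"});
    NGRAPH_CHECK(var->get_info().variable_id == "state_renamed");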
@ -30,6 +30,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/asinh.hpp"
#include "ngraph/op/assign.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/atan2.hpp"
#include "ngraph/op/atanh.hpp"
@ -149,6 +150,7 @@
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/quantized_dot.hpp"
#include "ngraph/op/range.hpp"
#include "ngraph/op/read_value.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/reduce_logical_and.hpp"
#include "ngraph/op/reduce_logical_or.hpp"
@ -166,4 +166,6 @@ NGRAPH_OP(ScatterElementsUpdate, ngraph::op::v3)
NGRAPH_OP(ScatterUpdate, ngraph::op::v3)
NGRAPH_OP(ShuffleChannels, ngraph::op::v0)
NGRAPH_OP(ShapeOf, ngraph::op::v3)
NGRAPH_OP(Assign, ngraph::op::v3)
NGRAPH_OP(ReadValue, ngraph::op::v3)
NGRAPH_OP(TopK, ngraph::op::v3)
@ -16,6 +16,7 @@
#include "ngraph/specialize_function.hpp"
#include <pass/constant_folding.hpp>
#include "ngraph/op/assign.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/tensor_iterator.hpp"
@ -84,7 +85,18 @@ std::shared_ptr<Function>
    }
    else
    {
        m[old_node.get()] = old_node->copy_with_new_inputs(new_args);
        NodeVector cloned_dependencies;
        for (auto& dependency : old_node->get_control_dependencies())
        {
            std::shared_ptr<Node> dependent = m.at(dependency.get());
            if (find(cloned_dependencies.begin(), cloned_dependencies.end(), dependent) ==
                cloned_dependencies.end())
            {
                cloned_dependencies.push_back(dependent);
            }
        }
        m[old_node.get()] = old_node->copy_with_new_inputs(new_args, cloned_dependencies);

        // TODO: workaround for shape inference, delete it after fix
        if (::ngraph::as_type_ptr<ngraph::op::TensorIterator>(m[old_node.get()]))
        {
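Editorial note: before this change, specialize_function dropped control dependencies when cloning, which would detach the Assign leaves that the IR parser anchors via control edges. The new overload re-attaches the (already cloned) dependencies; a hedged sketch of the call:

    auto a = std::make_shared<op::Parameter>(element::f32, Shape{2});
    auto abs_node = std::make_shared<op::Abs>(a);
    auto neg_node = std::make_shared<op::Negative>(a);
    abs_node->add_control_dependency(neg_node);
    // Clone abs_node onto new inputs, re-attaching neg_node as the clone's
    // control dependency so ordering constraints survive the copy.
    auto clone = abs_node->copy_with_new_inputs({a}, NodeVector{neg_node});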
@ -114,6 +114,7 @@ set(SRC
    tensor.cpp
    type_prop/all.cpp
    type_prop/any.cpp
    type_prop/assign.cpp
    type_prop/avg_pool.cpp
    type_prop/batch_mat_mul.cpp
    type_prop/batch_mat_mul_transpose.cpp
@ -178,6 +179,7 @@ set(SRC
    type_prop/quantized_dot.cpp
    type_prop/random_uniform.cpp
    type_prop/range.cpp
    type_prop/read_value.cpp
    type_prop/replace_slice.cpp
    type_prop/reshape.cpp
    type_prop/reverse.cpp

File diff suppressed because it is too large
@ -64,10 +64,10 @@
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "runtime/backend.hpp"
#include "util/all_close_f.hpp"
#include "util/ndarray.hpp"
#include "util/test_tools.hpp"
@ -1555,3 +1555,358 @@ TEST(eval, evaluate_dynamic_scatter_elements_update_one_elem_i32)
|
|||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
|
||||||
ASSERT_EQ(cval, out);
|
ASSERT_EQ(cval, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(eval, topk_v1)
|
||||||
|
{
|
||||||
|
Shape shape{2, 3, 2};
|
||||||
|
Shape rshape{2, 2, 2};
|
||||||
|
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
const auto k = op::Constant::create(element::i32, Shape{}, {2});
|
||||||
|
auto B = make_shared<op::v1::TopK>(A, k, 1, "max", "index", element::i32);
|
||||||
|
|
||||||
|
auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A});
|
||||||
|
|
||||||
|
auto result0 = make_shared<HostTensor>();
|
||||||
|
auto result1 = make_shared<HostTensor>();
|
||||||
|
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||||
|
{make_host_tensor<element::Type_t::f32>(
|
||||||
|
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7})}));
|
||||||
|
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||||
|
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||||
|
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
auto result0_val = read_vector<float>(result0);
|
||||||
|
|
||||||
|
auto result1_val = read_vector<int32_t>(result1);
|
||||||
|
|
||||||
|
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
|
||||||
|
ASSERT_EQ(result0_val, expec0);
|
||||||
|
|
||||||
|
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
|
||||||
|
ASSERT_EQ(result1_val, expec1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(eval, topk_v1_dyn)
|
||||||
|
{
|
||||||
|
Shape shape{2, 3, 2};
|
||||||
|
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||||
|
auto B = make_shared<op::v1::TopK>(A, k, 1, "max", "index", element::i32);
|
||||||
|
|
||||||
|
auto fun =
|
||||||
|
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||||
|
|
||||||
|
auto result0 = make_shared<HostTensor>();
|
||||||
|
auto result1 = make_shared<HostTensor>();
|
||||||
|
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||||
|
{make_host_tensor<element::Type_t::f32>(
|
||||||
|
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||||
|
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
|
||||||
|
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||||
|
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||||
|
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
auto result0_val = read_vector<float>(result0);
|
||||||
|
auto result1_val = read_vector<int32_t>(result1);
|
||||||
|
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
|
||||||
|
ASSERT_EQ(result0_val, expec0);
|
||||||
|
|
||||||
|
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
|
||||||
|
ASSERT_EQ(result1_val, expec1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(eval, topk_v3_dyn)
|
||||||
|
{
|
||||||
|
Shape shape{2, 3, 2};
|
||||||
|
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||||
|
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "index", element::i32);
|
||||||
|
|
||||||
|
auto fun =
|
||||||
|
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||||
|
|
||||||
|
auto result0 = make_shared<HostTensor>();
|
||||||
|
auto result1 = make_shared<HostTensor>();
|
||||||
|
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||||
|
{make_host_tensor<element::Type_t::f32>(
|
||||||
|
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||||
|
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
|
||||||
|
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||||
|
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||||
|
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
auto result0_val = read_vector<float>(result0);
|
||||||
|
auto result1_val = read_vector<int32_t>(result1);
|
||||||
|
vector<float> expec0{12, 9, 10, 4, 6, 3, 11, 7};
|
||||||
|
ASSERT_EQ(result0_val, expec0);
|
||||||
|
|
||||||
|
vector<int32_t> expec1{0, 1, 1, 2, 0, 1, 2, 2};
|
||||||
|
ASSERT_EQ(result1_val, expec1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(eval, topk_v3_dyn_values)
|
||||||
|
{
|
||||||
|
Shape shape{2, 3, 2};
|
||||||
|
|
||||||
|
auto A = make_shared<op::Parameter>(element::f32, shape);
|
||||||
|
auto k = make_shared<op::Parameter>(element::u32, Shape{});
|
||||||
|
auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);
|
||||||
|
|
||||||
|
auto fun =
|
||||||
|
make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});
|
||||||
|
|
||||||
|
auto result0 = make_shared<HostTensor>();
|
||||||
|
auto result1 = make_shared<HostTensor>();
|
||||||
|
ASSERT_TRUE(fun->evaluate({result0, result1},
|
||||||
|
{make_host_tensor<element::Type_t::f32>(
|
||||||
|
Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
|
||||||
|
make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
|
||||||
|
EXPECT_EQ(result0->get_element_type(), element::f32);
|
||||||
|
EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
EXPECT_EQ(result1->get_element_type(), element::i32);
|
||||||
|
EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
|
||||||
|
auto result0_val = read_vector<float>(result0);
|
||||||
|
auto result1_val = read_vector<int32_t>(result1);
|
||||||
|
vector<float> expec0{12, 9, 10, 4, 11, 7, 6, 3};
|
||||||
|
ASSERT_EQ(result0_val, expec0);
|
||||||
|
|
||||||
|
vector<int32_t> expec1{0, 1, 1, 2, 2, 2, 0, 1};
|
||||||
|
ASSERT_EQ(result1_val, expec1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(eval, topk_v3_dyn_values_k0)
{
    Shape shape{2, 3, 2};

    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto k = make_shared<op::Parameter>(element::u32, Shape{});
    auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);

    auto fun =
        make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});

    auto result0 = make_shared<HostTensor>();
    auto result1 = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result0, result1},
                              {make_host_tensor<element::Type_t::f32>(
                                   Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
                               make_host_tensor<element::Type_t::i32>(Shape{}, {0})}));
    EXPECT_EQ(result0->get_element_type(), element::f32);
    EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
    EXPECT_EQ(result1->get_element_type(), element::i32);
    EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
    auto result0_val = read_vector<float>(result0);
    auto result1_val = read_vector<int32_t>(result1);
    vector<float> expec0{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
    ASSERT_EQ(result0_val, expec0);

    vector<int32_t> expec1{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
    ASSERT_EQ(result1_val, expec1);
}

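In the _k0 variant above a runtime k of 0 is supplied, and the expected shapes keep the full extent of the sorted axis (Shape{2, 3, 2} in, Shape{2, 3, 2} out), i.e. k == 0 evaluates as "take everything along the axis". A minimal sketch of that implied normalization; the helper name normalize_k is hypothetical, not from this commit:

#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the behaviour the _k0 tests assert:
// a requested k of 0 falls back to the full extent of the chosen axis.
std::size_t normalize_k(std::size_t k, const std::vector<std::size_t>& shape, std::size_t axis)
{
    return k == 0 ? shape[axis] : k;
}

// normalize_k(0, {2, 3, 2}, 1) == 3, so the output stays Shape{2, 3, 2}.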
TEST(eval, topk_v0_dyn)
{
    Shape shape{2, 3, 2};

    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto k = make_shared<op::Parameter>(element::i64, Shape{});
    auto axis = make_shared<op::Parameter>(element::i64, Shape{});

    element::Type result_et{element::i32};
    bool compute_max = true;

    auto B = make_shared<op::v0::TopK>(
        A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);

    auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
                                     ParameterVector{A, k, axis});

    auto result0 = make_shared<HostTensor>();
    auto result1 = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result0, result1},
                              {make_host_tensor<element::Type_t::f32>(
                                   Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {2}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
    EXPECT_EQ(result0->get_element_type(), element::i32);
    EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
    EXPECT_EQ(result1->get_element_type(), element::f32);
    EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
    auto result1_val = read_vector<float>(result1);
    auto result0_val = read_vector<int32_t>(result0);

    vector<float> expec1{12, 9, 10, 4, 11, 7, 6, 3};
    ASSERT_EQ(result1_val, expec1);

    vector<int32_t> expec0{0, 1, 1, 2, 2, 2, 0, 1};
    ASSERT_EQ(result0_val, expec0);
}

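Note the swapped assertions relative to the v3 tests: here result0 is checked for element::i32 and result1 for element::f32. A hedged sketch of the difference as these tests exercise it (inferred from the assertions, not from opset documentation):

#include <memory>
#include <utility>

#include "ngraph/ngraph.hpp"

using namespace ngraph;

// v0::TopK emits indices on output(0) and values on output(1);
// v3::TopK (earlier tests) emits values first and indices second.
std::pair<Output<Node>, Output<Node>> values_then_indices(
    const std::shared_ptr<op::v0::TopK>& topk)
{
    return {topk->output(1), topk->output(0)}; // swap so values come first
}

std::pair<Output<Node>, Output<Node>> values_then_indices(
    const std::shared_ptr<op::v3::TopK>& topk)
{
    return {topk->output(0), topk->output(1)}; // already values-first
}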
TEST(eval, topk_v0_dyn_k0)
{
    Shape shape{2, 3, 2};

    auto A = make_shared<op::Parameter>(element::f32, shape);
    auto k = make_shared<op::Parameter>(element::i64, Shape{});
    auto axis = make_shared<op::Parameter>(element::i64, Shape{});

    element::Type result_et{element::i32};
    bool compute_max = true;

    auto B = make_shared<op::v0::TopK>(
        A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);

    auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
                                     ParameterVector{A, k, axis});

    auto result0 = make_shared<HostTensor>();
    auto result1 = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result0, result1},
                              {make_host_tensor<element::Type_t::f32>(
                                   Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {0}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
    EXPECT_EQ(result0->get_element_type(), element::i32);
    EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
    EXPECT_EQ(result1->get_element_type(), element::f32);
    EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
    auto result1_val = read_vector<float>(result1);
    auto result0_val = read_vector<int32_t>(result0);

    vector<float> expec1{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
    ASSERT_EQ(result1_val, expec1);

    vector<int32_t> expec0{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
    ASSERT_EQ(result0_val, expec0);
}

TEST(eval, topk_v3_param_dyn_values_k0)
{
    auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto k = make_shared<op::Parameter>(element::u32, Shape{});
    auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);

    auto fun =
        make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});

    auto result0 = make_shared<HostTensor>();
    auto result1 = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result0, result1},
                              {make_host_tensor<element::Type_t::f32>(
                                   Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
                               make_host_tensor<element::Type_t::i32>(Shape{}, {0})}));
    EXPECT_EQ(result0->get_element_type(), element::f32);
    EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
    EXPECT_EQ(result1->get_element_type(), element::i32);
    EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
    auto result0_val = read_vector<float>(result0);
    auto result1_val = read_vector<int32_t>(result1);
    vector<float> expec0{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
    ASSERT_EQ(result0_val, expec0);

    vector<int32_t> expec1{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
    ASSERT_EQ(result1_val, expec1);
}

TEST(eval, topk_v3_param_dyn_values_k2)
{
    auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto k = make_shared<op::Parameter>(element::u32, Shape{});
    auto B = make_shared<op::v3::TopK>(A, k, 1, "max", "value", element::i32);

    auto fun =
        make_shared<Function>(OutputVector{B->output(0), B->output(1)}, ParameterVector{A, k});

    auto result0 = make_shared<HostTensor>();
    auto result1 = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result0, result1},
                              {make_host_tensor<element::Type_t::f32>(
                                   Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
                               make_host_tensor<element::Type_t::i32>(Shape{}, {2})}));
    EXPECT_EQ(result0->get_element_type(), element::f32);
    EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
    EXPECT_EQ(result1->get_element_type(), element::i32);
    EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
    auto result0_val = read_vector<float>(result0);
    auto result1_val = read_vector<int32_t>(result1);
    vector<float> expec0{12, 9, 10, 4, 11, 7, 6, 3};
    ASSERT_EQ(result0_val, expec0);

    vector<int32_t> expec1{0, 1, 1, 2, 2, 2, 0, 1};
    ASSERT_EQ(result1_val, expec1);
}

TEST(eval, topk_v0_param_dyn_k2)
{
    auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto k = make_shared<op::Parameter>(element::i64, Shape{});
    auto axis = make_shared<op::Parameter>(element::i64, Shape{});

    element::Type result_et{element::i32};
    bool compute_max = true;

    auto B = make_shared<op::v0::TopK>(
        A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);

    auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
                                     ParameterVector{A, k, axis});

    auto result0 = make_shared<HostTensor>();
    auto result1 = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result0, result1},
                              {make_host_tensor<element::Type_t::f32>(
                                   Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {2}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
    EXPECT_EQ(result0->get_element_type(), element::i32);
    EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 2, 2}));
    EXPECT_EQ(result1->get_element_type(), element::f32);
    EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 2, 2}));
    auto result1_val = read_vector<float>(result1);
    auto result0_val = read_vector<int32_t>(result0);

    vector<float> expec1{12, 9, 10, 4, 11, 7, 6, 3};
    ASSERT_EQ(result1_val, expec1);

    vector<int32_t> expec0{0, 1, 1, 2, 2, 2, 0, 1};
    ASSERT_EQ(result0_val, expec0);
}

TEST(eval, topk_v0_param_dyn_k0)
{
    auto A = make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
    auto k = make_shared<op::Parameter>(element::i64, Shape{});
    auto axis = make_shared<op::Parameter>(element::i64, Shape{});

    element::Type result_et{element::i32};
    bool compute_max = true;

    auto B = make_shared<op::v0::TopK>(
        A, k, axis, result_et, compute_max, op::v0::TopK::SortType::SORT_VALUES);

    auto fun = make_shared<Function>(OutputVector{B->output(0), B->output(1)},
                                     ParameterVector{A, k, axis});

    auto result0 = make_shared<HostTensor>();
    auto result1 = make_shared<HostTensor>();
    ASSERT_TRUE(fun->evaluate({result0, result1},
                              {make_host_tensor<element::Type_t::f32>(
                                   Shape{2, 3, 2}, {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {0}),
                               make_host_tensor<element::Type_t::i64>(Shape{}, {1})}));
    EXPECT_EQ(result0->get_element_type(), element::i32);
    EXPECT_EQ(result0->get_partial_shape(), (PartialShape{2, 3, 2}));
    EXPECT_EQ(result1->get_element_type(), element::f32);
    EXPECT_EQ(result1->get_partial_shape(), (PartialShape{2, 3, 2}));
    auto result1_val = read_vector<float>(result1);
    auto result0_val = read_vector<int32_t>(result0);

    vector<float> expec1{12, 9, 10, 4, 8, 2, 11, 7, 6, 3, 5, 1};
    ASSERT_EQ(result1_val, expec1);

    vector<int32_t> expec0{0, 1, 1, 2, 2, 0, 2, 2, 0, 1, 1, 0};
    ASSERT_EQ(result0_val, expec0);
}

52
ngraph/test/type_prop/assign.cpp
Normal file
@ -0,0 +1,52 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

TEST(type_prop, assign_variable_not_found)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
    try
    {
        auto assign = make_shared<op::Assign>(A, "variable_id");
        // Should have thrown, so fail if it didn't
        FAIL() << "Should not find variable with variable_id";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Can't find variable with id = variable_id"));
    }
    catch (...)
    {
        FAIL() << "Deduced type check failed for unexpected reason";
    }
}

TEST(type_prop, assign_deduce)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
    auto read_value = make_shared<op::ReadValue>(input, "variable_id");
    auto assign = make_shared<op::Assign>(read_value, "variable_id");

    ASSERT_EQ(assign->get_element_type(), element::f32);
    ASSERT_EQ(assign->get_shape(), (Shape{1, 2, 64, 64}));
}
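For orientation, a minimal sketch of how the two ops pair up through a shared variable id to model persistent state, as assign_deduce does with read_value feeding assign. Only the ReadValue/Assign constructors exercised above are assumed; the use of op::v1::Add and the "accumulator" id are illustrative:

#include <memory>

#include "ngraph/ngraph.hpp"

using namespace std;
using namespace ngraph;

// Read the variable, update it, and write it back under the same id.
shared_ptr<Function> make_stateful_accumulator()
{
    auto init = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
    auto state = make_shared<op::ReadValue>(init, "accumulator");
    auto sum = make_shared<op::v1::Add>(state, init);
    // Assign has no data consumers; it takes effect through the shared id.
    auto assign = make_shared<op::Assign>(sum, "accumulator");
    auto result = make_shared<op::Result>(sum);
    return make_shared<Function>(ResultVector{result}, ParameterVector{init});
}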
31
ngraph/test/type_prop/read_value.cpp
Normal file
@ -0,0 +1,31 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

TEST(type_prop, read_value_deduce)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{1, 2, 64, 64});
    auto read_value = make_shared<op::ReadValue>(input, "variable_id");

    ASSERT_EQ(read_value->get_element_type(), element::f32);
    ASSERT_EQ(read_value->get_shape(), (Shape{1, 2, 64, 64}));
}
@ -49,7 +49,7 @@ ngraph::test::NgraphTestCase::NgraphTestCase(const std::shared_ptr<Function>& fu
        }
    }

void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
::testing::AssertionResult ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
{
    m_tolerance_bits = tolerance_bits;
    const auto& function_results = m_function->get_results();
@ -85,6 +85,7 @@ void ngraph::test::NgraphTestCase::run(size_t tolerance_bits)
    m_output_index = 0;
    m_expected_outputs.clear();
    m_input_tensors.clear();

    return ::testing::AssertionSuccess();
}

ngraph::test::NgraphTestCase& ngraph::test::NgraphTestCase::dump_results(bool dump)
@ -187,7 +187,7 @@ namespace ngraph
            add_expected_output<T>(expected_shape, value);
        }

        void run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);
        ::testing::AssertionResult run(size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS);

    private:
        template <typename T>
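The run() signature change from void to ::testing::AssertionResult lets the outcome surface through gtest macros at the call site. A hedged usage sketch; the "INTERPRETER" backend argument, the add_input call, and the header path are assumptions about the test utility, not shown in this diff:

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/test_case.hpp" // assumed header for NgraphTestCase

using namespace ngraph;

TEST(test_case_sketch, run_returns_assertion_result)
{
    // Trivial identity function; shapes and values are illustrative only.
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{3});
    auto function = std::make_shared<Function>(OutputVector{A}, ParameterVector{A});

    test::NgraphTestCase test_case(function, "INTERPRETER");
    test_case.add_input<float>({1.f, 2.f, 3.f});
    test_case.add_expected_output<float>(Shape{3}, {1.f, 2.f, 3.f});

    // run() now reports through gtest instead of failing internally.
    EXPECT_TRUE(test_case.run());
}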