[MO, TF] Support Custom Wide and Deep CTR model by MO (#8505)
* [MO, TF] Support Custom Wide and Deep CTR model by MO. It implements implicit support of the EmbeddingSegmentsMean operation through decomposition. It also extends the current transformation to fuse the TensorFlow sub-graph (for the Wide and Deep model family) containing SparseSegmentSum and SparseSegmentMean operations into EmbeddingSegmentsSum or EmbeddingSegmentsMean.
  Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Fix unit tests after modifications of SparseToDense and EmbeddingSegmentsOperationFusing
  Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Document SparseSegmentMean support
  Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
* Add computation scheme for normalization coefficients and correct the documentation
  Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent c307f206dc · commit 8e327bd2ff
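For context, the decomposition introduced by this commit rests on a simple identity: a per-object mean of gathered embedding vectors equals the per-object sum divided by the number of gathered vectors, with objects that gathered nothing falling back to the default embedding vector and a coefficient of one. Below is a minimal NumPy sketch of that idea, a hypothetical reference model for illustration only (not MO code; the per_sample_weights input is omitted):

```python
import numpy as np


def embedding_segments_mean_reference(emb_table, indices, segment_ids, num_segments, default_index=None):
    # EmbeddingSegmentsSum part: sum the gathered embedding vectors per segment (object)
    summed = np.zeros((num_segments,) + emb_table.shape[1:], dtype=emb_table.dtype)
    np.add.at(summed, segment_ids, emb_table[indices])

    # Normalization coefficients: number of gathered vectors per segment, zeros replaced by one
    counts = np.bincount(segment_ids, minlength=num_segments).astype(emb_table.dtype)
    if default_index is not None:
        # empty segments receive the single default embedding vector
        summed[counts == 0] = emb_table[default_index]
    counts[counts == 0] = 1

    # EmbeddingSegmentsMean = EmbeddingSegmentsSum / per-segment counts
    return summed / counts.reshape((-1,) + (1,) * (emb_table.ndim - 1))


table = np.arange(12, dtype=np.float32).reshape(6, 2)
print(embedding_segments_mean_reference(table, indices=np.array([0, 2, 3]),
                                        segment_ids=np.array([0, 0, 2]),
                                        num_segments=4, default_index=0))
```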
@@ -304,6 +304,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
 | SparseFillEmptyRows | Supported only when it is part of a sub-graph of the special form |
 | SparseReshape | Supported only when it is part of a sub-graph of the special form |
 | SparseSegmentSum | Supported only when it is part of a sub-graph of the special form |
+| SparseSegmentMean | Supported only when it is part of a sub-graph of the special form |
 | SparseToDense | CPU only |
 | Split | |
 | SplitV | |
@@ -99,30 +99,30 @@ python mo.py
 IteratorGetNext:2[2],
 IteratorGetNext:4[2],
 IteratorGetNext:7[2],
-linear/linear_model/linear_model/linear_model/education/to_sparse_input/indices:0[10 2]{i32},
-linear/linear_model/linear_model/linear_model/education/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-linear/linear_model/linear_model/linear_model/education/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
-linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/indices:0[10 2]{i32},
-linear/linear_model/linear_model/linear_model/marital_status/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
-linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/indices:0[10 2]{i32},
-linear/linear_model/linear_model/linear_model/relationship/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
-linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/indices:0[10 2]{i32},
-linear/linear_model/linear_model/linear_model/workclass/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
-dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/indices:0[10 2]{i32},
-dnn/input_from_feature_columns/input_layer/education_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
-dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/indices:0[10 2]{i32},
-dnn/input_from_feature_columns/input_layer/marital_status_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
-dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/indices:0[10 2]{i32},
-dnn/input_from_feature_columns/input_layer/relationship_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
-dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/indices:0[10 2]{i32},
-dnn/input_from_feature_columns/input_layer/workclass_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
-dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50]"
+linear/linear_model/linear_model/linear_model/education/to_sparse_input/indices:0[10 2]{i64},
+linear/linear_model/linear_model/linear_model/education/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+linear/linear_model/linear_model/linear_model/education/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
+linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/indices:0[10 2]{i64},
+linear/linear_model/linear_model/linear_model/marital_status/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
+linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/indices:0[10 2]{i64},
+linear/linear_model/linear_model/linear_model/relationship/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
+linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/indices:0[10 2]{i64},
+linear/linear_model/linear_model/linear_model/workclass/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
+dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/indices:0[10 2]{i64},
+dnn/input_from_feature_columns/input_layer/education_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
+dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/indices:0[10 2]{i64},
+dnn/input_from_feature_columns/input_layer/marital_status_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
+dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/indices:0[10 2]{i64},
+dnn/input_from_feature_columns/input_layer/relationship_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
+dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/indices:0[10 2]{i64},
+dnn/input_from_feature_columns/input_layer/workclass_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
+dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50]"
 --output head/predictions/probabilities
 ```
@@ -403,7 +403,8 @@ extensions/front/tf/efficient_det_support_api_v2.0.json
 extensions/front/tf/efficient_det_support_api_v2.4.json
 extensions/front/tf/einsum_ext.py
 extensions/front/tf/elementwise_ext.py
-extensions/front/tf/embedding_segments_sum.py
+extensions/front/tf/embedding_segments_mean_decomposition.py
+extensions/front/tf/embedding_segments_operation_fusing.py
 extensions/front/tf/expand_dims_ext.py
 extensions/front/tf/extract_image_patches_ext.py
 extensions/front/tf/fake_const_ext.py
@@ -20,9 +20,11 @@ class WhereDecomposition(FrontReplacementOp):
     enabled = True

     def run_after(self):
-        from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
+        from extensions.front.tf.embedding_segments_operation_fusing import \
+            EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing
         from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer
-        return [EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2, TransposeOrderNormalizer]
+        return [EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing,
+                TransposeOrderNormalizer]

     def replace_op(self, graph: Graph, node: Node):
         node_name = node.soft_get('name', node.id)
@@ -0,0 +1,149 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np

from extensions.ops.ConvertLike import ConvertLike
from extensions.ops.ReduceOps import ReduceSum
from extensions.ops.elementwise import Div
from extensions.ops.elementwise import Equal
from extensions.ops.embedding_bag import EmbeddingSegmentsSum
from extensions.ops.range import Range
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.broadcast import Broadcast
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.unsqueeze import Unsqueeze


class EmbeddingSegmentsMeanDecomposition(FrontReplacementPattern):
    """
    This transformation decomposes the EmbeddingSegmentsMean operation into EmbeddingSegmentsSum operations taking into
    account that the summed up embedding vectors for each object must be normalized appropriately by a coefficient
    equal to a number of gathered embedding vectors for each object. If there is no gathered embedding vector
    for an object, the coefficient equals one.

    Approximate computation scheme (Cast operations omitted) for the normalization coefficients:

                                                           Const(0)
    segment_ids -> Unsqueeze(axis=1) -----------------\        |
                                                        \       \/
                                                         ---> Equal() --> Select --> ReduceSum(axis=0) --> Norm. Coeff.
                                                        /       /\
    Range(0, num_segments) -> Unsqueeze(axis=0)------  /        |
                                                           Const(1)
    """
    enabled = True

    def run_after(self):
        from extensions.front.tf.embedding_segments_operation_fusing import \
            EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing
        return [EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing]

    def find_and_replace_pattern(self, graph: Graph):
        for embedding_segments_mean in graph.get_op_nodes(op='EmbeddingSegmentsMean'):
            embedding_segments_mean_name = embedding_segments_mean.soft_get('name', embedding_segments_mean.id)
            embedding_table_input = embedding_segments_mean.in_port(0)
            segment_ids_input = embedding_segments_mean.in_port(2)
            num_segments_input = embedding_segments_mean.in_port(3)

            # TODO: support EmbeddingSegmentsMean with a specified weights vector.
            #  This case has not appeared in models so far, so the EmbeddingSegmentsOperation fusion
            #  transformations do not handle it either.
            if embedding_segments_mean.is_in_port_connected(5):
                return

            # 1. compute the indices membership matrix, i.e. which indices belong to which object
            # the shape of this matrix is [num_segments, num_indices]
            non_norm_range_1_to_num_segments = create_op_with_const_inputs(graph, Range,
                                                                           {0: int64_array(0),
                                                                            2: int64_array(1)},
                                                                           {'name': embedding_segments_mean_name +
                                                                                    '/Range1ToNumSegments',
                                                                            'output_type': np.int64})
            num_segments_input.get_connection().add_destination(non_norm_range_1_to_num_segments.in_port(1))

            range_1_to_num_segments = ConvertLike(graph, {'name': embedding_segments_mean_name +
                                                                  '/Range1ToNumSegmentsNorm'}).create_node()
            range_1_to_num_segments.in_port(0).connect(non_norm_range_1_to_num_segments.out_port(0))
            num_segments_input.get_connection().add_destination(range_1_to_num_segments.in_port(1))

            unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(1)},
                                                                            {'name': embedding_segments_mean_name +
                                                                                     '/Range1ToNumSegmentsUnsqueeze'})
            unsqueeze_range_1_to_num_segments.in_port(0).connect(range_1_to_num_segments.out_port(0))
            unsqueeze_segment_ids = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                {'name': embedding_segments_mean_name +
                                                                         '/SegmentIdsUnsqueeze'})
            segment_ids_input.get_connection().add_destination(unsqueeze_segment_ids.in_port(0))
            boolean_membership_matrix = Equal(graph, {'name': embedding_segments_mean_name +
                                                              '/BooleanMembershipMatrix'}).create_node()
            boolean_membership_matrix.in_port(0).connect(unsqueeze_range_1_to_num_segments.out_port(0))
            boolean_membership_matrix.in_port(1).connect(unsqueeze_segment_ids.out_port(0))
            shape_of_membership_matrix = Shape(graph, {'name': embedding_segments_mean_name +
                                                               '/ShapeOfMembershipMatrix'}
                                               ).create_node([boolean_membership_matrix])
            one_scalar_constant = Const(graph, {'name': embedding_segments_mean_name + '/OneScalar',
                                                'value': int64_array([1])}).create_node()
            one_constant = Broadcast(graph, {'name': embedding_segments_mean_name + '/One'}
                                     ).create_node([one_scalar_constant, shape_of_membership_matrix])
            zero_constant = Const(graph, {'name': embedding_segments_mean_name + '/Zero',
                                          'value': int64_array(0)}).create_node()
            membership_matrix = Select(graph, {'name': embedding_segments_mean_name + '/MembershipMatrix',
                                               'auto_broadcast': 'numpy'}).create_node([boolean_membership_matrix,
                                                                                        one_constant,
                                                                                        zero_constant])

            # 2. compute the number of indices belonging to each object from the batch;
            # these are the normalization coefficients
            num_indices_per_object = create_op_with_const_inputs(graph, ReduceSum, {1: int64_array(1)},
                                                                 {'name': embedding_segments_mean_name +
                                                                          '/NumIndicesPerObject'})
            num_indices_per_object.in_port(0).connect(membership_matrix.out_port(0))

            # 3. replace a zero coefficient (zero number of indices belonging to an object) with one
            # because for such an object the single default embedding vector is used
            where_zero_number = Equal(graph, {'name': embedding_segments_mean_name +
                                                      '/WhereZeroIndicesNumber'}
                                      ).create_node([num_indices_per_object, zero_constant])
            normalized_num_indices_per_object = Select(graph, {'name': embedding_segments_mean_name +
                                                                       '/NormNumIndicesPerObject',
                                                               'auto_broadcast': 'numpy'}
                                                       ).create_node([where_zero_number,
                                                                      one_scalar_constant,
                                                                      num_indices_per_object])

            # 4. cast normalized_num_indices_per_object to the same type as the embedding vector table
            norm_coefficients = ConvertLike(graph, {'name': embedding_segments_mean_name +
                                                            '/NormCoefficients'}).create_node()
            norm_coefficients.in_port(0).connect(normalized_num_indices_per_object.out_port(0))
            embedding_table_input.get_connection().add_destination(norm_coefficients.in_port(1))

            # 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum
            embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': embedding_segments_mean_name +
                                                                          '/EmbeddingSegmentsSum'}).create_node()
            for in_port in embedding_segments_mean.in_ports():
                if embedding_segments_mean.is_in_port_connected(in_port):
                    embedding_segments_mean.in_port(in_port).get_connection().set_destination(
                        embedding_segments_sum.in_port(in_port))

            # 6. normalize the EmbeddingSegmentsSum result by the computed coefficients
            result_node = Div(graph, {'name': embedding_segments_mean_name + '/Div'}
                              ).create_node([embedding_segments_sum, norm_coefficients])
            embedding_segments_mean.out_port(0).get_connection().set_source(result_node.out_port(0))

            rename_nodes([(embedding_segments_mean, embedding_segments_mean_name + '/AbandonedName'),
                          (result_node, embedding_segments_mean_name)])
            graph.remove_nodes_from([embedding_segments_mean.id])
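To make the docstring's computation scheme concrete, here is a small NumPy sketch of just the normalization-coefficient branch, with hypothetical sample values; the transformation itself builds the equivalent sub-graph out of Range, Unsqueeze, Equal, Select and ReduceSum nodes (possibly in the transposed orientation of the membership matrix):

```python
import numpy as np

segment_ids = np.array([0, 0, 2])            # segment assignment of each gathered index
num_segments = 4

range_row = np.arange(num_segments)[np.newaxis, :]   # Range(0, num_segments) -> Unsqueeze(axis=0): [1, num_segments]
segment_col = segment_ids[:, np.newaxis]              # segment_ids -> Unsqueeze(axis=1): [num_indices, 1]

# Equal + Select(Const(1), Const(0)): membership matrix of shape [num_indices, num_segments]
membership = np.where(range_row == segment_col, 1, 0)

# ReduceSum over the indices axis: number of gathered vectors per object
num_indices_per_object = membership.sum(axis=0)       # [2, 0, 1, 0]

# Select: replace zero counts with one (such objects keep the single default embedding vector)
norm_coefficients = np.where(num_indices_per_object == 0, 1, num_indices_per_object)
print(norm_coefficients)                               # [2 1 1 1]
```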
@@ -2,10 +2,11 @@
 # SPDX-License-Identifier: Apache-2.0

 import logging as log

 import numpy as np

 from extensions.ops.Cast import Cast
-from extensions.ops.embedding_bag import EmbeddingSegmentsSum
+from extensions.ops.embedding_bag import EmbeddingSegmentsMean, EmbeddingSegmentsSum
 from extensions.ops.split import Split
 from mo.front.common.partial_infer.utils import int64_array
 from mo.front.common.replacement import FrontReplacementSubgraph
@@ -14,10 +15,10 @@ from mo.graph.graph import Graph, rename_nodes
 from mo.ops.squeeze import Squeeze


-class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
+class EmbeddingSegmentsOperationSingleFeatureFusing(FrontReplacementSubgraph):
     """
-    The transformation looks for pattern (sub-graph) that performs extraction of embedding vectors from the parameters table
-    for object feature values and sum up these embedding vectors for every object.
+    The transformation looks for a pattern (sub-graph) that performs extraction of embedding vectors from the parameters
+    table for object feature values, and sums up these embedding vectors for every object or computes their mean value.
     Such sub-graph is met in the Wide and Deep model in case of the SINGLE categorical feature.
     """
     enabled = True
@@ -36,7 +37,7 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
             ('strided_slice', dict(op='StridedSlice')),
             ('cast', dict(op='Cast')),
             ('gather', dict(type='Gather')),
-            ('sparse_segment_sum', dict(op='SparseSegmentSum')),
+            ('sparse_segment_op', dict(op=lambda op: op in ['SparseSegmentSum', 'SparseSegmentMean'])),
             ('reshape', dict(op='Reshape')),
             ('tile', dict(type='Tile')),
             ('select', dict(op='Select'))
@@ -52,12 +53,12 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
             ('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
             ('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
             ('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
-            ('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
+            ('unique', 'sparse_segment_op', {'out': 1, 'in': 1}),
            ('unique', 'gather', {'out': 0, 'in': 1}),
            ('strided_slice', 'cast', {'out': 0, 'in': 0}),
-            ('gather', 'sparse_segment_sum', {'out': 0, 'in': 0}),
-            ('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
-            ('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
+            ('gather', 'sparse_segment_op', {'out': 0, 'in': 0}),
+            ('cast', 'sparse_segment_op', {'out': 0, 'in': 2}),
+            ('sparse_segment_op', 'select', {'out': 0, 'in': 2}),
             ('reshape', 'tile', {'out': 0, 'in': 0}),
             ('tile', 'select', {'out': 0, 'in': 0})
         ])
@@ -71,41 +72,52 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
         gather = match['gather']
         select = match['select']
         where0 = match['where0']
+        sparse_segment_op = match['sparse_segment_op']
         output_node_name = select.soft_get('name', select.id)

-        log.debug('Found EmbeddingSegmentsSum pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
-                                                                                    sparse_fill_empty_rows.name))
+        log.debug('Found EmbeddingSparseSegmentsSingleFeature pattern after {} with name {}'.format(
+            sparse_fill_empty_rows.op,
+            sparse_fill_empty_rows.name))

         split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2})
         squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
         split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
         squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
-        cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node()
-        cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
-        cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node()
-        embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
-        rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])
+
+        # TODO: remove Cast nodes once we start to support EmbeddingSegmentSum (new version) with segment_ids,
+        #  indices, and num_segments of different integer type.
+        #  Because the real cases show that it is possible to have it in TensorFlow
+        cast_indices = Cast(graph, {'name': output_node_name + '/CastIndices', 'dst_type': np.int32}).create_node()
+        cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
+                                        'dst_type': np.int32}).create_node()
+        cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
+                                          'dst_type': np.int32}).create_node()
+        cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
+                                         'dst_type': np.int32}).create_node()
+        if sparse_segment_op.op == 'SparseSegmentSum':
+            embedding_segments_op = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
+        else:
+            embedding_segments_op = EmbeddingSegmentsMean(graph, {'name': output_node_name}).create_node()
+        rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_op, output_node_name)])

         # connect parameters table
-        gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
+        gather.in_port(0).get_connection().set_destination(embedding_segments_op.in_port(0))
         # connect indices values
-        greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
+        greaterequal0.in_port(0).get_connection().set_destination(cast_indices.in_port(0))
+        embedding_segments_op.in_port(1).connect(cast_indices.out_port(0))
         # split and connect segment ids
         gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
         squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
         # TODO: remove casting once we start to support I64 model input
         cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
-        embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
+        embedding_segments_op.in_port(2).connect(cast_segment_ids.out_port(0))
         # split and connect number of segments
         identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
         squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
         # TODO: remove casting once we start to support I64 model input
         cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
-        embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
+        embedding_segments_op.in_port(3).connect(cast_num_segments.out_port(0))
         # connect default value
         # TODO: remove casting once we start to support I64 model input
         sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
-        embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
+        embedding_segments_op.in_port(4).connect(cast_default_value.out_port(0))
         # no input port for per_sample_weight

         identity_spw.in_port(0).disconnect()
@@ -115,14 +127,15 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
         sparse_fill_empty_rows.in_port(2).disconnect()
         gather.in_port(0).disconnect()

-        select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
-        graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
+        select.out_port(0).get_connection().set_source(embedding_segments_op.out_port(0))
+        graph.remove_nodes_from(
+            [gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])


-class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
+class EmbeddingSegmentsOperationMultipleFeaturesFusing(FrontReplacementSubgraph):
     """
-    The transformation looks for pattern (sub-graph) that performs extraction of embedding vectors from the parameters table
-    for object feature values and sum up these embedding vectors for every object.
+    The transformation looks for a pattern (sub-graph) that performs extraction of embedding vectors from the parameters
+    table for object feature values, and sums up these embedding vectors for every object or computes their mean value.
     Such sub-graph is met in the Wide and Deep model in case of MULTIPLE categorical features.
     """
     enabled = True
@@ -143,7 +156,7 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
             ('gather', dict(type='Gather')),
             ('identity', dict(op='Identity')),
             ('identity_1', dict(op='Identity')),
-            ('sparse_segment_sum', dict(op='SparseSegmentSum')),
+            ('sparse_segment_op', dict(op=lambda op: op in ['SparseSegmentSum', 'SparseSegmentMean'])),
             ('reshape', dict(op='Reshape')),
             ('tile', dict(type='Tile')),
             ('select', dict(op='Select'))
@@ -159,14 +172,14 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
             ('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
             ('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
             ('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
-            ('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
+            ('unique', 'sparse_segment_op', {'out': 1, 'in': 1}),
             ('unique', 'gather', {'out': 0, 'in': 1}),
             ('strided_slice', 'cast', {'out': 0, 'in': 0}),
             ('gather', 'identity', {'out': 0, 'in': 0}),
             ('identity', 'identity_1', {'out': 0, 'in': 0}),
-            ('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}),
-            ('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
-            ('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
+            ('identity_1', 'sparse_segment_op', {'out': 0, 'in': 0}),
+            ('cast', 'sparse_segment_op', {'out': 0, 'in': 2}),
+            ('sparse_segment_op', 'select', {'out': 0, 'in': 2}),
             ('reshape', 'tile', {'out': 0, 'in': 0}),
             ('tile', 'select', {'out': 0, 'in': 0})
         ])
@@ -180,10 +193,12 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
         gather = match['gather']
         select = match['select']
         where0 = match['where0']
+        sparse_segment_op = match['sparse_segment_op']
         output_node_name = select.soft_get('name', select.id)

-        log.debug('Found EmbeddingSegmentsSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
-                                                                                     sparse_fill_empty_rows.name))
+        log.debug('Found EmbeddingSparseSegmentsMultipleFeatures pattern after {} with name {}'.format(
+            sparse_fill_empty_rows.op,
+            sparse_fill_empty_rows.name))

         split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)},
                                                         {'num_splits': 2,
@@ -193,32 +208,42 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
                                                             {'num_splits': 2,
                                                              'name': output_node_name + '/SplitForDenseShape'})
         squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
-        cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node()
-        cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
-        cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node()
-        embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
-        rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])
+
+        # TODO: remove Cast nodes once we start to support EmbeddingSegmentSum (new version) with segment_ids,
+        #  indices, and num_segments of different integer type.
+        #  Because the real cases show that it is possible to have it in TensorFlow
+        cast_indices = Cast(graph, {'name': output_node_name + '/CastIndices', 'dst_type': np.int32}).create_node()
+        cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
+                                        'dst_type': np.int32}).create_node()
+        cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
+                                          'dst_type': np.int32}).create_node()
+        cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
+                                         'dst_type': np.int32}).create_node()
+
+        if sparse_segment_op.op == 'SparseSegmentSum':
+            embedding_segments_op = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
+        else:
+            embedding_segments_op = EmbeddingSegmentsMean(graph, {'name': output_node_name}).create_node()
+        rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_op, output_node_name)])

         # connect parameters table
-        gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
+        gather.in_port(0).get_connection().set_destination(embedding_segments_op.in_port(0))
         # connect indices values
-        greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
+        greaterequal0.in_port(0).get_connection().set_destination(cast_indices.in_port(0))
+        embedding_segments_op.in_port(1).connect(cast_indices.out_port(0))
         # split and connect segment ids
         gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
         squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
         # TODO: remove casting once we start to support I64 model input
         cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
-        embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
+        embedding_segments_op.in_port(2).connect(cast_segment_ids.out_port(0))
         # split and connect number of segments
         identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
         squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
         # TODO: remove casting once we start to support I64 model input
         cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
-        embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
+        embedding_segments_op.in_port(3).connect(cast_num_segments.out_port(0))
         # connect default value
         # TODO: remove casting once we start to support I64 model input
         sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
-        embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
+        embedding_segments_op.in_port(4).connect(cast_default_value.out_port(0))
         # no input port for per_sample_weight

         identity_spw.in_port(0).disconnect()
@@ -228,5 +253,6 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
         sparse_fill_empty_rows.in_port(2).disconnect()
         gather.in_port(0).disconnect()

-        select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
-        graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
+        select.out_port(0).get_connection().set_source(embedding_segments_op.out_port(0))
+        graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id,
+                                 select.id, where0.id])
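As a rough orientation aid, the sketch below shows in plain NumPy, with hypothetical sample data, how the tensors feeding the matched TensorFlow sub-graph map onto the inputs of the fused EmbeddingSegmentsSum/EmbeddingSegmentsMean node. The SparseFillEmptyRows/Unique plumbing is folded into plain indexing, so treat this as a simplified reading of the pattern rather than the exact graph semantics:

```python
import numpy as np

# Tensors that feed the matched sub-graph (one categorical feature, batch of 3 objects)
sparse_indices = np.array([[0, 0], [0, 1], [2, 0]])    # to_sparse_input/indices: [row, position] pairs
sparse_values = np.array([4, 1, 3])                     # hash-table lookup result: ids in the embedding table
dense_shape = np.array([3, 50])                          # to_sparse_input/dense_shape
default_value = 0                                        # id used for rows with no feature values
params_table = np.random.rand(10, 2).astype(np.float32)

# Mapping performed by the transformation:
indices = sparse_values                  # -> EmbeddingSegments* port 1, via CastIndices
segment_ids = sparse_indices[:, 0]       # -> port 2, first column of the sparse indices (Split + Squeeze)
num_segments = int(dense_shape[0])       # -> port 3, first element of dense_shape (Split + Squeeze)
# default_value                          # -> port 4, via CastDefaultValue

# What the fused op computes (sum combiner; the mean variant divides by per-row counts)
out = np.tile(params_table[default_value], (num_segments, 1))   # empty rows keep the default vector
for row in range(num_segments):
    gathered = params_table[indices[segment_ids == row]]
    if gathered.size:
        out[row] = gathered.sum(axis=0)
print(out.shape)   # (3, 2): one embedding vector per object
```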
@@ -3,7 +3,6 @@

 import numpy as np

-from extensions.ops.Cast import Cast
 from extensions.ops.scatternd import ScatterNDUpdate
 from mo.front.common.replacement import FrontReplacementOp
 from mo.graph.graph import Node, Graph, rename_nodes
@@ -32,12 +31,7 @@ class SparseToDenseReplacer(FrontReplacementOp):
         broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
         node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
         if not node.in_port(3).disconnected():
-            # TODO: remove casting once we start to support I64 model input
-            # cast default value to I32 due limitation about I64 input support
-            # so that input parameter and default value will be of the same I32 type as required ScatterNDUpdate
-            cast_default_value = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
-            node.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
-            broadcast_node.in_port(0).connect(cast_default_value.out_port(0))
+            node.in_port(3).get_connection().set_destination(broadcast_node.in_port(0))
         else:
             broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
                                                             'value': np.float32(0)}
@@ -46,10 +46,10 @@ class EmbeddingBagOffsetsSum(EmbeddingBagBase):
             "for node: `{}`. Ports: {}".format(name, connected_in_ports)

         weights_shape = node.in_port(0).data.get_shape()
-        assert len(weights_shape) >= 2,\
+        assert len(weights_shape) >= 2, \
             "EmbeddingBagOffsetsSum should have at least 2D weights for node: `{}`".format(name)
         offsets_shape = node.in_port(2).data.get_shape()
-        assert offsets_shape is not None and len(offsets_shape) == 1,\
+        assert offsets_shape is not None and len(offsets_shape) == 1, \
             "Rank of the offsets in EmbeddingBagOffsetsSum should be equal to 1 for node: `{}`".format(name)

         node.out_port(0).data.set_shape(np.ma.concatenate((offsets_shape[:1], weights_shape[1:])))
@@ -88,18 +88,41 @@ class EmbeddingSegmentsSum(EmbeddingBagBase):

         connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
         assert len(connected_in_ports) >= 4 and all(p in connected_in_ports for p in [0, 1, 2, 3]), \
-            "EmbeddingSegmentsSum should have at least 4 connected input port, but it doesn't for node: `{}`. " \
-            "Ports: {}".format(name, connected_in_ports)
+            "{} should have at least 4 connected input ports, but it doesn't for node: `{}`. " \
+            "Ports: {}".format(node.op, name, connected_in_ports)

         weights_shape = node.in_port(0).data.get_shape()
-        assert len(weights_shape) >= 2,\
-            "EmbeddingSegmentsSum should have at least 2D weights for node: `{}`".format(name)
+        assert len(weights_shape) >= 2, \
+            "{} should have at least 2D weights for node: `{}`".format(node.op, name)
         indices_shape = node.in_port(1).data.get_shape()
         segment_ids = node.in_port(2).data.get_shape()
-        assert len(indices_shape) == 1 and len(segment_ids) == 1 and indices_shape == segment_ids,\
+        assert len(indices_shape) == 1 and len(segment_ids) == 1 and indices_shape == segment_ids, \
             "Both indices and segment_ids should have the same shape for node: `{}`".format(name)
         num_segments = node.in_port(3).data.get_value()
-        assert num_segments is not None, "EmbeddingSegmentsSum should have a constant num_segments provided, but it " \
-                                         "doesn't for node: `{}`.".format(name)
+        assert num_segments is not None, "{} should have a constant num_segments provided, but it " \
+                                         "doesn't for node: `{}`.".format(node.op, name)
         output_shape = np.ma.concatenate(([num_segments], weights_shape[1:]))
         node.out_port(0).data.set_shape(output_shape)
+
+
+class EmbeddingSegmentsMean(Op):
+    """
+    Internal Operation.
+
+    In order not to overload the transformations (EmbeddingSegmentsOperationSingleFeatureFusing,
+    EmbeddingSegmentsOperationMultipleFeaturesFusing) with an additional sub-graph computing the mean value of
+    embedding vectors, we introduce the internal operation EmbeddingSegmentsMean. After these transformations, this
+    operation is decomposed into EmbeddingSegmentsSum with appropriate computation of the mean value for the
+    embedding vectors collected for each object in a batch.
+    """
+    op = "EmbeddingSegmentsMean"
+
+    def __init__(self, graph: Graph, attrs: dict):
+        super().__init__(graph, {
+            'type': None,
+            'op': self.op,
+            'in_ports_count': 6,
+            'out_ports_count': 1,
+            # it must have the same shape infer function as EmbeddingSegmentsSum
+            'infer': EmbeddingSegmentsSum.infer
+        }, attrs)
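Since EmbeddingSegmentsMean reuses EmbeddingSegmentsSum.infer, both operations share one shape contract. A standalone sketch of that rule (plain Python for illustration, not the MO Node API):

```python
def embedding_segments_infer_shape(weights_shape, indices_shape, segment_ids_shape, num_segments):
    # Mirrors the asserts above: 2D+ weights, rank-1 indices and segment_ids of equal length,
    # and a constant num_segments; the output shape is [num_segments] + weights_shape[1:].
    assert len(weights_shape) >= 2
    assert len(indices_shape) == 1 and len(segment_ids_shape) == 1 and indices_shape == segment_ids_shape
    assert num_segments is not None
    return [num_segments] + list(weights_shape[1:])


print(embedding_segments_infer_shape([10, 3, 4], [5], [5], 2))  # [2, 3, 4]
```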
@@ -3,19 +3,21 @@

 import unittest

-from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
+from extensions.front.tf.embedding_segments_operation_fusing import EmbeddingSegmentsOperationMultipleFeaturesFusing, \
+    EmbeddingSegmentsOperationSingleFeatureFusing
 from mo.front.common.partial_infer.utils import int64_array
 from mo.utils.ir_engine.compare_graphs import compare_graphs
 from unit_tests.utils.graph import build_graph, const


-class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
+class EmbeddingSegmentsOperationFusingTest(unittest.TestCase):
     def test1(self):
         nodes_attributes = {
             'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
             'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
             'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-            'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+            'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
+                                   'op': 'Parameter'},
             'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},

             'identity_spw': {'kind': 'op', 'op': 'Identity'},
@@ -38,6 +40,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
             'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
             'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
             'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
+            'cast_indices': {'kind': 'op', 'op': 'Cast'},
             'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
             'cast_default_value': {'kind': 'op', 'op': 'Cast'},
             'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
@@ -58,7 +61,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
                              ('input_values', 'gather0_2', {'out': 0, 'in': 0}),
                              ('input_params_table', 'gather', {'out': 0, 'in': 0}),
                              ('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),

                              ('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
                              ('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
                              ('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
@@ -78,9 +81,9 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
                              ('reshape', 'tile', {'out': 0, 'in': 0}),
                              ('tile', 'select', {'out': 0, 'in': 0}),
                              ('select', 'last', {'out': 0, 'in': 0}),
-                             ], nodes_with_edges_only=True)
+                              ], nodes_with_edges_only=True)
         graph.stage = 'front'
-        EmbeddingSegmentsSumFrontReplacer().find_and_replace_pattern(graph)
+        EmbeddingSegmentsOperationSingleFeatureFusing().find_and_replace_pattern(graph)

         graph_ref = build_graph(nodes_attributes,
                                 [('input_indices', 'split_for_indices', {'in': 0}),
@@ -89,7 +92,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
                                  ('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
                                  ('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
                                  ('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
-                                 ('input_values', 'embedding_segments_sum', {'in': 1}),
+                                 ('input_values', 'cast_indices', {'in': 0}),
+                                 ('cast_indices', 'embedding_segments_sum', {'in': 1}),
                                  ('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
                                  ('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
                                  ('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
@@ -99,8 +103,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
                                  ('input_params_table', 'embedding_segments_sum', {'in': 0}),
                                  ('input_default_value', 'cast_default_value', {'in': 0}),
                                  ('cast_default_value', 'embedding_segments_sum', {'in': 4}),
-                                 ('embedding_segments_sum', 'last', {'in': 0}),],
-                                nodes_with_edges_only=True)
+                                 ('embedding_segments_sum', 'last', {'in': 0}), ],
+                                nodes_with_edges_only=True)

         (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
         self.assertTrue(flag, resp)
@@ -110,7 +114,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
             'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
             'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
             'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-            'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+            'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
+                                   'op': 'Parameter'},
             'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},

             'identity_spw': {'kind': 'op', 'op': 'Identity'},
@@ -126,7 +131,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
             'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'},
             'identity': {'kind': 'op', 'op': 'Identity'},
             'identity_1': {'kind': 'op', 'op': 'Identity'},
-            'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'},
+            'sparse_segment_mean': {'kind': 'op', 'op': 'SparseSegmentMean'},
             'reshape': {'kind': 'op', 'op': 'Reshape'},
             'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'},
             'select': {'kind': 'op', 'op': 'Select'},
@@ -135,10 +140,11 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
             'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
             'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
             'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
+            'cast_indices': {'kind': 'op', 'op': 'Cast'},
             'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
             'cast_default_value': {'kind': 'op', 'op': 'Cast'},
             'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
-            'embedding_segments_sum': {'kind': 'op', 'op': 'EmbeddingSegmentsSum'},
+            'embedding_segments_mean': {'kind': 'op', 'op': 'EmbeddingSegmentsMean'},

             **const('split_for_indices_axis', int64_array(1)),
             **const('split_for_dense_shape_axis', int64_array(0)),
@@ -155,7 +161,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
                              ('input_values', 'gather0_2', {'out': 0, 'in': 0}),
                              ('input_params_table', 'gather', {'out': 0, 'in': 0}),
                              ('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),

                              ('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
                              ('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
                              ('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
@@ -166,20 +172,20 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
                              ('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
                              ('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
                              ('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
-                             ('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
+                             ('unique', 'sparse_segment_mean', {'out': 1, 'in': 1}),
                              ('unique', 'gather', {'out': 0, 'in': 1}),
                              ('strided_slice', 'cast', {'out': 0, 'in': 0}),
                              ('gather', 'identity', {'out': 0, 'in': 0}),
                              ('identity', 'identity_1', {'out': 0, 'in': 0}),
-                             ('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}),
-                             ('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
-                             ('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
+                             ('identity_1', 'sparse_segment_mean', {'out': 0, 'in': 0}),
+                             ('cast', 'sparse_segment_mean', {'out': 0, 'in': 2}),
+                             ('sparse_segment_mean', 'select', {'out': 0, 'in': 2}),
                              ('reshape', 'tile', {'out': 0, 'in': 0}),
                              ('tile', 'select', {'out': 0, 'in': 0}),
                              ('select', 'last', {'out': 0, 'in': 0})],
-                            nodes_with_edges_only=True)
+                             nodes_with_edges_only=True)
         graph.stage = 'front'
-        EmbeddingSegmentsSumFrontReplacer2().find_and_replace_pattern(graph)
+        EmbeddingSegmentsOperationMultipleFeaturesFusing().find_and_replace_pattern(graph)

         graph_ref = build_graph(nodes_attributes,
                                 [('input_indices', 'split_for_indices', {'in': 0}),
@@ -187,18 +193,19 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
                                  ('split_for_indices', 'squeeze_for_indices', {'in': 0}),
                                  ('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
                                  ('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
-                                 ('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
-                                 ('input_values', 'embedding_segments_sum', {'in': 1}),
+                                 ('cast_segment_ids', 'embedding_segments_mean', {'in': 2, 'out': 0}),
+                                 ('input_values', 'cast_indices', {'in': 0}),
+                                 ('cast_indices', 'embedding_segments_mean', {'in': 1}),
                                  ('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
                                  ('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
                                  ('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
                                  ('squeeze_axis', 'squeeze_to_scalar', {'in': 1}),
                                  ('squeeze_to_scalar', 'cast_number_segments', {'in': 0}),
-                                 ('cast_number_segments', 'embedding_segments_sum', {'in': 3, 'out': 0}),
-                                 ('input_params_table', 'embedding_segments_sum', {'in': 0}),
+                                 ('cast_number_segments', 'embedding_segments_mean', {'in': 3, 'out': 0}),
+                                 ('input_params_table', 'embedding_segments_mean', {'in': 0}),
                                  ('input_default_value', 'cast_default_value', {'in': 0}),
-                                 ('cast_default_value', 'embedding_segments_sum', {'in': 4}),
-                                 ('embedding_segments_sum', 'last', {'in': 0}),],
+                                 ('cast_default_value', 'embedding_segments_mean', {'in': 4}),
+                                 ('embedding_segments_mean', 'last', {'in': 0}), ],
                                  nodes_with_edges_only=True)

         (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
@@ -13,12 +13,11 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
     def test1(self):
         nodes_attributes = {
             'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
-            'input_values' : {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
+            'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},

-            'sparse_to_dense' : {'kind': 'op', 'op': 'SparseToDense'},
-            'broadcast' : {'kind': 'op', 'op': 'Broadcast'},
-            'scatternd' : {'kind': 'op', 'op': 'ScatterNDUpdate'},
-            'cast_default_value': {'kind': 'op', 'op': 'Cast'},
+            'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
+            'broadcast': {'kind': 'op', 'op': 'Broadcast'},
+            'scatternd': {'kind': 'op', 'op': 'ScatterNDUpdate'},

             'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},

@@ -31,19 +30,18 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
                              ('input_values', 'sparse_to_dense', {'out': 0, 'in': 2}),
                              ('input_default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
                              ('sparse_to_dense', 'last', {'out': 0, 'in': 0})],
-                            nodes_with_edges_only=True)
+                             nodes_with_edges_only=True)
         graph.stage = 'front'
         SparseToDenseReplacer().find_and_replace_pattern(graph)

         graph_ref = build_graph(nodes_attributes,
-                                [('input_default_value', 'cast_default_value', {'in': 0}),
-                                 ('cast_default_value', 'broadcast', {'in': 0}),
+                                [('input_default_value', 'broadcast', {'in': 0}),
                                  ('input_dense_shape', 'broadcast', {'in': 1}),
                                  ('broadcast', 'scatternd', {'in': 0}),
                                  ('input_indices', 'scatternd', {'in': 1}),
                                  ('input_values', 'scatternd', {'in': 2}),
                                  ('scatternd', 'last', {'in': 0})],
-                                nodes_with_edges_only=True)
+                                 nodes_with_edges_only=True)

         (flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
         self.assertTrue(flag, resp)