[MO, TF] Support Custom Wide and Deep CTR model by MO (#8505)

* [MO, TF] Support Custom Wide and Deep CTR model by MO

It implements implicit support of the EmbeddingSegmentsMean operation through decomposition into EmbeddingSegmentsSum.
It also extends the existing transformation that fuses the TensorFlow sub-graph (for the Wide and Deep model family)
containing SparseSegmentSum or SparseSegmentMean operations into EmbeddingSegmentsSum or EmbeddingSegmentsMean, respectively.
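
In rough terms the decomposition relies on mean = sum / count, with the count clamped to 1 for empty segments. A minimal NumPy sketch for illustration only (not part of the commit; the function name is made up):

```python
import numpy as np

def normalize_segments_sum(sum_result, segment_ids, num_segments):
    """Turn an EmbeddingSegmentsSum result into EmbeddingSegmentsMean."""
    # number of gathered embedding vectors per segment (object)
    counts = np.bincount(np.asarray(segment_ids), minlength=num_segments)
    # empty segments already hold the single default embedding -> divide by 1
    counts = np.where(counts == 0, 1, counts)
    return sum_result / counts[:, np.newaxis]
```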

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix unit-tests after modifications of SparseToDense and EmbeddingSegmentsOperationFusing

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Document SparseSegmentMean support

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add computation scheme for normalization coeffs and correct documentation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
Roman Kazantsev committed 2021-11-17 11:44:04 +03:00 via GitHub
commit 8e327bd2ff (parent c307f206dc)
10 changed files with 329 additions and 128 deletions


@ -304,6 +304,7 @@ Some TensorFlow\* operations do not match to any Inference Engine layer, but are
| SparseFillEmptyRows | Supported only when it is part of a sub-graph of the special form |
| SparseReshape | Supported only when it is part of a sub-graph of the special form |
| SparseSegmentSum | Supported only when it is part of a sub-graph of the special form |
| SparseSegmentMean | Supported only when it is part of a sub-graph of the special form |
| SparseToDense | CPU only |
| Split | |
| SplitV | |


@ -99,30 +99,30 @@ python mo.py
IteratorGetNext:2[2],
IteratorGetNext:4[2],
IteratorGetNext:7[2],
linear/linear_model/linear_model/linear_model/education/to_sparse_input/indices:0[10 2]{i32},
linear/linear_model/linear_model/linear_model/education/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
linear/linear_model/linear_model/linear_model/education/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/indices:0[10 2]{i32},
linear/linear_model/linear_model/linear_model/marital_status/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/indices:0[10 2]{i32},
linear/linear_model/linear_model/linear_model/relationship/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/indices:0[10 2]{i32},
linear/linear_model/linear_model/linear_model/workclass/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/indices:0[10 2]{i32},
dnn/input_from_feature_columns/input_layer/education_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/indices:0[10 2]{i32},
dnn/input_from_feature_columns/input_layer/marital_status_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/indices:0[10 2]{i32},
dnn/input_from_feature_columns/input_layer/relationship_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50],
dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/indices:0[10 2]{i32},
dnn/input_from_feature_columns/input_layer/workclass_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i32},
dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/dense_shape:0[2]{i32}->[2 50]"
linear/linear_model/linear_model/linear_model/education/to_sparse_input/indices:0[10 2]{i64},
linear/linear_model/linear_model/linear_model/education/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
linear/linear_model/linear_model/linear_model/education/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/indices:0[10 2]{i64},
linear/linear_model/linear_model/linear_model/marital_status/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
linear/linear_model/linear_model/linear_model/marital_status/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/indices:0[10 2]{i64},
linear/linear_model/linear_model/linear_model/relationship/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
linear/linear_model/linear_model/linear_model/relationship/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/indices:0[10 2]{i64},
linear/linear_model/linear_model/linear_model/workclass/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
linear/linear_model/linear_model/linear_model/workclass/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/indices:0[10 2]{i64},
dnn/input_from_feature_columns/input_layer/education_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
dnn/input_from_feature_columns/input_layer/education_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/indices:0[10 2]{i64},
dnn/input_from_feature_columns/input_layer/marital_status_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
dnn/input_from_feature_columns/input_layer/marital_status_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/indices:0[10 2]{i64},
dnn/input_from_feature_columns/input_layer/relationship_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
dnn/input_from_feature_columns/input_layer/relationship_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50],
dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/indices:0[10 2]{i64},
dnn/input_from_feature_columns/input_layer/workclass_indicator/hash_table_Lookup/LookupTableFindV2:0[10]{i64},
dnn/input_from_feature_columns/input_layer/workclass_indicator/to_sparse_input/dense_shape:0[2]{i64}->[2 50]"
--output head/predictions/probabilities
```


@ -403,7 +403,8 @@ extensions/front/tf/efficient_det_support_api_v2.0.json
extensions/front/tf/efficient_det_support_api_v2.4.json
extensions/front/tf/einsum_ext.py
extensions/front/tf/elementwise_ext.py
extensions/front/tf/embedding_segments_sum.py
extensions/front/tf/embedding_segments_mean_decomposition.py
extensions/front/tf/embedding_segments_operation_fusing.py
extensions/front/tf/expand_dims_ext.py
extensions/front/tf/extract_image_patches_ext.py
extensions/front/tf/fake_const_ext.py


@ -20,9 +20,11 @@ class WhereDecomposition(FrontReplacementOp):
enabled = True
def run_after(self):
from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
from extensions.front.tf.embedding_segments_operation_fusing import \
EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing
from extensions.front.TransposeOrderNormalizer import TransposeOrderNormalizer
return [EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2, TransposeOrderNormalizer]
return [EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing,
TransposeOrderNormalizer]
def replace_op(self, graph: Graph, node: Node):
node_name = node.soft_get('name', node.id)


@ -0,0 +1,149 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from extensions.ops.ConvertLike import ConvertLike
from extensions.ops.ReduceOps import ReduceSum
from extensions.ops.elementwise import Div
from extensions.ops.elementwise import Equal
from extensions.ops.embedding_bag import EmbeddingSegmentsSum
from extensions.ops.range import Range
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
from mo.ops.broadcast import Broadcast
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.unsqueeze import Unsqueeze
class EmbeddingSegmentsMeanDecomposition(FrontReplacementPattern):
"""
This transformation decomposes the EmbeddingSegmentsMean operation into EmbeddingSegmentsSum with normalization,
taking into account that the summed-up embedding vectors for each object must be normalized by a coefficient
equal to the number of gathered embedding vectors for that object. If no embedding vector is gathered
for an object, the coefficient equals one.
Approximate computation scheme (Cast operations omitted) for the normalization coefficients:

                                                           Const(0)
segment_ids -> Unsqueeze(axis=1) -------------\                |
                                               \              \/
                                             ---> Equal() --> Select --> ReduceSum(axis=0) --> Norm. Coeff.
                                               /              /\
Range(0, num_segments) -> Unsqueeze(axis=0) --/                |
                                                           Const(1)
"""
enabled = True
def run_after(self):
from extensions.front.tf.embedding_segments_operation_fusing import \
EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing
return [EmbeddingSegmentsOperationMultipleFeaturesFusing, EmbeddingSegmentsOperationSingleFeatureFusing]
def find_and_replace_pattern(self, graph: Graph):
for embedding_segments_mean in graph.get_op_nodes(op='EmbeddingSegmentsMean'):
embedding_segments_mean_name = embedding_segments_mean.soft_get('name',
embedding_segments_mean.id)
embedding_table_input = embedding_segments_mean.in_port(0)
segment_ids_input = embedding_segments_mean.in_port(2)
num_segments_input = embedding_segments_mean.in_port(3)
# TODO: support EmbeddingSegmentsMean with a specified weights vector.
# This case has not appeared in models so far, so the EmbeddingSegmentsOperation fusing
# transformations do not handle it either.
if embedding_segments_mean.is_in_port_connected(5):
return
# 1. compute the indices membership matrix, i.e. which indices belong to which object;
# the shape of this matrix is [num_segments, num_indices]
non_norm_range_1_to_num_segments = create_op_with_const_inputs(graph, Range,
{0: int64_array(0),
2: int64_array(1)},
{'name': embedding_segments_mean_name +
'/Range1ToNumSegments',
'output_type': np.int64})
num_segments_input.get_connection().add_destination(non_norm_range_1_to_num_segments.in_port(1))
range_1_to_num_segments = ConvertLike(graph, {'name': embedding_segments_mean_name +
'/Range1ToNumSegmentsNorm'}
).create_node()
range_1_to_num_segments.in_port(0).connect(non_norm_range_1_to_num_segments.out_port(0))
num_segments_input.get_connection().add_destination(range_1_to_num_segments.in_port(1))
unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(1)},
{'name': embedding_segments_mean_name +
'/Range1ToNumSegmentsUnsqueeze'})
unsqueeze_range_1_to_num_segments.in_port(0).connect(range_1_to_num_segments.out_port(0))
unsqueeze_segment_ids = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
{'name': embedding_segments_mean_name +
'/SegmentIdsUnsqueeze'})
segment_ids_input.get_connection().add_destination(unsqueeze_segment_ids.in_port(0))
boolean_membership_matrix = Equal(graph, {'name': embedding_segments_mean_name +
'/BooleanMembershipMatrix'}
).create_node()
boolean_membership_matrix.in_port(0).connect(unsqueeze_range_1_to_num_segments.out_port(0))
boolean_membership_matrix.in_port(1).connect(unsqueeze_segment_ids.out_port(0))
shape_of_membership_matrix = Shape(graph, {'name': embedding_segments_mean_name +
'/ShapeOfMembershipMatrix'}
).create_node([boolean_membership_matrix])
one_scalar_constant = Const(graph, {'name': embedding_segments_mean_name + '/OneScalar',
'value': int64_array([1])}).create_node()
one_constant = Broadcast(graph, {'name': embedding_segments_mean_name + '/One'}
).create_node([one_scalar_constant,
shape_of_membership_matrix])
zero_constant = Const(graph, {'name': embedding_segments_mean_name + '/Zero',
'value': int64_array(0)}).create_node()
membership_matrix = Select(graph, {'name': embedding_segments_mean_name + '/MembershipMatrix',
'auto_broadcast': 'numpy'}).create_node([boolean_membership_matrix,
one_constant,
zero_constant])
# 2. compute the number of indices belonging to each object in the batch;
# these are the normalization coefficients
num_indices_per_object = create_op_with_const_inputs(graph, ReduceSum,
{1: int64_array(1)},
{'name': embedding_segments_mean_name +
'/NumIndicesPerObject'})
num_indices_per_object.in_port(0).connect(membership_matrix.out_port(0))
# 3. replace zero coefficients (no indices belong to the object) with ones
# because for such objects the single default embedding vector is used
where_zero_number = Equal(graph, {'name': embedding_segments_mean_name +
'/WhereZeroIndicesNumber'}
).create_node([num_indices_per_object, zero_constant])
normalized_num_indices_per_object = Select(graph, {'name': embedding_segments_mean_name +
'/NormNumIndicesPerObject',
'auto_broadcast': 'numpy'}
).create_node([where_zero_number,
one_scalar_constant,
num_indices_per_object])
# 4. cast normalized_num_indices_per_object to the same type as embedding vector table
norm_coefficients = ConvertLike(graph, {'name': embedding_segments_mean_name +
'/NormCoefficients'}
).create_node()
norm_coefficients.in_port(0).connect(normalized_num_indices_per_object.out_port(0))
embedding_table_input.get_connection().add_destination(norm_coefficients.in_port(1))
# 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum
embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': embedding_segments_mean_name +
'/EmbeddingSegmentsSum'}
).create_node()
for in_port in embedding_segments_mean.in_ports():
if embedding_segments_mean.is_in_port_connected(in_port):
embedding_segments_mean.in_port(in_port).get_connection().set_destination(
embedding_segments_sum.in_port(in_port))
# 6. normalize EmbeddingSegmentsSum results by the computed coefficients
result_node = Div(graph, {'name': embedding_segments_mean_name +
'/Div'}
).create_node([embedding_segments_sum, norm_coefficients])
embedding_segments_mean.out_port(0).get_connection().set_source(result_node.out_port(0))
rename_nodes([(embedding_segments_mean, embedding_segments_mean_name + '/AbandonedName'),
(result_node, embedding_segments_mean_name)])
graph.remove_nodes_from([embedding_segments_mean.id])
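
For reference, the normalization-coefficient scheme described in the class docstring above can be written in plain NumPy roughly as follows (an illustrative sketch; Cast/ConvertLike steps are omitted and the function name is made up). The transformation itself builds the transposed membership matrix, which yields the same coefficients.

```python
import numpy as np

def norm_coefficients(segment_ids, num_segments):
    # membership matrix: element [j, i] is True when index j belongs to object i
    membership = np.equal(np.asarray(segment_ids)[:, np.newaxis],
                          np.arange(num_segments)[np.newaxis, :])
    # Select(membership, 1, 0) followed by ReduceSum(axis=0)
    counts = np.where(membership, 1, 0).sum(axis=0)
    # objects with no gathered vectors keep the single default embedding,
    # so their coefficient is set to one
    return np.where(counts == 0, 1, counts)
```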


@ -2,10 +2,11 @@
# SPDX-License-Identifier: Apache-2.0
import logging as log
import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.embedding_bag import EmbeddingSegmentsSum
from extensions.ops.embedding_bag import EmbeddingSegmentsMean, EmbeddingSegmentsSum
from extensions.ops.split import Split
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
@ -14,10 +15,10 @@ from mo.graph.graph import Graph, rename_nodes
from mo.ops.squeeze import Squeeze
class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
class EmbeddingSegmentsOperationSingleFeatureFusing(FrontReplacementSubgraph):
"""
The transformation looks for pattern (sub-graph) that performs extraction of embedding vectors from the parameters table
for object feature values and sum up these embedding vectors for every object.
The transformation looks for a pattern (sub-graph) that extracts embedding vectors from the parameters
table for object feature values, and sums these embedding vectors up for every object or computes their mean value.
Such a sub-graph occurs in the Wide and Deep model in the case of a SINGLE categorical feature.
"""
enabled = True
@ -36,7 +37,7 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
('strided_slice', dict(op='StridedSlice')),
('cast', dict(op='Cast')),
('gather', dict(type='Gather')),
('sparse_segment_sum', dict(op='SparseSegmentSum')),
('sparse_segment_op', dict(op=lambda op: op in ['SparseSegmentSum', 'SparseSegmentMean'])),
('reshape', dict(op='Reshape')),
('tile', dict(type='Tile')),
('select', dict(op='Select'))
@ -52,12 +53,12 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
('unique', 'sparse_segment_op', {'out': 1, 'in': 1}),
('unique', 'gather', {'out': 0, 'in': 1}),
('strided_slice', 'cast', {'out': 0, 'in': 0}),
('gather', 'sparse_segment_sum', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
('gather', 'sparse_segment_op', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_op', {'out': 0, 'in': 2}),
('sparse_segment_op', 'select', {'out': 0, 'in': 2}),
('reshape', 'tile', {'out': 0, 'in': 0}),
('tile', 'select', {'out': 0, 'in': 0})
])
@ -71,41 +72,52 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
gather = match['gather']
select = match['select']
where0 = match['where0']
sparse_segment_op = match['sparse_segment_op']
output_node_name = select.soft_get('name', select.id)
log.debug('Found EmbeddingSegmentsSum pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
sparse_fill_empty_rows.name))
log.debug('Found EmbeddingSparseSegmentsSingleFeature pattern after {} with name {}'.format(
sparse_fill_empty_rows.op,
sparse_fill_empty_rows.name))
split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2})
squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node()
cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node()
embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])
# TODO: remove Cast nodes once we start to support EmbeddingSegmentsSum (new version) with segment_ids,
# indices, and num_segments of different integer types,
# because real models show that this can occur in TensorFlow
cast_indices = Cast(graph, {'name': output_node_name + '/CastIndices', 'dst_type': np.int32}).create_node()
cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
'dst_type': np.int32}).create_node()
cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
'dst_type': np.int32}).create_node()
cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
'dst_type': np.int32}).create_node()
if sparse_segment_op.op == 'SparseSegmentSum':
embedding_segments_op = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
else:
embedding_segments_op = EmbeddingSegmentsMean(graph, {'name': output_node_name}).create_node()
rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_op, output_node_name)])
# connect parameters table
gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
gather.in_port(0).get_connection().set_destination(embedding_segments_op.in_port(0))
# connect indices values
greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
greaterequal0.in_port(0).get_connection().set_destination(cast_indices.in_port(0))
embedding_segments_op.in_port(1).connect(cast_indices.out_port(0))
# split and connect segment ids
gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
embedding_segments_op.in_port(2).connect(cast_segment_ids.out_port(0))
# split and connect number of segments
identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
embedding_segments_op.in_port(3).connect(cast_num_segments.out_port(0))
# connect default value
# TODO: remove casting once we start to support I64 model input
sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
embedding_segments_op.in_port(4).connect(cast_default_value.out_port(0))
# no input port for per_sample_weight
identity_spw.in_port(0).disconnect()
@ -115,14 +127,15 @@ class EmbeddingSegmentsSumFrontReplacer(FrontReplacementSubgraph):
sparse_fill_empty_rows.in_port(2).disconnect()
gather.in_port(0).disconnect()
select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
select.out_port(0).get_connection().set_source(embedding_segments_op.out_port(0))
graph.remove_nodes_from(
[gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
class EmbeddingSegmentsOperationMultipleFeaturesFusing(FrontReplacementSubgraph):
"""
The transformation looks for pattern (sub-graph) that performs extraction of embedding vectors from the parameters table
for object feature values and sum up these embedding vectors for every object.
The transformation looks for a pattern (sub-graph) that extracts embedding vectors from the parameters
table for object feature values, and sums these embedding vectors up for every object or computes their mean value.
Such a sub-graph occurs in the Wide and Deep model in the case of MULTIPLE categorical features.
"""
enabled = True
@ -143,7 +156,7 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
('gather', dict(type='Gather')),
('identity', dict(op='Identity')),
('identity_1', dict(op='Identity')),
('sparse_segment_sum', dict(op='SparseSegmentSum')),
('sparse_segment_op', dict(op=lambda op: op in ['SparseSegmentSum', 'SparseSegmentMean'])),
('reshape', dict(op='Reshape')),
('tile', dict(type='Tile')),
('select', dict(op='Select'))
@ -159,14 +172,14 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
('unique', 'sparse_segment_op', {'out': 1, 'in': 1}),
('unique', 'gather', {'out': 0, 'in': 1}),
('strided_slice', 'cast', {'out': 0, 'in': 0}),
('gather', 'identity', {'out': 0, 'in': 0}),
('identity', 'identity_1', {'out': 0, 'in': 0}),
('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
('identity_1', 'sparse_segment_op', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_op', {'out': 0, 'in': 2}),
('sparse_segment_op', 'select', {'out': 0, 'in': 2}),
('reshape', 'tile', {'out': 0, 'in': 0}),
('tile', 'select', {'out': 0, 'in': 0})
])
@ -180,10 +193,12 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
gather = match['gather']
select = match['select']
where0 = match['where0']
sparse_segment_op = match['sparse_segment_op']
output_node_name = select.soft_get('name', select.id)
log.debug('Found EmbeddingSegmentsSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
sparse_fill_empty_rows.name))
log.debug('Found EmbeddingSparseSegmentsMultipleFeatures pattern after {} with name {}'.format(
sparse_fill_empty_rows.op,
sparse_fill_empty_rows.name))
split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)},
{'num_splits': 2,
@ -193,32 +208,42 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
{'num_splits': 2,
'name': output_node_name + '/SplitForDenseShape'})
squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds', 'dst_type': np.int32}).create_node()
cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber', 'dst_type': np.int32}).create_node()
embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])
# TODO: remove Cast nodes once we start to support EmbeddingSegmentsSum (new version) with segment_ids,
# indices, and num_segments of different integer types,
# because real models show that this can occur in TensorFlow
cast_indices = Cast(graph, {'name': output_node_name + '/CastIndices', 'dst_type': np.int32}).create_node()
cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
'dst_type': np.int32}).create_node()
cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
'dst_type': np.int32}).create_node()
cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
'dst_type': np.int32}).create_node()
if sparse_segment_op.op == 'SparseSegmentSum':
embedding_segments_op = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
else:
embedding_segments_op = EmbeddingSegmentsMean(graph, {'name': output_node_name}).create_node()
rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_op, output_node_name)])
# connect parameters table
gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
gather.in_port(0).get_connection().set_destination(embedding_segments_op.in_port(0))
# connect indices values
greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
greaterequal0.in_port(0).get_connection().set_destination(cast_indices.in_port(0))
embedding_segments_op.in_port(1).connect(cast_indices.out_port(0))
# split and connect segment ids
gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
embedding_segments_op.in_port(2).connect(cast_segment_ids.out_port(0))
# split and connect number of segments
identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
# TODO: remove casting once we start to support I64 model input
cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
embedding_segments_op.in_port(3).connect(cast_num_segments.out_port(0))
# connect default value
# TODO: remove casting once we start to support I64 model input
sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
embedding_segments_op.in_port(4).connect(cast_default_value.out_port(0))
# no input port for per_sample_weight
identity_spw.in_port(0).disconnect()
@ -228,5 +253,6 @@ class EmbeddingSegmentsSumFrontReplacer2(FrontReplacementSubgraph):
sparse_fill_empty_rows.in_port(2).disconnect()
gather.in_port(0).disconnect()
select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
select.out_port(0).get_connection().set_source(embedding_segments_op.out_port(0))
graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id,
select.id, where0.id])
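
For orientation, the operation that the matched sub-graph is fused into behaves roughly as below. This is an illustrative NumPy sketch assuming a 2-D embedding table and no per_sample_weights; the name is made up and this is not the transformation's code.

```python
import numpy as np

def embedding_segments_ref(emb_table, indices, segment_ids, num_segments,
                           default_index, mean=False):
    out = np.zeros((num_segments, emb_table.shape[1]), dtype=emb_table.dtype)
    counts = np.zeros(num_segments, dtype=np.int64)
    for idx, seg in zip(indices, segment_ids):
        out[seg] += emb_table[idx]      # EmbeddingSegmentsSum accumulation
        counts[seg] += 1
    # segments with no gathered vectors take the default embedding row
    out[counts == 0] = emb_table[default_index]
    if mean:                            # EmbeddingSegmentsMean variant
        out = out / np.maximum(counts, 1)[:, np.newaxis]
    return out
```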


@ -3,7 +3,6 @@
import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.scatternd import ScatterNDUpdate
from mo.front.common.replacement import FrontReplacementOp
from mo.graph.graph import Node, Graph, rename_nodes
@ -32,12 +31,7 @@ class SparseToDenseReplacer(FrontReplacementOp):
broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
if not node.in_port(3).disconnected():
# TODO: remove casting once we start to support I64 model input
# cast default value to I32 due limitation about I64 input support
# so that input parameter and default value will be of the same I32 type as required ScatterNDUpdate
cast_default_value = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
node.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
broadcast_node.in_port(0).connect(cast_default_value.out_port(0))
node.in_port(3).get_connection().set_destination(broadcast_node.in_port(0))
else:
broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
'value': np.float32(0)}
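
The replacement keeps the SparseToDense semantics: broadcast the default value over the dense shape, then scatter the sparse values at the given indices. A rough NumPy equivalent for illustration (the function name is made up):

```python
import numpy as np

def sparse_to_dense_ref(indices, dense_shape, values, default_value):
    dense = np.broadcast_to(default_value, dense_shape).copy()  # Broadcast
    dense[tuple(np.asarray(indices).T)] = values                # ScatterNDUpdate
    return dense
```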


@ -46,10 +46,10 @@ class EmbeddingBagOffsetsSum(EmbeddingBagBase):
"for node: `{}`. Ports: {}".format(name, connected_in_ports)
weights_shape = node.in_port(0).data.get_shape()
assert len(weights_shape) >= 2,\
assert len(weights_shape) >= 2, \
"EmbeddingBagOffsetsSum should have at least 2D weights for node: `{}`".format(name)
offsets_shape = node.in_port(2).data.get_shape()
assert offsets_shape is not None and len(offsets_shape) == 1,\
assert offsets_shape is not None and len(offsets_shape) == 1, \
"Rank of the offsets in EmbeddingBagOffsetsSum should be equal to 1 for node: `{}`".format(name)
node.out_port(0).data.set_shape(np.ma.concatenate((offsets_shape[:1], weights_shape[1:])))
@ -88,18 +88,41 @@ class EmbeddingSegmentsSum(EmbeddingBagBase):
connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
assert len(connected_in_ports) >= 4 and all(p in connected_in_ports for p in [0, 1, 2, 3]), \
"EmbeddingSegmentsSum should have at least 4 connected input port, but it doesn't for node: `{}`. " \
"Ports: {}".format(name, connected_in_ports)
"{} should have at least 4 connected input port, but it doesn't for node: `{}`. " \
"Ports: {}".format(node.op, name, connected_in_ports)
weights_shape = node.in_port(0).data.get_shape()
assert len(weights_shape) >= 2,\
"EmbeddingSegmentsSum should have at least 2D weights for node: `{}`".format(name)
assert len(weights_shape) >= 2, \
"{} should have at least 2D weights for node: `{}`".format(node.op, name)
indices_shape = node.in_port(1).data.get_shape()
segment_ids = node.in_port(2).data.get_shape()
assert len(indices_shape) == 1 and len(segment_ids) == 1 and indices_shape == segment_ids,\
assert len(indices_shape) == 1 and len(segment_ids) == 1 and indices_shape == segment_ids, \
"Both indices and segment_ids should have the same shape for node: `{}`".format(name)
num_segments = node.in_port(3).data.get_value()
assert num_segments is not None, "EmbeddingSegmentsSum should have a constant num_segments provided, but it " \
"doesn't for node: `{}`.".format(name)
assert num_segments is not None, "{} should have a constant num_segments provided, but it " \
"doesn't for node: `{}`.".format(node.op, name)
output_shape = np.ma.concatenate(([num_segments], weights_shape[1:]))
node.out_port(0).data.set_shape(output_shape)
class EmbeddingSegmentsMean(Op):
"""
Internal Operation.
In order not to overload the fusing transformations (EmbeddingSegmentsOperationSingleFeatureFusing,
EmbeddingSegmentsOperationMultipleFeaturesFusing) with an additional sub-graph computing the mean value of embedding
vectors, we introduce the internal operation EmbeddingSegmentsMean. After these transformations, this operation
is decomposed into EmbeddingSegmentsSum with an appropriate computation of the mean value of the embedding vectors
collected for each object in the batch.
"""
op = "EmbeddingSegmentsMean"
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': None,
'op': self.op,
'in_ports_count': 6,
'out_ports_count': 1,
# it must have the same shape infer function as EmbeddingSegmentsSum
'infer': EmbeddingSegmentsSum.infer
}, attrs)
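
Since EmbeddingSegmentsMean reuses EmbeddingSegmentsSum.infer, both operations share the same output shape rule; roughly (illustrative sketch):

```python
def embedding_segments_output_shape(weights_shape, num_segments):
    # one aggregated embedding vector per segment: [num_segments, emb_dims...]
    return [num_segments] + list(weights_shape[1:])
```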


@ -3,19 +3,21 @@
import unittest
from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
from extensions.front.tf.embedding_segments_operation_fusing import EmbeddingSegmentsOperationMultipleFeaturesFusing, \
EmbeddingSegmentsOperationSingleFeatureFusing
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, const
class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
class EmbeddingSegmentsOperationFusingTest(unittest.TestCase):
def test1(self):
nodes_attributes = {
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
'op': 'Parameter'},
'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'identity_spw': {'kind': 'op', 'op': 'Identity'},
@ -38,6 +40,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
'cast_indices': {'kind': 'op', 'op': 'Cast'},
'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
@ -58,7 +61,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('input_values', 'gather0_2', {'out': 0, 'in': 0}),
('input_params_table', 'gather', {'out': 0, 'in': 0}),
('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),
('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
@ -78,9 +81,9 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('reshape', 'tile', {'out': 0, 'in': 0}),
('tile', 'select', {'out': 0, 'in': 0}),
('select', 'last', {'out': 0, 'in': 0}),
], nodes_with_edges_only=True)
], nodes_with_edges_only=True)
graph.stage = 'front'
EmbeddingSegmentsSumFrontReplacer().find_and_replace_pattern(graph)
EmbeddingSegmentsOperationSingleFeatureFusing().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('input_indices', 'split_for_indices', {'in': 0}),
@ -89,7 +92,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
('input_values', 'embedding_segments_sum', {'in': 1}),
('input_values', 'cast_indices', {'in': 0}),
('cast_indices', 'embedding_segments_sum', {'in': 1}),
('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
@ -99,8 +103,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('input_params_table', 'embedding_segments_sum', {'in': 0}),
('input_default_value', 'cast_default_value', {'in': 0}),
('cast_default_value', 'embedding_segments_sum', {'in': 4}),
('embedding_segments_sum', 'last', {'in': 0}),],
nodes_with_edges_only=True)
('embedding_segments_sum', 'last', {'in': 0}), ],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertTrue(flag, resp)
@ -110,7 +114,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
'op': 'Parameter'},
'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'identity_spw': {'kind': 'op', 'op': 'Identity'},
@ -126,7 +131,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'},
'identity': {'kind': 'op', 'op': 'Identity'},
'identity_1': {'kind': 'op', 'op': 'Identity'},
'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'},
'sparse_segment_mean': {'kind': 'op', 'op': 'SparseSegmentMean'},
'reshape': {'kind': 'op', 'op': 'Reshape'},
'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'},
'select': {'kind': 'op', 'op': 'Select'},
@ -135,10 +140,11 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
'cast_indices': {'kind': 'op', 'op': 'Cast'},
'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
'embedding_segments_sum': {'kind': 'op', 'op': 'EmbeddingSegmentsSum'},
'embedding_segments_mean': {'kind': 'op', 'op': 'EmbeddingSegmentsMean'},
**const('split_for_indices_axis', int64_array(1)),
**const('split_for_dense_shape_axis', int64_array(0)),
@ -155,7 +161,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('input_values', 'gather0_2', {'out': 0, 'in': 0}),
('input_params_table', 'gather', {'out': 0, 'in': 0}),
('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),
('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
@ -166,20 +172,20 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
('unique', 'sparse_segment_mean', {'out': 1, 'in': 1}),
('unique', 'gather', {'out': 0, 'in': 1}),
('strided_slice', 'cast', {'out': 0, 'in': 0}),
('gather', 'identity', {'out': 0, 'in': 0}),
('identity', 'identity_1', {'out': 0, 'in': 0}),
('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
('identity_1', 'sparse_segment_mean', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_mean', {'out': 0, 'in': 2}),
('sparse_segment_mean', 'select', {'out': 0, 'in': 2}),
('reshape', 'tile', {'out': 0, 'in': 0}),
('tile', 'select', {'out': 0, 'in': 0}),
('select', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
nodes_with_edges_only=True)
graph.stage = 'front'
EmbeddingSegmentsSumFrontReplacer2().find_and_replace_pattern(graph)
EmbeddingSegmentsOperationMultipleFeaturesFusing().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('input_indices', 'split_for_indices', {'in': 0}),
@ -187,18 +193,19 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('split_for_indices', 'squeeze_for_indices', {'in': 0}),
('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
('input_values', 'embedding_segments_sum', {'in': 1}),
('cast_segment_ids', 'embedding_segments_mean', {'in': 2, 'out': 0}),
('input_values', 'cast_indices', {'in': 0}),
('cast_indices', 'embedding_segments_mean', {'in': 1}),
('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
('squeeze_axis', 'squeeze_to_scalar', {'in': 1}),
('squeeze_to_scalar', 'cast_number_segments', {'in': 0}),
('cast_number_segments', 'embedding_segments_sum', {'in': 3, 'out': 0}),
('input_params_table', 'embedding_segments_sum', {'in': 0}),
('cast_number_segments', 'embedding_segments_mean', {'in': 3, 'out': 0}),
('input_params_table', 'embedding_segments_mean', {'in': 0}),
('input_default_value', 'cast_default_value', {'in': 0}),
('cast_default_value', 'embedding_segments_sum', {'in': 4}),
('embedding_segments_sum', 'last', {'in': 0}),],
('cast_default_value', 'embedding_segments_mean', {'in': 4}),
('embedding_segments_mean', 'last', {'in': 0}), ],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)


@ -13,12 +13,11 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
def test1(self):
nodes_attributes = {
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values' : {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'sparse_to_dense' : {'kind': 'op', 'op': 'SparseToDense'},
'broadcast' : {'kind': 'op', 'op': 'Broadcast'},
'scatternd' : {'kind': 'op', 'op': 'ScatterNDUpdate'},
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'scatternd': {'kind': 'op', 'op': 'ScatterNDUpdate'},
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
@ -31,19 +30,18 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
('input_values', 'sparse_to_dense', {'out': 0, 'in': 2}),
('input_default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
('sparse_to_dense', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
nodes_with_edges_only=True)
graph.stage = 'front'
SparseToDenseReplacer().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('input_default_value', 'cast_default_value', {'in': 0}),
('cast_default_value', 'broadcast', {'in': 0}),
[('input_default_value', 'broadcast', {'in': 0}),
('input_dense_shape', 'broadcast', {'in': 1}),
('broadcast', 'scatternd', {'in': 0}),
('input_indices', 'scatternd', {'in': 1}),
('input_values', 'scatternd', {'in': 2}),
('scatternd', 'last', {'in': 0})],
nodes_with_edges_only=True)
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertTrue(flag, resp)