[MO, TF] Support Custom Wide and Deep CTR model by MO (#8505)

* [MO, TF] Support Custom Wide and Deep CTR model by MO

It implements implicit support of the EmbeddingSegmentsMean operation through decomposition.
It also extends the existing transformation to fuse TensorFlow sub-graphs (for the Wide and Deep model family)
containing SparseSegmentSum or SparseSegmentMean operations into EmbeddingSegmentsSum or EmbeddingSegmentsMean, respectively.

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix unit tests after modifications to SparseToDense and EmbeddingSegmentsOperationFusing

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Document SparseSegmentMean support

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add computation scheme for normalization coefficients and correct the documentation

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev
2021-11-17 11:44:04 +03:00
committed by GitHub
parent c307f206dc
commit 8e327bd2ff
10 changed files with 329 additions and 128 deletions

View File

@@ -3,19 +3,21 @@
import unittest
from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
from extensions.front.tf.embedding_segments_operation_fusing import EmbeddingSegmentsOperationMultipleFeaturesFusing, \
EmbeddingSegmentsOperationSingleFeatureFusing
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, const
class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
class EmbeddingSegmentsOperationFusingTest(unittest.TestCase):
def test1(self):
nodes_attributes = {
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
'op': 'Parameter'},
'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'identity_spw': {'kind': 'op', 'op': 'Identity'},
@@ -38,6 +40,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
'cast_indices': {'kind': 'op', 'op': 'Cast'},
'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
@@ -58,7 +61,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('input_values', 'gather0_2', {'out': 0, 'in': 0}),
('input_params_table', 'gather', {'out': 0, 'in': 0}),
('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),
('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
@@ -78,9 +81,9 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('reshape', 'tile', {'out': 0, 'in': 0}),
('tile', 'select', {'out': 0, 'in': 0}),
('select', 'last', {'out': 0, 'in': 0}),
], nodes_with_edges_only=True)
], nodes_with_edges_only=True)
graph.stage = 'front'
EmbeddingSegmentsSumFrontReplacer().find_and_replace_pattern(graph)
EmbeddingSegmentsOperationSingleFeatureFusing().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('input_indices', 'split_for_indices', {'in': 0}),
@@ -89,7 +92,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
('input_values', 'embedding_segments_sum', {'in': 1}),
('input_values', 'cast_indices', {'in': 0}),
('cast_indices', 'embedding_segments_sum', {'in': 1}),
('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
@@ -99,8 +103,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('input_params_table', 'embedding_segments_sum', {'in': 0}),
('input_default_value', 'cast_default_value', {'in': 0}),
('cast_default_value', 'embedding_segments_sum', {'in': 4}),
('embedding_segments_sum', 'last', {'in': 0}),],
nodes_with_edges_only=True)
('embedding_segments_sum', 'last', {'in': 0}), ],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertTrue(flag, resp)
@@ -110,7 +114,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
'op': 'Parameter'},
'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'identity_spw': {'kind': 'op', 'op': 'Identity'},
@@ -126,7 +131,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'},
'identity': {'kind': 'op', 'op': 'Identity'},
'identity_1': {'kind': 'op', 'op': 'Identity'},
'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'},
'sparse_segment_mean': {'kind': 'op', 'op': 'SparseSegmentMean'},
'reshape': {'kind': 'op', 'op': 'Reshape'},
'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'},
'select': {'kind': 'op', 'op': 'Select'},
@@ -135,10 +140,11 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
'cast_indices': {'kind': 'op', 'op': 'Cast'},
'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
'embedding_segments_sum': {'kind': 'op', 'op': 'EmbeddingSegmentsSum'},
'embedding_segments_mean': {'kind': 'op', 'op': 'EmbeddingSegmentsMean'},
**const('split_for_indices_axis', int64_array(1)),
**const('split_for_dense_shape_axis', int64_array(0)),
@@ -155,7 +161,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('input_values', 'gather0_2', {'out': 0, 'in': 0}),
('input_params_table', 'gather', {'out': 0, 'in': 0}),
('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),
('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
@@ -166,20 +172,20 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
('unique', 'sparse_segment_mean', {'out': 1, 'in': 1}),
('unique', 'gather', {'out': 0, 'in': 1}),
('strided_slice', 'cast', {'out': 0, 'in': 0}),
('gather', 'identity', {'out': 0, 'in': 0}),
('identity', 'identity_1', {'out': 0, 'in': 0}),
('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
('identity_1', 'sparse_segment_mean', {'out': 0, 'in': 0}),
('cast', 'sparse_segment_mean', {'out': 0, 'in': 2}),
('sparse_segment_mean', 'select', {'out': 0, 'in': 2}),
('reshape', 'tile', {'out': 0, 'in': 0}),
('tile', 'select', {'out': 0, 'in': 0}),
('select', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
nodes_with_edges_only=True)
graph.stage = 'front'
EmbeddingSegmentsSumFrontReplacer2().find_and_replace_pattern(graph)
EmbeddingSegmentsOperationMultipleFeaturesFusing().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('input_indices', 'split_for_indices', {'in': 0}),
@@ -187,18 +193,19 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
('split_for_indices', 'squeeze_for_indices', {'in': 0}),
('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
('input_values', 'embedding_segments_sum', {'in': 1}),
('cast_segment_ids', 'embedding_segments_mean', {'in': 2, 'out': 0}),
('input_values', 'cast_indices', {'in': 0}),
('cast_indices', 'embedding_segments_mean', {'in': 1}),
('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
('squeeze_axis', 'squeeze_to_scalar', {'in': 1}),
('squeeze_to_scalar', 'cast_number_segments', {'in': 0}),
('cast_number_segments', 'embedding_segments_sum', {'in': 3, 'out': 0}),
('input_params_table', 'embedding_segments_sum', {'in': 0}),
('cast_number_segments', 'embedding_segments_mean', {'in': 3, 'out': 0}),
('input_params_table', 'embedding_segments_mean', {'in': 0}),
('input_default_value', 'cast_default_value', {'in': 0}),
('cast_default_value', 'embedding_segments_sum', {'in': 4}),
('embedding_segments_sum', 'last', {'in': 0}),],
('cast_default_value', 'embedding_segments_mean', {'in': 4}),
('embedding_segments_mean', 'last', {'in': 0}), ],
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)

View File

@@ -13,12 +13,11 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
def test1(self):
nodes_attributes = {
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values' : {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'sparse_to_dense' : {'kind': 'op', 'op': 'SparseToDense'},
'broadcast' : {'kind': 'op', 'op': 'Broadcast'},
'scatternd' : {'kind': 'op', 'op': 'ScatterNDUpdate'},
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'scatternd': {'kind': 'op', 'op': 'ScatterNDUpdate'},
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
@@ -31,19 +30,18 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
('input_values', 'sparse_to_dense', {'out': 0, 'in': 2}),
('input_default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
('sparse_to_dense', 'last', {'out': 0, 'in': 0})],
nodes_with_edges_only=True)
nodes_with_edges_only=True)
graph.stage = 'front'
SparseToDenseReplacer().find_and_replace_pattern(graph)
graph_ref = build_graph(nodes_attributes,
[('input_default_value', 'cast_default_value', {'in': 0}),
('cast_default_value', 'broadcast', {'in': 0}),
[('input_default_value', 'broadcast', {'in': 0}),
('input_dense_shape', 'broadcast', {'in': 1}),
('broadcast', 'scatternd', {'in': 0}),
('input_indices', 'scatternd', {'in': 1}),
('input_values', 'scatternd', {'in': 2}),
('scatternd', 'last', {'in': 0})],
nodes_with_edges_only=True)
nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
self.assertTrue(flag, resp)