[MO, TF] Support Custom Wide and Deep CTR model by MO (#8505)
* [MO, TF] Support Custom Wide and Deep CTR model by MO It implements implicit support of EmbeddingSegmentsMean operation through decomposition. Also, this extends the current transformation to fuse TensorFlow sub-graph (for Wide and Deep model family) containing SparseSegmentSum and SparseSegmentMean operations into EmbeddingSegmentsSum or EmbeddingSegmentsMean. Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix unit-tests after modifications of SparseToDense and EmbeddingSegmentsOperationFusing Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Document SparseSegmentMean support Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Add computation scheme for normalization coeffs and correct documentation Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
@@ -3,19 +3,21 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from extensions.front.tf.embedding_segments_sum import EmbeddingSegmentsSumFrontReplacer, EmbeddingSegmentsSumFrontReplacer2
|
||||
from extensions.front.tf.embedding_segments_operation_fusing import EmbeddingSegmentsOperationMultipleFeaturesFusing, \
|
||||
EmbeddingSegmentsOperationSingleFeatureFusing
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, const
|
||||
|
||||
|
||||
class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
class EmbeddingSegmentsOperationFusingTest(unittest.TestCase):
|
||||
def test1(self):
|
||||
nodes_attributes = {
|
||||
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
|
||||
'op': 'Parameter'},
|
||||
'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
|
||||
'identity_spw': {'kind': 'op', 'op': 'Identity'},
|
||||
@@ -38,6 +40,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
|
||||
'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
|
||||
'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
|
||||
'cast_indices': {'kind': 'op', 'op': 'Cast'},
|
||||
'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
|
||||
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
|
||||
'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
|
||||
@@ -58,7 +61,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
('input_values', 'gather0_2', {'out': 0, 'in': 0}),
|
||||
('input_params_table', 'gather', {'out': 0, 'in': 0}),
|
||||
('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),
|
||||
|
||||
|
||||
('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
|
||||
('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
|
||||
('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
|
||||
@@ -78,9 +81,9 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
('reshape', 'tile', {'out': 0, 'in': 0}),
|
||||
('tile', 'select', {'out': 0, 'in': 0}),
|
||||
('select', 'last', {'out': 0, 'in': 0}),
|
||||
], nodes_with_edges_only=True)
|
||||
], nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
EmbeddingSegmentsSumFrontReplacer().find_and_replace_pattern(graph)
|
||||
EmbeddingSegmentsOperationSingleFeatureFusing().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('input_indices', 'split_for_indices', {'in': 0}),
|
||||
@@ -89,7 +92,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
|
||||
('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
|
||||
('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
|
||||
('input_values', 'embedding_segments_sum', {'in': 1}),
|
||||
('input_values', 'cast_indices', {'in': 0}),
|
||||
('cast_indices', 'embedding_segments_sum', {'in': 1}),
|
||||
('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
|
||||
('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
|
||||
('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
|
||||
@@ -99,8 +103,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
('input_params_table', 'embedding_segments_sum', {'in': 0}),
|
||||
('input_default_value', 'cast_default_value', {'in': 0}),
|
||||
('cast_default_value', 'embedding_segments_sum', {'in': 4}),
|
||||
('embedding_segments_sum', 'last', {'in': 0}),],
|
||||
nodes_with_edges_only=True)
|
||||
('embedding_segments_sum', 'last', {'in': 0}), ],
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
@@ -110,7 +114,8 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_dense_shape': {'shape': int64_array([2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_params_table': {'shape': int64_array([10, 3, 4]), 'type': 'Parameter', 'kind': 'op',
|
||||
'op': 'Parameter'},
|
||||
'input_default_value': {'shape': int64_array([]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
|
||||
'identity_spw': {'kind': 'op', 'op': 'Identity'},
|
||||
@@ -126,7 +131,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
'gather': {'kind': 'op', 'op': 'Gather', 'type': 'Gather'},
|
||||
'identity': {'kind': 'op', 'op': 'Identity'},
|
||||
'identity_1': {'kind': 'op', 'op': 'Identity'},
|
||||
'sparse_segment_sum': {'kind': 'op', 'op': 'SparseSegmentSum'},
|
||||
'sparse_segment_mean': {'kind': 'op', 'op': 'SparseSegmentMean'},
|
||||
'reshape': {'kind': 'op', 'op': 'Reshape'},
|
||||
'tile': {'kind': 'op', 'op': 'Tile', 'type': 'Tile'},
|
||||
'select': {'kind': 'op', 'op': 'Select'},
|
||||
@@ -135,10 +140,11 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
'squeeze_for_indices': {'kind': 'op', 'op': 'Squeeze'},
|
||||
'split_for_dense_shape': {'kind': 'op', 'op': 'Split'},
|
||||
'squeeze_to_scalar': {'kind': 'op', 'op': 'Squeeze'},
|
||||
'cast_indices': {'kind': 'op', 'op': 'Cast'},
|
||||
'cast_segment_ids': {'kind': 'op', 'op': 'Cast'},
|
||||
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
|
||||
'cast_number_segments': {'kind': 'op', 'op': 'Cast'},
|
||||
'embedding_segments_sum': {'kind': 'op', 'op': 'EmbeddingSegmentsSum'},
|
||||
'embedding_segments_mean': {'kind': 'op', 'op': 'EmbeddingSegmentsMean'},
|
||||
|
||||
**const('split_for_indices_axis', int64_array(1)),
|
||||
**const('split_for_dense_shape_axis', int64_array(0)),
|
||||
@@ -155,7 +161,7 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
('input_values', 'gather0_2', {'out': 0, 'in': 0}),
|
||||
('input_params_table', 'gather', {'out': 0, 'in': 0}),
|
||||
('input_default_value', 'sparse_fill_empty_rows', {'out': 0, 'in': 3}),
|
||||
|
||||
|
||||
('identity_spw', 'sparse_fill_empty_rows', {'out': 0, 'in': 2}),
|
||||
('gather0_1', 'sparse_fill_empty_rows', {'out': 0, 'in': 0}),
|
||||
('gather0_2', 'sparse_fill_empty_rows', {'out': 0, 'in': 1}),
|
||||
@@ -166,20 +172,20 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
('sparse_fill_empty_rows', 'unique', {'out': 1, 'in': 0}),
|
||||
('sparse_fill_empty_rows', 'strided_slice', {'out': 0, 'in': 0}),
|
||||
('sparse_fill_empty_rows', 'reshape', {'out': 2, 'in': 0}),
|
||||
('unique', 'sparse_segment_sum', {'out': 1, 'in': 1}),
|
||||
('unique', 'sparse_segment_mean', {'out': 1, 'in': 1}),
|
||||
('unique', 'gather', {'out': 0, 'in': 1}),
|
||||
('strided_slice', 'cast', {'out': 0, 'in': 0}),
|
||||
('gather', 'identity', {'out': 0, 'in': 0}),
|
||||
('identity', 'identity_1', {'out': 0, 'in': 0}),
|
||||
('identity_1', 'sparse_segment_sum', {'out': 0, 'in': 0}),
|
||||
('cast', 'sparse_segment_sum', {'out': 0, 'in': 2}),
|
||||
('sparse_segment_sum', 'select', {'out': 0, 'in': 2}),
|
||||
('identity_1', 'sparse_segment_mean', {'out': 0, 'in': 0}),
|
||||
('cast', 'sparse_segment_mean', {'out': 0, 'in': 2}),
|
||||
('sparse_segment_mean', 'select', {'out': 0, 'in': 2}),
|
||||
('reshape', 'tile', {'out': 0, 'in': 0}),
|
||||
('tile', 'select', {'out': 0, 'in': 0}),
|
||||
('select', 'last', {'out': 0, 'in': 0})],
|
||||
nodes_with_edges_only=True)
|
||||
nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
EmbeddingSegmentsSumFrontReplacer2().find_and_replace_pattern(graph)
|
||||
EmbeddingSegmentsOperationMultipleFeaturesFusing().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('input_indices', 'split_for_indices', {'in': 0}),
|
||||
@@ -187,18 +193,19 @@ class EmbeddingSegmentsSumFrontReplacerFrontReplacersTest(unittest.TestCase):
|
||||
('split_for_indices', 'squeeze_for_indices', {'in': 0}),
|
||||
('squeeze_for_indices_axis', 'squeeze_for_indices', {'in': 1}),
|
||||
('squeeze_for_indices', 'cast_segment_ids', {'in': 0}),
|
||||
('cast_segment_ids', 'embedding_segments_sum', {'in': 2, 'out': 0}),
|
||||
('input_values', 'embedding_segments_sum', {'in': 1}),
|
||||
('cast_segment_ids', 'embedding_segments_mean', {'in': 2, 'out': 0}),
|
||||
('input_values', 'cast_indices', {'in': 0}),
|
||||
('cast_indices', 'embedding_segments_mean', {'in': 1}),
|
||||
('input_dense_shape', 'split_for_dense_shape', {'in': 0}),
|
||||
('split_for_dense_shape_axis', 'split_for_dense_shape', {'in': 1}),
|
||||
('split_for_dense_shape', 'squeeze_to_scalar', {'in': 0}),
|
||||
('squeeze_axis', 'squeeze_to_scalar', {'in': 1}),
|
||||
('squeeze_to_scalar', 'cast_number_segments', {'in': 0}),
|
||||
('cast_number_segments', 'embedding_segments_sum', {'in': 3, 'out': 0}),
|
||||
('input_params_table', 'embedding_segments_sum', {'in': 0}),
|
||||
('cast_number_segments', 'embedding_segments_mean', {'in': 3, 'out': 0}),
|
||||
('input_params_table', 'embedding_segments_mean', {'in': 0}),
|
||||
('input_default_value', 'cast_default_value', {'in': 0}),
|
||||
('cast_default_value', 'embedding_segments_sum', {'in': 4}),
|
||||
('embedding_segments_sum', 'last', {'in': 0}),],
|
||||
('cast_default_value', 'embedding_segments_mean', {'in': 4}),
|
||||
('embedding_segments_mean', 'last', {'in': 0}), ],
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
|
||||
@@ -13,12 +13,11 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
|
||||
def test1(self):
|
||||
nodes_attributes = {
|
||||
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_values' : {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'input_values': {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
|
||||
'sparse_to_dense' : {'kind': 'op', 'op': 'SparseToDense'},
|
||||
'broadcast' : {'kind': 'op', 'op': 'Broadcast'},
|
||||
'scatternd' : {'kind': 'op', 'op': 'ScatterNDUpdate'},
|
||||
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
|
||||
'sparse_to_dense': {'kind': 'op', 'op': 'SparseToDense'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'scatternd': {'kind': 'op', 'op': 'ScatterNDUpdate'},
|
||||
|
||||
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
|
||||
|
||||
@@ -31,19 +30,18 @@ class SparseToDenseFrontReplacersTest(unittest.TestCase):
|
||||
('input_values', 'sparse_to_dense', {'out': 0, 'in': 2}),
|
||||
('input_default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
|
||||
('sparse_to_dense', 'last', {'out': 0, 'in': 0})],
|
||||
nodes_with_edges_only=True)
|
||||
nodes_with_edges_only=True)
|
||||
graph.stage = 'front'
|
||||
SparseToDenseReplacer().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('input_default_value', 'cast_default_value', {'in': 0}),
|
||||
('cast_default_value', 'broadcast', {'in': 0}),
|
||||
[('input_default_value', 'broadcast', {'in': 0}),
|
||||
('input_dense_shape', 'broadcast', {'in': 1}),
|
||||
('broadcast', 'scatternd', {'in': 0}),
|
||||
('input_indices', 'scatternd', {'in': 1}),
|
||||
('input_values', 'scatternd', {'in': 2}),
|
||||
('scatternd', 'last', {'in': 0})],
|
||||
nodes_with_edges_only=True)
|
||||
nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
Reference in New Issue
Block a user