[TF FE][TF Hub] Support BatchMatMulV3 operation (#20528)

* [TF FE][TF Hub] Support BatchMatMulV3 operation

* Update src/frontends/tensorflow_common/src/op/matmul.cpp

* Update src/frontends/tensorflow_common/src/op/matmul.cpp

---------

Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Siddhant Chauhan
2023-10-18 11:19:33 +05:30
committed by GitHub
parent 3b2ad48d79
commit a30e25c725
4 changed files with 25 additions and 3 deletions

View File

@@ -130,6 +130,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"AvgPool3D", CreatorFunction(translate_avg_pool_op)},
{"BatchMatMul", CreatorFunction(translate_batch_mat_mul_op)},
{"BatchMatMulV2", CreatorFunction(translate_batch_mat_mul_op)},
{"BatchMatMulV3", CreatorFunction(translate_batch_mat_mul_with_type_op)},
{"BatchToSpaceND", CreatorFunction(translate_batch_to_space_nd_op)},
{"BroadcastArgs", CreatorFunction(translate_broadcast_args_op)},
{"BroadcastTo", CreatorFunction(translate_broadcast_to_op)},

View File

@@ -38,6 +38,7 @@ OP_CONVERTER(translate_arg_max_op);
OP_CONVERTER(translate_arg_min_op);
OP_CONVERTER(translate_avg_pool_op);
OP_CONVERTER(translate_batch_mat_mul_op);
OP_CONVERTER(translate_batch_mat_mul_with_type_op);
OP_CONVERTER(translate_batch_to_space_nd_op);
OP_CONVERTER(translate_bias_add_op);
OP_CONVERTER(translate_broadcast_args_op);

View File

@@ -35,6 +35,26 @@ OutputVector translate_batch_mat_mul_op(const NodeContext& node) {
set_node_name(node.get_name(), result);
return result->outputs();
}
OutputVector translate_batch_mat_mul_with_type_op(const NodeContext& node) {
auto x = node.get_input(0);
auto y = node.get_input(1);
auto input_type = x.get_element_type();
auto adj_x = node.get_attribute<bool>("adj_x", false);
auto adj_y = node.get_attribute<bool>("adj_y", false);
auto t_out = node.get_attribute<ov::element::Type>("Tout", input_type);
auto result = make_shared<MatMul>(x, y, adj_x, adj_y)->output(0);
if (t_out != input_type) {
result = make_shared<Convert>(result, t_out);
}
set_node_name(node.get_name(), result.get_node_shared_ptr());
return {result};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend

View File

@@ -36,7 +36,7 @@ class TestMatMul(CommonTFLayerTest):
elif op_type == 'BatchMatMulV3':
op_type_to_tf[op_type](x=tf_x, y=tf_y, Tout=tf.float32, adj_x=x_bool, adj_y=y_bool, name='Operation')
else:
raise RuntimeError("Undknown operation")
raise RuntimeError("Unknown operation")
tf.compat.v1.global_variables_initializer()
tf_net = sess.graph_def
@@ -53,7 +53,7 @@ class TestMatMul(CommonTFLayerTest):
@pytest.mark.parametrize("params", test_data_precommit)
@pytest.mark.parametrize("op_type", ['BatchMatMul',
'BatchMatMulV2',
#'BatchMatMulV3', #Isn't supported
'BatchMatMulV3',
'MatMul',
])
@pytest.mark.precommit_tf_fe
@@ -72,7 +72,7 @@ class TestMatMul(CommonTFLayerTest):
@pytest.mark.parametrize("params", test_data)
@pytest.mark.parametrize("op_type", ['BatchMatMul',
'BatchMatMulV2',
#'BatchMatMulV3', #Isn't supported
'BatchMatMulV3',
'MatMul',
])
@pytest.mark.parametrize("x_bool", [