[TF FE][TF Hub] Support BatchMatMulV3 operation (#20528)
* [TF FE][TF Hub] Support BatchMatMulV3 operation * Update src/frontends/tensorflow_common/src/op/matmul.cpp * Update src/frontends/tensorflow_common/src/op/matmul.cpp --------- Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
@@ -130,6 +130,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"AvgPool3D", CreatorFunction(translate_avg_pool_op)},
|
||||
{"BatchMatMul", CreatorFunction(translate_batch_mat_mul_op)},
|
||||
{"BatchMatMulV2", CreatorFunction(translate_batch_mat_mul_op)},
|
||||
{"BatchMatMulV3", CreatorFunction(translate_batch_mat_mul_with_type_op)},
|
||||
{"BatchToSpaceND", CreatorFunction(translate_batch_to_space_nd_op)},
|
||||
{"BroadcastArgs", CreatorFunction(translate_broadcast_args_op)},
|
||||
{"BroadcastTo", CreatorFunction(translate_broadcast_to_op)},
|
||||
|
||||
@@ -38,6 +38,7 @@ OP_CONVERTER(translate_arg_max_op);
|
||||
OP_CONVERTER(translate_arg_min_op);
|
||||
OP_CONVERTER(translate_avg_pool_op);
|
||||
OP_CONVERTER(translate_batch_mat_mul_op);
|
||||
OP_CONVERTER(translate_batch_mat_mul_with_type_op);
|
||||
OP_CONVERTER(translate_batch_to_space_nd_op);
|
||||
OP_CONVERTER(translate_bias_add_op);
|
||||
OP_CONVERTER(translate_broadcast_args_op);
|
||||
|
||||
@@ -35,6 +35,26 @@ OutputVector translate_batch_mat_mul_op(const NodeContext& node) {
|
||||
set_node_name(node.get_name(), result);
|
||||
return result->outputs();
|
||||
}
|
||||
|
||||
OutputVector translate_batch_mat_mul_with_type_op(const NodeContext& node) {
|
||||
auto x = node.get_input(0);
|
||||
auto y = node.get_input(1);
|
||||
|
||||
auto input_type = x.get_element_type();
|
||||
|
||||
auto adj_x = node.get_attribute<bool>("adj_x", false);
|
||||
auto adj_y = node.get_attribute<bool>("adj_y", false);
|
||||
auto t_out = node.get_attribute<ov::element::Type>("Tout", input_type);
|
||||
|
||||
auto result = make_shared<MatMul>(x, y, adj_x, adj_y)->output(0);
|
||||
|
||||
if (t_out != input_type) {
|
||||
result = make_shared<Convert>(result, t_out);
|
||||
}
|
||||
|
||||
set_node_name(node.get_name(), result.get_node_shared_ptr());
|
||||
return {result};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace tensorflow
|
||||
} // namespace frontend
|
||||
|
||||
@@ -36,7 +36,7 @@ class TestMatMul(CommonTFLayerTest):
|
||||
elif op_type == 'BatchMatMulV3':
|
||||
op_type_to_tf[op_type](x=tf_x, y=tf_y, Tout=tf.float32, adj_x=x_bool, adj_y=y_bool, name='Operation')
|
||||
else:
|
||||
raise RuntimeError("Undknown operation")
|
||||
raise RuntimeError("Unknown operation")
|
||||
|
||||
tf.compat.v1.global_variables_initializer()
|
||||
tf_net = sess.graph_def
|
||||
@@ -53,7 +53,7 @@ class TestMatMul(CommonTFLayerTest):
|
||||
@pytest.mark.parametrize("params", test_data_precommit)
|
||||
@pytest.mark.parametrize("op_type", ['BatchMatMul',
|
||||
'BatchMatMulV2',
|
||||
#'BatchMatMulV3', #Isn't supported
|
||||
'BatchMatMulV3',
|
||||
'MatMul',
|
||||
])
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@@ -72,7 +72,7 @@ class TestMatMul(CommonTFLayerTest):
|
||||
@pytest.mark.parametrize("params", test_data)
|
||||
@pytest.mark.parametrize("op_type", ['BatchMatMul',
|
||||
'BatchMatMulV2',
|
||||
#'BatchMatMulV3', #Isn't supported
|
||||
'BatchMatMulV3',
|
||||
'MatMul',
|
||||
])
|
||||
@pytest.mark.parametrize("x_bool", [
|
||||
|
||||
Reference in New Issue
Block a user