diff --git a/src/frontends/tensorflow/src/op/xla_conv_v2.cpp b/src/frontends/tensorflow/src/op/xla_conv_v2.cpp
new file mode 100644
index 00000000000..dc2e319c9a0
--- /dev/null
+++ b/src/frontends/tensorflow/src/op/xla_conv_v2.cpp
@@ -0,0 +1,234 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <limits>
+#include <numeric>
+
+#include "common_op_table.hpp"
+#include "input_model.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convolution.hpp"
+#include "openvino/op/group_conv.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
+#include "openvino/op/transpose.hpp"
+#include "utils.hpp"
+#include "xla_data.pb.h"
+
+using namespace std;
+using namespace ov;
+using namespace ov::op;
+using namespace xla;
+
+namespace ov {
+namespace frontend {
+namespace tensorflow {
+namespace op {
+
+namespace {
+vector<int64_t> get_const_vector(const NodeContext& node, const Output<Node>& input, const string& input_name) {
+    OPENVINO_SUPPRESS_DEPRECATED_START
+    auto input_const = get_constant_from_source(input);
+    TENSORFLOW_OP_VALIDATION(node, input_const, "XlaConvV2 is supported only with constant " + input_name + ".");
+    OPENVINO_SUPPRESS_DEPRECATED_END
+    return input_const->cast_vector<int64_t>();
+}
+
+void set_transpose_order_element(const NodeContext& node,
+                                 vector<int64_t>& transpose_order,
+                                 int64_t index,
+                                 int64_t value) {
+    int64_t size = static_cast<int64_t>(transpose_order.size());
+    TENSORFLOW_OP_VALIDATION(
+        node,
+        0 <= index && index < size,
+        "[TensorFlow Frontend] inconsistent model: output dimension is out-of-range for XlaConvV2");
+    TENSORFLOW_OP_VALIDATION(
+        node,
+        0 <= value && value < size,
+        "[TensorFlow Frontend] inconsistent model: output dimension is out-of-range for XlaConvV2");
+    transpose_order[index] = value;
+}
+
+bool is_identity_transpose(vector<int64_t>& transpose_order) {
+    vector<int64_t> ref_vector(transpose_order.size());
+    std::iota(ref_vector.begin(), ref_vector.end(), 0);
+    if (ref_vector == transpose_order) {
+        return true;
+    }
+    return false;
+}
+
+}  // namespace
+
+OutputVector translate_xla_conv_v2_op(const NodeContext& node) {
+    // see the specification of XlaConvV2 here:
+    // https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution
+    default_op_checks(node, 7, {"XlaConvV2"});
+    auto node_name = node.get_name();
+    auto input = node.get_input(0);
+    auto kernel = node.get_input(1);
+    auto dimension_numbers_message = node.get_attribute<string>("dimension_numbers");
+    auto window_strides_vector = get_const_vector(node, node.get_input(2), "window_strides");
+    size_t spatial_dim = window_strides_vector.size();
+    TENSORFLOW_OP_VALIDATION(node,
+                             spatial_dim == 2 || spatial_dim == 3,
+                             "[TensorFlow Frontend] internal error: only 2D and 3D convolutions are supported");
+    auto padding_vector = get_const_vector(node, node.get_input(3), "padding");
+    TENSORFLOW_OP_VALIDATION(node,
+                             padding_vector.size() == 2 * spatial_dim,
+                             "[TensorFlow Frontend] inconsistent model: padding vector must contain twice as many "
+                             "elements as spatial dimensions");
+    auto input_dilation_vector = get_const_vector(node, node.get_input(4), "lhs_dilation");
+    TENSORFLOW_OP_VALIDATION(
+        node,
+        input_dilation_vector.size() == spatial_dim,
+        "[TensorFlow Frontend] inconsistent model: input dilation vector must contain as many elements as "
+        "spatial dimensions");
+    auto kernel_dilation_vector = get_const_vector(node, node.get_input(5), "rhs_dilation");
+    TENSORFLOW_OP_VALIDATION(
+        node,
+        kernel_dilation_vector.size() == spatial_dim,
+        "[TensorFlow Frontend] inconsistent model: kernel dilation vector must contain as many elements as "
+        "spatial dimensions");
+    auto feature_group_count_vector = get_const_vector(node, node.get_input(6), "feature_group_count");
+    TENSORFLOW_OP_VALIDATION(
+        node,
+        feature_group_count_vector.size() == 1 && feature_group_count_vector[0] > 0,
+        "[TensorFlow Frontend] inconsistent model: feature_group_count input must contain one positive element.");
+    int64_t feature_group_count = feature_group_count_vector[0];
+
+    // check that kernel dilation is one for each dimension
+    // other values are not supported
+    bool is_all_one = true;
+    for (auto dilation : kernel_dilation_vector) {
+        if (dilation != 1) {
+            is_all_one = false;
+            break;
+        }
+    }
+    TENSORFLOW_OP_VALIDATION(node,
+                             is_all_one,
+                             "[TensorFlow Frontend] internal error: convolutional kernel with holes is not supported");
+
+    ConvolutionDimensionNumbers dimension_numbers;
+    TENSORFLOW_OP_VALIDATION(
+        node,
+        dimension_numbers.ParseFromArray(dimension_numbers_message.data(),
+                                         static_cast<int>(dimension_numbers_message.size())),
+        "[TensorFlow Frontend] Incorrect input model: incorrect ConvolutionDimensionNumbers field for XlaConvV2 " +
+            node_name);
+
+    if (node.get_input_size() > 7) {
+        // batch_group_count input is present
+        auto batch_group_count_vector = get_const_vector(node, node.get_input(7), "batch_group_count");
+        TENSORFLOW_OP_VALIDATION(
+            node,
+            batch_group_count_vector.size() == 1,
+            "[TensorFlow Frontend] inconsistent model: batch_group_count input must contain one element.");
+        TENSORFLOW_OP_VALIDATION(
+            node,
+            batch_group_count_vector[0] == 1,
+            "[TensorFlow Frontend] internal error: XlaConvV2 is supported only with batch_group_count equal to one.");
+    }
+
+    // compute permutation vectors to transpose inputs and output
+    vector<int64_t> input_transpose_vector = {dimension_numbers.input_batch_dimension(),
+                                              dimension_numbers.input_feature_dimension()};
+    input_transpose_vector.insert(input_transpose_vector.end(),
+                                  dimension_numbers.input_spatial_dimensions().begin(),
+                                  dimension_numbers.input_spatial_dimensions().end());
+    vector<int64_t> kernel_transpose_vector = {dimension_numbers.kernel_output_feature_dimension(),
+                                               dimension_numbers.kernel_input_feature_dimension()};
+    kernel_transpose_vector.insert(kernel_transpose_vector.end(),
+                                   dimension_numbers.kernel_spatial_dimensions().begin(),
+                                   dimension_numbers.kernel_spatial_dimensions().end());
+
+    // adjust input layouts so that input and kernel come in [N, C, H, W] and [Cout, Cin, H, W] layouts
+    if (!is_identity_transpose(input_transpose_vector)) {
+        auto input_transpose_order =
+            make_shared<v0::Constant>(element::i64, Shape{input_transpose_vector.size()}, input_transpose_vector);
+        input = make_shared<v1::Transpose>(input, input_transpose_order);
+    }
+    if (!is_identity_transpose(kernel_transpose_vector)) {
+        auto kernel_transpose_order =
+            make_shared<v0::Constant>(element::i64, Shape{kernel_transpose_vector.size()}, kernel_transpose_vector);
+        kernel = make_shared<v1::Transpose>(kernel, kernel_transpose_order);
+    }
+
+    // create pads_begin and pads_end vectors
+    Strides strides(spatial_dim);
+    Strides dilations(spatial_dim);
+    CoordinateDiff pads_begin(spatial_dim);
+    CoordinateDiff pads_end(spatial_dim);
+    for (size_t ind = 0; ind < spatial_dim; ++ind) {
+        strides[ind] = static_cast<size_t>(window_strides_vector[ind]);
+        dilations[ind] = static_cast<size_t>(input_dilation_vector[ind]);
+        TENSORFLOW_OP_VALIDATION(
+            node,
+            padding_vector[2 * ind] >= 0 && padding_vector[2 * ind + 1] >= 0,
+            "[TensorFlow Frontend] internal error: only non-negative padding is supported for convolution");
+        pads_begin[ind] = padding_vector[2 * ind];
+        pads_end[ind] = padding_vector[2 * ind + 1];
+    }
+
+    Output<Node> conv;
+    if (feature_group_count == 1) {
+        // use regular convolution when there is no group
+        conv = make_shared<v1::Convolution>(input, kernel, strides, pads_begin, pads_end, dilations, PadType::EXPLICIT);
+    } else {
+        // use group convolution
+        // for this, reformat kernel to have [GROUPS, C_OUT, C_IN, Z, Y, X]
+        // 1. compute a part of kernel shape [C_IN, Z, Y, X]
+        auto kernel_shape = make_shared<v3::ShapeOf>(kernel, element::i64);
+        auto start = make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
+        auto step = make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
+        auto stop = make_shared<v0::Constant>(ov::element::i32, Shape{1}, numeric_limits<int>::max());
+        auto kernel_shape_part = make_shared<v8::Slice>(kernel_shape, start, stop, step);
+        // 2. create a new shape of the kernel [GROUPS, -1, C_IN, Z, Y, X]
+        auto feature_group_const = make_shared<v0::Constant>(ov::element::i64, Shape{1}, feature_group_count);
+        auto minus_one = make_shared<v0::Constant>(ov::element::i64, Shape{1}, -1);
+        auto new_shape = make_shared<v0::Concat>(OutputVector{feature_group_const, minus_one, kernel_shape_part}, 0);
+        kernel = make_shared<v1::Reshape>(kernel, new_shape, false);
+        // 3. compute group convolution using the reformatted kernel
+        conv = make_shared<v1::GroupConvolution>(input,
+                                                 kernel,
+                                                 strides,
+                                                 pads_begin,
+                                                 pads_end,
+                                                 dilations,
+                                                 PadType::EXPLICIT);
+    }
+
+    // adjust output to transform to the required layout
+    // at this point, output is in [N, C_OUT, Z, Y, X] layout
+    vector<int64_t> output_transpose_vector(spatial_dim + 2, 0);
+    int64_t output_batch_dimension = dimension_numbers.output_batch_dimension();
+    int64_t output_feature_dimension = dimension_numbers.output_feature_dimension();
+    vector<int64_t> output_spatial_dimensions(dimension_numbers.output_spatial_dimensions().begin(),
+                                              dimension_numbers.output_spatial_dimensions().end());
+    TENSORFLOW_OP_VALIDATION(node,
+                             spatial_dim == output_spatial_dimensions.size(),
+                             "[TensorFlow Frontend] inconsistent model: output_spatial_dimensions size is not equal "
+                             "to the number of spatial dimensions");
+    set_transpose_order_element(node, output_transpose_vector, output_batch_dimension, 0);
+    set_transpose_order_element(node, output_transpose_vector, output_feature_dimension, 1);
+    for (int64_t ind = 0; ind < static_cast<int64_t>(spatial_dim); ++ind) {
+        set_transpose_order_element(node, output_transpose_vector, output_spatial_dimensions[ind], ind + 2);
+    }
+    if (!is_identity_transpose(output_transpose_vector)) {
+        auto output_transpose_order =
+            make_shared<v0::Constant>(element::i64, Shape{output_transpose_vector.size()}, output_transpose_vector);
+        conv = make_shared<v1::Transpose>(conv, output_transpose_order);
+    }
+
+    set_node_name(node_name, conv.get_node_shared_ptr());
+    return {conv};
+}
+
+}  // namespace op
+}  // namespace tensorflow
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp
index 041231cc8c4..2ff98b810b1 100644
--- a/src/frontends/tensorflow/src/op_table.cpp
+++ b/src/frontends/tensorflow/src/op_table.cpp
@@ -46,6 +46,7 @@ TF_OP_CONVERTER(translate_varhandle_op);
 TF_OP_CONVERTER(translate_variable_op);
 TF_OP_CONVERTER(translate_varisinitialized_op);
 TF_OP_CONVERTER(translate_while_op);
+TF_OP_CONVERTER(translate_xla_conv_v2_op);
 TF_OP_CONVERTER(translate_xla_dot_op);
 
 const std::map<std::string, CreatorFunction> get_supported_ops() {
@@ -306,6 +307,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
         {"Unique", CreatorFunction(translate_unique_op)},
 
         // XLA operations
+        {"XlaConvV2", CreatorFunction(translate_xla_conv_v2_op)},
         {"XlaDotV2", CreatorFunction(translate_xla_dot_op)},
     };
 };
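Illustrative sketch (not part of the patch): the group-convolution branch of the converter rebuilds the kernel shape at runtime with ShapeOf, Slice, Concat and Reshape. Assuming example shapes and an example group count, the effect is equivalent to this NumPy reshape:

    import numpy as np

    kernel = np.zeros((6, 4, 2, 2), dtype=np.float32)  # [C_OUT, C_IN, H, W], example values
    feature_group_count = 3                             # example value
    # [C_OUT, C_IN, H, W] -> [GROUPS, C_OUT / GROUPS, C_IN, H, W], the layout GroupConvolution expects
    grouped_kernel = kernel.reshape((feature_group_count, -1) + kernel.shape[1:])
    print(grouped_kernel.shape)  # (3, 2, 4, 2, 2)
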
diff --git a/tests/layer_tests/jax_tests/test_conv_general_dilated.py b/tests/layer_tests/jax_tests/test_conv_general_dilated.py
new file mode 100644
index 00000000000..81123ff49c5
--- /dev/null
+++ b/tests/layer_tests/jax_tests/test_conv_general_dilated.py
@@ -0,0 +1,62 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+from jax import lax
+from jax import numpy as jnp
+
+from jax_layer_test_class import JaxLayerTest
+
+
+class TestConvGeneralDilated(JaxLayerTest):
+    def _prepare_input(self):
+        lhs = np.random.rand(*self.lhs_shape).astype(np.float32)
+        return [lhs]
+
+    def create_model(self, lhs_shape, rhs_shape, window_strides, padding,
+                     lhs_dilation, dimension_numbers,
+                     feature_group_count):
+        self.lhs_shape = lhs_shape
+        kernel = jnp.array(np.random.rand(*rhs_shape), dtype=jnp.float32)
+
+        def jax_conv_general_dilated(lhs):
+            out = lax.conv_general_dilated(lhs=lhs, rhs=kernel, window_strides=window_strides, padding=padding,
+                                           lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers,
+                                           feature_group_count=feature_group_count)
+            return out
+
+        return jax_conv_general_dilated, None
+
+    test_data_basic = [
+        # regular convolution with NCHW layout for inputs and NHWC layout for output
+        dict(lhs_shape=[2, 3, 40, 60], rhs_shape=[4, 3, 2, 3],
+             dimension_numbers=('NCHW', 'OIHW', 'NHWC'), feature_group_count=1),
+        # group convolution with groups = 3
+        dict(lhs_shape=[2, 3 * 4, 20, 30], rhs_shape=[3 * 2, 4, 2, 2],
+             dimension_numbers=('NCHW', 'OIHW', 'NHWC'), feature_group_count=3),
+        # regular convolution with NHWC layout for input and NCHW layout for output
+        dict(lhs_shape=[1, 30, 20, 3], rhs_shape=[4, 3, 2, 3],
+             dimension_numbers=('NHWC', 'OIHW', 'NCHW'), feature_group_count=1),
+    ]
+
+    @pytest.mark.parametrize("padding", [
+        'SAME_LOWER', 'SAME', 'VALID'
+    ])
+    @pytest.mark.parametrize("window_strides", [
+        [1, 1], [1, 2], [3, 2]
+    ])
+    @pytest.mark.parametrize("lhs_dilation", [
+        None, [1, 1],
+        # other type of lhs dilation is not supported by TF for tracing
+        # https://github.com/google/jax/issues/4216
+    ])
+    @pytest.mark.parametrize("params", test_data_basic)
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_conv_general_dilated(self, ie_device, precision, ir_version, params, padding, window_strides,
+                                  lhs_dilation):
+        self._test(*self.create_model(**params, padding=padding,
+                                      window_strides=window_strides, lhs_dilation=lhs_dilation),
+                   ie_device, precision,
+                   ir_version)
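
Illustrative sketch (not part of the patch): the converter inserts Transpose nodes only when the layout described by dimension_numbers differs from the canonical [N, C, *spatial] order. The hypothetical helper below shows the permutation computed for the input side, expressed with the JAX-style layout strings used in the test above:

    def to_canonical_order(layout):
        # permutation that brings a tensor stored in `layout` into [N, C, *spatial] order
        spatial = [i for i, d in enumerate(layout) if d not in ('N', 'C')]
        return [layout.index('N'), layout.index('C')] + spatial

    print(to_canonical_order('NHWC'))  # [0, 3, 1, 2]
    print(to_canonical_order('NCHW'))  # [0, 1, 2, 3] -> identity, so no Transpose is added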