[JAX][TF Hub][TF FE] Support XlaConvV2 operation and add JAX test (#19466)

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2023-08-29 12:28:12 +04:00 committed by GitHub
parent f30afa9ad3
commit 928c75623b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 298 additions and 0 deletions

View File

@ -0,0 +1,234 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <numeric>
#include "common_op_table.hpp"
#include "input_model.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convolution.hpp"
#include "openvino/op/group_conv.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/transpose.hpp"
#include "utils.hpp"
#include "xla_data.pb.h"
using namespace std;
using namespace ov;
using namespace ov::op;
using namespace xla;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
namespace {
vector<int64_t> get_const_vector(const NodeContext& node, const Output<Node>& input, const string& input_name) {
OPENVINO_SUPPRESS_DEPRECATED_START
auto input_const = get_constant_from_source(input);
TENSORFLOW_OP_VALIDATION(node, input_const, "XlaConvV2 is supported only with constant " + input_name + ".");
OPENVINO_SUPPRESS_DEPRECATED_END
return input_const->cast_vector<int64_t>();
}
void set_transpose_order_element(const NodeContext& node,
vector<int64_t>& transpose_order,
int64_t index,
int64_t value) {
int64_t size = static_cast<int64_t>(transpose_order.size());
TENSORFLOW_OP_VALIDATION(
node,
0 <= index && index < size,
"[TensorFlow Frontend] inconsistent model: output dimension is out-of-range for XlaConvV2");
TENSORFLOW_OP_VALIDATION(
node,
0 <= value && value < size,
"[TensorFlow Frontend] inconsistent model: output dimension is out-of-range for XlaConvV2");
transpose_order[index] = value;
}
bool is_identity_transpose(vector<int64_t>& transpose_order) {
vector<int64_t> ref_vector(transpose_order.size());
std::iota(ref_vector.begin(), ref_vector.end(), 0);
if (ref_vector == transpose_order) {
return true;
}
return false;
}
} // namespace
OutputVector translate_xla_conv_v2_op(const NodeContext& node) {
// see specification of XlaConvV2 here:
// https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution
default_op_checks(node, 7, {"XlaConvV2"});
auto node_name = node.get_name();
auto input = node.get_input(0);
auto kernel = node.get_input(1);
auto dimension_numbers_message = node.get_attribute<string>("dimension_numbers");
auto window_strides_vector = get_const_vector(node, node.get_input(2), "window_strides");
size_t spatial_dim = window_strides_vector.size();
TENSORFLOW_OP_VALIDATION(node,
spatial_dim == 2 || spatial_dim == 3,
"[TensorFlow Frontend] internal error: only 2D and 3D convolutions are supported");
auto padding_vector = get_const_vector(node, node.get_input(3), "padding");
TENSORFLOW_OP_VALIDATION(node,
padding_vector.size() == 2 * spatial_dim,
"[TensorFlow Frontend] inconsistent model: padding vector must contain elements equal to "
"doubled spatial dimensions ");
auto input_dilation_vector = get_const_vector(node, node.get_input(4), "lhs_dilation");
TENSORFLOW_OP_VALIDATION(
node,
input_dilation_vector.size() == spatial_dim,
"[TensorFlow Frontend] inconsistent model: input dilation vector must contain elements equal to "
"spatial dimensions");
auto kernel_dilation_vector = get_const_vector(node, node.get_input(5), "rhs_dilation");
TENSORFLOW_OP_VALIDATION(
node,
kernel_dilation_vector.size() == spatial_dim,
"[TensorFlow Frontend] inconsistent model: kernel dilation vector must contain elements equal to "
"spatial dimensions");
auto feature_group_count_vector = get_const_vector(node, node.get_input(6), "feature_group_count");
TENSORFLOW_OP_VALIDATION(
node,
feature_group_count_vector.size() == 1 && feature_group_count_vector[0] > 0,
"[TensorFlow Frontend] inconsistent model: feature_group_count input must contain one positive element.");
int64_t feature_group_count = feature_group_count_vector[0];
// check that kernel dilation is one for each dimension
// other values are not supported
bool is_all_one = true;
for (auto dilation : kernel_dilation_vector) {
if (dilation != 1) {
is_all_one = false;
break;
}
}
TENSORFLOW_OP_VALIDATION(node,
is_all_one,
"[TensorFlow Frontend] internal error: convolutional kernel with holes is not supported");
ConvolutionDimensionNumbers dimension_numbers;
TENSORFLOW_OP_VALIDATION(
node,
dimension_numbers.ParseFromArray(dimension_numbers_message.data(),
static_cast<int>(dimension_numbers_message.size())),
"[TensorFlow Frontend] Incorrect input model: incorrect ConvolutionDimensionNumbers field for XlaConvV2 " +
node_name);
if (node.get_input_size() > 7) {
// batch_group_count input presents
auto batch_group_count_vector = get_const_vector(node, node.get_input(7), "batch_group_count");
TENSORFLOW_OP_VALIDATION(
node,
batch_group_count_vector.size() == 1,
"[TensorFlow Frontend] inconsistent model: batch_group_count input must contain one element.");
TENSORFLOW_OP_VALIDATION(
node,
batch_group_count_vector[0] == 1,
"[TensorFlow Frontend] internal error: XlaConvV2 is supported only with batch_group_count equal to one.");
}
// compute permutation vectors to transpose inputs and output
vector<int64_t> input_transpose_vector = {dimension_numbers.input_batch_dimension(),
dimension_numbers.input_feature_dimension()};
input_transpose_vector.insert(input_transpose_vector.end(),
dimension_numbers.input_spatial_dimensions().begin(),
dimension_numbers.input_spatial_dimensions().end());
vector<int64_t> kernel_transpose_vector = {dimension_numbers.kernel_output_feature_dimension(),
dimension_numbers.kernel_input_feature_dimension()};
kernel_transpose_vector.insert(kernel_transpose_vector.end(),
dimension_numbers.kernel_spatial_dimensions().begin(),
dimension_numbers.kernel_spatial_dimensions().end());
// adjust inputs layout to have input and kernel of [N, C, H, W] and [Cout, Cin, H, W] layouts
if (!is_identity_transpose(input_transpose_vector)) {
auto input_transpose_order =
make_shared<v0::Constant>(element::i64, Shape{input_transpose_vector.size()}, input_transpose_vector);
input = make_shared<v1::Transpose>(input, input_transpose_order);
}
if (!is_identity_transpose(kernel_transpose_vector)) {
auto kernel_transpose_order =
make_shared<v0::Constant>(element::i64, Shape{kernel_transpose_vector.size()}, kernel_transpose_vector);
kernel = make_shared<v1::Transpose>(kernel, kernel_transpose_order);
}
// create pads_begin and pads_end vectors
Strides strides(spatial_dim);
Strides dilations(spatial_dim);
CoordinateDiff pads_begin(spatial_dim);
CoordinateDiff pads_end(spatial_dim);
for (size_t ind = 0; ind < spatial_dim; ++ind) {
strides[ind] = static_cast<size_t>(window_strides_vector[ind]);
dilations[ind] = static_cast<size_t>(input_dilation_vector[ind]);
TENSORFLOW_OP_VALIDATION(
node,
padding_vector[2 * ind] >= 0 && padding_vector[2 * ind + 1] >= 0,
"[TensorFlow Frontend] internal error: only non-negative padding is supported for convolution");
pads_begin[ind] = padding_vector[2 * ind];
pads_end[ind] = padding_vector[2 * ind + 1];
}
Output<Node> conv;
if (feature_group_count == 1) {
// use regular convolution when there is no group
conv = make_shared<v1::Convolution>(input, kernel, strides, pads_begin, pads_end, dilations, PadType::EXPLICIT);
} else {
// use group convolution
// for this, reformat kernel to have [GROUPS, C_OUT, C_IN, Z, Y, X]
// 1. compute a part of kernel shape [C_IN, Z, Y, X]
auto kernel_shape = make_shared<v3::ShapeOf>(kernel, element::i64);
auto start = make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
auto step = make_shared<v0::Constant>(ov::element::i32, Shape{1}, 1);
auto stop = make_shared<v0::Constant>(ov::element::i32, Shape{1}, numeric_limits<int>::max());
auto kernel_shape_part = make_shared<v8::Slice>(kernel_shape, start, stop, step);
// 2. create a new shape of the kernel [GROUPS, -1, C_IN, Z, Y, X]
auto feature_group_const = make_shared<v0::Constant>(ov::element::i64, Shape{1}, feature_group_count);
auto minus_one = make_shared<v0::Constant>(ov::element::i64, Shape{1}, -1);
auto new_shape = make_shared<v0::Concat>(OutputVector{feature_group_const, minus_one, kernel_shape_part}, 0);
kernel = make_shared<v1::Reshape>(kernel, new_shape, false);
// 3. compute group convolution using reformatted kernel
conv = make_shared<v1::GroupConvolution>(input,
kernel,
strides,
pads_begin,
pads_end,
dilations,
PadType::EXPLICIT);
}
// adjust output to transform to the required layout
// at this point, output is in [N, C_OUT, Z, Y, X] layout
vector<int64_t> output_transpose_vector(spatial_dim + 2, 0);
int64_t output_batch_dimension = dimension_numbers.output_batch_dimension();
int64_t output_feature_dimension = dimension_numbers.output_feature_dimension();
vector<int64_t> output_spatial_dimensions(dimension_numbers.output_spatial_dimensions().begin(),
dimension_numbers.output_spatial_dimensions().end());
TENSORFLOW_OP_VALIDATION(node,
spatial_dim == output_spatial_dimensions.size(),
"[TensorFlow Frontend] inconsistent model: output_spatial_dimensions size is not equal to "
"spatial dimensions number");
set_transpose_order_element(node, output_transpose_vector, output_batch_dimension, 0);
set_transpose_order_element(node, output_transpose_vector, output_feature_dimension, 1);
for (int64_t ind = 0; ind < static_cast<int64_t>(spatial_dim); ++ind) {
set_transpose_order_element(node, output_transpose_vector, output_spatial_dimensions[ind], ind + 2);
}
if (!is_identity_transpose(output_transpose_vector)) {
auto output_transpose_order =
make_shared<v0::Constant>(element::i64, Shape{output_transpose_vector.size()}, output_transpose_vector);
conv = make_shared<v1::Transpose>(conv, output_transpose_order);
}
set_node_name(node_name, conv.get_node_shared_ptr());
return {conv};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov

View File

@ -46,6 +46,7 @@ TF_OP_CONVERTER(translate_varhandle_op);
TF_OP_CONVERTER(translate_variable_op); TF_OP_CONVERTER(translate_variable_op);
TF_OP_CONVERTER(translate_varisinitialized_op); TF_OP_CONVERTER(translate_varisinitialized_op);
TF_OP_CONVERTER(translate_while_op); TF_OP_CONVERTER(translate_while_op);
TF_OP_CONVERTER(translate_xla_conv_v2_op);
TF_OP_CONVERTER(translate_xla_dot_op); TF_OP_CONVERTER(translate_xla_dot_op);
const std::map<std::string, CreatorFunction> get_supported_ops() { const std::map<std::string, CreatorFunction> get_supported_ops() {
@ -306,6 +307,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"Unique", CreatorFunction(translate_unique_op)}, {"Unique", CreatorFunction(translate_unique_op)},
// XLA operations // XLA operations
{"XlaConvV2", CreatorFunction(translate_xla_conv_v2_op)},
{"XlaDotV2", CreatorFunction(translate_xla_dot_op)}, {"XlaDotV2", CreatorFunction(translate_xla_dot_op)},
}; };
}; };

View File

@ -0,0 +1,62 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from jax import lax
from jax import numpy as jnp
from jax_layer_test_class import JaxLayerTest
class TestConvGeneralDilated(JaxLayerTest):
def _prepare_input(self):
lhs = np.random.rand(*self.lhs_shape).astype(np.float32)
return [lhs]
def create_model(self, lhs_shape, rhs_shape, window_strides, padding,
lhs_dilation, dimension_numbers,
feature_group_count):
self.lhs_shape = lhs_shape
kernel = jnp.array(np.random.rand(*rhs_shape), dtype=jnp.float32)
def jax_conv_general_dilated(lhs):
out = lax.conv_general_dilated(lhs=lhs, rhs=kernel, window_strides=window_strides, padding=padding,
lhs_dilation=lhs_dilation, dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count)
return out
return jax_conv_general_dilated, None
test_data_basic = [
# regular convolution with NCHW layout for inputs and NHWC layout for output
dict(lhs_shape=[2, 3, 40, 60], rhs_shape=[4, 3, 2, 3],
dimension_numbers=('NCHW', 'OIHW', 'NHWC'), feature_group_count=1),
# group convolution with groups = 3
dict(lhs_shape=[2, 3 * 4, 20, 30], rhs_shape=[3 * 2, 4, 2, 2],
dimension_numbers=('NCHW', 'OIHW', 'NHWC'), feature_group_count=3),
# regular convolution with NHWC layout for input and NCHW layout for output
dict(lhs_shape=[1, 30, 20, 3], rhs_shape=[4, 3, 2, 3],
dimension_numbers=('NHWC', 'OIHW', 'NCHW'), feature_group_count=1),
]
@pytest.mark.parametrize("padding", [
'SAME_LOWER', 'SAME', 'VALID'
])
@pytest.mark.parametrize("window_strides", [
[1, 1], [1, 2], [3, 2]
])
@pytest.mark.parametrize("lhs_dilation", [
None, [1, 1],
# other type of lhs dilation is not supported by TF for tracing
# https://github.com/google/jax/issues/4216
])
@pytest.mark.parametrize("params", test_data_basic)
@pytest.mark.nightly
@pytest.mark.precommit
def test_conv_general_dilated(self, ie_device, precision, ir_version, params, padding, window_strides,
lhs_dilation):
self._test(*self.create_model(**params, padding=padding,
window_strides=window_strides, lhs_dilation=lhs_dilation),
ie_device, precision,
ir_version)