[JAX][TF Hub][TF FE] Introduce JAX layer tests and support of XLA operations (#19269)
* [JAX][TF Hub][TF FE] Introduce JAX layer tests and support of XLA operations Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix JAX layer tests infa Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Extend run for JAX layer tests * Use ovc convert_model * Fix translator and extend layer test cases * Exclude jax testing on Windows --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
This commit is contained in:
parent
cbe744b717
commit
5539d052b0
@ -525,6 +525,15 @@ jobs:
|
|||||||
TEST_DEVICE: CPU
|
TEST_DEVICE: CPU
|
||||||
displayName: 'TensorFlow 2 Layer Tests - TF FE'
|
displayName: 'TensorFlow 2 Layer Tests - TF FE'
|
||||||
|
|
||||||
|
- script: |
|
||||||
|
set -e
|
||||||
|
python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt
|
||||||
|
$(RUN_PREFIX) python3 -m pytest $(LAYER_TESTS_DIR)/jax_tests/ -m precommit --junitxml=$(INSTALL_TEST_DIR)/TEST-jax.xmlTEST
|
||||||
|
env:
|
||||||
|
PYTHONPATH: $(LAYER_TESTS_DIR)
|
||||||
|
TEST_DEVICE: CPU
|
||||||
|
displayName: 'JAX Layer Tests - TF FE'
|
||||||
|
|
||||||
- script: |
|
- script: |
|
||||||
set -e
|
set -e
|
||||||
python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt
|
python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt
|
||||||
|
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -92,6 +92,7 @@
|
|||||||
/tests/layer_tests/ @openvinotoolkit/openvino-tests-maintainers @openvinotoolkit/openvino-mo-maintainers
|
/tests/layer_tests/ @openvinotoolkit/openvino-tests-maintainers @openvinotoolkit/openvino-mo-maintainers
|
||||||
/tests/layer_tests/pytorch_tests/ @openvinotoolkit/openvino-pytorch-frontend-maintainers
|
/tests/layer_tests/pytorch_tests/ @openvinotoolkit/openvino-pytorch-frontend-maintainers
|
||||||
/tests/layer_tests/tensorflow_tests @openvinotoolkit/openvino-tf-frontend-maintainers
|
/tests/layer_tests/tensorflow_tests @openvinotoolkit/openvino-tf-frontend-maintainers
|
||||||
|
/tests/layer_tests/jax_tests @openvinotoolkit/openvino-tf-frontend-maintainers
|
||||||
|
|
||||||
# Tools:
|
# Tools:
|
||||||
/tools/ @openvinotoolkit/openvino-tools-maintainers
|
/tools/ @openvinotoolkit/openvino-tools-maintainers
|
||||||
|
11
.github/workflows/linux.yml
vendored
11
.github/workflows/linux.yml
vendored
@ -641,6 +641,17 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
TEST_DEVICE: CPU
|
TEST_DEVICE: CPU
|
||||||
|
|
||||||
|
- name: JAX Layer Tests - TF FE
|
||||||
|
run: |
|
||||||
|
python3 -m pip install -r ${{ env.LAYER_TESTS_INSTALL_DIR }}/requirements.txt
|
||||||
|
export PYTHONPATH=${{ env.LAYER_TESTS_INSTALL_DIR }}:$PYTHONPATH
|
||||||
|
|
||||||
|
source ${{ env.INSTALL_DIR }}/setupvars.sh
|
||||||
|
|
||||||
|
python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/jax_tests/ -m precommit --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-jax.xml
|
||||||
|
env:
|
||||||
|
TEST_DEVICE: CPU
|
||||||
|
|
||||||
- name: TensorFlow 1 Layer Tests - Legacy FE
|
- name: TensorFlow 1 Layer Tests - Legacy FE
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install -r ${{ env.LAYER_TESTS_INSTALL_DIR }}/requirements.txt
|
python3 -m pip install -r ${{ env.LAYER_TESTS_INSTALL_DIR }}/requirements.txt
|
||||||
|
200
src/frontends/tensorflow/src/op/xla_dot.cpp
Normal file
200
src/frontends/tensorflow/src/op/xla_dot.cpp
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "common_op_table.hpp"
|
||||||
|
#include "input_model.hpp"
|
||||||
|
#include "openvino/op/concat.hpp"
|
||||||
|
#include "openvino/op/constant.hpp"
|
||||||
|
#include "openvino/op/gather.hpp"
|
||||||
|
#include "openvino/op/matmul.hpp"
|
||||||
|
#include "openvino/op/reduce_prod.hpp"
|
||||||
|
#include "openvino/op/reshape.hpp"
|
||||||
|
#include "openvino/op/shape_of.hpp"
|
||||||
|
#include "openvino/op/transpose.hpp"
|
||||||
|
#include "openvino/op/unsqueeze.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
#include "xla_data.pb.h"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::op;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace frontend {
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace op {
|
||||||
|
|
||||||
|
vector<int64_t> compute_non_contracting_dims(const NodeContext& node,
|
||||||
|
const vector<int64_t>& batch_dims,
|
||||||
|
const vector<int64_t>& contracting_dims,
|
||||||
|
const Output<Node>& operand) {
|
||||||
|
// combine two vectors of batch_dims and contracting_dims
|
||||||
|
set<int64_t> unique_dims(batch_dims.begin(), batch_dims.end());
|
||||||
|
unique_dims.insert(contracting_dims.begin(), contracting_dims.end());
|
||||||
|
vector<int64_t> all_dims(unique_dims.begin(), unique_dims.end());
|
||||||
|
|
||||||
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
|
operand.get_partial_shape().rank().is_static(),
|
||||||
|
"[TensorFlow Frontend] internal operation: XlaDotV2 expects inputs of static rank");
|
||||||
|
|
||||||
|
int64_t operand_rank = operand.get_partial_shape().rank().get_length();
|
||||||
|
vector<int64_t> non_contracting_dims;
|
||||||
|
for (int64_t ind = 0; ind < operand_rank; ++ind) {
|
||||||
|
if (find(all_dims.begin(), all_dims.end(), ind) == all_dims.end()) {
|
||||||
|
non_contracting_dims.push_back(ind);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return non_contracting_dims;
|
||||||
|
}
|
||||||
|
|
||||||
|
void insert_aux_dim(const NodeContext& node, Output<Node>& operand, vector<int64_t>& dims) {
|
||||||
|
TENSORFLOW_OP_VALIDATION(node,
|
||||||
|
operand.get_partial_shape().rank().is_static(),
|
||||||
|
"[TensorFlow Frontend] internal operation: XlaDotV2 expects inputs of static rank");
|
||||||
|
if (dims.size() == 0) {
|
||||||
|
int64_t operand_rank = operand.get_partial_shape().rank().get_length();
|
||||||
|
dims.push_back(operand_rank);
|
||||||
|
auto unsqueeze_axis = make_shared<v0::Constant>(element::i64, Shape{1}, operand_rank);
|
||||||
|
operand = make_shared<v0::Unsqueeze>(operand, unsqueeze_axis);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void insert_aux_dims(const NodeContext& node,
|
||||||
|
Output<Node>& operand,
|
||||||
|
vector<int64_t>& batch_dims,
|
||||||
|
vector<int64_t>& contracting_dims,
|
||||||
|
vector<int64_t>& non_contract_dims) {
|
||||||
|
insert_aux_dim(node, operand, batch_dims);
|
||||||
|
insert_aux_dim(node, operand, contracting_dims);
|
||||||
|
insert_aux_dim(node, operand, non_contract_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
Output<Node> compute_dims_shape(const Output<Node>& hs_shape, const vector<int64_t>& dims) {
|
||||||
|
auto const_dims = make_shared<v0::Constant>(element::i64, Shape{dims.size()}, dims);
|
||||||
|
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
|
||||||
|
auto dims_shape = make_shared<v8::Gather>(hs_shape, const_dims, gather_axis);
|
||||||
|
return dims_shape;
|
||||||
|
}
|
||||||
|
|
||||||
|
Output<Node> compute_dims_size(const Output<Node>& hs_shape, const Output<Node>& dims) {
|
||||||
|
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
|
||||||
|
auto dims_shape = make_shared<v8::Gather>(hs_shape, dims, gather_axis);
|
||||||
|
auto dims_size = make_shared<v1::ReduceProd>(dims_shape, gather_axis, true);
|
||||||
|
return {dims_size};
|
||||||
|
}
|
||||||
|
|
||||||
|
OutputVector translate_xla_dot_op(const NodeContext& node) {
|
||||||
|
// see specification of XlaDotV2 here: https://www.tensorflow.org/xla/operation_semantics#dot
|
||||||
|
default_op_checks(node, 2, {"XlaDotV2"});
|
||||||
|
auto lhs = node.get_input(0);
|
||||||
|
auto rhs = node.get_input(1);
|
||||||
|
auto node_name = node.get_name();
|
||||||
|
auto dimension_numbers_message = node.get_attribute<string>("dimension_numbers");
|
||||||
|
::xla::DotDimensionNumbers dimension_numbers;
|
||||||
|
TENSORFLOW_OP_VALIDATION(
|
||||||
|
node,
|
||||||
|
dimension_numbers.ParseFromArray(dimension_numbers_message.data(),
|
||||||
|
static_cast<int>(dimension_numbers_message.size())),
|
||||||
|
"[TensorFlow Frontend] Incorrect input model: incorrect DotDimensionNumbers field for XlaDotV2 " + node_name);
|
||||||
|
|
||||||
|
vector<int64_t> lhs_batch_dims(dimension_numbers.lhs_batch_dimensions().begin(),
|
||||||
|
dimension_numbers.lhs_batch_dimensions().end());
|
||||||
|
vector<int64_t> rhs_batch_dims(dimension_numbers.rhs_batch_dimensions().begin(),
|
||||||
|
dimension_numbers.rhs_batch_dimensions().end());
|
||||||
|
vector<int64_t> rhs_contract_dims(dimension_numbers.rhs_contracting_dimensions().begin(),
|
||||||
|
dimension_numbers.rhs_contracting_dimensions().end());
|
||||||
|
vector<int64_t> lhs_contract_dims(dimension_numbers.lhs_contracting_dimensions().begin(),
|
||||||
|
dimension_numbers.lhs_contracting_dimensions().end());
|
||||||
|
|
||||||
|
// compute non-contracting dimensions
|
||||||
|
auto lhs_non_contract_dims = compute_non_contracting_dims(node, lhs_batch_dims, lhs_contract_dims, lhs);
|
||||||
|
auto rhs_non_contract_dims = compute_non_contracting_dims(node, rhs_batch_dims, rhs_contract_dims, rhs);
|
||||||
|
|
||||||
|
// compute the resulted shape before possible modification
|
||||||
|
auto resulted_shape = make_shared<v0::Constant>(element::i64, Shape{0}, vector<int64_t>{})->output(0);
|
||||||
|
bool apply_reshape = false;
|
||||||
|
auto lhs_shape = make_shared<v3::ShapeOf>(lhs, element::i64);
|
||||||
|
auto rhs_shape = make_shared<v3::ShapeOf>(rhs, element::i64);
|
||||||
|
if (lhs_batch_dims.size() > 0) {
|
||||||
|
auto batch_dims_shape = compute_dims_shape(lhs_shape, lhs_batch_dims);
|
||||||
|
resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, batch_dims_shape}, 0);
|
||||||
|
apply_reshape = true;
|
||||||
|
}
|
||||||
|
if (lhs_non_contract_dims.size() > 0) {
|
||||||
|
auto lhs_non_contract_shape = compute_dims_shape(lhs_shape, lhs_non_contract_dims);
|
||||||
|
resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, lhs_non_contract_shape}, 0);
|
||||||
|
apply_reshape = true;
|
||||||
|
}
|
||||||
|
if (rhs_non_contract_dims.size() > 0) {
|
||||||
|
auto rhs_non_contract_shape = compute_dims_shape(rhs_shape, rhs_non_contract_dims);
|
||||||
|
resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, rhs_non_contract_shape}, 0);
|
||||||
|
apply_reshape = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// take care of that at least one dimension of each type (batch, contracting, and non-contracting) exists
|
||||||
|
// if it does not, insert it to the end
|
||||||
|
insert_aux_dims(node, lhs, lhs_batch_dims, lhs_contract_dims, lhs_non_contract_dims);
|
||||||
|
insert_aux_dims(node, rhs, rhs_batch_dims, rhs_contract_dims, rhs_non_contract_dims);
|
||||||
|
|
||||||
|
// compute non-batch and non-contracting dimensions
|
||||||
|
auto const_lhs_batch_dims = make_shared<v0::Constant>(element::i64, Shape{lhs_batch_dims.size()}, lhs_batch_dims);
|
||||||
|
auto const_rhs_batch_dims = make_shared<v0::Constant>(element::i64, Shape{rhs_batch_dims.size()}, rhs_batch_dims);
|
||||||
|
auto const_lhs_contract_dims =
|
||||||
|
make_shared<v0::Constant>(element::i64, Shape{lhs_contract_dims.size()}, lhs_contract_dims);
|
||||||
|
auto const_rhs_contract_dims =
|
||||||
|
make_shared<v0::Constant>(element::i64, Shape{rhs_contract_dims.size()}, rhs_contract_dims);
|
||||||
|
auto const_lhs_non_contract_dims =
|
||||||
|
make_shared<v0::Constant>(element::i64, Shape{lhs_non_contract_dims.size()}, lhs_non_contract_dims);
|
||||||
|
auto const_rhs_non_contract_dims =
|
||||||
|
make_shared<v0::Constant>(element::i64, Shape{rhs_non_contract_dims.size()}, rhs_non_contract_dims);
|
||||||
|
|
||||||
|
lhs_shape = make_shared<v3::ShapeOf>(lhs, element::i64);
|
||||||
|
rhs_shape = make_shared<v3::ShapeOf>(rhs, element::i64);
|
||||||
|
|
||||||
|
// compute a part of the input shape covering batch dimensions and non-contracting dimensions
|
||||||
|
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
|
||||||
|
auto batch_dims_shape = compute_dims_shape(lhs_shape, lhs_batch_dims);
|
||||||
|
|
||||||
|
// transpose both operand in a way to have dimensions in the order
|
||||||
|
// [batch dims, non-contracting dims, contracting dims]
|
||||||
|
auto lhs_transpose_order = make_shared<v0::Concat>(
|
||||||
|
OutputVector{const_lhs_batch_dims, const_lhs_non_contract_dims, const_lhs_contract_dims},
|
||||||
|
0);
|
||||||
|
auto rhs_transpose_order = make_shared<v0::Concat>(
|
||||||
|
OutputVector{const_rhs_batch_dims, const_rhs_non_contract_dims, const_rhs_contract_dims},
|
||||||
|
0);
|
||||||
|
lhs = make_shared<v1::Transpose>(lhs, lhs_transpose_order);
|
||||||
|
rhs = make_shared<v1::Transpose>(rhs, rhs_transpose_order);
|
||||||
|
|
||||||
|
// compute size of contracting dims and non-contracting dims for each operand
|
||||||
|
auto lhs_contract_size = compute_dims_size(lhs_shape, const_lhs_contract_dims);
|
||||||
|
auto rhs_contract_size = compute_dims_size(rhs_shape, const_rhs_contract_dims);
|
||||||
|
auto lhs_non_contract_size = compute_dims_size(lhs_shape, const_lhs_non_contract_dims);
|
||||||
|
auto rhs_non_contract_size = compute_dims_size(rhs_shape, const_rhs_non_contract_dims);
|
||||||
|
|
||||||
|
// merge contracting and non-contracting dimensions to have operand
|
||||||
|
// of a shape [batch dims, non-contracting dim size, contracting dims size]
|
||||||
|
auto new_lhs_shape =
|
||||||
|
make_shared<v0::Concat>(OutputVector{batch_dims_shape, lhs_non_contract_size, lhs_contract_size}, 0);
|
||||||
|
auto new_rhs_shape =
|
||||||
|
make_shared<v0::Concat>(OutputVector{batch_dims_shape, rhs_non_contract_size, rhs_contract_size}, 0);
|
||||||
|
lhs = make_shared<v1::Reshape>(lhs, new_lhs_shape, false);
|
||||||
|
rhs = make_shared<v1::Reshape>(rhs, new_rhs_shape, false);
|
||||||
|
|
||||||
|
// execute MatMul that support batch matrix-multiplication
|
||||||
|
// note that the second operand is transposed
|
||||||
|
auto matmul = make_shared<v0::MatMul>(lhs, rhs, false, true)->output(0);
|
||||||
|
if (apply_reshape) {
|
||||||
|
matmul = make_shared<v1::Reshape>(matmul, resulted_shape, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
set_node_name(node_name, matmul.get_node_shared_ptr());
|
||||||
|
return {matmul};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace op
|
||||||
|
} // namespace tensorflow
|
||||||
|
} // namespace frontend
|
||||||
|
} // namespace ov
|
@ -46,6 +46,7 @@ TF_OP_CONVERTER(translate_varhandle_op);
|
|||||||
TF_OP_CONVERTER(translate_variable_op);
|
TF_OP_CONVERTER(translate_variable_op);
|
||||||
TF_OP_CONVERTER(translate_varisinitialized_op);
|
TF_OP_CONVERTER(translate_varisinitialized_op);
|
||||||
TF_OP_CONVERTER(translate_while_op);
|
TF_OP_CONVERTER(translate_while_op);
|
||||||
|
TF_OP_CONVERTER(translate_xla_dot_op);
|
||||||
|
|
||||||
const std::map<std::string, CreatorFunction> get_supported_ops() {
|
const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||||
return {
|
return {
|
||||||
@ -301,6 +302,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
|||||||
{"SparseFillEmptyRows", CreatorFunction(translate_sparse_fill_empty_rows_op)},
|
{"SparseFillEmptyRows", CreatorFunction(translate_sparse_fill_empty_rows_op)},
|
||||||
{"SparseSegmentSum", CreatorFunction(translate_sparse_segment_sum_op)},
|
{"SparseSegmentSum", CreatorFunction(translate_sparse_segment_sum_op)},
|
||||||
{"Unique", CreatorFunction(translate_unique_op)},
|
{"Unique", CreatorFunction(translate_unique_op)},
|
||||||
|
|
||||||
|
// XLA operations
|
||||||
|
{"XlaDotV2", CreatorFunction(translate_xla_dot_op)},
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
744
src/frontends/tensorflow/src/proto/xla_data.proto
Normal file
744
src/frontends/tensorflow/src/proto/xla_data.proto
Normal file
@ -0,0 +1,744 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package xla;
|
||||||
|
|
||||||
|
option cc_enable_arenas = true;
|
||||||
|
|
||||||
|
// Primitive types are the individual values that can be held in rectangular
|
||||||
|
// multidimensional arrays. A description of the rectangular multidimensional
|
||||||
|
// array dimensions / primitive type is given by Shape, below.
|
||||||
|
//
|
||||||
|
// LINT.IfChange
|
||||||
|
enum PrimitiveType {
|
||||||
|
// Invalid primitive type to serve as default.
|
||||||
|
PRIMITIVE_TYPE_INVALID = 0;
|
||||||
|
|
||||||
|
// Predicates are two-state booleans.
|
||||||
|
PRED = 1;
|
||||||
|
|
||||||
|
// Signed integral values of fixed width.
|
||||||
|
S8 = 2;
|
||||||
|
S16 = 3;
|
||||||
|
S32 = 4;
|
||||||
|
S64 = 5;
|
||||||
|
|
||||||
|
// Unsigned integral values of fixed width.
|
||||||
|
U8 = 6;
|
||||||
|
U16 = 7;
|
||||||
|
U32 = 8;
|
||||||
|
U64 = 9;
|
||||||
|
|
||||||
|
// Floating-point values of fixed width.
|
||||||
|
//
|
||||||
|
// Note: if f16s are not natively supported on the device, they will be
|
||||||
|
// converted to f16 from f32 at arbirary points in the computation.
|
||||||
|
F16 = 10;
|
||||||
|
F32 = 11;
|
||||||
|
|
||||||
|
// Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
|
||||||
|
// floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
|
||||||
|
// and 7 bits for the mantissa.
|
||||||
|
BF16 = 16;
|
||||||
|
|
||||||
|
F64 = 12;
|
||||||
|
|
||||||
|
// Complex values of fixed width.
|
||||||
|
C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
|
||||||
|
C128 = 18; // Paired F64 (real, imag), as in std::complex<double>.
|
||||||
|
|
||||||
|
// A tuple is a polymorphic sequence; e.g. a shape that holds different
|
||||||
|
// sub-shapes. They are used for things like returning multiple values from a
|
||||||
|
// computation; e.g. a computation that returns weights and biases may have a
|
||||||
|
// signature that results in a tuple like (f32[784x2000], f32[2000])
|
||||||
|
//
|
||||||
|
// If a shape proto has the tuple element type, it may not have any entries
|
||||||
|
// in the dimensions field.
|
||||||
|
TUPLE = 13;
|
||||||
|
|
||||||
|
// An opaque type used for passing context-specific data to a custom
|
||||||
|
// operation. Shapes of this primitive type will have empty dimensions and
|
||||||
|
// tuple_shapes fields.
|
||||||
|
//
|
||||||
|
// (OPAQUE would be a better name for this identifier, but that conflicts with
|
||||||
|
// a macro defined in windows.h.)
|
||||||
|
OPAQUE_TYPE = 14;
|
||||||
|
|
||||||
|
// A token type threaded between side-effecting operations. Shapes of this
|
||||||
|
// primitive type will have empty dimensions and tuple_shapes fields.
|
||||||
|
TOKEN = 17;
|
||||||
|
|
||||||
|
// Next = 19
|
||||||
|
}
|
||||||
|
// LINT.ThenChange(
|
||||||
|
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
|
||||||
|
// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
|
||||||
|
// )
|
||||||
|
|
||||||
|
// Describes the padding configuration for Pad operation. The padding amount on
|
||||||
|
// both edges as well as between the elements are specified for each dimension.
|
||||||
|
message PaddingConfig {
|
||||||
|
// Describes the padding configuration for a dimension.
|
||||||
|
message PaddingConfigDimension {
|
||||||
|
// Padding amount on the low-end (next to the index 0). May be negative.
|
||||||
|
int64 edge_padding_low = 1;
|
||||||
|
|
||||||
|
// Padding amount on the high-end (next to the highest index). May be
|
||||||
|
// negative.
|
||||||
|
int64 edge_padding_high = 2;
|
||||||
|
|
||||||
|
// Padding amount between the elements. May not be negative.
|
||||||
|
int64 interior_padding = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The padding configuration for all dimensions.
|
||||||
|
repeated PaddingConfigDimension dimensions = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A format specifies the method used by a layout to store an array in memory.
|
||||||
|
enum Format {
|
||||||
|
// TODO(b/120869032): Rename this to FORMAT_NONE or something else which
|
||||||
|
// better corresponds to its meaning.
|
||||||
|
INVALID_FORMAT = 0;
|
||||||
|
// The default layout, with exactly one storage location per element.
|
||||||
|
DENSE = 1;
|
||||||
|
reserved 2;
|
||||||
|
reserved "SPARSE";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes a tile used in tiling-based layout. Refer to
|
||||||
|
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
|
||||||
|
// details about tiling-based layout.
|
||||||
|
message TileProto {
|
||||||
|
// Number of elements in each dimension of the tile. It's ordered from the
|
||||||
|
// most major dimension of the tile to the most minor dimension of the tile.
|
||||||
|
// The dimensions correspond to a suffix of the dimensions of the shape being
|
||||||
|
// tiled.
|
||||||
|
repeated int64 dimensions = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A layout describes how the array is placed in (1D) memory space. This
|
||||||
|
// includes the minor-to-major ordering of dimensions within a shape.
|
||||||
|
//
|
||||||
|
// Clients must specify the layouts of input Literals to the
|
||||||
|
// computation. Layouts specified in interior operations which take Shapes (for
|
||||||
|
// example, Convert) are ignored.
|
||||||
|
//
|
||||||
|
// See the XLA documentation for more information on shapes and layouts.
|
||||||
|
//
|
||||||
|
// LINT.IfChange
|
||||||
|
message LayoutProto {
|
||||||
|
// The method used to store the data in memory. The format determines which of
|
||||||
|
// the other fields are used by the layout.
|
||||||
|
Format format = 4;
|
||||||
|
|
||||||
|
// Sequence of dimension numbers, from minor (fastest varying index) to major
|
||||||
|
// (slowest varying index). This field is required.
|
||||||
|
repeated int64 minor_to_major = 1;
|
||||||
|
|
||||||
|
reserved 2;
|
||||||
|
reserved "padded_dimensions";
|
||||||
|
|
||||||
|
reserved 3;
|
||||||
|
reserved "padding_value";
|
||||||
|
|
||||||
|
reserved 5;
|
||||||
|
reserved "max_sparse_elements";
|
||||||
|
|
||||||
|
// A sequence of tiles, starting from the tile that's applied first to the
|
||||||
|
// Shape.
|
||||||
|
//
|
||||||
|
// TODO(b/119839262): implement tiling in each backend or add Unimplemented
|
||||||
|
// error.
|
||||||
|
repeated TileProto tiles = 6;
|
||||||
|
|
||||||
|
// Bit size of each element. If the size is bigger than what the element
|
||||||
|
// type requires, the value is stored in the least significant
|
||||||
|
// bits and the additional most significant bits are filled with 0's.
|
||||||
|
//
|
||||||
|
// TODO(b/119839262): implement in each backend or add Unimplemented error.
|
||||||
|
int64 element_size_in_bits = 7;
|
||||||
|
|
||||||
|
// Memory space where this array resides. The integer field is interpreted in
|
||||||
|
// a backend-specific manner.
|
||||||
|
int64 memory_space = 8;
|
||||||
|
|
||||||
|
// Important: if any field is added, be sure to modify ShapeUtil::Equal() and
|
||||||
|
// LayoutUtil::Hash appropriately to account for the new field.
|
||||||
|
}
|
||||||
|
// LINT.ThenChange( \
|
||||||
|
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
|
||||||
|
// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
|
||||||
|
|
||||||
|
// A shape describes the number of dimensions in the array, the size of each
|
||||||
|
// dimension, and the primitive component type.
|
||||||
|
//
|
||||||
|
// Tuples are a special case in that they have rank zero and have tuple_shapes
|
||||||
|
// defined.
|
||||||
|
//
|
||||||
|
// See the XLA documentation for more information on shapes and layouts.
|
||||||
|
//
|
||||||
|
// LINT.IfChange
|
||||||
|
message ShapeProto {
|
||||||
|
reserved 1;
|
||||||
|
reserved "rank";
|
||||||
|
|
||||||
|
// The element type for this shape.
|
||||||
|
PrimitiveType element_type = 2;
|
||||||
|
|
||||||
|
// The size (number of elements) for each dimension, or an upper bound on the
|
||||||
|
// size if the dimension is dynamic. In XLA, dimensions are numbered from 0
|
||||||
|
// to N-1 for an N-dimensional array. The first element of 'dimensions' is the
|
||||||
|
// size of dimension 0, the second element is the size of dimension 1, and so
|
||||||
|
// forth. Empty list indicates a scalar.
|
||||||
|
//
|
||||||
|
// If the respective element in 'is_dimension_dynamic' is true then the value
|
||||||
|
// in this field represents an upper bound on the size of the dimension.
|
||||||
|
repeated int64 dimensions = 3;
|
||||||
|
|
||||||
|
// For tuples only, the shapes of constituent shapes in the tuple sequence.
|
||||||
|
repeated ShapeProto tuple_shapes = 4;
|
||||||
|
|
||||||
|
// The layout used to back this shape.
|
||||||
|
LayoutProto layout = 5;
|
||||||
|
|
||||||
|
// For arrays, this indicates whether or not each dimension is
|
||||||
|
// dynamically-sized. The number of elements in this repeated field should be
|
||||||
|
// zero (indicating that no dimensions are dynamic) or equal to the number of
|
||||||
|
// elements in the 'dimensions' field.
|
||||||
|
repeated bool is_dynamic_dimension = 6;
|
||||||
|
|
||||||
|
// Important: if any field is added, be sure to modify ShapeUtil::Equal(),
|
||||||
|
// ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
|
||||||
|
// the new field.
|
||||||
|
}
|
||||||
|
// LINT.ThenChange( \
|
||||||
|
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
|
||||||
|
|
||||||
|
// Shape of the parameters and output of a computation (like a traditional
|
||||||
|
// function signature).
|
||||||
|
message ProgramShapeProto {
|
||||||
|
repeated ShapeProto parameters = 1;
|
||||||
|
ShapeProto result = 2;
|
||||||
|
repeated string parameter_names = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Statistics of a computation.
|
||||||
|
message ComputationStats {
|
||||||
|
// The number of floating point operations in the computation.
|
||||||
|
double flop_count = 1;
|
||||||
|
|
||||||
|
// The number of transcendental operations (e.g., exp) in the computation.
|
||||||
|
double transcendental_count = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The type optimization profiles in use.
|
||||||
|
enum ProfileType {
|
||||||
|
INVALID = 0;
|
||||||
|
WINDOW = 1;
|
||||||
|
FLAG = 2;
|
||||||
|
INTEGER = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Symbolization metadata for HLO Instructions.
|
||||||
|
//
|
||||||
|
// This metadata is used for debugging XLA code generation, as well as
|
||||||
|
// performance profiling of XLA-generated executables.
|
||||||
|
message OpMetadata {
|
||||||
|
// The framework op name that generated this XLA op.
|
||||||
|
//
|
||||||
|
// Frameworks that build on top of XLA should mirror the names of their ops
|
||||||
|
// back to users by specifying the op_type. In this way, even if the
|
||||||
|
// framework's "ops" are implemented as multiple XLA HLO Ops, they can be
|
||||||
|
// grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
|
||||||
|
// multiple ops, then each op should have the op_type be "SoftMax".)
|
||||||
|
string op_type = 1;
|
||||||
|
// The user-specified name of the op.
|
||||||
|
//
|
||||||
|
// This name is often unique within a computation. Note: some frameworks
|
||||||
|
// add auto-generated names if the user does not provide one.
|
||||||
|
string op_name = 2;
|
||||||
|
// Indicate a file and line that this op is associated to in a user's program.
|
||||||
|
//
|
||||||
|
// e.g. it could be the file and line of user code that generated the op.
|
||||||
|
string source_file = 3;
|
||||||
|
int32 source_line = 4;
|
||||||
|
|
||||||
|
repeated ProfileType profile_type = 5;
|
||||||
|
|
||||||
|
// HloPassMetadata.pass_id of the pass that created this HLO instruction
|
||||||
|
// object. Should never be copied between HLO instructions. Zero if unset and
|
||||||
|
// -1 if the instruction was created before HLO passes began.
|
||||||
|
int64 creation_pass_id = 6;
|
||||||
|
|
||||||
|
// HloPassMetadata.pass_id of the pass that created the logical functionality
|
||||||
|
// that this HLO instruction represents. Should be copied between HLO
|
||||||
|
// instructions that correspond across compilation passes. Zero if unset and
|
||||||
|
// -1 if the instruction was created before HLO passes began.
|
||||||
|
int64 logical_creation_pass_id = 7;
|
||||||
|
|
||||||
|
// The footprint of the generated code for the instruction.
|
||||||
|
int64 size_of_generated_code_in_bytes = 8;
|
||||||
|
// The size of the working set, i.e., the amount of memory, used by the
|
||||||
|
// instruction in a compiler-managed fast device memory.
|
||||||
|
int64 size_of_memory_working_set_in_bytes = 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Profile data from the execution of a computation.
|
||||||
|
message ExecutionProfile {
|
||||||
|
// Whether the executable was read from the compilation cache.
|
||||||
|
bool compilation_cache_hit = 1;
|
||||||
|
|
||||||
|
// The time in milliseconds spent to compile the computation. This only set if
|
||||||
|
// the executable was not read from the compilation cache
|
||||||
|
// (compilation_cache_hit == false).
|
||||||
|
int64 compile_time_ms = 2;
|
||||||
|
|
||||||
|
// The number of cycles spent for the computation. This does not include the
|
||||||
|
// time taken for the data transfers between the host and the device. This is
|
||||||
|
// a target-dependent field and only used for debugging purposes.
|
||||||
|
int64 compute_cycle_count = 3;
|
||||||
|
|
||||||
|
// The time in nanoseconds spent for the computation, without data transfer.
|
||||||
|
int64 compute_time_ns = 4;
|
||||||
|
|
||||||
|
// The time in nanoseconds spent for the entire computation, including the
|
||||||
|
// result data transfer time. Current implementation does not spend any cycles
|
||||||
|
// for the input data transfer since the memory is initialized with the proper
|
||||||
|
// values before the execution.
|
||||||
|
int64 compute_and_transfer_time_ns = 5;
|
||||||
|
|
||||||
|
// The size of the binary code in the executable.
|
||||||
|
int64 executable_size_in_bytes = 6;
|
||||||
|
|
||||||
|
// Whether this profile was drawn from a cache of profiles instead of from
|
||||||
|
// execution on the hardware.
|
||||||
|
bool profile_cache_hit = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle given to a user that represents an execution that the user launched
|
||||||
|
// asynchronously on the device.
|
||||||
|
message ExecutionHandle {
|
||||||
|
int64 handle = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle given to a user that represents a globally accessible allocation.
|
||||||
|
// Contrast this against a ComputationDataHandle, which is not globally
|
||||||
|
// accessible, since it only exists within a specific computation.
|
||||||
|
message GlobalDataHandle {
|
||||||
|
int64 handle = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle given to a user that represents a replicated virtual device. Each
|
||||||
|
// replicated device represents N physical devices for execution where N is the
|
||||||
|
// number of replicas.
|
||||||
|
message DeviceHandle {
|
||||||
|
int64 handle = 1;
|
||||||
|
|
||||||
|
// The number of model-parallel virtual devices that communicate via XLA
|
||||||
|
// Send/Recv instructions.
|
||||||
|
int64 device_count = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle given to a user to represent a channel between two computations
|
||||||
|
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
|
||||||
|
// Send instructions will be blocked until the data is transferred.
|
||||||
|
message ChannelHandle {
|
||||||
|
int64 handle = 1;
|
||||||
|
enum ChannelType {
|
||||||
|
// Invalid primitive type to serve as default.
|
||||||
|
CHANNEL_TYPE_INVALID = 0;
|
||||||
|
|
||||||
|
// A channel for sending data between devices.
|
||||||
|
DEVICE_TO_DEVICE = 1;
|
||||||
|
|
||||||
|
// A channel for sending data from the device to the host. Can only be used
|
||||||
|
// with a Send operation.
|
||||||
|
DEVICE_TO_HOST = 2;
|
||||||
|
|
||||||
|
// A channel for sending data from the host to the device. Can only be used
|
||||||
|
// with a Recv operation.
|
||||||
|
HOST_TO_DEVICE = 3;
|
||||||
|
}
|
||||||
|
ChannelType type = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
|
||||||
|
// represents the device ids assigned to a set of replicated computations.
|
||||||
|
// See xla::DeviceAssignment class comment for more details.
|
||||||
|
message DeviceAssignmentProto {
|
||||||
|
int32 replica_count = 1;
|
||||||
|
int32 computation_count = 2;
|
||||||
|
|
||||||
|
// Each logical computation runs on replica_count physical devices.
|
||||||
|
// ComputationDevice represents the device ids assinged to the replicas.
|
||||||
|
message ComputationDevice {
|
||||||
|
repeated int32 replica_device_ids = 1;
|
||||||
|
}
|
||||||
|
repeated ComputationDevice computation_devices = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Literals are used when the server and client need to exchange materialized
|
||||||
|
// data / results. Literals are also used to describe constants used in
|
||||||
|
// computations.
|
||||||
|
//
|
||||||
|
// Transfers to/from the client are encoded in literal form, and the structure
|
||||||
|
// of the repeated fields is implied by the shape.
|
||||||
|
message LiteralProto {
|
||||||
|
ShapeProto shape = 1;
|
||||||
|
repeated bool preds = 2;
|
||||||
|
bytes s8s = 15;
|
||||||
|
bytes u8s = 3;
|
||||||
|
repeated int32 s32s = 4;
|
||||||
|
repeated int64 s64s = 5;
|
||||||
|
repeated uint32 u32s = 6;
|
||||||
|
repeated uint64 u64s = 7;
|
||||||
|
repeated float f32s = 8;
|
||||||
|
repeated double f64s = 9;
|
||||||
|
repeated float c64s = 12; // Stored as interleaved real, imag floats.
|
||||||
|
repeated double c128s = 18; // Stored as interleaved real, imag doubles.
|
||||||
|
repeated LiteralProto tuple_literals = 10;
|
||||||
|
// The F16s, BF16s, U16s and S16s are encoded in little endian byte order
|
||||||
|
bytes f16s = 11;
|
||||||
|
bytes bf16s = 13;
|
||||||
|
bytes u16s = 16;
|
||||||
|
bytes s16s = 17;
|
||||||
|
repeated int64 sparse_indices = 14;
|
||||||
|
// Next = 19
|
||||||
|
}
|
||||||
|
|
||||||
|
message WindowDimension {
|
||||||
|
// The size of the window in this dimension. For a rectangle, this would be
|
||||||
|
// the width or height.
|
||||||
|
int64 size = 1;
|
||||||
|
|
||||||
|
// The stride at which the window moves across the base area in this
|
||||||
|
// dimension. In other words, this is the spacing between different
|
||||||
|
// positions of the window in this dimension.
|
||||||
|
int64 stride = 2;
|
||||||
|
|
||||||
|
// If positive, means the amount of padding to add to the base area at the low
|
||||||
|
// end of this dimension; if negative, its negative means the number of
|
||||||
|
// elements removed from the low end of this dimension. For example, in the
|
||||||
|
// horizontal dimension of a rectangle, this would be the number of padding
|
||||||
|
// values to pad on the left, given that indices increase when going right.
|
||||||
|
// The actual padding value depends upon the context. Convolution pads with
|
||||||
|
// zeros. ReduceWindow and SelectAndScatter pads with the reduce function's
|
||||||
|
// init value.
|
||||||
|
int64 padding_low = 3;
|
||||||
|
|
||||||
|
// As padding_low, but on the high end of this dimension. For example, in the
|
||||||
|
// horizontal dimension of a rectangle, this would be the number of values to
|
||||||
|
// pad on the right, given that indices increase when going right.
|
||||||
|
int64 padding_high = 4;
|
||||||
|
|
||||||
|
// Dilation factor of the sliding window in this dimension. A dilation factor
|
||||||
|
// of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
|
||||||
|
// implicitly placed between each kernel element. This value may not be less
|
||||||
|
// than 1. See documentation for convolution.
|
||||||
|
int64 window_dilation = 5;
|
||||||
|
|
||||||
|
// Dilation factor of the base area in this dimension. A dilation factor of 1
|
||||||
|
// means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
|
||||||
|
// placed between each base area element. This value may not be less than 1.
|
||||||
|
// See documentation for convolution.
|
||||||
|
int64 base_dilation = 6;
|
||||||
|
|
||||||
|
// Window reversal means that this dimension was logically reversed before the
|
||||||
|
// operation.
|
||||||
|
bool window_reversal = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes the windowing in an operation such as convolution.
|
||||||
|
//
|
||||||
|
// The window is moved across a base area and for each position of the
|
||||||
|
// window a computation is performed. The field below describes the
|
||||||
|
// window and the movement of the window across a base area.
|
||||||
|
message Window {
|
||||||
|
repeated WindowDimension dimensions = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes the dimension numbers for a gather operation.
|
||||||
|
//
|
||||||
|
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
|
||||||
|
// more details.
|
||||||
|
message GatherDimensionNumbers {
|
||||||
|
// "Window indices" is a term for a set of indices that index into the
|
||||||
|
// interior of a dynamic-slice from the input tensor, the starting indices for
|
||||||
|
// which were computed from output_gather_dims (see the operation semantic for
|
||||||
|
// how this is defined) and the start_indices tensor.
|
||||||
|
//
|
||||||
|
// The window indices for a specific output index Out is computed as:
|
||||||
|
//
|
||||||
|
// i = 0
|
||||||
|
// for (k : [0, input_tensor_shape.rank))
|
||||||
|
// window_indices[k] =
|
||||||
|
// if k in collapsed_slice_dims
|
||||||
|
// then 0
|
||||||
|
// else Out[offset_dims[i++]]
|
||||||
|
repeated int64 offset_dims = 1;
|
||||||
|
repeated int64 collapsed_slice_dims = 2;
|
||||||
|
|
||||||
|
// This is interpreted as a map from i to start_index_map[i]. It
|
||||||
|
// transforms the gather index looked up from the start_indices tensor into
|
||||||
|
// the starting index in the input space.
|
||||||
|
repeated int64 start_index_map = 3;
|
||||||
|
|
||||||
|
// The dimension in the start_indices input that contains the starting
|
||||||
|
// indices.
|
||||||
|
int64 index_vector_dim = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes the dimension numbers for a scatter operation.
|
||||||
|
//
|
||||||
|
// All the fields are similar to the corresponding fields in
|
||||||
|
// GatherDimensionNumbers. Differences are noted below.
|
||||||
|
message ScatterDimensionNumbers {
|
||||||
|
// The set of dimensions in the updates shape that are window dimensions.
|
||||||
|
repeated int64 update_window_dims = 1;
|
||||||
|
// The set of window dimensions that must be inserted into the updates shape.
|
||||||
|
repeated int64 inserted_window_dims = 2;
|
||||||
|
|
||||||
|
repeated int64 scatter_dims_to_operand_dims = 3;
|
||||||
|
int64 index_vector_dim = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ConvolutionDimensionNumbers {
|
||||||
|
// The number of the dimension that represents batch in the input.
|
||||||
|
int64 input_batch_dimension = 7;
|
||||||
|
|
||||||
|
// The number of the dimension that represents features in the input.
|
||||||
|
int64 input_feature_dimension = 8;
|
||||||
|
|
||||||
|
// The dimension numbers for the spatial dimensions that the window
|
||||||
|
// moves through in the input.
|
||||||
|
repeated int64 input_spatial_dimensions = 11;
|
||||||
|
|
||||||
|
// The number of the dimension that represents input features in the
|
||||||
|
// convolutional kernel (rhs).
|
||||||
|
int64 kernel_input_feature_dimension = 3;
|
||||||
|
|
||||||
|
// The number of the dimension that represents output features in
|
||||||
|
// the convolutional kernel (rhs).
|
||||||
|
int64 kernel_output_feature_dimension = 4;
|
||||||
|
|
||||||
|
// The dimension numbers for the spatial dimensions that the window
|
||||||
|
// moves through in the kernel (rhs). window.strides(0) is the
|
||||||
|
// stride in the kernel_spatial_dimensions(0) dimension.
|
||||||
|
repeated int64 kernel_spatial_dimensions = 6;
|
||||||
|
|
||||||
|
// The number of the dimension that represents batch in the output.
|
||||||
|
int64 output_batch_dimension = 9;
|
||||||
|
|
||||||
|
// The number of the dimension that represents features in the output.
|
||||||
|
int64 output_feature_dimension = 10;
|
||||||
|
|
||||||
|
// The dimension numbers for the spatial dimensions that the window
|
||||||
|
// moves through in the output.
|
||||||
|
repeated int64 output_spatial_dimensions = 12;
|
||||||
|
|
||||||
|
// Next = 13
|
||||||
|
}
|
||||||
|
|
||||||
|
enum PaddingType {
|
||||||
|
PADDING_INVALID = 0;
|
||||||
|
PADDING_VALID = 1; // Only valid portion of the base are covered.
|
||||||
|
PADDING_SAME = 2; // Extra is added to produce same output size as the input.
|
||||||
|
}
|
||||||
|
|
||||||
|
enum FftType {
|
||||||
|
FFT = 0; // Forward FFT; complex in, complex out.
|
||||||
|
IFFT = 1; // Inverse FFT; complex in, complex out.
|
||||||
|
RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out
|
||||||
|
IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in,
|
||||||
|
// fft_length real out
|
||||||
|
}
|
||||||
|
|
||||||
|
message DotDimensionNumbers {
|
||||||
|
// The dimension numbers that represent the 'lhs' contracting dimensions.
|
||||||
|
repeated int64 lhs_contracting_dimensions = 1;
|
||||||
|
// The dimension numbers that represent the 'rhs' contracting dimensions.
|
||||||
|
repeated int64 rhs_contracting_dimensions = 2;
|
||||||
|
// The dimension numbers that represent the 'lhs' batch dimensions.
|
||||||
|
repeated int64 lhs_batch_dimensions = 3;
|
||||||
|
// The dimension numbers that represent the 'rhs' batch dimensions.
|
||||||
|
repeated int64 rhs_batch_dimensions = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum RandomDistribution {
|
||||||
|
RNG_INVALID = 0;
|
||||||
|
|
||||||
|
// Creates a uniform-distribution-generated random number on the semi-open
|
||||||
|
// interval [parameter[0], parameter[1]).
|
||||||
|
RNG_UNIFORM = 1;
|
||||||
|
|
||||||
|
// Creates a normal-distribution-generated random number with mean
|
||||||
|
// parameter[0] and standard deviation parameter[1].
|
||||||
|
RNG_NORMAL = 2;
|
||||||
|
|
||||||
|
// Next: 4
|
||||||
|
}
|
||||||
|
|
||||||
|
enum RandomAlgorithm {
|
||||||
|
RNG_DEFAULT = 0; // Backend dependent default algorithm.
|
||||||
|
RNG_THREE_FRY = 1;
|
||||||
|
RNG_PHILOX = 2;
|
||||||
|
// Next: 2
|
||||||
|
}
|
||||||
|
|
||||||
|
message TriangularSolveOptions {
|
||||||
|
// If true, solves ax = b. If false, solves xa = b.
|
||||||
|
bool left_side = 1;
|
||||||
|
|
||||||
|
// If true, 'a' is lower triangular. If false, 'a' is upper triangular.
|
||||||
|
bool lower = 2;
|
||||||
|
|
||||||
|
// If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
|
||||||
|
bool unit_diagonal = 3;
|
||||||
|
|
||||||
|
// Should we transpose or use the adjoint of 'a'?
|
||||||
|
enum Transpose {
|
||||||
|
TRANSPOSE_INVALID = 0;
|
||||||
|
NO_TRANSPOSE = 1; // Don't transpose 'a'.
|
||||||
|
TRANSPOSE = 2; // Transpose 'a'.
|
||||||
|
ADJOINT = 3; // Complex conjugate and transpose 'a'.
|
||||||
|
}
|
||||||
|
Transpose transpose_a = 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CholeskyOptions {
|
||||||
|
// If true, uses the lower triangle of `a`. If false, uses the upper triangle
|
||||||
|
// of `a`.
|
||||||
|
bool lower = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic map of attributes used to pass hints / configuration options from
|
||||||
|
// the Python frontend to the XLA backend.
|
||||||
|
message FrontendAttributes {
|
||||||
|
map<string, string> map = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message OpSharding {
|
||||||
|
enum Type {
|
||||||
|
// This sharding is replicated across all devices (implies maximal,
|
||||||
|
// all other fields are unused).
|
||||||
|
REPLICATED = 0;
|
||||||
|
// This sharding is maximal - one device runs the entire operation.
|
||||||
|
MAXIMAL = 1;
|
||||||
|
// This sharding is a tuple - only the tuple_shardings field is valid.
|
||||||
|
TUPLE = 2;
|
||||||
|
// None of the above; tile_shape and tile_assignment are both used.
|
||||||
|
OTHER = 3;
|
||||||
|
// This op is manually sharded: the shapes are already partitioned and the
|
||||||
|
// partitioner should not change this op.
|
||||||
|
MANUAL = 4;
|
||||||
|
}
|
||||||
|
Type type = 1;
|
||||||
|
// The shape of the sharded tile.
|
||||||
|
ShapeProto tile_shape = 2;
|
||||||
|
// The shape of the tile assignment tensor - this must be the same rank as
|
||||||
|
// tile_shape and the product of its dimensions must equal
|
||||||
|
// tile_assignment_devices.size().
|
||||||
|
repeated int64 tile_assignment_dimensions = 3;
|
||||||
|
// Flattened list of device IDs. The order of flattening is the same as used
|
||||||
|
// by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
|
||||||
|
repeated int64 tile_assignment_devices = 4;
|
||||||
|
// If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
|
||||||
|
// in pre-order. The tuple shape could be nested; here we store just a
|
||||||
|
// flattened list of all leaves in the tuple shape. Note that the tuple shape
|
||||||
|
// is not stored here; shardings do not store the shapes to which they are
|
||||||
|
// applied, this is inferred from the instruction this sharding gets attached
|
||||||
|
// to.
|
||||||
|
repeated OpSharding tuple_shardings = 5;
|
||||||
|
|
||||||
|
// Only used for OTHER type. If true, data is sharded according to other
|
||||||
|
// dimensions of tile_assignment(), but replicated across devices along the
|
||||||
|
// last dimension. (Experimental)
|
||||||
|
bool replicate_on_last_tile_dim = 6;
|
||||||
|
// This field is used to track the source of this sharding, usually derived
|
||||||
|
// from instructions. Multple metadata may be populated if sharding is
|
||||||
|
// combined with other shardings. Metadata are to not be populated when
|
||||||
|
// type == TUPLE and instead metadata should be set on individual tuple
|
||||||
|
// elements.
|
||||||
|
repeated OpMetadata metadata = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes the replica groups in a cross replica op (e.g., all-reduce and
|
||||||
|
// all-to-all).
|
||||||
|
message ReplicaGroup {
|
||||||
|
// The ids of the replicas that belongs to the same group. The ordering of the
|
||||||
|
// ids matters in some ops (e.g., all-to-all).
|
||||||
|
repeated int64 replica_ids = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes the source target pair in the collective permute op.
|
||||||
|
message SourceTarget {
|
||||||
|
int64 source = 1;
|
||||||
|
int64 target = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Used to indicate the precision configuration. It has backend specific
|
||||||
|
// meaning.
|
||||||
|
message PrecisionConfig {
|
||||||
|
enum Precision {
|
||||||
|
DEFAULT = 0;
|
||||||
|
HIGH = 1;
|
||||||
|
HIGHEST = 2;
|
||||||
|
|
||||||
|
// Next: 3
|
||||||
|
}
|
||||||
|
repeated Precision operand_precision = 1;
|
||||||
|
|
||||||
|
// Next: 2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Describes whether all data-parallelism replicas will receive the same
|
||||||
|
// parameter data at each buffer.
|
||||||
|
message ParameterReplication {
|
||||||
|
// A list of boolean values for the flattened leaf buffers. Each value
|
||||||
|
// indicates whether the corresponding leaf buffer is replicated.
|
||||||
|
//
|
||||||
|
// If this field is empty, it means no buffer is replicated. Otherwise, the
|
||||||
|
// number of elements in this field must match the number of leaf buffers in
|
||||||
|
// the HLO instruction's shape.
|
||||||
|
repeated bool replicated_at_leaf_buffers = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A backend-config for kWhile loops that stores the loop's trip count, if it is
|
||||||
|
// known.
|
||||||
|
//
|
||||||
|
// This is useful for backends that can implement a `for i in 0..N` loop more
|
||||||
|
// efficiently than a `while` loop. For example, on GPUs, we can implement a
|
||||||
|
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
|
||||||
|
// whereas implementing a `while` loop requires a host-device sync on each
|
||||||
|
// iteration.
|
||||||
|
message WhileLoopBackendConfig {
|
||||||
|
message KnownTripCount {
|
||||||
|
int64 n = 1;
|
||||||
|
}
|
||||||
|
// This indirection lets us distinguish between known-trip-count == 0 and
|
||||||
|
// unknown-trip-count.
|
||||||
|
KnownTripCount known_trip_count = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Specifies a pair of output/operand buffers for kCustomCall that alias each
|
||||||
|
// other.
|
||||||
|
message CustomCallOutputOperandAliasing {
|
||||||
|
repeated int64 output_shape_index = 1;
|
||||||
|
int64 operand_index = 2;
|
||||||
|
repeated int64 operand_shape_index = 3;
|
||||||
|
}
|
12
tests/layer_tests/jax_tests/conftest.py
Normal file
12
tests/layer_tests/jax_tests/conftest.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
from common.layer_test_class import get_params
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
test_gen_attrs_names = list(inspect.signature(get_params).parameters)
|
||||||
|
params = get_params()
|
||||||
|
metafunc.parametrize(test_gen_attrs_names, params, scope="function")
|
142
tests/layer_tests/jax_tests/jax_layer_test_class.py
Normal file
142
tests/layer_tests/jax_tests/jax_layer_test_class.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from common.constants import test_device, test_precision
|
||||||
|
from jax import numpy as jnp
|
||||||
|
from openvino.runtime import Core
|
||||||
|
|
||||||
|
|
||||||
|
class JaxLayerTest:
|
||||||
|
def _test(self, model, ref_net, ie_device, precision, ir_version, infer_timeout=60, dynamic_shapes=True,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
:param enabled_transforms/disabled_transforms: string with idxs of transforms that should be enabled/disabled.
|
||||||
|
Example: "transform_1,transform_2"
|
||||||
|
"""
|
||||||
|
inputs = self._prepare_input()
|
||||||
|
converted_model = self.convert_via_tensorflow_function(model, inputs)
|
||||||
|
|
||||||
|
# OV infer:
|
||||||
|
core = Core()
|
||||||
|
compiled = core.compile_model(converted_model, ie_device)
|
||||||
|
infer_res = compiled(deepcopy(inputs))
|
||||||
|
|
||||||
|
# Framework infer:
|
||||||
|
fw_res = model(*deepcopy(inputs))
|
||||||
|
|
||||||
|
if not isinstance(fw_res, (tuple)):
|
||||||
|
fw_res = (fw_res,)
|
||||||
|
|
||||||
|
output_list = ov_res_to_list(infer_res)
|
||||||
|
|
||||||
|
def flattenize_dict_outputs(res):
|
||||||
|
if isinstance(res, dict):
|
||||||
|
return flattenize_outputs(res.values())
|
||||||
|
|
||||||
|
def flattenize_outputs(res):
|
||||||
|
results = []
|
||||||
|
for res_item in res:
|
||||||
|
# if None is at output we skip it
|
||||||
|
if res_item is None:
|
||||||
|
continue
|
||||||
|
# If input is list or tuple flattenize it
|
||||||
|
if isinstance(res_item, (list, tuple)):
|
||||||
|
decomposed_res = flattenize_outputs(res_item)
|
||||||
|
results.extend(decomposed_res)
|
||||||
|
continue
|
||||||
|
if isinstance(res_item, dict):
|
||||||
|
decomposed_res = flattenize_dict_outputs(res_item)
|
||||||
|
results.extend(decomposed_res)
|
||||||
|
continue
|
||||||
|
results.append(res_item)
|
||||||
|
return results
|
||||||
|
|
||||||
|
flatten_fw_res = flattenize_outputs(fw_res)
|
||||||
|
|
||||||
|
assert len(flatten_fw_res) == len(
|
||||||
|
output_list), f'number of outputs not equal, {len(flatten_fw_res)} != {len(output_list)}'
|
||||||
|
# check if results dtypes match
|
||||||
|
for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
|
||||||
|
fw_tensor_type = np.array(fw_tensor).dtype
|
||||||
|
ov_tensor_type = ov_tensor.dtype
|
||||||
|
assert ov_tensor_type == fw_tensor_type, f"dtype validation failed: {ov_tensor_type} != {fw_tensor_type}"
|
||||||
|
|
||||||
|
if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
|
||||||
|
custom_eps = kwargs['custom_eps']
|
||||||
|
else:
|
||||||
|
custom_eps = 1e-4
|
||||||
|
|
||||||
|
# compare OpenVINO results with JAX results
|
||||||
|
fw_eps = custom_eps if precision == 'FP32' else 5e-2
|
||||||
|
is_ok = True
|
||||||
|
for i in range(len(flatten_fw_res)):
|
||||||
|
cur_fw_res = np.array(flatten_fw_res[i])
|
||||||
|
cur_ov_res = infer_res[compiled.output(i)]
|
||||||
|
print(f"fw_re: {cur_fw_res};\n ov_res: {cur_ov_res}")
|
||||||
|
if not np.allclose(cur_ov_res, cur_fw_res,
|
||||||
|
atol=fw_eps,
|
||||||
|
rtol=fw_eps, equal_nan=True):
|
||||||
|
is_ok = False
|
||||||
|
print("Max diff is {}".format(
|
||||||
|
np.array(
|
||||||
|
abs(cur_ov_res - cur_fw_res)).max()))
|
||||||
|
else:
|
||||||
|
print("Accuracy validation successful!\n")
|
||||||
|
print("absolute eps: {}, relative eps: {}".format(fw_eps, fw_eps))
|
||||||
|
assert is_ok, "Accuracy validation failed"
|
||||||
|
|
||||||
|
# Each model should specify inputs
|
||||||
|
def _prepare_input(self):
|
||||||
|
raise RuntimeError("Please provide inputs generation function")
|
||||||
|
|
||||||
|
def convert_via_tensorflow_function(self, model, inputs):
|
||||||
|
import tensorflow as tf
|
||||||
|
from jax.experimental import jax2tf
|
||||||
|
from openvino.tools.ovc import convert_model
|
||||||
|
# create function signature based on input shapes and types
|
||||||
|
function_signature = []
|
||||||
|
for _input in inputs:
|
||||||
|
assert isinstance(_input, np.ndarray)
|
||||||
|
input_shape = _input.shape
|
||||||
|
input_type = _input.dtype
|
||||||
|
function_signature.append(tf.TensorSpec(input_shape, input_type))
|
||||||
|
|
||||||
|
f = tf.function(jax2tf.convert(model), autograph=False,
|
||||||
|
input_signature=function_signature)
|
||||||
|
converted_model = convert_model(f)
|
||||||
|
return converted_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_params(ie_device=None, precision=None):
|
||||||
|
"""
|
||||||
|
:param ie_device: list of devices
|
||||||
|
:param precision: list of precisions
|
||||||
|
"""
|
||||||
|
|
||||||
|
ie_device_params = ie_device if ie_device else test_device
|
||||||
|
precision_params = precision if precision else test_precision
|
||||||
|
|
||||||
|
test_args = []
|
||||||
|
for element in itertools.product(ie_device_params, precision_params):
|
||||||
|
if element[0] == 'CPU' and element[1] == 'FP16':
|
||||||
|
continue
|
||||||
|
test_args.append(element)
|
||||||
|
return test_args
|
||||||
|
|
||||||
|
|
||||||
|
def ov_res_to_list(ov_res_dict):
|
||||||
|
# 118221: remove this WA that clean-up repeating output tensors
|
||||||
|
# with the same tensor names
|
||||||
|
# probably, we do not utilize some meta info from tf.function
|
||||||
|
values = []
|
||||||
|
met_names = set()
|
||||||
|
for ov_res_name, ov_res_value in ov_res_dict.items():
|
||||||
|
if bool(set(ov_res_name.names) & met_names):
|
||||||
|
continue
|
||||||
|
met_names |= set(ov_res_name.names)
|
||||||
|
values.append(ov_res_value)
|
||||||
|
return values
|
47
tests/layer_tests/jax_tests/test_dot_general.py
Normal file
47
tests/layer_tests/jax_tests/test_dot_general.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from jax import lax
|
||||||
|
|
||||||
|
from jax_layer_test_class import JaxLayerTest
|
||||||
|
|
||||||
|
|
||||||
|
class TestDotGeneral(JaxLayerTest):
|
||||||
|
def _prepare_input(self):
|
||||||
|
lhs = np.random.randint(-10, 10, self.lhs_shape).astype(self.input_type)
|
||||||
|
rhs = np.random.randint(-10, 10, self.rhs_shape).astype(self.input_type)
|
||||||
|
return (lhs, rhs)
|
||||||
|
|
||||||
|
def create_model(self, lhs_shape, rhs_shape, dimension_numbers, input_type):
|
||||||
|
self.lhs_shape = lhs_shape
|
||||||
|
self.rhs_shape = rhs_shape
|
||||||
|
self.input_type = input_type
|
||||||
|
|
||||||
|
def jax_dot_general(lhs, rhs):
|
||||||
|
out = lax.dot_general(lhs, rhs, dimension_numbers)
|
||||||
|
return out
|
||||||
|
|
||||||
|
return jax_dot_general, None
|
||||||
|
|
||||||
|
test_data = [
|
||||||
|
# 1D vector dot 1D vector
|
||||||
|
dict(lhs_shape=[4], rhs_shape=[4], dimension_numbers=(((0), (0)), ((), ()))),
|
||||||
|
# matrix mxk dot vector k
|
||||||
|
dict(lhs_shape=[2, 5], rhs_shape=[5], dimension_numbers=(((1), (0)), ((), ()))),
|
||||||
|
# matrix mxk dot matrix kxn
|
||||||
|
dict(lhs_shape=[2, 5], rhs_shape=[5, 6], dimension_numbers=(((1), (0)), ((), ()))),
|
||||||
|
# batch matmul case
|
||||||
|
dict(lhs_shape=[3, 2, 3, 4], rhs_shape=[3, 2, 2, 4], dimension_numbers=(((3), (3)), ((0, 1), (0, 1)))),
|
||||||
|
# batch matmul case: different batch and contracting dimensions
|
||||||
|
dict(lhs_shape=[2, 3, 4, 5], rhs_shape=[4, 2, 5, 3], dimension_numbers=(((2, 3), (0, 2)), ((0, 1), (1, 3)))),
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.nightly
|
||||||
|
@pytest.mark.precommit
|
||||||
|
@pytest.mark.parametrize("params", test_data)
|
||||||
|
@pytest.mark.parametrize("input_type", [np.float32, np.int32])
|
||||||
|
def test_dot_general(self, ie_device, precision, ir_version, params, input_type):
|
||||||
|
self._test(*self.create_model(**params, input_type=input_type), ie_device, precision,
|
||||||
|
ir_version)
|
@ -6,3 +6,5 @@ torch
|
|||||||
torchvision
|
torchvision
|
||||||
pytest
|
pytest
|
||||||
tensorflow-addons; python_version <= '3.10'
|
tensorflow-addons; python_version <= '3.10'
|
||||||
|
jax; sys_platform == "linux"
|
||||||
|
jaxlib; sys_platform == "linux"
|
||||||
|
Loading…
Reference in New Issue
Block a user