[JAX][TF Hub][TF FE] Introduce JAX layer tests and support of XLA operations (#19269)

* [JAX][TF Hub][TF FE] Introduce JAX layer tests and support of XLA operations

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Fix JAX layer tests infra

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>

* Extend run for JAX layer tests

* Use ovc convert_model

* Fix translator and extend layer test cases

* Exclude jax testing on Windows

---------

Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
Roman Kazantsev 2023-08-21 16:01:48 +04:00 committed by GitHub
parent cbe744b717
commit 5539d052b0
10 changed files with 1172 additions and 0 deletions


@@ -525,6 +525,15 @@ jobs:
TEST_DEVICE: CPU
displayName: 'TensorFlow 2 Layer Tests - TF FE'
- script: |
set -e
python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt
$(RUN_PREFIX) python3 -m pytest $(LAYER_TESTS_DIR)/jax_tests/ -m precommit --junitxml=$(INSTALL_TEST_DIR)/TEST-jax.xml
env:
PYTHONPATH: $(LAYER_TESTS_DIR)
TEST_DEVICE: CPU
displayName: 'JAX Layer Tests - TF FE'
- script: |
set -e
python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt

.github/CODEOWNERS (vendored)

@@ -92,6 +92,7 @@
/tests/layer_tests/ @openvinotoolkit/openvino-tests-maintainers @openvinotoolkit/openvino-mo-maintainers
/tests/layer_tests/pytorch_tests/ @openvinotoolkit/openvino-pytorch-frontend-maintainers
/tests/layer_tests/tensorflow_tests @openvinotoolkit/openvino-tf-frontend-maintainers
/tests/layer_tests/jax_tests @openvinotoolkit/openvino-tf-frontend-maintainers
# Tools:
/tools/ @openvinotoolkit/openvino-tools-maintainers


@@ -641,6 +641,17 @@ jobs:
env:
TEST_DEVICE: CPU
- name: JAX Layer Tests - TF FE
run: |
python3 -m pip install -r ${{ env.LAYER_TESTS_INSTALL_DIR }}/requirements.txt
export PYTHONPATH=${{ env.LAYER_TESTS_INSTALL_DIR }}:$PYTHONPATH
source ${{ env.INSTALL_DIR }}/setupvars.sh
python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/jax_tests/ -m precommit --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-jax.xml
env:
TEST_DEVICE: CPU
- name: TensorFlow 1 Layer Tests - Legacy FE
run: |
python3 -m pip install -r ${{ env.LAYER_TESTS_INSTALL_DIR }}/requirements.txt


@@ -0,0 +1,200 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "common_op_table.hpp"
#include "input_model.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reduce_prod.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
#include "xla_data.pb.h"
using namespace std;
using namespace ov;
using namespace ov::op;
namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
vector<int64_t> compute_non_contracting_dims(const NodeContext& node,
const vector<int64_t>& batch_dims,
const vector<int64_t>& contracting_dims,
const Output<Node>& operand) {
// combine two vectors of batch_dims and contracting_dims
set<int64_t> unique_dims(batch_dims.begin(), batch_dims.end());
unique_dims.insert(contracting_dims.begin(), contracting_dims.end());
vector<int64_t> all_dims(unique_dims.begin(), unique_dims.end());
TENSORFLOW_OP_VALIDATION(node,
operand.get_partial_shape().rank().is_static(),
"[TensorFlow Frontend] internal operation: XlaDotV2 expects inputs of static rank");
int64_t operand_rank = operand.get_partial_shape().rank().get_length();
vector<int64_t> non_contracting_dims;
for (int64_t ind = 0; ind < operand_rank; ++ind) {
if (find(all_dims.begin(), all_dims.end(), ind) == all_dims.end()) {
non_contracting_dims.push_back(ind);
}
}
return non_contracting_dims;
}
void insert_aux_dim(const NodeContext& node, Output<Node>& operand, vector<int64_t>& dims) {
TENSORFLOW_OP_VALIDATION(node,
operand.get_partial_shape().rank().is_static(),
"[TensorFlow Frontend] internal operation: XlaDotV2 expects inputs of static rank");
if (dims.size() == 0) {
int64_t operand_rank = operand.get_partial_shape().rank().get_length();
dims.push_back(operand_rank);
auto unsqueeze_axis = make_shared<v0::Constant>(element::i64, Shape{1}, operand_rank);
operand = make_shared<v0::Unsqueeze>(operand, unsqueeze_axis);
}
}
void insert_aux_dims(const NodeContext& node,
Output<Node>& operand,
vector<int64_t>& batch_dims,
vector<int64_t>& contracting_dims,
vector<int64_t>& non_contract_dims) {
insert_aux_dim(node, operand, batch_dims);
insert_aux_dim(node, operand, contracting_dims);
insert_aux_dim(node, operand, non_contract_dims);
}
Output<Node> compute_dims_shape(const Output<Node>& hs_shape, const vector<int64_t>& dims) {
auto const_dims = make_shared<v0::Constant>(element::i64, Shape{dims.size()}, dims);
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto dims_shape = make_shared<v8::Gather>(hs_shape, const_dims, gather_axis);
return dims_shape;
}
Output<Node> compute_dims_size(const Output<Node>& hs_shape, const Output<Node>& dims) {
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto dims_shape = make_shared<v8::Gather>(hs_shape, dims, gather_axis);
auto dims_size = make_shared<v1::ReduceProd>(dims_shape, gather_axis, true);
return {dims_size};
}
OutputVector translate_xla_dot_op(const NodeContext& node) {
// see specification of XlaDotV2 here: https://www.tensorflow.org/xla/operation_semantics#dot
default_op_checks(node, 2, {"XlaDotV2"});
auto lhs = node.get_input(0);
auto rhs = node.get_input(1);
auto node_name = node.get_name();
auto dimension_numbers_message = node.get_attribute<string>("dimension_numbers");
::xla::DotDimensionNumbers dimension_numbers;
TENSORFLOW_OP_VALIDATION(
node,
dimension_numbers.ParseFromArray(dimension_numbers_message.data(),
static_cast<int>(dimension_numbers_message.size())),
"[TensorFlow Frontend] Incorrect input model: incorrect DotDimensionNumbers field for XlaDotV2 " + node_name);
vector<int64_t> lhs_batch_dims(dimension_numbers.lhs_batch_dimensions().begin(),
dimension_numbers.lhs_batch_dimensions().end());
vector<int64_t> rhs_batch_dims(dimension_numbers.rhs_batch_dimensions().begin(),
dimension_numbers.rhs_batch_dimensions().end());
vector<int64_t> rhs_contract_dims(dimension_numbers.rhs_contracting_dimensions().begin(),
dimension_numbers.rhs_contracting_dimensions().end());
vector<int64_t> lhs_contract_dims(dimension_numbers.lhs_contracting_dimensions().begin(),
dimension_numbers.lhs_contracting_dimensions().end());
// compute non-contracting dimensions
auto lhs_non_contract_dims = compute_non_contracting_dims(node, lhs_batch_dims, lhs_contract_dims, lhs);
auto rhs_non_contract_dims = compute_non_contracting_dims(node, rhs_batch_dims, rhs_contract_dims, rhs);
// compute the resulting shape before possible modification
auto resulted_shape = make_shared<v0::Constant>(element::i64, Shape{0}, vector<int64_t>{})->output(0);
bool apply_reshape = false;
auto lhs_shape = make_shared<v3::ShapeOf>(lhs, element::i64);
auto rhs_shape = make_shared<v3::ShapeOf>(rhs, element::i64);
if (lhs_batch_dims.size() > 0) {
auto batch_dims_shape = compute_dims_shape(lhs_shape, lhs_batch_dims);
resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, batch_dims_shape}, 0);
apply_reshape = true;
}
if (lhs_non_contract_dims.size() > 0) {
auto lhs_non_contract_shape = compute_dims_shape(lhs_shape, lhs_non_contract_dims);
resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, lhs_non_contract_shape}, 0);
apply_reshape = true;
}
if (rhs_non_contract_dims.size() > 0) {
auto rhs_non_contract_shape = compute_dims_shape(rhs_shape, rhs_non_contract_dims);
resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, rhs_non_contract_shape}, 0);
apply_reshape = true;
}
// ensure that at least one dimension of each type (batch, contracting, and non-contracting) exists
// if one does not, insert an auxiliary dimension at the end
insert_aux_dims(node, lhs, lhs_batch_dims, lhs_contract_dims, lhs_non_contract_dims);
insert_aux_dims(node, rhs, rhs_batch_dims, rhs_contract_dims, rhs_non_contract_dims);
// compute non-batch and non-contracting dimensions
auto const_lhs_batch_dims = make_shared<v0::Constant>(element::i64, Shape{lhs_batch_dims.size()}, lhs_batch_dims);
auto const_rhs_batch_dims = make_shared<v0::Constant>(element::i64, Shape{rhs_batch_dims.size()}, rhs_batch_dims);
auto const_lhs_contract_dims =
make_shared<v0::Constant>(element::i64, Shape{lhs_contract_dims.size()}, lhs_contract_dims);
auto const_rhs_contract_dims =
make_shared<v0::Constant>(element::i64, Shape{rhs_contract_dims.size()}, rhs_contract_dims);
auto const_lhs_non_contract_dims =
make_shared<v0::Constant>(element::i64, Shape{lhs_non_contract_dims.size()}, lhs_non_contract_dims);
auto const_rhs_non_contract_dims =
make_shared<v0::Constant>(element::i64, Shape{rhs_non_contract_dims.size()}, rhs_non_contract_dims);
lhs_shape = make_shared<v3::ShapeOf>(lhs, element::i64);
rhs_shape = make_shared<v3::ShapeOf>(rhs, element::i64);
// compute the part of the input shape covering batch dimensions
auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0);
auto batch_dims_shape = compute_dims_shape(lhs_shape, lhs_batch_dims);
// transpose both operands so that their dimensions follow the order
// [batch dims, non-contracting dims, contracting dims]
auto lhs_transpose_order = make_shared<v0::Concat>(
OutputVector{const_lhs_batch_dims, const_lhs_non_contract_dims, const_lhs_contract_dims},
0);
auto rhs_transpose_order = make_shared<v0::Concat>(
OutputVector{const_rhs_batch_dims, const_rhs_non_contract_dims, const_rhs_contract_dims},
0);
lhs = make_shared<v1::Transpose>(lhs, lhs_transpose_order);
rhs = make_shared<v1::Transpose>(rhs, rhs_transpose_order);
// compute size of contracting dims and non-contracting dims for each operand
auto lhs_contract_size = compute_dims_size(lhs_shape, const_lhs_contract_dims);
auto rhs_contract_size = compute_dims_size(rhs_shape, const_rhs_contract_dims);
auto lhs_non_contract_size = compute_dims_size(lhs_shape, const_lhs_non_contract_dims);
auto rhs_non_contract_size = compute_dims_size(rhs_shape, const_rhs_non_contract_dims);
// merge contracting and non-contracting dimensions so that each operand
// has a shape [batch dims, non-contracting dims size, contracting dims size]
auto new_lhs_shape =
make_shared<v0::Concat>(OutputVector{batch_dims_shape, lhs_non_contract_size, lhs_contract_size}, 0);
auto new_rhs_shape =
make_shared<v0::Concat>(OutputVector{batch_dims_shape, rhs_non_contract_size, rhs_contract_size}, 0);
lhs = make_shared<v1::Reshape>(lhs, new_lhs_shape, false);
rhs = make_shared<v1::Reshape>(rhs, new_rhs_shape, false);
// execute MatMul, which supports batch matrix multiplication
// note that the second operand is transposed
auto matmul = make_shared<v0::MatMul>(lhs, rhs, false, true)->output(0);
if (apply_reshape) {
matmul = make_shared<v1::Reshape>(matmul, resulted_shape, false);
}
set_node_name(node_name, matmul.get_node_shared_ptr());
return {matmul};
}
} // namespace op
} // namespace tensorflow
} // namespace frontend
} // namespace ov
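
For intuition, the decomposition implemented by this translator can be mirrored in plain NumPy: transpose each operand to [batch, non-contracting, contracting] order, collapse each group into one axis, run a batched MatMul with the second operand transposed, and restore the target shape. A minimal sketch under those assumptions; dot_general_reference is an illustrative helper, not part of the frontend:

import math
import numpy as np

def dot_general_reference(lhs, rhs, lhs_batch, rhs_batch, lhs_contract, rhs_contract):
    # Dimensions that are neither batch nor contracting are non-contracting.
    lhs_non = [d for d in range(lhs.ndim) if d not in lhs_batch + lhs_contract]
    rhs_non = [d for d in range(rhs.ndim) if d not in rhs_batch + rhs_contract]
    # Transpose each operand to [batch, non-contracting, contracting] order.
    l = np.transpose(lhs, lhs_batch + lhs_non + lhs_contract)
    r = np.transpose(rhs, rhs_batch + rhs_non + rhs_contract)
    # Collapse each group: [batch..., non-contracting size, contracting size].
    batch = [lhs.shape[d] for d in lhs_batch]
    l = l.reshape(batch + [math.prod(lhs.shape[d] for d in lhs_non),
                           math.prod(lhs.shape[d] for d in lhs_contract)])
    r = r.reshape(batch + [math.prod(rhs.shape[d] for d in rhs_non),
                           math.prod(rhs.shape[d] for d in rhs_contract)])
    # Batched MatMul with the second operand transposed, as in the translator.
    out = np.matmul(l, np.swapaxes(r, -1, -2))
    # Restore the resulting shape: [batch..., lhs non-contracting..., rhs non-contracting...].
    return out.reshape(batch + [lhs.shape[d] for d in lhs_non]
                             + [rhs.shape[d] for d in rhs_non])

For example, lhs of shape [3, 2, 3, 4] and rhs of shape [3, 2, 2, 4] with batch dims [0, 1] on both sides and contracting dim [3] on both sides produce an output of shape [3, 2, 3, 2], matching lax.dot_general.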


@@ -46,6 +46,7 @@ TF_OP_CONVERTER(translate_varhandle_op);
TF_OP_CONVERTER(translate_variable_op);
TF_OP_CONVERTER(translate_varisinitialized_op);
TF_OP_CONVERTER(translate_while_op);
TF_OP_CONVERTER(translate_xla_dot_op);
const std::map<std::string, CreatorFunction> get_supported_ops() {
return {
@@ -301,6 +302,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"SparseFillEmptyRows", CreatorFunction(translate_sparse_fill_empty_rows_op)}, {"SparseFillEmptyRows", CreatorFunction(translate_sparse_fill_empty_rows_op)},
{"SparseSegmentSum", CreatorFunction(translate_sparse_segment_sum_op)}, {"SparseSegmentSum", CreatorFunction(translate_sparse_segment_sum_op)},
{"Unique", CreatorFunction(translate_unique_op)}, {"Unique", CreatorFunction(translate_unique_op)},
// XLA operations
{"XlaDotV2", CreatorFunction(translate_xla_dot_op)},
};
};
} // namespace op


@@ -0,0 +1,744 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
syntax = "proto3";
package xla;
option cc_enable_arenas = true;
// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// LINT.IfChange
enum PrimitiveType {
// Invalid primitive type to serve as default.
PRIMITIVE_TYPE_INVALID = 0;
// Predicates are two-state booleans.
PRED = 1;
// Signed integral values of fixed width.
S8 = 2;
S16 = 3;
S32 = 4;
S64 = 5;
// Unsigned integral values of fixed width.
U8 = 6;
U16 = 7;
U32 = 8;
U64 = 9;
// Floating-point values of fixed width.
//
// Note: if f16s are not natively supported on the device, they will be
// converted to f16 from f32 at arbitrary points in the computation.
F16 = 10;
F32 = 11;
// Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
// floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
// and 7 bits for the mantissa.
BF16 = 16;
F64 = 12;
// Complex values of fixed width.
C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
C128 = 18; // Paired F64 (real, imag), as in std::complex<double>.
// A tuple is a polymorphic sequence; e.g. a shape that holds different
// sub-shapes. They are used for things like returning multiple values from a
// computation; e.g. a computation that returns weights and biases may have a
// signature that results in a tuple like (f32[784x2000], f32[2000])
//
// If a shape proto has the tuple element type, it may not have any entries
// in the dimensions field.
TUPLE = 13;
// An opaque type used for passing context-specific data to a custom
// operation. Shapes of this primitive type will have empty dimensions and
// tuple_shapes fields.
//
// (OPAQUE would be a better name for this identifier, but that conflicts with
// a macro defined in windows.h.)
OPAQUE_TYPE = 14;
// A token type threaded between side-effecting operations. Shapes of this
// primitive type will have empty dimensions and tuple_shapes fields.
TOKEN = 17;
// Next = 19
}
// LINT.ThenChange(
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
// )
// Describes the padding configuration for the Pad operation. The padding amounts
// on both edges as well as between the elements are specified for each dimension.
message PaddingConfig {
// Describes the padding configuration for a dimension.
message PaddingConfigDimension {
// Padding amount on the low-end (next to the index 0). May be negative.
int64 edge_padding_low = 1;
// Padding amount on the high-end (next to the highest index). May be
// negative.
int64 edge_padding_high = 2;
// Padding amount between the elements. May not be negative.
int64 interior_padding = 3;
}
// The padding configuration for all dimensions.
repeated PaddingConfigDimension dimensions = 1;
}
// A format specifies the method used by a layout to store an array in memory.
enum Format {
// TODO(b/120869032): Rename this to FORMAT_NONE or something else which
// better corresponds to its meaning.
INVALID_FORMAT = 0;
// The default layout, with exactly one storage location per element.
DENSE = 1;
reserved 2;
reserved "SPARSE";
}
// Describes a tile used in tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
// details about tiling-based layout.
message TileProto {
// Number of elements in each dimension of the tile. It's ordered from the
// most major dimension of the tile to the most minor dimension of the tile.
// The dimensions correspond to a suffix of the dimensions of the shape being
// tiled.
repeated int64 dimensions = 1;
}
// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
// The method used to store the data in memory. The format determines which of
// the other fields are used by the layout.
Format format = 4;
// Sequence of dimension numbers, from minor (fastest varying index) to major
// (slowest varying index). This field is required.
repeated int64 minor_to_major = 1;
reserved 2;
reserved "padded_dimensions";
reserved 3;
reserved "padding_value";
reserved 5;
reserved "max_sparse_elements";
// A sequence of tiles, starting from the tile that's applied first to the
// Shape.
//
// TODO(b/119839262): implement tiling in each backend or add Unimplemented
// error.
repeated TileProto tiles = 6;
// Bit size of each element. If the size is bigger than what the element
// type requires, the value is stored in the least significant
// bits and the additional most significant bits are filled with 0's.
//
// TODO(b/119839262): implement in each backend or add Unimplemented error.
int64 element_size_in_bits = 7;
// Memory space where this array resides. The integer field is interpreted in
// a backend-specific manner.
int64 memory_space = 8;
// Important: if any field is added, be sure to modify ShapeUtil::Equal() and
// LayoutUtil::Hash appropriately to account for the new field.
}
// LINT.ThenChange( \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
reserved 1;
reserved "rank";
// The element type for this shape.
PrimitiveType element_type = 2;
// The size (number of elements) for each dimension, or an upper bound on the
// size if the dimension is dynamic. In XLA, dimensions are numbered from 0
// to N-1 for an N-dimensional array. The first element of 'dimensions' is the
// size of dimension 0, the second element is the size of dimension 1, and so
// forth. Empty list indicates a scalar.
//
// If the respective element in 'is_dimension_dynamic' is true then the value
// in this field represents an upper bound on the size of the dimension.
repeated int64 dimensions = 3;
// For tuples only, the shapes of constituent shapes in the tuple sequence.
repeated ShapeProto tuple_shapes = 4;
// The layout used to back this shape.
LayoutProto layout = 5;
// For arrays, this indicates whether or not each dimension is
// dynamically-sized. The number of elements in this repeated field should be
// zero (indicating that no dimensions are dynamic) or equal to the number of
// elements in the 'dimensions' field.
repeated bool is_dynamic_dimension = 6;
// Important: if any field is added, be sure to modify ShapeUtil::Equal(),
// ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
// the new field.
}
// LINT.ThenChange( \
// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
repeated ShapeProto parameters = 1;
ShapeProto result = 2;
repeated string parameter_names = 3;
}
// Statistics of a computation.
message ComputationStats {
// The number of floating point operations in the computation.
double flop_count = 1;
// The number of transcendental operations (e.g., exp) in the computation.
double transcendental_count = 2;
}
// The type optimization profiles in use.
enum ProfileType {
INVALID = 0;
WINDOW = 1;
FLAG = 2;
INTEGER = 3;
}
// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
// The framework op name that generated this XLA op.
//
// Frameworks that build on top of XLA should mirror the names of their ops
// back to users by specifying the op_type. In this way, even if the
// framework's "ops" are implemented as multiple XLA HLO Ops, they can be
// grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
// multiple ops, then each op should have the op_type be "SoftMax".)
string op_type = 1;
// The user-specified name of the op.
//
// This name is often unique within a computation. Note: some frameworks
// add auto-generated names if the user does not provide one.
string op_name = 2;
// Indicate a file and line that this op is associated to in a user's program.
//
// e.g. it could be the file and line of user code that generated the op.
string source_file = 3;
int32 source_line = 4;
repeated ProfileType profile_type = 5;
// HloPassMetadata.pass_id of the pass that created this HLO instruction
// object. Should never be copied between HLO instructions. Zero if unset and
// -1 if the instruction was created before HLO passes began.
int64 creation_pass_id = 6;
// HloPassMetadata.pass_id of the pass that created the logical functionality
// that this HLO instruction represents. Should be copied between HLO
// instructions that correspond across compilation passes. Zero if unset and
// -1 if the instruction was created before HLO passes began.
int64 logical_creation_pass_id = 7;
// The footprint of the generated code for the instruction.
int64 size_of_generated_code_in_bytes = 8;
// The size of the working set, i.e., the amount of memory, used by the
// instruction in a compiler-managed fast device memory.
int64 size_of_memory_working_set_in_bytes = 9;
}
// Profile data from the execution of a computation.
message ExecutionProfile {
// Whether the executable was read from the compilation cache.
bool compilation_cache_hit = 1;
// The time in milliseconds spent to compile the computation. This is only set
// if the executable was not read from the compilation cache
// (compilation_cache_hit == false).
int64 compile_time_ms = 2;
// The number of cycles spent for the computation. This does not include the
// time taken for the data transfers between the host and the device. This is
// a target-dependent field and only used for debugging purposes.
int64 compute_cycle_count = 3;
// The time in nanoseconds spent for the computation, without data transfer.
int64 compute_time_ns = 4;
// The time in nanoseconds spent for the entire computation, including the
// result data transfer time. Current implementation does not spend any cycles
// for the input data transfer since the memory is initialized with the proper
// values before the execution.
int64 compute_and_transfer_time_ns = 5;
// The size of the binary code in the executable.
int64 executable_size_in_bytes = 6;
// Whether this profile was drawn from a cache of profiles instead of from
// execution on the hardware.
bool profile_cache_hit = 7;
}
// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
int64 handle = 1;
}
// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
int64 handle = 1;
}
// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
int64 handle = 1;
// The number of model-parallel virtual devices that communicate via XLA
// Send/Recv instructions.
int64 device_count = 2;
}
// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
int64 handle = 1;
enum ChannelType {
// Invalid primitive type to serve as default.
CHANNEL_TYPE_INVALID = 0;
// A channel for sending data between devices.
DEVICE_TO_DEVICE = 1;
// A channel for sending data from the device to the host. Can only be used
// with a Send operation.
DEVICE_TO_HOST = 2;
// A channel for sending data from the host to the device. Can only be used
// with a Recv operation.
HOST_TO_DEVICE = 3;
}
ChannelType type = 2;
}
// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which
// represents the device ids assigned to a set of replicated computations.
// See xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
int32 replica_count = 1;
int32 computation_count = 2;
// Each logical computation runs on replica_count physical devices.
// ComputationDevice represents the device ids assigned to the replicas.
message ComputationDevice {
repeated int32 replica_device_ids = 1;
}
repeated ComputationDevice computation_devices = 3;
}
// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
ShapeProto shape = 1;
repeated bool preds = 2;
bytes s8s = 15;
bytes u8s = 3;
repeated int32 s32s = 4;
repeated int64 s64s = 5;
repeated uint32 u32s = 6;
repeated uint64 u64s = 7;
repeated float f32s = 8;
repeated double f64s = 9;
repeated float c64s = 12; // Stored as interleaved real, imag floats.
repeated double c128s = 18; // Stored as interleaved real, imag doubles.
repeated LiteralProto tuple_literals = 10;
// The F16s, BF16s, U16s and S16s are encoded in little endian byte order
bytes f16s = 11;
bytes bf16s = 13;
bytes u16s = 16;
bytes s16s = 17;
repeated int64 sparse_indices = 14;
// Next = 19
}
message WindowDimension {
// The size of the window in this dimension. For a rectangle, this would be
// the width or height.
int64 size = 1;
// The stride at which the window moves across the base area in this
// dimension. In other words, this is the spacing between different
// positions of the window in this dimension.
int64 stride = 2;
// If positive, means the amount of padding to add to the base area at the low
// end of this dimension; if negative, its negative means the number of
// elements removed from the low end of this dimension. For example, in the
// horizontal dimension of a rectangle, this would be the number of padding
// values to pad on the left, given that indices increase when going right.
// The actual padding value depends upon the context. Convolution pads with
// zeros. ReduceWindow and SelectAndScatter pads with the reduce function's
// init value.
int64 padding_low = 3;
// As padding_low, but on the high end of this dimension. For example, in the
// horizontal dimension of a rectangle, this would be the number of values to
// pad on the right, given that indices increase when going right.
int64 padding_high = 4;
// Dilation factor of the sliding window in this dimension. A dilation factor
// of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
// implicitly placed between each kernel element. This value may not be less
// than 1. See documentation for convolution.
int64 window_dilation = 5;
// Dilation factor of the base area in this dimension. A dilation factor of 1
// means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
// placed between each base area element. This value may not be less than 1.
// See documentation for convolution.
int64 base_dilation = 6;
// Window reversal means that this dimension was logically reversed before the
// operation.
bool window_reversal = 7;
}
// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
repeated WindowDimension dimensions = 1;
}
// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
// "Window indices" is a term for a set of indices that index into the
// interior of a dynamic-slice from the input tensor, the starting indices for
// which were computed from output_gather_dims (see the operation semantic for
// how this is defined) and the start_indices tensor.
//
// The window indices for a specific output index Out is computed as:
//
// i = 0
// for (k : [0, input_tensor_shape.rank))
// window_indices[k] =
// if k in collapsed_slice_dims
// then 0
// else Out[offset_dims[i++]]
repeated int64 offset_dims = 1;
repeated int64 collapsed_slice_dims = 2;
// This is interpreted as a map from i to start_index_map[i]. It
// transforms the gather index looked up from the start_indices tensor into
// the starting index in the input space.
repeated int64 start_index_map = 3;
// The dimension in the start_indices input that contains the starting
// indices.
int64 index_vector_dim = 4;
}
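
As a reading aid, the window-indices pseudocode in the GatherDimensionNumbers comment above maps one-to-one onto Python. A small illustrative sketch; the names are hypothetical and not part of any frontend API:

def window_indices(out_index, offset_dims, collapsed_slice_dims, input_rank):
    # For each input dimension k: collapsed dimensions contribute index 0;
    # every other dimension consumes the next output coordinate named in offset_dims.
    indices = []
    i = 0
    for k in range(input_rank):
        if k in collapsed_slice_dims:
            indices.append(0)
        else:
            indices.append(out_index[offset_dims[i]])
            i += 1
    return indices
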
// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
// The set of dimensions in the updates shape that are window dimensions.
repeated int64 update_window_dims = 1;
// The set of window dimensions that must be inserted into the updates shape.
repeated int64 inserted_window_dims = 2;
repeated int64 scatter_dims_to_operand_dims = 3;
int64 index_vector_dim = 4;
}
message ConvolutionDimensionNumbers {
// The number of the dimension that represents batch in the input.
int64 input_batch_dimension = 7;
// The number of the dimension that represents features in the input.
int64 input_feature_dimension = 8;
// The dimension numbers for the spatial dimensions that the window
// moves through in the input.
repeated int64 input_spatial_dimensions = 11;
// The number of the dimension that represents input features in the
// convolutional kernel (rhs).
int64 kernel_input_feature_dimension = 3;
// The number of the dimension that represents output features in
// the convolutional kernel (rhs).
int64 kernel_output_feature_dimension = 4;
// The dimension numbers for the spatial dimensions that the window
// moves through in the kernel (rhs). window.strides(0) is the
// stride in the kernel_spatial_dimensions(0) dimension.
repeated int64 kernel_spatial_dimensions = 6;
// The number of the dimension that represents batch in the output.
int64 output_batch_dimension = 9;
// The number of the dimension that represents features in the output.
int64 output_feature_dimension = 10;
// The dimension numbers for the spatial dimensions that the window
// moves through in the output.
repeated int64 output_spatial_dimensions = 12;
// Next = 13
}
enum PaddingType {
PADDING_INVALID = 0;
PADDING_VALID = 1; // Only the valid portion of the base is covered.
PADDING_SAME = 2; // Extra is added to produce same output size as the input.
}
enum FftType {
FFT = 0; // Forward FFT; complex in, complex out.
IFFT = 1; // Inverse FFT; complex in, complex out.
RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out
IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in,
// fft_length real out
}
message DotDimensionNumbers {
// The dimension numbers that represent the 'lhs' contracting dimensions.
repeated int64 lhs_contracting_dimensions = 1;
// The dimension numbers that represent the 'rhs' contracting dimensions.
repeated int64 rhs_contracting_dimensions = 2;
// The dimension numbers that represent the 'lhs' batch dimensions.
repeated int64 lhs_batch_dimensions = 3;
// The dimension numbers that represent the 'rhs' batch dimensions.
repeated int64 rhs_batch_dimensions = 4;
}
enum RandomDistribution {
RNG_INVALID = 0;
// Creates a uniform-distribution-generated random number on the semi-open
// interval [parameter[0], parameter[1]).
RNG_UNIFORM = 1;
// Creates a normal-distribution-generated random number with mean
// parameter[0] and standard deviation parameter[1].
RNG_NORMAL = 2;
// Next: 4
}
enum RandomAlgorithm {
RNG_DEFAULT = 0; // Backend dependent default algorithm.
RNG_THREE_FRY = 1;
RNG_PHILOX = 2;
// Next: 3
}
message TriangularSolveOptions {
// If true, solves ax = b. If false, solves xa = b.
bool left_side = 1;
// If true, 'a' is lower triangular. If false, 'a' is upper triangular.
bool lower = 2;
// If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
bool unit_diagonal = 3;
// Should we transpose or use the adjoint of 'a'?
enum Transpose {
TRANSPOSE_INVALID = 0;
NO_TRANSPOSE = 1; // Don't transpose 'a'.
TRANSPOSE = 2; // Transpose 'a'.
ADJOINT = 3; // Complex conjugate and transpose 'a'.
}
Transpose transpose_a = 4;
}
message CholeskyOptions {
// If true, uses the lower triangle of `a`. If false, uses the upper triangle
// of `a`.
bool lower = 1;
}
// Generic map of attributes used to pass hints / configuration options from
// the Python frontend to the XLA backend.
message FrontendAttributes {
map<string, string> map = 1;
}
message OpSharding {
enum Type {
// This sharding is replicated across all devices (implies maximal,
// all other fields are unused).
REPLICATED = 0;
// This sharding is maximal - one device runs the entire operation.
MAXIMAL = 1;
// This sharding is a tuple - only the tuple_shardings field is valid.
TUPLE = 2;
// None of the above; tile_shape and tile_assignment are both used.
OTHER = 3;
// This op is manually sharded: the shapes are already partitioned and the
// partitioner should not change this op.
MANUAL = 4;
}
Type type = 1;
// The shape of the sharded tile.
ShapeProto tile_shape = 2;
// The shape of the tile assignment tensor - this must be the same rank as
// tile_shape and the product of its dimensions must equal
// tile_assignment_devices.size().
repeated int64 tile_assignment_dimensions = 3;
// Flattened list of device IDs. The order of flattening is the same as used
// by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
repeated int64 tile_assignment_devices = 4;
// If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
// in pre-order. The tuple shape could be nested; here we store just a
// flattened list of all leaves in the tuple shape. Note that the tuple shape
// is not stored here; shardings do not store the shapes to which they are
// applied, this is inferred from the instruction this sharding gets attached
// to.
repeated OpSharding tuple_shardings = 5;
// Only used for OTHER type. If true, data is sharded according to other
// dimensions of tile_assignment(), but replicated across devices along the
// last dimension. (Experimental)
bool replicate_on_last_tile_dim = 6;
// This field is used to track the source of this sharding, usually derived
// from instructions. Multiple metadata may be populated if sharding is
// combined with other shardings. Metadata are not to be populated when
// type == TUPLE and instead metadata should be set on individual tuple
// elements.
repeated OpMetadata metadata = 7;
}
// Describes the replica groups in a cross replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
// The ids of the replicas that belong to the same group. The ordering of the
// ids matters in some ops (e.g., all-to-all).
repeated int64 replica_ids = 1;
}
// Describes the source target pair in the collective permute op.
message SourceTarget {
int64 source = 1;
int64 target = 2;
}
// Used to indicate the precision configuration. It has backend specific
// meaning.
message PrecisionConfig {
enum Precision {
DEFAULT = 0;
HIGH = 1;
HIGHEST = 2;
// Next: 3
}
repeated Precision operand_precision = 1;
// Next: 2
}
// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
// A list of boolean values for the flattened leaf buffers. Each value
// indicates whether the corresponding leaf buffer is replicated.
//
// If this field is empty, it means no buffer is replicated. Otherwise, the
// number of elements in this field must match the number of leaf buffers in
// the HLO instruction's shape.
repeated bool replicated_at_leaf_buffers = 1;
}
// A backend-config for kWhile loops that stores the loop's trip count, if it is
// known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop. For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
message KnownTripCount {
int64 n = 1;
}
// This indirection lets us distinguish between known-trip-count == 0 and
// unknown-trip-count.
KnownTripCount known_trip_count = 1;
}
// Specifies a pair of output/operand buffers for kCustomCall that alias each
// other.
message CustomCallOutputOperandAliasing {
repeated int64 output_shape_index = 1;
int64 operand_index = 2;
repeated int64 operand_shape_index = 3;
}
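
The C++ translator above consumes DotDimensionNumbers in exactly the serialized form this schema defines. A sketch of the round trip in Python, assuming bindings generated from this file with protoc; the module name xla_data_pb2 is illustrative:

from xla_data_pb2 import DotDimensionNumbers  # hypothetical protoc-generated module

# Describe a batched matmul: batch over dims 0 and 1, contract the last dims.
dims = DotDimensionNumbers()
dims.lhs_batch_dimensions.extend([0, 1])
dims.rhs_batch_dimensions.extend([0, 1])
dims.lhs_contracting_dimensions.append(3)
dims.rhs_contracting_dimensions.append(3)

# Serialize as the XlaDotV2 'dimension_numbers' attribute would carry it...
payload = dims.SerializeToString()

# ...and parse it back, mirroring ParseFromArray in the translator.
parsed = DotDimensionNumbers()
parsed.ParseFromString(payload)
print(list(parsed.lhs_batch_dimensions))  # [0, 1]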


@@ -0,0 +1,12 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import inspect
from common.layer_test_class import get_params
def pytest_generate_tests(metafunc):
test_gen_attrs_names = list(inspect.signature(get_params).parameters)
params = get_params()
metafunc.parametrize(test_gen_attrs_names, params, scope="function")


@@ -0,0 +1,142 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import itertools
from copy import deepcopy
import numpy as np
from common.constants import test_device, test_precision
from jax import numpy as jnp
from openvino.runtime import Core
class JaxLayerTest:
def _test(self, model, ref_net, ie_device, precision, ir_version, infer_timeout=60, dynamic_shapes=True,
**kwargs):
"""
:param enabled_transforms/disabled_transforms: string with idxs of transforms that should be enabled/disabled.
Example: "transform_1,transform_2"
"""
inputs = self._prepare_input()
converted_model = self.convert_via_tensorflow_function(model, inputs)
# OV infer:
core = Core()
compiled = core.compile_model(converted_model, ie_device)
infer_res = compiled(deepcopy(inputs))
# Framework infer:
fw_res = model(*deepcopy(inputs))
if not isinstance(fw_res, tuple):
fw_res = (fw_res,)
output_list = ov_res_to_list(infer_res)
def flattenize_dict_outputs(res):
if isinstance(res, dict):
return flattenize_outputs(res.values())
def flattenize_outputs(res):
results = []
for res_item in res:
# if None is at output we skip it
if res_item is None:
continue
# If input is list or tuple flattenize it
if isinstance(res_item, (list, tuple)):
decomposed_res = flattenize_outputs(res_item)
results.extend(decomposed_res)
continue
if isinstance(res_item, dict):
decomposed_res = flattenize_dict_outputs(res_item)
results.extend(decomposed_res)
continue
results.append(res_item)
return results
flatten_fw_res = flattenize_outputs(fw_res)
assert len(flatten_fw_res) == len(
output_list), f'number of outputs not equal, {len(flatten_fw_res)} != {len(output_list)}'
# check if results dtypes match
for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
fw_tensor_type = np.array(fw_tensor).dtype
ov_tensor_type = ov_tensor.dtype
assert ov_tensor_type == fw_tensor_type, f"dtype validation failed: {ov_tensor_type} != {fw_tensor_type}"
if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
custom_eps = kwargs['custom_eps']
else:
custom_eps = 1e-4
# compare OpenVINO results with JAX results
fw_eps = custom_eps if precision == 'FP32' else 5e-2
is_ok = True
for i in range(len(flatten_fw_res)):
cur_fw_res = np.array(flatten_fw_res[i])
cur_ov_res = infer_res[compiled.output(i)]
print(f"fw_re: {cur_fw_res};\n ov_res: {cur_ov_res}")
if not np.allclose(cur_ov_res, cur_fw_res,
atol=fw_eps,
rtol=fw_eps, equal_nan=True):
is_ok = False
print("Max diff is {}".format(
np.array(
abs(cur_ov_res - cur_fw_res)).max()))
else:
print("Accuracy validation successful!\n")
print("absolute eps: {}, relative eps: {}".format(fw_eps, fw_eps))
assert is_ok, "Accuracy validation failed"
# Each model should specify inputs
def _prepare_input(self):
raise RuntimeError("Please provide inputs generation function")
def convert_via_tensorflow_function(self, model, inputs):
import tensorflow as tf
from jax.experimental import jax2tf
from openvino.tools.ovc import convert_model
# create function signature based on input shapes and types
function_signature = []
for _input in inputs:
assert isinstance(_input, np.ndarray)
input_shape = _input.shape
input_type = _input.dtype
function_signature.append(tf.TensorSpec(input_shape, input_type))
f = tf.function(jax2tf.convert(model), autograph=False,
input_signature=function_signature)
converted_model = convert_model(f)
return converted_model
def get_params(ie_device=None, precision=None):
"""
:param ie_device: list of devices
:param precision: list of precisions
"""
ie_device_params = ie_device if ie_device else test_device
precision_params = precision if precision else test_precision
test_args = []
for element in itertools.product(ie_device_params, precision_params):
if element[0] == 'CPU' and element[1] == 'FP16':
continue
test_args.append(element)
return test_args
def ov_res_to_list(ov_res_dict):
# 118221: remove this workaround that cleans up repeated output tensors
# with the same tensor names
# probably, we do not utilize some meta info from tf.function
values = []
met_names = set()
for ov_res_name, ov_res_value in ov_res_dict.items():
if bool(set(ov_res_name.names) & met_names):
continue
met_names |= set(ov_res_name.names)
values.append(ov_res_value)
return values
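
Outside the harness, the same conversion path that convert_via_tensorflow_function implements can be exercised standalone. A minimal sketch, assuming JAX, TensorFlow, and the OpenVINO Python packages are installed:

import numpy as np
import tensorflow as tf
from jax import numpy as jnp
from jax.experimental import jax2tf
from openvino.runtime import Core
from openvino.tools.ovc import convert_model

def model(x):
    return jnp.tanh(x)

x = np.random.rand(2, 3).astype(np.float32)
# Wrap the JAX callable into a tf.function with a fixed signature, then convert.
f = tf.function(jax2tf.convert(model), autograph=False,
                input_signature=[tf.TensorSpec(x.shape, x.dtype)])
ov_model = convert_model(f)
result = Core().compile_model(ov_model, 'CPU')([x])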


@@ -0,0 +1,47 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from jax import lax
from jax_layer_test_class import JaxLayerTest
class TestDotGeneral(JaxLayerTest):
def _prepare_input(self):
lhs = np.random.randint(-10, 10, self.lhs_shape).astype(self.input_type)
rhs = np.random.randint(-10, 10, self.rhs_shape).astype(self.input_type)
return (lhs, rhs)
def create_model(self, lhs_shape, rhs_shape, dimension_numbers, input_type):
self.lhs_shape = lhs_shape
self.rhs_shape = rhs_shape
self.input_type = input_type
def jax_dot_general(lhs, rhs):
out = lax.dot_general(lhs, rhs, dimension_numbers)
return out
return jax_dot_general, None
test_data = [
# 1D vector dot 1D vector
dict(lhs_shape=[4], rhs_shape=[4], dimension_numbers=(((0,), (0,)), ((), ()))),
# matrix mxk dot vector k
dict(lhs_shape=[2, 5], rhs_shape=[5], dimension_numbers=(((1,), (0,)), ((), ()))),
# matrix mxk dot matrix kxn
dict(lhs_shape=[2, 5], rhs_shape=[5, 6], dimension_numbers=(((1,), (0,)), ((), ()))),
# batch matmul case
dict(lhs_shape=[3, 2, 3, 4], rhs_shape=[3, 2, 2, 4], dimension_numbers=(((3,), (3,)), ((0, 1), (0, 1)))),
# batch matmul case: different batch and contracting dimensions
dict(lhs_shape=[2, 3, 4, 5], rhs_shape=[4, 2, 5, 3], dimension_numbers=(((2, 3), (0, 2)), ((0, 1), (1, 3)))),
]
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("params", test_data)
@pytest.mark.parametrize("input_type", [np.float32, np.int32])
def test_dot_general(self, ie_device, precision, ir_version, params, input_type):
self._test(*self.create_model(**params, input_type=input_type), ie_device, precision,
ir_version)
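
The batched case in test_data can also be sanity-checked against an einsum formulation. A small sketch outside the test suite, for illustration only:

import numpy as np
from jax import lax

lhs = np.random.rand(3, 2, 3, 4).astype(np.float32)
rhs = np.random.rand(3, 2, 2, 4).astype(np.float32)
# Batch over dims 0 and 1, contract the last dims: output has shape [3, 2, 3, 2].
out = lax.dot_general(lhs, rhs, (((3,), (3,)), ((0, 1), (0, 1))))
ref = np.einsum('abik,abjk->abij', lhs, rhs)
assert np.allclose(out, ref, atol=1e-5)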


@@ -6,3 +6,5 @@ torch
torchvision
pytest
tensorflow-addons; python_version <= '3.10'
jax; sys_platform == "linux"
jaxlib; sys_platform == "linux"