From 5539d052b01d88d306663b7a9051046b0224a9b7 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev <roman.kazantsev@intel.com> Date: Mon, 21 Aug 2023 16:01:48 +0400 Subject: [PATCH] [JAX][TF Hub][TF FE] Introduce JAX layer tests and support of XLA operations (#19269) * [JAX][TF Hub][TF FE] Introduce JAX layer tests and support of XLA operations Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Fix JAX layer tests infra Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> * Extend run for JAX layer tests * Use ovc convert_model * Fix translator and extend layer test cases * Exclude jax testing on Windows --------- Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com> --- .ci/azure/linux.yml | 9 + .github/CODEOWNERS | 1 + .github/workflows/linux.yml | 11 + src/frontends/tensorflow/src/op/xla_dot.cpp | 200 +++++ src/frontends/tensorflow/src/op_table.cpp | 4 + .../tensorflow/src/proto/xla_data.proto | 744 ++++++++++++++++++ tests/layer_tests/jax_tests/conftest.py | 12 + .../jax_tests/jax_layer_test_class.py | 142 ++++ .../layer_tests/jax_tests/test_dot_general.py | 47 ++ tests/layer_tests/requirements.txt | 2 + 10 files changed, 1172 insertions(+) create mode 100644 src/frontends/tensorflow/src/op/xla_dot.cpp create mode 100644 src/frontends/tensorflow/src/proto/xla_data.proto create mode 100644 tests/layer_tests/jax_tests/conftest.py create mode 100644 tests/layer_tests/jax_tests/jax_layer_test_class.py create mode 100644 tests/layer_tests/jax_tests/test_dot_general.py diff --git a/.ci/azure/linux.yml b/.ci/azure/linux.yml index f07f048001f..baf12669aa7 100644 --- a/.ci/azure/linux.yml +++ b/.ci/azure/linux.yml @@ -525,6 +525,15 @@ jobs: TEST_DEVICE: CPU displayName: 'TensorFlow 2 Layer Tests - TF FE' + - script: | + set -e + python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt + $(RUN_PREFIX) python3 -m pytest $(LAYER_TESTS_DIR)/jax_tests/ -m precommit --junitxml=$(INSTALL_TEST_DIR)/TEST-jax.xml + env: + PYTHONPATH: $(LAYER_TESTS_DIR) + TEST_DEVICE: CPU + displayName: 'JAX Layer Tests - TF FE' + - script: | + set -e + python3 -m pip install -r $(LAYER_TESTS_DIR)/requirements.txt diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 25ab1cec3e0..c43e6edf501 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -92,6 +92,7 @@ /tests/layer_tests/ @openvinotoolkit/openvino-tests-maintainers @openvinotoolkit/openvino-mo-maintainers /tests/layer_tests/pytorch_tests/ @openvinotoolkit/openvino-pytorch-frontend-maintainers /tests/layer_tests/tensorflow_tests @openvinotoolkit/openvino-tf-frontend-maintainers +/tests/layer_tests/jax_tests @openvinotoolkit/openvino-tf-frontend-maintainers # Tools: /tools/ @openvinotoolkit/openvino-tools-maintainers diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index cc16c105950..0c9baad6733 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -641,6 +641,17 @@ jobs: env: TEST_DEVICE: CPU + - name: JAX Layer Tests - TF FE + run: | + python3 -m pip install -r ${{ env.LAYER_TESTS_INSTALL_DIR }}/requirements.txt + export PYTHONPATH=${{ env.LAYER_TESTS_INSTALL_DIR }}:$PYTHONPATH + + source ${{ env.INSTALL_DIR }}/setupvars.sh + + python3 -m pytest ${{ env.LAYER_TESTS_INSTALL_DIR }}/jax_tests/ -m precommit --junitxml=${{ env.INSTALL_TEST_DIR }}/TEST-jax.xml + env: + TEST_DEVICE: CPU + - name: TensorFlow 1 Layer Tests - Legacy FE run: | python3 -m pip install -r ${{ env.LAYER_TESTS_INSTALL_DIR }}/requirements.txt diff --git a/src/frontends/tensorflow/src/op/xla_dot.cpp b/src/frontends/tensorflow/src/op/xla_dot.cpp new file mode 100644 index 00000000000..e463494511f --- /dev/null +++
b/src/frontends/tensorflow/src/op/xla_dot.cpp @@ -0,0 +1,200 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "common_op_table.hpp" +#include "input_model.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/reduce_prod.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "utils.hpp" +#include "xla_data.pb.h" + +using namespace std; +using namespace ov; +using namespace ov::op; + +namespace ov { +namespace frontend { +namespace tensorflow { +namespace op { + +vector<int64_t> compute_non_contracting_dims(const NodeContext& node, + const vector<int64_t>& batch_dims, + const vector<int64_t>& contracting_dims, + const Output<Node>& operand) { + // combine two vectors of batch_dims and contracting_dims + set<int64_t> unique_dims(batch_dims.begin(), batch_dims.end()); + unique_dims.insert(contracting_dims.begin(), contracting_dims.end()); + vector<int64_t> all_dims(unique_dims.begin(), unique_dims.end()); + + TENSORFLOW_OP_VALIDATION(node, + operand.get_partial_shape().rank().is_static(), + "[TensorFlow Frontend] internal operation: XlaDotV2 expects inputs of static rank"); + + int64_t operand_rank = operand.get_partial_shape().rank().get_length(); + vector<int64_t> non_contracting_dims; + for (int64_t ind = 0; ind < operand_rank; ++ind) { + if (find(all_dims.begin(), all_dims.end(), ind) == all_dims.end()) { + non_contracting_dims.push_back(ind); + } + } + + return non_contracting_dims; +} + +void insert_aux_dim(const NodeContext& node, Output<Node>& operand, vector<int64_t>& dims) { + TENSORFLOW_OP_VALIDATION(node, + operand.get_partial_shape().rank().is_static(), + "[TensorFlow Frontend] internal operation: XlaDotV2 expects inputs of static rank"); + if (dims.size() == 0) { + int64_t operand_rank = operand.get_partial_shape().rank().get_length(); + dims.push_back(operand_rank); + auto unsqueeze_axis = make_shared<v0::Constant>(element::i64, Shape{1}, operand_rank); + operand = make_shared<v0::Unsqueeze>(operand, unsqueeze_axis); + } +} + +void insert_aux_dims(const NodeContext& node, + Output<Node>& operand, + vector<int64_t>& batch_dims, + vector<int64_t>& contracting_dims, + vector<int64_t>& non_contract_dims) { + insert_aux_dim(node, operand, batch_dims); + insert_aux_dim(node, operand, contracting_dims); + insert_aux_dim(node, operand, non_contract_dims); +} + +Output<Node> compute_dims_shape(const Output<Node>& hs_shape, const vector<int64_t>& dims) { + auto const_dims = make_shared<v0::Constant>(element::i64, Shape{dims.size()}, dims); + auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0); + auto dims_shape = make_shared<v8::Gather>(hs_shape, const_dims, gather_axis); + return dims_shape; +} + +Output<Node> compute_dims_size(const Output<Node>& hs_shape, const Output<Node>& dims) { + auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0); + auto dims_shape = make_shared<v8::Gather>(hs_shape, dims, gather_axis); + auto dims_size = make_shared<v1::ReduceProd>(dims_shape, gather_axis, true); + return {dims_size}; +} + +OutputVector translate_xla_dot_op(const NodeContext& node) { + // see specification of XlaDotV2 here: https://www.tensorflow.org/xla/operation_semantics#dot + default_op_checks(node, 2, {"XlaDotV2"}); + auto lhs = node.get_input(0); + auto rhs = node.get_input(1); + auto node_name = node.get_name(); + auto dimension_numbers_message = node.get_attribute<string>("dimension_numbers"); + ::xla::DotDimensionNumbers dimension_numbers; + TENSORFLOW_OP_VALIDATION( + node, +
dimension_numbers.ParseFromArray(dimension_numbers_message.data(), + static_cast<int>(dimension_numbers_message.size())), + "[TensorFlow Frontend] Incorrect input model: incorrect DotDimensionNumbers field for XlaDotV2 " + node_name); + + vector<int64_t> lhs_batch_dims(dimension_numbers.lhs_batch_dimensions().begin(), + dimension_numbers.lhs_batch_dimensions().end()); + vector<int64_t> rhs_batch_dims(dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); + vector<int64_t> rhs_contract_dims(dimension_numbers.rhs_contracting_dimensions().begin(), + dimension_numbers.rhs_contracting_dimensions().end()); + vector<int64_t> lhs_contract_dims(dimension_numbers.lhs_contracting_dimensions().begin(), + dimension_numbers.lhs_contracting_dimensions().end()); + + // compute non-contracting dimensions + auto lhs_non_contract_dims = compute_non_contracting_dims(node, lhs_batch_dims, lhs_contract_dims, lhs); + auto rhs_non_contract_dims = compute_non_contracting_dims(node, rhs_batch_dims, rhs_contract_dims, rhs); + + // compute the resulting shape before possible modification + auto resulted_shape = make_shared<v0::Constant>(element::i64, Shape{0}, vector<int64_t>{})->output(0); + bool apply_reshape = false; + auto lhs_shape = make_shared<v3::ShapeOf>(lhs, element::i64); + auto rhs_shape = make_shared<v3::ShapeOf>(rhs, element::i64); + if (lhs_batch_dims.size() > 0) { + auto batch_dims_shape = compute_dims_shape(lhs_shape, lhs_batch_dims); + resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, batch_dims_shape}, 0); + apply_reshape = true; + } + if (lhs_non_contract_dims.size() > 0) { + auto lhs_non_contract_shape = compute_dims_shape(lhs_shape, lhs_non_contract_dims); + resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, lhs_non_contract_shape}, 0); + apply_reshape = true; + } + if (rhs_non_contract_dims.size() > 0) { + auto rhs_non_contract_shape = compute_dims_shape(rhs_shape, rhs_non_contract_dims); + resulted_shape = make_shared<v0::Concat>(OutputVector{resulted_shape, rhs_non_contract_shape}, 0); + apply_reshape = true; + } + + // make sure that at least one dimension of each type (batch, contracting, and non-contracting) exists; + // if it does not, insert one at the end + insert_aux_dims(node, lhs, lhs_batch_dims, lhs_contract_dims, lhs_non_contract_dims); + insert_aux_dims(node, rhs, rhs_batch_dims, rhs_contract_dims, rhs_non_contract_dims); + + // compute non-batch and non-contracting dimensions + auto const_lhs_batch_dims = make_shared<v0::Constant>(element::i64, Shape{lhs_batch_dims.size()}, lhs_batch_dims); + auto const_rhs_batch_dims = make_shared<v0::Constant>(element::i64, Shape{rhs_batch_dims.size()}, rhs_batch_dims); + auto const_lhs_contract_dims = + make_shared<v0::Constant>(element::i64, Shape{lhs_contract_dims.size()}, lhs_contract_dims); + auto const_rhs_contract_dims = + make_shared<v0::Constant>(element::i64, Shape{rhs_contract_dims.size()}, rhs_contract_dims); + auto const_lhs_non_contract_dims = + make_shared<v0::Constant>(element::i64, Shape{lhs_non_contract_dims.size()}, lhs_non_contract_dims); + auto const_rhs_non_contract_dims = + make_shared<v0::Constant>(element::i64, Shape{rhs_non_contract_dims.size()}, rhs_non_contract_dims); + + lhs_shape = make_shared<v3::ShapeOf>(lhs, element::i64); + rhs_shape = make_shared<v3::ShapeOf>(rhs, element::i64); + + // compute a part of the input shape covering batch dimensions and non-contracting dimensions + auto gather_axis = make_shared<v0::Constant>(element::i32, Shape{1}, 0); + auto batch_dims_shape = compute_dims_shape(lhs_shape, lhs_batch_dims); + + // transpose both operands in a way to have dimensions in the order + // [batch dims, non-contracting dims, contracting dims] + auto
lhs_transpose_order = make_shared<v0::Concat>( + OutputVector{const_lhs_batch_dims, const_lhs_non_contract_dims, const_lhs_contract_dims}, + 0); + auto rhs_transpose_order = make_shared<v0::Concat>( + OutputVector{const_rhs_batch_dims, const_rhs_non_contract_dims, const_rhs_contract_dims}, + 0); + lhs = make_shared<v1::Transpose>(lhs, lhs_transpose_order); + rhs = make_shared<v1::Transpose>(rhs, rhs_transpose_order); + + // compute size of contracting dims and non-contracting dims for each operand + auto lhs_contract_size = compute_dims_size(lhs_shape, const_lhs_contract_dims); + auto rhs_contract_size = compute_dims_size(rhs_shape, const_rhs_contract_dims); + auto lhs_non_contract_size = compute_dims_size(lhs_shape, const_lhs_non_contract_dims); + auto rhs_non_contract_size = compute_dims_size(rhs_shape, const_rhs_non_contract_dims); + + // merge contracting and non-contracting dimensions to have operand + // of a shape [batch dims, non-contracting dim size, contracting dims size] + auto new_lhs_shape = + make_shared<v0::Concat>(OutputVector{batch_dims_shape, lhs_non_contract_size, lhs_contract_size}, 0); + auto new_rhs_shape = + make_shared<v0::Concat>(OutputVector{batch_dims_shape, rhs_non_contract_size, rhs_contract_size}, 0); + lhs = make_shared<v1::Reshape>(lhs, new_lhs_shape, false); + rhs = make_shared<v1::Reshape>(rhs, new_rhs_shape, false); + + // execute MatMul that supports batch matrix-multiplication + // note that the second operand is transposed + auto matmul = make_shared<v0::MatMul>(lhs, rhs, false, true)->output(0); + if (apply_reshape) { + matmul = make_shared<v1::Reshape>(matmul, resulted_shape, false); + } + + set_node_name(node_name, matmul.get_node_shared_ptr()); + return {matmul}; +} + +} // namespace op +} // namespace tensorflow +} // namespace frontend +} // namespace ov diff --git a/src/frontends/tensorflow/src/op_table.cpp b/src/frontends/tensorflow/src/op_table.cpp index 955701d807c..32c6c093e34 100644 --- a/src/frontends/tensorflow/src/op_table.cpp +++ b/src/frontends/tensorflow/src/op_table.cpp @@ -46,6 +46,7 @@ TF_OP_CONVERTER(translate_varhandle_op); TF_OP_CONVERTER(translate_variable_op); TF_OP_CONVERTER(translate_varisinitialized_op); TF_OP_CONVERTER(translate_while_op); +TF_OP_CONVERTER(translate_xla_dot_op); const std::map<std::string, CreatorFunction> get_supported_ops() { return { @@ -301,6 +302,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops() { {"SparseFillEmptyRows", CreatorFunction(translate_sparse_fill_empty_rows_op)}, {"SparseSegmentSum", CreatorFunction(translate_sparse_segment_sum_op)}, {"Unique", CreatorFunction(translate_unique_op)}, + + // XLA operations + {"XlaDotV2", CreatorFunction(translate_xla_dot_op)}, }; }; } // namespace op diff --git a/src/frontends/tensorflow/src/proto/xla_data.proto b/src/frontends/tensorflow/src/proto/xla_data.proto new file mode 100644 index 00000000000..95695ba78a2 --- /dev/null +++ b/src/frontends/tensorflow/src/proto/xla_data.proto @@ -0,0 +1,744 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +syntax = "proto3"; + +package xla; + +option cc_enable_arenas = true; + +// Primitive types are the individual values that can be held in rectangular +// multidimensional arrays. A description of the rectangular multidimensional +// array dimensions / primitive type is given by Shape, below. +// +// LINT.IfChange +enum PrimitiveType { + // Invalid primitive type to serve as default. + PRIMITIVE_TYPE_INVALID = 0; + + // Predicates are two-state booleans. + PRED = 1; + + // Signed integral values of fixed width. + S8 = 2; + S16 = 3; + S32 = 4; + S64 = 5; + + // Unsigned integral values of fixed width. + U8 = 6; + U16 = 7; + U32 = 8; + U64 = 9; + + // Floating-point values of fixed width. + // + // Note: if f16s are not natively supported on the device, they will be + // converted to f16 from f32 at arbitrary points in the computation. + F16 = 10; + F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + BF16 = 16; + + F64 = 12; + + // Complex values of fixed width. + C64 = 15; // Paired F32 (real, imag), as in std::complex<float>. + C128 = 18; // Paired F64 (real, imag), as in std::complex<double>. + + // A tuple is a polymorphic sequence; e.g. a shape that holds different + // sub-shapes. They are used for things like returning multiple values from a + // computation; e.g. a computation that returns weights and biases may have a + // signature that results in a tuple like (f32[784x2000], f32[2000]) + // + // If a shape proto has the tuple element type, it may not have any entries + // in the dimensions field. + TUPLE = 13; + + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. + // + // (OPAQUE would be a better name for this identifier, but that conflicts with + // a macro defined in windows.h.) + OPAQUE_TYPE = 14; + + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 19 +} +// LINT.ThenChange( +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, +// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc +// ) + +// Describes the padding configuration for Pad operation. The padding amount on +// both edges as well as between the elements are specified for each dimension. +message PaddingConfig { + // Describes the padding configuration for a dimension. + message PaddingConfigDimension { + // Padding amount on the low-end (next to the index 0). May be negative. + int64 edge_padding_low = 1; + + // Padding amount on the high-end (next to the highest index). May be + // negative. + int64 edge_padding_high = 2; + + // Padding amount between the elements. May not be negative. + int64 interior_padding = 3; + } + + // The padding configuration for all dimensions. + repeated PaddingConfigDimension dimensions = 1; +} + +// A format specifies the method used by a layout to store an array in memory. +enum Format { + // TODO(b/120869032): Rename this to FORMAT_NONE or something else which + // better corresponds to its meaning. + INVALID_FORMAT = 0; + // The default layout, with exactly one storage location per element.
+ DENSE = 1; + reserved 2; + reserved "SPARSE"; +} + +// Describes a tile used in tiling-based layout. Refer to +// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for +// details about tiling-based layout. +message TileProto { + // Number of elements in each dimension of the tile. It's ordered from the + // most major dimension of the tile to the most minor dimension of the tile. + // The dimensions correspond to a suffix of the dimensions of the shape being + // tiled. + repeated int64 dimensions = 1; +} + +// A layout describes how the array is placed in (1D) memory space. This +// includes the minor-to-major ordering of dimensions within a shape. +// +// Clients must specify the layouts of input Literals to the +// computation. Layouts specified in interior operations which take Shapes (for +// example, Convert) are ignored. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message LayoutProto { + // The method used to store the data in memory. The format determines which of + // the other fields are used by the layout. + Format format = 4; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). This field is required. + repeated int64 minor_to_major = 1; + + reserved 2; + reserved "padded_dimensions"; + + reserved 3; + reserved "padding_value"; + + reserved 5; + reserved "max_sparse_elements"; + + // A sequence of tiles, starting from the tile that's applied first to the + // Shape. + // + // TODO(b/119839262): implement tiling in each backend or add Unimplemented + // error. + repeated TileProto tiles = 6; + + // Bit size of each element. If the size is bigger than what the element + // type requires, the value is stored in the least significant + // bits and the additional most significant bits are filled with 0's. + // + // TODO(b/119839262): implement in each backend or add Unimplemented error. + int64 element_size_in_bits = 7; + + // Memory space where this array resides. The integer field is interpreted in + // a backend-specific manner. + int64 memory_space = 8; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and + // LayoutUtil::Hash appropriately to account for the new field. +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) + +// A shape describes the number of dimensions in the array, the size of each +// dimension, and the primitive component type. +// +// Tuples are a special case in that they have rank zero and have tuple_shapes +// defined. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message ShapeProto { + reserved 1; + reserved "rank"; + + // The element type for this shape. + PrimitiveType element_type = 2; + + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dimension_dynamic' is true then the value + // in this field represents an upper bound on the size of the dimension. + repeated int64 dimensions = 3; + + // For tuples only, the shapes of constituent shapes in the tuple sequence. 
+ repeated ShapeProto tuple_shapes = 4; + + // The layout used to back this shape. + LayoutProto layout = 5; + + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), + // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for + // the new field. +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) + +// Shape of the parameters and output of a computation (like a traditional +// function signature). +message ProgramShapeProto { + repeated ShapeProto parameters = 1; + ShapeProto result = 2; + repeated string parameter_names = 3; +} + +// Statistics of a computation. +message ComputationStats { + // The number of floating point operations in the computation. + double flop_count = 1; + + // The number of transcendental operations (e.g., exp) in the computation. + double transcendental_count = 2; +} + +// The type optimization profiles in use. +enum ProfileType { + INVALID = 0; + WINDOW = 1; + FLAG = 2; + INTEGER = 3; +} + +// Symbolization metadata for HLO Instructions. +// +// This metadata is used for debugging XLA code generation, as well as +// performance profiling of XLA-generated executables. +message OpMetadata { + // The framework op name that generated this XLA op. + // + // Frameworks that build on top of XLA should mirror the names of their ops + // back to users by specifying the op_type. In this way, even if the + // framework's "ops" are implemented as multiple XLA HLO Ops, they can be + // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as + // multiple ops, then each op should have the op_type be "SoftMax".) + string op_type = 1; + // The user-specified name of the op. + // + // This name is often unique within a computation. Note: some frameworks + // add auto-generated names if the user does not provide one. + string op_name = 2; + // Indicate a file and line that this op is associated to in a user's program. + // + // e.g. it could be the file and line of user code that generated the op. + string source_file = 3; + int32 source_line = 4; + + repeated ProfileType profile_type = 5; + + // HloPassMetadata.pass_id of the pass that created this HLO instruction + // object. Should never be copied between HLO instructions. Zero if unset and + // -1 if the instruction was created before HLO passes began. + int64 creation_pass_id = 6; + + // HloPassMetadata.pass_id of the pass that created the logical functionality + // that this HLO instruction represents. Should be copied between HLO + // instructions that correspond across compilation passes. Zero if unset and + // -1 if the instruction was created before HLO passes began. + int64 logical_creation_pass_id = 7; + + // The footprint of the generated code for the instruction. + int64 size_of_generated_code_in_bytes = 8; + // The size of the working set, i.e., the amount of memory, used by the + // instruction in a compiler-managed fast device memory. + int64 size_of_memory_working_set_in_bytes = 9; +} + +// Profile data from the execution of a computation. +message ExecutionProfile { + // Whether the executable was read from the compilation cache. 
+ bool compilation_cache_hit = 1; + + // The time in milliseconds spent to compile the computation. This is only set + // if the executable was not read from the compilation cache + // (compilation_cache_hit == false). + int64 compile_time_ms = 2; + + // The number of cycles spent for the computation. This does not include the + // time taken for the data transfers between the host and the device. This is + // a target-dependent field and only used for debugging purposes. + int64 compute_cycle_count = 3; + + // The time in nanoseconds spent for the computation, without data transfer. + int64 compute_time_ns = 4; + + // The time in nanoseconds spent for the entire computation, including the + // result data transfer time. Current implementation does not spend any cycles + // for the input data transfer since the memory is initialized with the proper + // values before the execution. + int64 compute_and_transfer_time_ns = 5; + + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; + + // Whether this profile was drawn from a cache of profiles instead of from + // execution on the hardware. + bool profile_cache_hit = 7; +} + +// Handle given to a user that represents an execution that the user launched +// asynchronously on the device. +message ExecutionHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a globally accessible allocation. +// Contrast this against a ComputationDataHandle, which is not globally +// accessible, since it only exists within a specific computation. +message GlobalDataHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a replicated virtual device. Each +// replicated device represents N physical devices for execution where N is the +// number of replicas. +message DeviceHandle { + int64 handle = 1; + + // The number of model-parallel virtual devices that communicate via XLA + // Send/Recv instructions. + int64 device_count = 2; +} + +// Handle given to a user to represent a channel between two computations +// via a Send and Recv instruction pair. Channels are unbuffered, so Send +// instructions will be blocked until the data is transferred. +message ChannelHandle { + int64 handle = 1; + enum ChannelType { + // Invalid primitive type to serve as default. + CHANNEL_TYPE_INVALID = 0; + + // A channel for sending data between devices. + DEVICE_TO_DEVICE = 1; + + // A channel for sending data from the device to the host. Can only be used + // with a Send operation. + DEVICE_TO_HOST = 2; + + // A channel for sending data from the host to the device. Can only be used + // with a Recv operation. + HOST_TO_DEVICE = 3; + } + ChannelType type = 2; +} + +// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which +// represents the device ids assigned to a set of replicated computations. +// See xla::DeviceAssignment class comment for more details. +message DeviceAssignmentProto { + int32 replica_count = 1; + int32 computation_count = 2; + + // Each logical computation runs on replica_count physical devices. + // ComputationDevice represents the device ids assigned to the replicas. + message ComputationDevice { + repeated int32 replica_device_ids = 1; + } + repeated ComputationDevice computation_devices = 3; +} + +// Literals are used when the server and client need to exchange materialized +// data / results. Literals are also used to describe constants used in +// computations.
+// +// Transfers to/from the client are encoded in literal form, and the structure +// of the repeated fields is implied by the shape. +message LiteralProto { + ShapeProto shape = 1; + repeated bool preds = 2; + bytes s8s = 15; + bytes u8s = 3; + repeated int32 s32s = 4; + repeated int64 s64s = 5; + repeated uint32 u32s = 6; + repeated uint64 u64s = 7; + repeated float f32s = 8; + repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. + repeated LiteralProto tuple_literals = 10; + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order + bytes f16s = 11; + bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; + repeated int64 sparse_indices = 14; + // Next = 19 +} + +message WindowDimension { + // The size of the window in this dimension. For a rectangle, this would be + // the width or height. + int64 size = 1; + + // The stride at which the window moves across the base area in this + // dimension. In other words, this is the spacing between different + // positions of the window in this dimension. + int64 stride = 2; + + // If positive, means the amount of padding to add to the base area at the low + // end of this dimension; if negative, its negative means the number of + // elements removed from the low end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of padding + // values to pad on the left, given that indices increase when going right. + // The actual padding value depends upon the context. Convolution pads with + // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's + // init value. + int64 padding_low = 3; + + // As padding_low, but on the high end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of values to + // pad on the right, given that indices increase when going right. + int64 padding_high = 4; + + // Dilation factor of the sliding window in this dimension. A dilation factor + // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are + // implicitly placed between each kernel element. This value may not be less + // than 1. See documentation for convolution. + int64 window_dilation = 5; + + // Dilation factor of the base area in this dimension. A dilation factor of 1 + // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly + // placed between each base area element. This value may not be less than 1. + // See documentation for convolution. + int64 base_dilation = 6; + + // Window reversal means that this dimension was logically reversed before the + // operation. + bool window_reversal = 7; +} + +// Describes the windowing in an operation such as convolution. +// +// The window is moved across a base area and for each position of the +// window a computation is performed. The field below describes the +// window and the movement of the window across a base area. +message Window { + repeated WindowDimension dimensions = 1; +} + +// Describes the dimension numbers for a gather operation. +// +// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for +// more details. 
+message GatherDimensionNumbers { + // "Window indices" is a term for a set of indices that index into the + // interior of a dynamic-slice from the input tensor, the starting indices for + // which were computed from output_gather_dims (see the operation semantic for + // how this is defined) and the start_indices tensor. + // + // The window indices for a specific output index Out is computed as: + // + // i = 0 + // for (k : [0, input_tensor_shape.rank)) + // window_indices[k] = + // if k in collapsed_slice_dims + // then 0 + // else Out[offset_dims[i++]] + repeated int64 offset_dims = 1; + repeated int64 collapsed_slice_dims = 2; + + // This is interpreted as a map from i to start_index_map[i]. It + // transforms the gather index looked up from the start_indices tensor into + // the starting index in the input space. + repeated int64 start_index_map = 3; + + // The dimension in the start_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; +} + +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; +} + +message ConvolutionDimensionNumbers { + // The number of the dimension that represents batch in the input. + int64 input_batch_dimension = 7; + + // The number of the dimension that represents features in the input. + int64 input_feature_dimension = 8; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the input. + repeated int64 input_spatial_dimensions = 11; + + // The number of the dimension that represents input features in the + // convolutional kernel (rhs). + int64 kernel_input_feature_dimension = 3; + + // The number of the dimension that represents output features in + // the convolutional kernel (rhs). + int64 kernel_output_feature_dimension = 4; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the kernel (rhs). window.strides(0) is the + // stride in the kernel_spatial_dimensions(0) dimension. + repeated int64 kernel_spatial_dimensions = 6; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the output. + repeated int64 output_spatial_dimensions = 12; + + // Next = 13 +} + +enum PaddingType { + PADDING_INVALID = 0; + PADDING_VALID = 1; // Only valid portion of the base are covered. + PADDING_SAME = 2; // Extra is added to produce same output size as the input. +} + +enum FftType { + FFT = 0; // Forward FFT; complex in, complex out. + IFFT = 1; // Inverse FFT; complex in, complex out. + RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out + IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in, + // fft_length real out +} + +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. 
+ repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +} + +enum RandomDistribution { + RNG_INVALID = 0; + + // Creates a uniform-distribution-generated random number on the semi-open + // interval [parameter[0], parameter[1]). + RNG_UNIFORM = 1; + + // Creates a normal-distribution-generated random number with mean + // parameter[0] and standard deviation parameter[1]. + RNG_NORMAL = 2; + + // Next: 4 +} + +enum RandomAlgorithm { + RNG_DEFAULT = 0; // Backend dependent default algorithm. + RNG_THREE_FRY = 1; + RNG_PHILOX = 2; + // Next: 3 +} + +message TriangularSolveOptions { + // If true, solves ax = b. If false, solves xa = b. + bool left_side = 1; + + // If true, 'a' is lower triangular. If false, 'a' is upper triangular. + bool lower = 2; + + // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. + bool unit_diagonal = 3; + + // Should we transpose or use the adjoint of 'a'? + enum Transpose { + TRANSPOSE_INVALID = 0; + NO_TRANSPOSE = 1; // Don't transpose 'a'. + TRANSPOSE = 2; // Transpose 'a'. + ADJOINT = 3; // Complex conjugate and transpose 'a'. + } + Transpose transpose_a = 4; +} + +message CholeskyOptions { + // If true, uses the lower triangle of `a`. If false, uses the upper triangle + // of `a`. + bool lower = 1; +} + +// Generic map of attributes used to pass hints / configuration options from +// the Python frontend to the XLA backend. +message FrontendAttributes { + map<string, string> map = 1; +} + +message OpSharding { + enum Type { + // This sharding is replicated across all devices (implies maximal, + // all other fields are unused). + REPLICATED = 0; + // This sharding is maximal - one device runs the entire operation. + MAXIMAL = 1; + // This sharding is a tuple - only the tuple_shardings field is valid. + TUPLE = 2; + // None of the above; tile_shape and tile_assignment are both used. + OTHER = 3; + // This op is manually sharded: the shapes are already partitioned and the + // partitioner should not change this op. + MANUAL = 4; + } + Type type = 1; + // The shape of the sharded tile. + ShapeProto tile_shape = 2; + // The shape of the tile assignment tensor - this must be the same rank as + // tile_shape and the product of its dimensions must equal + // tile_assignment_devices.size(). + repeated int64 tile_assignment_dimensions = 3; + // Flattened list of device IDs. The order of flattening is the same as used + // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). + repeated int64 tile_assignment_devices = 4; + // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + // in pre-order. The tuple shape could be nested; here we store just a + // flattened list of all leaves in the tuple shape. Note that the tuple shape + // is not stored here; shardings do not store the shapes to which they are + // applied, this is inferred from the instruction this sharding gets attached + // to. + repeated OpSharding tuple_shardings = 5; + + // Only used for OTHER type. If true, data is sharded according to other + // dimensions of tile_assignment(), but replicated across devices along the + // last dimension.
(Experimental) + bool replicate_on_last_tile_dim = 6; + // This field is used to track the source of this sharding, usually derived + // from instructions. Multiple metadata may be populated if sharding is + // combined with other shardings. Metadata are not to be populated when + // type == TUPLE and instead metadata should be set on individual tuple + // elements. + repeated OpMetadata metadata = 7; +} + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belong to the same group. The ordering of the + // ids matters in some ops (e.g., all-to-all). + repeated int64 replica_ids = 1; +} + +// Describes the source target pair in the collective permute op. +message SourceTarget { + int64 source = 1; + int64 target = 2; +} + +// Used to indicate the precision configuration. It has backend specific +// meaning. +message PrecisionConfig { + enum Precision { + DEFAULT = 0; + HIGH = 1; + HIGHEST = 2; + + // Next: 3 + } + repeated Precision operand_precision = 1; + + // Next: 2 +} + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. +message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated. Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. + repeated bool replicated_at_leaf_buffers = 1; +} + +// A backend-config for kWhile loops that stores the loop's trip count, if it is +// known. +// +// This is useful for backends that can implement a `for i in 0..N` loop more +// efficiently than a `while` loop. For example, on GPUs, we can implement a +// `for i in 0..N` loop by enqueueing the kernels for the loop body N times, +// whereas implementing a `while` loop requires a host-device sync on each +// iteration. +message WhileLoopBackendConfig { + message KnownTripCount { + int64 n = 1; + } + // This indirection lets us distinguish between known-trip-count == 0 and + // unknown-trip-count. + KnownTripCount known_trip_count = 1; +} + +// Specifies a pair of output/operand buffers for kCustomCall that alias each +// other.
+message CustomCallOutputOperandAliasing { + repeated int64 output_shape_index = 1; + int64 operand_index = 2; + repeated int64 operand_shape_index = 3; +} diff --git a/tests/layer_tests/jax_tests/conftest.py b/tests/layer_tests/jax_tests/conftest.py new file mode 100644 index 00000000000..6d1ec3182a9 --- /dev/null +++ b/tests/layer_tests/jax_tests/conftest.py @@ -0,0 +1,12 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import inspect + +from common.layer_test_class import get_params + + +def pytest_generate_tests(metafunc): + test_gen_attrs_names = list(inspect.signature(get_params).parameters) + params = get_params() + metafunc.parametrize(test_gen_attrs_names, params, scope="function") diff --git a/tests/layer_tests/jax_tests/jax_layer_test_class.py b/tests/layer_tests/jax_tests/jax_layer_test_class.py new file mode 100644 index 00000000000..8d79a49809d --- /dev/null +++ b/tests/layer_tests/jax_tests/jax_layer_test_class.py @@ -0,0 +1,142 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import itertools +from copy import deepcopy + +import numpy as np +from common.constants import test_device, test_precision +from jax import numpy as jnp +from openvino.runtime import Core + + +class JaxLayerTest: + def _test(self, model, ref_net, ie_device, precision, ir_version, infer_timeout=60, dynamic_shapes=True, + **kwargs): + """ + Convert the JAX model via jax2tf and ovc, run inference with OpenVINO on + ie_device, and compare the results against native JAX execution. + ref_net is unused and kept for compatibility with other layer test classes. + """ + inputs = self._prepare_input() + converted_model = self.convert_via_tensorflow_function(model, inputs) + + # OV infer: + core = Core() + compiled = core.compile_model(converted_model, ie_device) + infer_res = compiled(deepcopy(inputs)) + + # Framework infer: + fw_res = model(*deepcopy(inputs)) + + if not isinstance(fw_res, tuple): + fw_res = (fw_res,) + + output_list = ov_res_to_list(infer_res) + + def flattenize_dict_outputs(res): + if isinstance(res, dict): + return flattenize_outputs(res.values()) + + def flattenize_outputs(res): + results = [] + for res_item in res: + # if None is at output we skip it + if res_item is None: + continue + # If input is list or tuple flattenize it + if isinstance(res_item, (list, tuple)): + decomposed_res = flattenize_outputs(res_item) + results.extend(decomposed_res) + continue + if isinstance(res_item, dict): + decomposed_res = flattenize_dict_outputs(res_item) + results.extend(decomposed_res) + continue + results.append(res_item) + return results + + flatten_fw_res = flattenize_outputs(fw_res) + + assert len(flatten_fw_res) == len( + output_list), f'number of outputs not equal, {len(flatten_fw_res)} != {len(output_list)}' + # check if results dtypes match + for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list): + fw_tensor_type = np.array(fw_tensor).dtype + ov_tensor_type = ov_tensor.dtype + assert ov_tensor_type == fw_tensor_type, f"dtype validation failed: {ov_tensor_type} != {fw_tensor_type}" + + if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None: + custom_eps = kwargs['custom_eps'] + else: + custom_eps = 1e-4 + + # compare OpenVINO results with JAX results + fw_eps = custom_eps if precision == 'FP32' else 5e-2 + is_ok = True + for i in range(len(flatten_fw_res)): + cur_fw_res = np.array(flatten_fw_res[i]) + cur_ov_res = infer_res[compiled.output(i)] + print(f"fw_res: {cur_fw_res};\n ov_res: {cur_ov_res}") + if not np.allclose(cur_ov_res, cur_fw_res, + atol=fw_eps, +
rtol=fw_eps, equal_nan=True): + is_ok = False + print("Max diff is {}".format( + np.array( + abs(cur_ov_res - cur_fw_res)).max())) + else: + print("Accuracy validation successful!\n") + print("absolute eps: {}, relative eps: {}".format(fw_eps, fw_eps)) + assert is_ok, "Accuracy validation failed" + + # Each model should specify inputs + def _prepare_input(self): + raise RuntimeError("Please provide inputs generation function") + + def convert_via_tensorflow_function(self, model, inputs): + import tensorflow as tf + from jax.experimental import jax2tf + from openvino.tools.ovc import convert_model + # create function signature based on input shapes and types + function_signature = [] + for _input in inputs: + assert isinstance(_input, np.ndarray) + input_shape = _input.shape + input_type = _input.dtype + function_signature.append(tf.TensorSpec(input_shape, input_type)) + + f = tf.function(jax2tf.convert(model), autograph=False, + input_signature=function_signature) + converted_model = convert_model(f) + return converted_model + + +def get_params(ie_device=None, precision=None): + """ + :param ie_device: list of devices + :param precision: list of precisions + """ + + ie_device_params = ie_device if ie_device else test_device + precision_params = precision if precision else test_precision + + test_args = [] + for element in itertools.product(ie_device_params, precision_params): + if element[0] == 'CPU' and element[1] == 'FP16': + continue + test_args.append(element) + return test_args + + +def ov_res_to_list(ov_res_dict): + # 118221: remove this WA that cleans up repeating output tensors + # with the same tensor names + # probably, we do not utilize some meta info from tf.function + values = [] + met_names = set() + for ov_res_name, ov_res_value in ov_res_dict.items(): + if bool(set(ov_res_name.names) & met_names): + continue + met_names |= set(ov_res_name.names) + values.append(ov_res_value) + return values diff --git a/tests/layer_tests/jax_tests/test_dot_general.py b/tests/layer_tests/jax_tests/test_dot_general.py new file mode 100644 index 00000000000..02de2dfaaf4 --- /dev/null +++ b/tests/layer_tests/jax_tests/test_dot_general.py @@ -0,0 +1,47 @@ +# Copyright (C) 2018-2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +from jax import lax + +from jax_layer_test_class import JaxLayerTest + + +class TestDotGeneral(JaxLayerTest): + def _prepare_input(self): + lhs = np.random.randint(-10, 10, self.lhs_shape).astype(self.input_type) + rhs = np.random.randint(-10, 10, self.rhs_shape).astype(self.input_type) + return (lhs, rhs) + + def create_model(self, lhs_shape, rhs_shape, dimension_numbers, input_type): + self.lhs_shape = lhs_shape + self.rhs_shape = rhs_shape + self.input_type = input_type + + def jax_dot_general(lhs, rhs): + out = lax.dot_general(lhs, rhs, dimension_numbers) + return out + + return jax_dot_general, None + + test_data = [ + # 1D vector dot 1D vector + dict(lhs_shape=[4], rhs_shape=[4], dimension_numbers=(((0,), (0,)), ((), ()))), + # matrix mxk dot vector k + dict(lhs_shape=[2, 5], rhs_shape=[5], dimension_numbers=(((1,), (0,)), ((), ()))), + # matrix mxk dot matrix kxn + dict(lhs_shape=[2, 5], rhs_shape=[5, 6], dimension_numbers=(((1,), (0,)), ((), ()))), + # batch matmul case + dict(lhs_shape=[3, 2, 3, 4], rhs_shape=[3, 2, 2, 4], dimension_numbers=(((3,), (3,)), ((0, 1), (0, 1)))), + # batch matmul case: different batch and contracting dimensions + dict(lhs_shape=[2, 3, 4, 5], rhs_shape=[4, 2, 5, 3], dimension_numbers=(((2, 3), (0,
2)), ((0, 1), (1, 3)))), + ] + + @pytest.mark.nightly + @pytest.mark.precommit + @pytest.mark.parametrize("params", test_data) + @pytest.mark.parametrize("input_type", [np.float32, np.int32]) + def test_dot_general(self, ie_device, precision, ir_version, params, input_type): + self._test(*self.create_model(**params, input_type=input_type), ie_device, precision, + ir_version) diff --git a/tests/layer_tests/requirements.txt b/tests/layer_tests/requirements.txt index a182efac32a..30a92216833 100644 --- a/tests/layer_tests/requirements.txt +++ b/tests/layer_tests/requirements.txt @@ -6,3 +6,5 @@ torch torchvision pytest tensorflow-addons; python_version <= '3.10' +jax; sys_platform == "linux" +jaxlib; sys_platform == "linux"
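
For reference, the conversion path exercised by the new JAX layer tests can be reproduced standalone: a JAX function built on lax.dot_general is wrapped by jax2tf into a tf.function with a concrete input signature and converted by ovc convert_model, at which point the XlaDotV2 node emitted by jax2tf is handled by translate_xla_dot_op. A minimal sketch assuming the Linux setup from requirements.txt above; the shapes and dimension_numbers value are illustrative:

    import numpy as np
    import tensorflow as tf
    from jax import lax
    from jax.experimental import jax2tf
    from openvino.tools.ovc import convert_model

    def model(lhs, rhs):
        # batch matmul: contract dim 3 of lhs with dim 3 of rhs, batch over dims (0, 1)
        return lax.dot_general(lhs, rhs, (((3,), (3,)), ((0, 1), (0, 1))))

    lhs = np.random.rand(3, 2, 3, 4).astype(np.float32)
    rhs = np.random.rand(3, 2, 2, 4).astype(np.float32)

    # wrap into a tf.function with a concrete signature, as
    # convert_via_tensorflow_function does in jax_layer_test_class.py
    f = tf.function(jax2tf.convert(model), autograph=False,
                    input_signature=[tf.TensorSpec(lhs.shape, lhs.dtype),
                                     tf.TensorSpec(rhs.shape, rhs.dtype)])
    ov_model = convert_model(f)  # the XlaDotV2 in the TF graph is translated here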
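The dimension_numbers attribute that xla_dot.cpp parses with ParseFromArray is a serialized xla.DotDimensionNumbers message as defined in xla_data.proto above. A hypothetical Python sketch of building such a payload, assuming xla_data.proto has been compiled by protoc into a module named xla_data_pb2 (that module name is illustrative, not part of this patch):

    from xla_data_pb2 import DotDimensionNumbers  # hypothetical generated module

    # the batch-matmul case from the sketch above
    dn = DotDimensionNumbers()
    dn.lhs_contracting_dimensions.append(3)
    dn.rhs_contracting_dimensions.append(3)
    dn.lhs_batch_dimensions.extend([0, 1])
    dn.rhs_batch_dimensions.extend([0, 1])
    # the TF node attribute carries these bytes; the translator feeds exactly
    # this buffer to DotDimensionNumbers::ParseFromArray on the C++ side
    raw = dn.SerializeToString()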
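The decomposition performed by translate_xla_dot_op can be mirrored in NumPy to check its equivalence to lax.dot_general: transpose each operand to [batch dims, non-contracting dims, contracting dims], collapse each group of dimensions, run a batched MatMul with the second operand transposed, and reshape to the resulting shape. A sketch under one labeled simplification, that batch dims are flattened into a single dim (the translator keeps them unflattened, which yields the same values):

    import numpy as np
    from jax import lax

    def dot_general_ref(lhs, rhs, dimension_numbers):
        (lc, rc), (lb, rb) = dimension_numbers
        # non-contracting dims are whatever is neither batch nor contracting
        lnc = [d for d in range(lhs.ndim) if d not in set(lb) | set(lc)]
        rnc = [d for d in range(rhs.ndim) if d not in set(rb) | set(rc)]
        out_shape = ([lhs.shape[d] for d in lb] + [lhs.shape[d] for d in lnc] +
                     [rhs.shape[d] for d in rnc])
        # simplification: flatten batch dims into one (the translator keeps them)
        batch = int(np.prod([lhs.shape[d] for d in lb], dtype=np.int64))
        lhs3 = lhs.transpose(list(lb) + lnc + list(lc)).reshape(
            batch, -1, int(np.prod([lhs.shape[d] for d in lc], dtype=np.int64)))
        rhs3 = rhs.transpose(list(rb) + rnc + list(rc)).reshape(
            batch, -1, int(np.prod([rhs.shape[d] for d in rc], dtype=np.int64)))
        # batched MatMul with the second operand transposed, then restore the shape
        return np.matmul(lhs3, rhs3.transpose(0, 2, 1)).reshape(out_shape)

    lhs = np.random.rand(2, 3, 4, 5).astype(np.float32)
    rhs = np.random.rand(4, 2, 5, 3).astype(np.float32)
    dn = (((2, 3), (0, 2)), ((0, 1), (1, 3)))  # last precommit case above
    assert np.allclose(dot_general_ref(lhs, rhs, dn),
                       lax.dot_general(lhs, rhs, dn), atol=1e-4)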