Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
904384fee3
@ -26,7 +26,7 @@ endif()
|
||||
# resolving dependencies for the project
|
||||
message (STATUS "PROJECT ............................... " ${PROJECT_NAME})
|
||||
message (STATUS "CMAKE_BINARY_DIR ...................... " ${CMAKE_BINARY_DIR})
|
||||
message (STATUS "OpenVINO_SOURCE_DIR .... .......... " ${OpenVINO_SOURCE_DIR})
|
||||
message (STATUS "OpenVINO_SOURCE_DIR ................... " ${OpenVINO_SOURCE_DIR})
|
||||
message (STATUS "CMAKE_GENERATOR ....................... " ${CMAKE_GENERATOR})
|
||||
message (STATUS "CMAKE_C_COMPILER_ID ................... " ${CMAKE_C_COMPILER_ID})
|
||||
message (STATUS "CMAKE_BUILD_TYPE ...................... " ${CMAKE_BUILD_TYPE})
|
||||
|
@ -189,15 +189,6 @@ limitations under the License.
|
||||
<tab type="user" title="Benchmark C++ Tool" url="@ref openvino_inference_engine_samples_benchmark_app_README"/>
|
||||
<tab type="user" title="Benchmark Python* Tool" url="@ref openvino_inference_engine_tools_benchmark_tool_README"/>
|
||||
</tab>
|
||||
<!-- Reference Implementations -->
|
||||
<tab type="usergroup" title="Reference Implementations" url="">
|
||||
<tab type="usergroup" title="Speech Library and Speech Recognition Demos" url="@ref openvino_inference_engine_samples_speech_libs_and_demos_Speech_libs_and_demos">
|
||||
<tab type="user" title="Speech Library" url="@ref openvino_inference_engine_samples_speech_libs_and_demos_Speech_library"/>
|
||||
<tab type="user" title="Offline Speech Recognition Demo" url="@ref openvino_inference_engine_samples_speech_libs_and_demos_Offline_speech_recognition_demo"/>
|
||||
<tab type="user" title="Live Speech Recognition Demo" url="@ref openvino_inference_engine_samples_speech_libs_and_demos_Live_speech_recognition_demo"/>
|
||||
<tab type="user" title="Kaldi* Statistical Language Model Conversion Tool" url="@ref openvino_inference_engine_samples_speech_libs_and_demos_Kaldi_SLM_conversion_tool"/>
|
||||
</tab>
|
||||
</tab>
|
||||
<!-- DL Streamer Examples -->
|
||||
<tab type="usergroup" title="DL Streamer Examples" url="@ref gst_samples_README">
|
||||
<tab type="usergroup" title="Command Line Samples" url="">
|
||||
|
@ -1,45 +0,0 @@
|
||||
# Kaldi* Statistical Language Model Conversion Tool {#openvino_inference_engine_samples_speech_libs_and_demos_Kaldi_SLM_conversion_tool}
|
||||
|
||||
The Kaldi* Statistical Language Model (SLM) Conversion Tool is a command-line tool that converts [Kaldi](https://kaldi-asr.org/) language model resources to the format supported by the OpenVINO™ Speech Recognition Demos.
|
||||
|
||||
## Command Line
|
||||
|
||||
`kaldi_slm_conversion_tool HCLG.const.fst transitions.txt words.txt slm.fst labels.bin`
|
||||
|
||||
## Input Parameters
|
||||
|
||||
**HCLG.const.fst**
|
||||
|
||||
The `HCLG.const.fst` parameter is the input weighted finite-state transducer (WFST) file in the OpenFST const format.
|
||||
|
||||
Most example scripts create a language model file in that format. If you have a WFST in a different OpenFST format, convert it with the following command:
|
||||
|
||||
```sh
|
||||
$KALDI_ROOT/tools/openfst/bin/fstconvert --fst_type=const HCLG.fst HCLG.const.fst
|
||||
```
|
||||
|
||||
The source Kaldi language model file `HCLG.fst` can be found in directories like `exp/tri2b/graph_xyz`, where `tri2b` is the name of the model used for speech recognition tests.
|
||||
|
||||
**transitions.txt**
|
||||
The WFST transitions file describes the relations between WFST transitions and neural acoustic model outputs. This file is usually not generated by Kaldi example scripts, so you have to create it with the following command:
|
||||
|
||||
```sh
|
||||
$KALDI_ROOT/src/bin/show-transitions phones.txt final.mdl > transitions.txt
|
||||
```
|
||||
|
||||
For this call, the `phones.txt` file is the phoneme description file, which can often be found in `data/lang/phones.txt`.
|
||||
The `final.mdl` file is the neural acoustic model that is used for speech recognition.
|
||||
|
||||
**words.txt**
|
||||
|
||||
The `words.txt` file defines the mappings from word IDs used internally to their text representation. For many Kaldi example scripts, the file can be found in the same directory as `HCLG.fst`.
|
||||
|
||||
## Output Parameters
|
||||
|
||||
**slm.fst**
|
||||
|
||||
The output file `slm.fst` is generated by the SLM Conversion Tool. It contains the information needed for the OpenVINO™ speech recognition demos for decoding.
|
||||
|
||||
**labels.bin**
|
||||
|
||||
The `labels.bin` file defines mappings from word IDs to word strings, like the `words.txt` file, but in the binary format. The OpenVINO™ speech recognition example needs the `labels.bin` file to convert recognized words into a human-readable format.
|
@ -1,78 +0,0 @@
|
||||
# Live Speech Recognition Demo {#openvino_inference_engine_samples_speech_libs_and_demos_Live_speech_recognition_demo}
|
||||
|
||||
This demo provides a GUI interface for automatic speech recognition using selected OpenVINO™ Inference Engine plugin, OpenVINO™ Feature Extraction Library, and OpenVINO™ Decoder Library.
|
||||
|
||||
## How It Works
|
||||
|
||||
The application transcribes audio from a WAV file and/or audio device. It supports recognition of two audio sources in parallel, for example audio coming from your microphone and audio coming from your PC (loopback). That enables use cases like audio conference or transcribing audio from an online video stream. Among other things, the user can select a specific plugin to use for the recognition, set batch size, and control volume.
|
||||
|
||||
The software stack used by the demo is as follows:
|
||||
|
||||

|
||||
|
||||
## Running
|
||||
|
||||
The application main window looks like this:
|
||||
|
||||

|
||||
|
||||
Refer to the sections below for instructions for particular scenarios.
|
||||
|
||||
### Transcribing Speech from WAV File
|
||||
|
||||
Click **Select File (9)** and navigate to the audio file using the file selection window dialog. Ensure the selected audio format is 16 kHz, 16 bit, 1 channel stored as WAV.
|
||||
|
||||
Alternatively, use the audio file that is already selected upon launching the app.
|
||||
|
||||
Click **Recognize (10)**.
|
||||
|
||||
Transcription appears in the **Source 1** box.
|
||||
|
||||
### Transcribing Speech from Audio or Video Playback (Loopback)
|
||||
|
||||
Select a proper audio output device **(3)**.
|
||||
|
||||
Click **Recognize (5)** and play your video or other multimedia.
|
||||
|
||||
Transcription appears in the **Source 1** box.
|
||||
|
||||
> **NOTE**: Loopback on Linux\* OS may need manual settings in PulseAudio Control or via a config file.
|
||||
|
||||
### Transcribing Speech Captured with Microphone
|
||||
|
||||
Select a microphone **(6)**.
|
||||
|
||||
Click **Recognize (8)** and start speaking.
|
||||
|
||||
Transcription appears in the **Source 2** box.
|
||||
|
||||
### Transcribing Speech from Audio Output and Microphone at the Same Time (Audio Conference)
|
||||
|
||||
Select an audio output device **(3)**.
|
||||
|
||||
Select a microphone **(6)**.
|
||||
|
||||
Click both **Recognize** buttons: **(5)** and **(8)**. Then start speaking.
|
||||
|
||||
Transcriptions appear in both **Source 1** and **Source 2** boxes.
|
||||
|
||||
> **NOTE**: Loopback on Linux OS may need manual settings in PulseAudio Control or via a config file.
|
||||
|
||||
### Changing Speech Recognition Model
|
||||
|
||||
Select the desired configuration from the dropdown list **(1)**.
|
||||
|
||||
To reset the application to default configuration, click **Reload (2)**.
|
||||
|
||||
### Controlling Volume
|
||||
|
||||
Audio volume for each stream can be controlled with sliders **(4)** and **(7)**.
|
||||
Current audio levels of each stream are shown in the bar on the same row as its source selector.
|
||||
|
||||
### Selecting Inference Engine Plugin
|
||||
|
||||
Select an Inference Engine plugin and batch size with **(11)** and **(12)**.
|
||||
|
||||
## Demo Output
|
||||
|
||||
The resulting transcription for each audio source is presented in the application in real time.
|
@ -1,90 +0,0 @@
|
||||
# Offline Speech Recognition Demo {#openvino_inference_engine_samples_speech_libs_and_demos_Offline_speech_recognition_demo}
|
||||
|
||||
This demo provides a command-line interface for automatic speech recognition using OpenVINO™.
|
||||
Components used by this executable:
|
||||
|
||||
* `lspeech_s5_ext` model - Example pre-trained LibriSpeech DNN
|
||||
* `speech_library.dll` (`.so`) - Open source speech recognition library that uses OpenVINO™ Inference Engine, Intel® Speech Feature Extraction and Intel® Speech Decoder libraries
|
||||
|
||||
## How It Works
|
||||
|
||||
The application transcribes speech from a given WAV file and outputs the text to the console.
|
||||
|
||||
## Running
|
||||
|
||||
The application requires two command-line parameters, which point to an audio file with speech to transcribe and a configuration file describing the resources to use for transcription.
|
||||
|
||||
### Parameters for Executable
|
||||
|
||||
* `-wave` - Path to input WAV to process. WAV file needs to be in the following format: RIFF WAVE PCM 16bit, 16kHz, 1 channel, with header.
|
||||
* `-c`, `--config` - Path to configuration file with paths to resources and other parameters.
|
||||
|
||||
Example usage:
|
||||
|
||||
```sh
|
||||
offline_speech_recognition_app.exe -wave="<path_to_audio>/inputAudio.wav" -c="<path_to_config>/configFile.cfg"
|
||||
```
|
||||
|
||||
### Configuration File Description
|
||||
|
||||
The configuration file is an ASCII text file where:
|
||||
* Parameter name and its value are separated with the space character
|
||||
* Parameter and value pair ends with the end of line character
|
||||
|
||||
#### Parameter Description
|
||||
|
||||
| Parameter | Description | Value used for demo |
|
||||
| --- | --- | --- |
|
||||
| `-fe:rt:numCeps` | Number of MFCC cepstrums | *13* |
|
||||
| `-fe:rt:contextLeft` | Numbers of past frames that are stacked to form input vector for neural network inference | *5* |
|
||||
| `-fe:rt:contextRight` | Numbers of future frames that are stacked to form input vector for neural network inference | *5* |
|
||||
| `-fe:rt:hpfBeta` | High pass filter beta coefficient, where 0.0f means no filtering | *0.0f* |
|
||||
| `-fe:rt:inputDataType` | Feature extraction input format description | *INT16_16KHZ* |
|
||||
| `-fe:rt:cepstralLifter` | Lifting factor | *22.0f* |
|
||||
| `-fe:rt:noDct` | Flag: use DCT as final step or not | *0* |
|
||||
| `-fe:rt:featureTransform` | [Kaldi](https://kaldi-asr.org/) feature transform file that normalizes stacked features for neural network inference | |
|
||||
| `-dec:wfst:acousticModelFName` | Full path to the acoustic model file, for example `openvino_ir.xml`| |
|
||||
| `-dec:wfst:acousticScaleFactor` | The acoustic log likelihood scaling factor | *0.1f* |
|
||||
| `-dec:wfst:beamWidth` | Viterbi search beam width | *14.0f* |
|
||||
| `-dec:wfst:latticeWidth` | Lattice beam width (extends the beam width) | *0.0f* |
|
||||
| `-dec:wfst:nbest` | Number of n-best hypothesis to be generated | *1* |
|
||||
| `-dec:wfst:confidenceAcousticScaleFactor` | Scaling parameter to factor in acoustic scores in confidence computations | *1.0f* |
|
||||
| `-dec:wfst:confidenceLMScaleFactor` | Scaling parameter to factor in language model in confidence computations | *1.0f* |
|
||||
| `-dec:wfst:hmmModelFName` | Full path to HMM model | |
|
||||
| `-dec:wfst:fsmFName` | Full path to pronunciation model or full statically composed LM, if static composition is used | |
|
||||
| `-dec:wfstotf:gramFsmFName` | Full path to grammar model | |
|
||||
| `-dec:wfst:outSymsFName` | Full path to the output symbols (lexicon) filename | |
|
||||
| `-dec:wfst:tokenBufferSize` | Token pool size expressed in number of DWORDs | *150000* |
|
||||
| `-dec:wfstotf:traceBackLogSize` | Number of entries in traceback expressed as log2(N) | *19* |
|
||||
| `-dec:wfstotf:minStableFrames` | The time expressed in frames, after which the winning hypothesis is recognized as stable and the final result can be printed | *45* |
|
||||
| `-dec:wfst:maxCumulativeTokenSize` | Maximum fill rate of token buffer before token beam is adjusted to keep token buffer fill constant. Expressed as factor of buffer size (0.0, 1.0) | *0.2f* |
|
||||
| `-dec:wfst:maxTokenBufferFill` | Active token count number triggering beam tightening expressed as factor of buffer size | *0.6f* |
|
||||
| `-dec:wfst:maxAvgTokenBufferFill` | Average active token count number for utterance, which triggers beam tightening when exceeded. Expressed as factor of buffer size | *1.0f* |
|
||||
| `-dec:wfst:tokenBufferMinFill` | Minimum fill rate of token buffer | *0.1f* |
|
||||
| `-dec:wfst:pruningTighteningDelta` | Beam tightening value when token pool usage reaches the pool capacity | *1.0f* |
|
||||
| `-dec:wfst:pruningRelaxationDelta` | Beam relaxation value when token pool is not meeting minimum fill ratio criterion | *0.5f* |
|
||||
| `-dec:wfst:useScoreTrendForEndpointing` | Extend end pointing with acoustic feedback | *1* |
|
||||
| `-dec:wfstotf:cacheLogSize` | Number of entries in LM cache expressed as log2(N) | *16* |
|
||||
| `-eng:output:format` | Format of the speech recognition output | *text* |
|
||||
| `-inference:contextLeft` | IE: Additional stacking option, independent from feature extraction | *0* |
|
||||
| `-inference:contextRight` | IE: Additional stacking option, independent from feature extraction | *0* |
|
||||
| `-inference:device` | IE: Device used for neural computations | CPU |
|
||||
| `-inference:numThreads` | IE: Number of threads used by GNA in SW mode | *1* |
|
||||
| `-inference:scaleFactor` | IE: Scale factor used for static quantization | *3000.0* |
|
||||
| `-inference:quantizationBits` | IE: Quantization resolution in bits | *16* or *8* |
|
||||
|
||||
|
||||
## Demo Output
|
||||
|
||||
The resulting transcription for the sample audio file:
|
||||
|
||||
```sh
|
||||
[ INFO ] Using feature transformation
|
||||
[ INFO ] InferenceEngine API
|
||||
[ INFO ] Device info:
|
||||
[ INFO ] CPU: MKLDNNPlugin
|
||||
[ INFO ] Batch size: 1
|
||||
[ INFO ] Model loading time: 61.01 ms
|
||||
Recognition result:
|
||||
HOW ARE YOU DOING
|
||||
```
|
@ -1,52 +0,0 @@
|
||||
# Speech Library {#openvino_inference_engine_samples_speech_libs_and_demos_Speech_library}
|
||||
|
||||
## Overview
|
||||
|
||||
Speech Library provides an easy way to work with the end-to-end speech recognition pipeline.
|
||||
The software stack is created to minimize effort required to build speech-enabled applications.
|
||||
Speech Library wraps all of the processing blocks and exposes a simple API. The library takes care of proper initialization and data passing between all the components in the pipeline.
|
||||
|
||||
Speech Library contains:
|
||||
|
||||
- Two core binary libraries in the `lib` folder: Intel® Feature Extraction library and Intel® Speech Decoder
|
||||
- Speech library source code in the `src` folder
|
||||
- Speech library header files in the `include` folder. The library API is in the file `speech_library.h`.
|
||||
|
||||
To compile the libraries, please run a `.bat/.sh` file in the root folder of speech libraries and demos, or run the demonstration script `<INSTALL_DIR>/deployment_tools/demo/speech_recognition.bat/sh`.
|
||||
|
||||
## Architecture
|
||||
|
||||
The implementation of speech recognition pipeline used in demo applications is based on classic HMM/DNN approach.
|
||||
|
||||
The pipeline consists of the following stages:
|
||||
|
||||
1. Mel-frequency cepstral coefficients (MFCC) feature extraction: the input audio signal or waveform is processed by Intel® Feature Extraction library to create a series of MFCC features
|
||||
2. Neural acoustic scoring: the OpenVINO™ Inference Engine transcribes the extracted features into a sequence of phonemes using a neural acoustic model
|
||||
3. Language model decoding: the Intel® Speech Decoder turns the phonemes into text hypothesis. The decoding graph takes into account the grammar of the data, as well as the distribution and probabilities of contiguous specific words (n-grams)
|
||||
|
||||

|
||||
|
||||
## Speech Library API
|
||||
|
||||
The Speech Library API consists of simple routines:
|
||||
|
||||
* Build recognizer pipeline
|
||||
* Provide audio samples for processing
|
||||
* Inform about new stable recognition result
|
||||
|
||||
The flow is described below:
|
||||

|
||||
|
||||
See `<INSTALL_DIR>/data_processing/audio/speech_recognition/include/speech_library.h` for details about the API.
|
||||
|
||||
A great example on how to use the API is the source code of [offline speech recognition demo](Offline_speech_recognition_demo.md).
|
||||
|
||||
## Run Your Application
|
||||
|
||||
Before running compiled binary files, make sure your application can find the Inference Engine, Speech, Decoder, and Feature Extraction libraries.
|
||||
|
||||
On Linux* operating systems, including Ubuntu*, the `LD_LIBRARY_PATH` environment variable is usually used to specify directories to search libraries in.
|
||||
|
||||
You can update the `LD_LIBRARY_PATH` with paths to the directories in the Inference Engine installation directory where the libraries are placed.
|
||||
|
||||
Please check `run_demo.sh` of offline and live speech recognition demos to learn how the `LD_LIBRARY_PATH` environment parameter can be set.
|
@ -1,133 +0,0 @@
|
||||
# Speech Library and Speech Recognition Demos {#openvino_inference_engine_samples_speech_libs_and_demos_Speech_libs_and_demos}
|
||||
|
||||
Intel® distributions of OpenVINO™ toolkit for Linux* OS and Windows* OS provide a set of libraries and demos to
|
||||
demonstrate end-to-end speech recognition, as well as new acoustic and language models that can work with these demos.
|
||||
The libraries are designed for preprocessing (feature extraction) to get a feature vector from a speech signal, as well
|
||||
as postprocessing (decoding) to produce text from scores. Together with OpenVINO™-based neural-network speech recognition,
|
||||
these libraries provide an end-to-end pipeline converting speech to text. This pipeline is demonstrated by the
|
||||
end-to-end demos:
|
||||
|
||||

|
||||
|
||||
Note that the OpenVINO™ package also includes an [automatic speech recognition sample](../speech_sample/README.md) demonstrating acoustic model inference based on Kaldi\* neural networks. The sample works with Kaldi ARK files only, so it does not cover an end-to-end speech recognition scenario (speech to text), requiring additional preprocessing (feature extraction) to get a feature vector from a speech signal, as well as postprocessing (decoding) to produce text from scores:
|
||||
|
||||

|
||||
|
||||
The main purpose of the sample is to demonstrate a variety of features and options provided by OpenVINO™
|
||||
for speech recognition neural networks.
|
||||
|
||||
Find new libraries, demos, and models at `<INSTALL_DIR>/data_processing/audio/speech_recognition`.
|
||||
|
||||
> **NOTE**: These components are installed only if you select the **Inference Engine Runtime for Intel® Gaussian & Neural Accelerator** component during installation. However, the Speech Library and speech recognition demos do not require the GNA accelerator. See <a href="#hardware-support">Hardware support</a> for details.
|
||||
|
||||
## Package Components
|
||||
|
||||
The package contains the following components:
|
||||
|
||||
* [Speech Library](Speech_library.md), which includes a feature extractor and decoder
|
||||
|
||||
* [Offline Speech Recognition Demo](Offline_speech_recognition_demo.md), which can process wave files with recorded speech
|
||||
|
||||
* [Live Speech Recognition Demo](Live_speech_recognition_demo.md), which showcases transcription from a microphone or speakers
|
||||
|
||||
* [Kaldi Statistical Language Model Conversion Tool](Kaldi_SLM_conversion_tool.md), which converts custom language models to use in the decoder
|
||||
|
||||
Additionally, new acoustic and language models are available in the OpenVINO™ [storage](https://storage.openvinotoolkit.org/models_contrib/speech/2021.2/librispeech_s5/).
|
||||
|
||||
## <a name="run-demos">Run Speech Recognition Demos with Pre-trained Models</a>
|
||||
|
||||
To download pre-trained models and build all dependencies:
|
||||
|
||||
* On Linux* OS, use the shell script `<INSTALL_DIR>/deployment_tools/demo/demo_speech_recognition.sh`
|
||||
|
||||
* On Windows* OS, use the batch file `<INSTALL_DIR>\deployment_tools\demo\demo_speech_recognition.bat`
|
||||
|
||||
The script follows the steps below:
|
||||
|
||||
1. Downloads US English models trained on the LibriSpeech dataset prepared for direct usage by the Inference Engine
|
||||
2. Installs the required components
|
||||
3. Runs the command line offline demo
|
||||
4. Runs live speech recognition application with graphical interface
|
||||
|
||||
If you are behind a proxy, set the following environment variables in a console session before running the script:
|
||||
|
||||
* On Linux* OS:
|
||||
|
||||
```sh
|
||||
export http_proxy=http://{proxyHost}:{proxyPort}
|
||||
export https_proxy=https://{proxyHost}:{proxyPort}
|
||||
```
|
||||
|
||||
* On Windows* OS:
|
||||
|
||||
```sh
|
||||
set http_proxy=http://{proxyHost}:{proxyPort}
|
||||
set https_proxy=https://{proxyHost}:{proxyPort}
|
||||
```
|
||||
|
||||
## <a name="hardware-support">Hardware Support</a>
|
||||
|
||||
The provided acoustic models have been tested on a CPU, graphics processing unit (GPU), and Intel® Gaussian & Neural Accelerator (Intel® GNA), and you can switch between these targets in offline and live speech recognition demos.
|
||||
|
||||
> **NOTE**: Intel® GNA is a specific low-power coprocessor, which offloads some workloads, thus saving power and CPU resources. If you use a processor supporting the GNA, such as Intel® Core™ i3-8121U and Intel® Core™ i7-1065G7, you can notice that CPU load is much lower when GNA is selected. If you selected GNA as a device for inference, and your processor does not support GNA, then execution is performed in the emulation mode (on CPU) because `GNA_AUTO` configuration option is used.
|
||||
> See [the GNA plugin documentation](https://docs.openvinotoolkit.org/latest/_docs_IE_DG_supported_plugins_GNA.html) for more information.
|
||||
|
||||
Speech Library provides a highly optimized implementation of preprocessing and postprocessing (feature extraction and decoding) on CPU only.
|
||||
|
||||
## Custom Models Requirements
|
||||
|
||||
Before running demonstration applications with custom models, follow the steps below:
|
||||
|
||||
1. Build the Speech Library and demonstration application using the `demo_speech_recognition.sh/.bat` file mentioned in <a href="#run-demos">Run Speech Recognition Demos with Pre-trained Models</a>
|
||||
2. Train acoustic and statistical language models using the Kaldi framework (if required)
|
||||
3. [Convert the acoustic model](../../../docs/MO_DG/prepare_model/convert_model/Convert_Model_From_Kaldi.md) using Model Optimizer for Kaldi
|
||||
4. [Convert the language model](Kaldi_SLM_conversion_tool.md) using the Kaldi toolkit and provided converter
|
||||
5. Create a configuration file that lists all the models required for recognition
|
||||
6. Copy configuration file to `{OpenVINO build folder}/data_processing/audio/speech_recognition/models/{LANG}`. The demo models are trained for US English, so use `en-us` for the `{LANG}` folder name.
|
||||
|
||||
Then you can use new models in the live speech recognition demo.
|
||||
To perform speech recognition using a new model and the command-line application, provide the path to the new configuration file as an input argument of the application.
|
||||
|
||||
## Convert Acoustic Model with OpenVINO™ Model Optimizer for Kaldi*
|
||||
|
||||
In order to convert acoustic models, the following Kaldi files are required:
|
||||
|
||||
- Acoustic model file, `final.nnet` – RAW neural network without topology information
|
||||
- Counts file, `pdf.counts` (if used)
|
||||
- Feature transformation file, `final.feature_transform` (if used)
|
||||
|
||||
For conversion steps, follow [Converting a Kaldi* Model](../../../docs/MO_DG/prepare_model/convert_model/Convert_Model_From_Kaldi.md).
|
||||
|
||||
> **NOTE**: Set the path to the XML file with the converted model in the configuration file.
|
||||
|
||||
## Convert Language Model with Provided Converter
|
||||
|
||||
In order to convert language models, the following Kaldi files are required:
|
||||
- Acoustic model with the Hidden Markov Model (HMM) topology, `final.mdl`
|
||||
- Language model Weighted Finite-State Transducers (WFST) graph, `HCLG.wfst`
|
||||
- Label symbol file, `words.txt`.
|
||||
|
||||
All these files are required to create resources for demo applications.
|
||||
|
||||
Model conversion from Kaldi requires the following steps:
|
||||
|
||||
1. Save HCLG WFST as the OpenFST const type:
|
||||
```
|
||||
$KALDI_ROOT/tools/openfst/bin/fstconvert --fst_type=const HCLG.fst HCLG.const_fst
|
||||
```
|
||||
|
||||
2. Generate transition ID information using `phones.txt` and `final.mdl`:
|
||||
```
|
||||
$KALDI_ROOT/src/bin/show-transitions phones.txt final.mdl > transitions.txt
|
||||
```
|
||||
|
||||
3. Convert HCLG WFST using resource conversion executable:
|
||||
```
|
||||
kaldi_slm_conversion_tool HCLG.const_fst transitions.txt words.txt cl.fst labels.bin
|
||||
```
|
||||
|
||||
> **NOTE**: Put the paths to `cl.fst` and `labels.bin` files in the configuration file to use them with the Live Speech Recognition Demo Application.
|
||||
|
||||
See the [offline speech recognition demo documentation](Offline_speech_recognition_demo.md) to learn about the configuration file format.
|
||||
|
||||
See [Kaldi* Statistical Language Model Conversion Tool](Kaldi_SLM_conversion_tool.md) for more information on the conversion tool.
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:620b57920f377725860bd8a2add72a8a9f33e5c1aee3aaf560c02d91f31817e1
|
||||
size 115336
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:844ddf108bceba2b7ca998809927aae6d463e32d5f4ee82af19e480752fac1c1
|
||||
size 658
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:11089fbf62b5011215a6767d83920eacc292405f013842b753e58bf5fa2555ec
|
||||
size 865
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b54202be4a359184b17bfa501bd77a6efb5b66c1d5c8da7eb8e4ce53d1192cd1
|
||||
size 717
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:62a4231a21a63ec9e2fc186da14ed3059498e3fc4723cdca8d771622633ef92d
|
||||
size 845
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:01a0e88e53ea004212e6386b920bb5bffd7269164a3535f006c733aa0c25adf0
|
||||
size 764
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:43c0ddfa3ff5f9edd30e03b15b0de1cdc201fb3f03ebc3f014db4c3bd0f03630
|
||||
size 802
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0103887a8cfc7fc03150762ab21fd7c1991ce9bfa8b1ae2ffc62051026e29543
|
||||
size 695
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:65d16aaa28d91b4fc2c0e3ee7e570fabb03e95d9178b96bff1580b7b0ce14b03
|
||||
size 776
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8984135312b3342675bcfa20351738a97c7e795c641b377d91b6b92cd071eb62
|
||||
size 843
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fd0a6dbecfb8b6fa9b7514372ff1ab3ca361dced155cbdf14ba3c7c41b7c357e
|
||||
size 709
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d9775faed5d611f092bb7f7254fe1d1232bd90b0960ebd03d3fe69a24236febe
|
||||
size 840
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4931e6ddeeb2512e997c2fc8a51b064f4e753470d7bf6661120ff4580b7830bd
|
||||
size 846
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b54c302e27470d139285c0958026b0b64d9f91ea41b0bf030e6be933bfd06d02
|
||||
size 86985
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a1dea7d13713a9c37bad21a1df8077d43696f67cbbab5eb1e87b92b3b91b2a29
|
||||
size 29798
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1b8a55d04d813db595af843de7dfb89f0776709ec6d1b46d795802cbe424b304
|
||||
size 20367
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:62b8315db0436299decd3e60c7a2171d575735e405610312bb9784e6a1b0b79c
|
||||
size 29120
|
@ -1,3 +0,0 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f965a796362dc4881e1bc129e6936d582d923ae83f2d97d3798fd186f152bbc9
|
||||
size 15294
|
@ -36,6 +36,17 @@ struct ScaleFactorUpdateResult {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Calculates a scale factor from FakeQuantize statistics according to the formula:
|
||||
* scale factor = max representable value / max absolute input value
|
||||
* @param levels Number of integer quants
|
||||
* @param minValue Minimum value to be quantized
|
||||
* @param maxValue Maximum value to be quantized
|
||||
*/
|
||||
inline float CalculateScaleFactorFromStats(size_t levels, float minValue, float maxValue) {
|
||||
return maxValue == minValue ? 1.0f : (levels - 1) / (maxValue - minValue);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Compares two float values and returns if they are equal
|
||||
* @param p1 First float value
|
||||
@ -372,7 +383,7 @@ class ScaleFactorPerLayer<InferenceEngine::CNNLayer *> {
|
||||
auto maxOutValue = quantizedParams->_dst_quant.GetMaxValues().front();
|
||||
auto absMax = std::max(std::abs(minOutValue), std::abs(maxOutValue));
|
||||
|
||||
result = (quantizedParams->_dst_quant.GetLevels() - 1) / (maxOutValue - minOutValue);
|
||||
result = CalculateScaleFactorFromStats(quantizedParams->_dst_quant.GetLevels(), minOutValue, maxOutValue);
|
||||
if (std::isinf(result) || fp32eq(absMax, 0.0f)) {
|
||||
result = max_activation_scale_factor;
|
||||
}
|
||||
@ -452,7 +463,7 @@ class ScaleFactorPerLayer<InferenceEngine::CNNLayer *> {
|
||||
if (CNNNetHasPrevLayer(cnnLayer) && quant->_dst_quant.IsStatsSet() && !quant->_dst_quant.IsScaleSet()) {
|
||||
auto minOutValue = quant->_dst_quant.GetMinValues().front();
|
||||
auto maxOutValue = quant->_dst_quant.GetMaxValues().front();
|
||||
auto scale = (quant->_dst_quant.GetLevels() - 1) / (maxOutValue - minOutValue);
|
||||
auto scale = CalculateScaleFactorFromStats(quant->_dst_quant.GetLevels(), minOutValue, maxOutValue);
|
||||
quant->_dst_quant.SetScale(scale);
|
||||
quant->_src_quant = quant->_dst_quant;
|
||||
}
|
||||
@ -1068,8 +1079,8 @@ class ScaleFactorPerLayer<InferenceEngine::WeightableLayer*> {
|
||||
quant->_src_quant = quantDataForInputLayer->_dst_quant;
|
||||
if (quant->_weights_quant.IsStatsSet() && !quant->_weights_quant.IsScaleSet()) {
|
||||
auto getScale = [&quant](size_t i) {
|
||||
auto valuesDiff = quant->_weights_quant.GetMaxValues(false)[i] - quant->_weights_quant.GetMinValues(false)[i];
|
||||
return valuesDiff == 0 ? 1.0f : (quant->_weights_quant.GetLevels() - 1) / valuesDiff;
|
||||
return CalculateScaleFactorFromStats(quant->_weights_quant.GetLevels(),
|
||||
quant->_weights_quant.GetMinValues(false)[i], quant->_weights_quant.GetMaxValues(false)[i]);
|
||||
};
|
||||
|
||||
float min_channel_scale = getScale(0);
|
||||
@ -1222,9 +1233,8 @@ public:
|
||||
quantData->_weights_quant.SetScale(quantParams1->_dst_quant.GetScale());
|
||||
if (quantData->_src_quant.IsStatsSet()) {
|
||||
auto getScale = [&quantParams0](size_t i) {
|
||||
return (quantParams0->_dst_quant.GetLevels() - 1) /
|
||||
(quantParams0->_dst_quant.GetMaxValues(false)[i] -
|
||||
quantParams0->_dst_quant.GetMinValues(false)[i]);
|
||||
return CalculateScaleFactorFromStats(quantParams0->_dst_quant.GetLevels(),
|
||||
quantParams0->_dst_quant.GetMinValues(false)[i], quantParams0->_dst_quant.GetMaxValues(false)[i]);
|
||||
};
|
||||
float min_channel_scale = getScale(0);
|
||||
quantParams0->_dst_quant.SetScale(min_channel_scale);
|
||||
|
@ -489,13 +489,9 @@ void GNAPlugin::UpdateInputScaleFromNetwork(InferenceEngine::CNNNetwork & networ
|
||||
return (std::abs(p1 - p2) <= 0.00001f * std::min(std::abs(p1), std::abs(p2)));
|
||||
};
|
||||
// GNA input is always quantized to int16, so number of levels can't be greater than max uint16
|
||||
size_t levels = std::min(fqLayer.getLevels(), static_cast<size_t>(std::numeric_limits<uint16_t>::max()));
|
||||
float scaleInput = (levels - 1) / (inputRange.second[0] - inputRange.first[0]);
|
||||
auto minAbsVal = std::min(std::abs(inputRange.second[0]), std::abs(inputRange.first[0]));
|
||||
auto maxAbsVal = std::max(std::abs(inputRange.second[0]), std::abs(inputRange.first[0]));
|
||||
if (fp32eq(minAbsVal, 0.0f) && !fp32eq(maxAbsVal, 0.0f)) {
|
||||
scaleInput = (fqLayer.getLevels() - 1) / (2 * maxAbsVal);
|
||||
}
|
||||
// todo: should be solved in POT (issue 63330)
|
||||
size_t levels = std::min(fqLayer.getLevels(), static_cast<size_t>(std::numeric_limits<uint16_t>::max() + 1));
|
||||
auto scaleInput = frontend::CalculateScaleFactorFromStats(levels, inputRange.first[0], inputRange.second[0]);
|
||||
|
||||
IE_ASSERT(config.inputScaleFactors.size() > inputIdx);
|
||||
IE_ASSERT(inputsDesc->inputScaleFactors.size() > inputIdx);
|
||||
@ -1616,7 +1612,7 @@ InferenceEngine::IExecutableNetworkInternal::Ptr GNAPlugin::ImportNetwork(std::i
|
||||
// If scale factors are defined in configuration we still need to use them instead of imported values,
|
||||
// for example to change the scale factors for the old models.
|
||||
if (!config.inputScaleFactors.empty()) {
|
||||
IE_ASSERT(config.inputScaleFactors.size() == inputsDesc->inputScaleFactors.size());
|
||||
IE_ASSERT(config.inputScaleFactors.size() <= inputsDesc->inputScaleFactors.size());
|
||||
for (size_t i = 0; i < config.inputScaleFactors.size(); ++i) {
|
||||
if (config.inputScaleFactors[i] != GNAPluginNS::kScaleFactorDefault) {
|
||||
gnalog() << "[Import Network] Using input scale factor defined in configuration for input " << i << std::endl;
|
||||
|
@ -95,8 +95,8 @@ void op::FrameworkNode::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
constexpr ov::DiscreteTypeInfo ov::AttributeAdapter<op::FrameworkNodeAttrs>::type_info;
|
||||
constexpr ov::DiscreteTypeInfo ov::AttributeAdapter<ngraph::op::FrameworkNodeAttrs>::type_info;
|
||||
|
||||
ov::AttributeAdapter<op::FrameworkNodeAttrs>::AttributeAdapter(
|
||||
op::FrameworkNodeAttrs& value)
|
||||
: DirectValueAccessor<op::FrameworkNodeAttrs>(value) {}
|
||||
ov::AttributeAdapter<ngraph::op::FrameworkNodeAttrs>::AttributeAdapter(
|
||||
ngraph::op::FrameworkNodeAttrs& value)
|
||||
: DirectValueAccessor<ngraph::op::FrameworkNodeAttrs>(value) {}
|
||||
|
@ -325,7 +325,7 @@ public:
|
||||
m_xml_node.append_attribute("offset").set_value(offset);
|
||||
m_xml_node.append_attribute("size").set_value(size);
|
||||
}
|
||||
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<op::FrameworkNodeAttrs>>(&adapter)) {
|
||||
} else if (const auto& a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::FrameworkNodeAttrs>>(&adapter)) {
|
||||
const auto & attrs = a->get();
|
||||
|
||||
// Update type and version attributes
|
||||
@ -623,7 +623,7 @@ bool resolve_dynamic_shapes(const ngraph::Function& f) {
|
||||
auto & op = f_ops[id];
|
||||
auto & clone_op = f_clone_ops[id];
|
||||
|
||||
if (auto op_subgraph = std::dynamic_pointer_cast<op::util::SubGraphOp>(op)) {
|
||||
if (auto op_subgraph = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(op)) {
|
||||
resolve_dynamic_shapes(*op_subgraph->get_function());
|
||||
}
|
||||
|
||||
@ -811,8 +811,34 @@ void ngfunction_2_irv10(pugi::xml_node& netXml,
|
||||
f.validate_nodes_and_infer_types();
|
||||
}
|
||||
}
|
||||
|
||||
std::string valid_xml_path(const std::string &path) {
|
||||
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
|
||||
|
||||
const char *const extension = ".xml";
|
||||
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
|
||||
NGRAPH_CHECK(has_xml_extension,
|
||||
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
|
||||
path + "\"");
|
||||
return path;
|
||||
}
|
||||
|
||||
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
|
||||
if (!binPath.empty()) {
|
||||
return binPath;
|
||||
}
|
||||
assert(xmlPath.size() > 4); // should be check by valid_xml_path
|
||||
std::string bestPath = xmlPath;
|
||||
const char *const extension = "bin";
|
||||
const auto ext_size = std::strlen(extension);
|
||||
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
|
||||
return bestPath;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace ngraph {
|
||||
|
||||
// ! [function_pass:serialize_cpp]
|
||||
// serialize.cpp
|
||||
bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
@ -868,33 +894,6 @@ bool pass::Serialize::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::string valid_xml_path(const std::string &path) {
|
||||
NGRAPH_CHECK(path.length() > 4, "Path for xml file is to short: \"" + path + "\"");
|
||||
|
||||
const char *const extension = ".xml";
|
||||
const bool has_xml_extension = path.rfind(extension) == path.size() - std::strlen(extension);
|
||||
NGRAPH_CHECK(has_xml_extension,
|
||||
"Path for xml file doesn't contains file name with 'xml' extension: \"" +
|
||||
path + "\"");
|
||||
return path;
|
||||
}
|
||||
|
||||
std::string provide_bin_path(const std::string &xmlPath, const std::string &binPath) {
|
||||
if (!binPath.empty()) {
|
||||
return binPath;
|
||||
}
|
||||
assert(xmlPath.size() > 4); // should be check by valid_xml_path
|
||||
std::string bestPath = xmlPath;
|
||||
const char *const extension = "bin";
|
||||
const auto ext_size = std::strlen(extension);
|
||||
bestPath.replace(bestPath.size() - ext_size, ext_size, extension);
|
||||
return bestPath;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
pass::Serialize::Serialize(std::ostream& xmlFile,
|
||||
std::ostream& binFile,
|
||||
pass::Serialize::Version version,
|
||||
@ -921,3 +920,4 @@ pass::Serialize::Serialize(const std::string& xmlPath,
|
||||
{
|
||||
}
|
||||
// ! [function_pass:serialize_cpp]
|
||||
} // namespace ngraph
|
||||
|
@ -73,8 +73,9 @@ protected:
|
||||
std::tie(netPrecision, targetDevice, configuration, inputShape, inputMinMax, levels) = this->GetParam();
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
|
||||
auto inputLowNode = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputMinMax.first });
|
||||
auto inputHighNode = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputMinMax.second });
|
||||
std::tie(inputDataMin, inputDataMax) = inputMinMax;
|
||||
auto inputLowNode = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMin });
|
||||
auto inputHighNode = ngraph::builder::makeConstant<float>(ngPrc, {1}, { inputDataMax });
|
||||
|
||||
auto inputVector = ngraph::builder::makeParams(ngPrc, {inputShape});
|
||||
|
||||
|
@ -0,0 +1,118 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "base/import_export_base/import_export_base.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
class ImportMultiInput : public FuncTestUtils::ImportNetworkTestBase {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(netPrecision, targetDevice, exportConfiguration, importConfiguration, applicationHeader) = this->GetParam();
|
||||
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto input = ngraph::builder::makeParams(ngPrc, {{1, 10}, {1, 10}});
|
||||
auto mul1 = ngraph::builder::makeEltwise(input[0], input[1], ngraph::helpers::EltwiseTypes::ADD);
|
||||
auto result = std::make_shared<ngraph::opset7::Result>(mul1);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, input, "multiple_input");
|
||||
}
|
||||
};
|
||||
|
||||
class ImportMultiInputChanged : public ImportMultiInput {};
|
||||
class ImportMultiInputUnchanged : public ImportMultiInput {};
|
||||
|
||||
TEST_P(ImportMultiInputUnchanged, CompareWithRefImpl) {
|
||||
TestRun(false);
|
||||
};
|
||||
|
||||
TEST_P(ImportMultiInputChanged, CompareWithRefImpl) {
|
||||
TestRun(true);
|
||||
};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32
|
||||
};
|
||||
|
||||
const std::vector<std::map<std::string, std::string>> exportConfigs = {
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "327.67"},
|
||||
{"GNA_SCALE_FACTOR_1", "327.67"}
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<std::map<std::string, std::string>> importConfigsChanged = {
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "32767"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_1", "32767"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "32767"},
|
||||
{"GNA_SCALE_FACTOR_1", "32767"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "1"},
|
||||
{"GNA_SCALE_FACTOR_1", "32767"}
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<std::map<std::string, std::string>> importConfigsUnchanged = {
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "327.67"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "1"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_0", "327.67"},
|
||||
{"GNA_SCALE_FACTOR_1", "327.67"}
|
||||
},
|
||||
{
|
||||
{"GNA_DEVICE_MODE", "GNA_SW_EXACT"},
|
||||
{"GNA_SCALE_FACTOR_1", "327.67"}
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_ImportNetworkGNA, ImportMultiInputUnchanged,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(exportConfigs),
|
||||
::testing::ValuesIn(importConfigsUnchanged),
|
||||
::testing::Values("")),
|
||||
ImportMultiInputUnchanged::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_ImportNetworkGNA, ImportMultiInputChanged,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GNA),
|
||||
::testing::ValuesIn(exportConfigs),
|
||||
::testing::ValuesIn(importConfigsChanged),
|
||||
::testing::Values("")),
|
||||
ImportMultiInputChanged::getTestCaseName);
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
@ -5,11 +5,6 @@
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include <legacy/details/ie_cnn_network_iterator.hpp>
|
||||
|
||||
::std::ostream& ngraph::operator << (::std::ostream & os, const Function&) {
|
||||
throw std::runtime_error("should not be called");
|
||||
return os;
|
||||
}
|
||||
|
||||
namespace CommonTestUtils {
|
||||
|
||||
IE_SUPPRESS_DEPRECATED_START
|
||||
|
@ -17,10 +17,6 @@
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
#include <ngraph/function.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
::std::ostream& operator << (::std::ostream &, const Function&);
|
||||
}
|
||||
|
||||
namespace InferenceEngine {
|
||||
class CNNLayer;
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import xml.etree.ElementTree as ET
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from utils import constants
|
||||
from utils import utils
|
||||
|
||||
logger = utils.get_logger('Summarize')
|
||||
@ -167,11 +166,9 @@ def create_summary(summary_root: ET.Element, output_folder: os.path, report_tag:
|
||||
env = Environment(loader=file_loader)
|
||||
template = env.get_template('report_template.html')
|
||||
|
||||
verified_operations = constants.VERIFIED_OP_REFERENCES
|
||||
|
||||
res_summary = template.render(ordered_ops=op_list, devices=device_list, results=results, timestamp=timestamp,
|
||||
general_pass_rate=general_pass_rate, pass_rate_avg=pass_rate_avg,
|
||||
verified_operations=verified_operations, trusted_ops=trusted_ops,
|
||||
trusted_ops=trusted_ops,
|
||||
general_test_count=general_test_count, report_tag=report_tag)
|
||||
|
||||
report_path = os.path.join(output_folder, f'{output_filename}.html')
|
||||
|
@ -22,10 +22,6 @@
|
||||
<div class="main">
|
||||
<h2>Operations coverage summary: {{report_tag}} {{ timestamp }}</h2>
|
||||
<div class="legend">
|
||||
<div>
|
||||
<span class="border colorRed">Acosh-4</span><span>Not verified Ngraph references</span>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<span class="table-primary border"></span><span>Collected statistic info</span>
|
||||
</div>
|
||||
@ -127,8 +123,7 @@
|
||||
<tbody id="data">
|
||||
{% for op in ordered_ops -%}
|
||||
<tr>
|
||||
<th scope="row" {% if op not in verified_operations -%} class="colorRed" {% endif -%} name="{{ op }}">{{
|
||||
op }}</th>
|
||||
<th scope="row" name="{{ op }}">{{ op }}</th>
|
||||
|
||||
{% for d in devices -%}
|
||||
{% if op in results[d] -%}
|
||||
|
@ -1,131 +0,0 @@
|
||||
# Copyright (C) 2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
VERIFIED_OP_REFERENCES = [
|
||||
'Abs-1',
|
||||
'Acos-1',
|
||||
'Acosh-3',
|
||||
'Add-1',
|
||||
'Asin-1',
|
||||
'Asinh-3',
|
||||
'Assign-6',
|
||||
'AvgPool-1',
|
||||
'BatchNormInference-5',
|
||||
'BatchToSpace-2',
|
||||
'BinaryConvolution-1',
|
||||
'Broadcast-1',
|
||||
'Broadcast-3',
|
||||
'Bucketize-3',
|
||||
'Ceiling-1',
|
||||
'CTCGreedyDecoder-1',
|
||||
'CTCGreedyDecoderSeqLen-6',
|
||||
'Concat-1',
|
||||
'Convert-1',
|
||||
'ConvertLike-1',
|
||||
'Convolution-1',
|
||||
'Constant-1',
|
||||
'Cos-1',
|
||||
'Cosh-1',
|
||||
'DeformableConvolution-1',
|
||||
'DeformablePSROIPooling-1',
|
||||
'DepthToSpace-1',
|
||||
'DetectionOutput-1',
|
||||
'Divide-1',
|
||||
'Equal-1',
|
||||
'Erf-1',
|
||||
'ExperimentalDetectronDetectionOutput-6',
|
||||
'ExperimentalDetectronGenerateProposalsSingleImage-6',
|
||||
'ExperimentalDetectronPriorGridGenerator-6',
|
||||
'ExperimentalDetectronROIFeatureExtractor-6',
|
||||
'ExperimentalDetectronTopKROIs-6',
|
||||
'FakeQuantize-1',
|
||||
'Floor-1'
|
||||
'FloorMod-1'
|
||||
'GRUSequence-5',
|
||||
'Gather-1',
|
||||
'GatherElements-6',
|
||||
'GatherND-5',
|
||||
'Gelu-7',
|
||||
'Greater-1',
|
||||
'GreaterEqual-1',
|
||||
'GRN-1',
|
||||
'GroupConvolution-1',
|
||||
'GroupConvolutionBackpropData-1',
|
||||
'GRUSequence-5',
|
||||
'HSigmoid-5',
|
||||
'HSwish-4',
|
||||
'HardSigmoid-1',
|
||||
'Interpolate-4',
|
||||
'Less-1',
|
||||
'LessEqual-1'
|
||||
'LRN-1',
|
||||
'LSTMCell-4',
|
||||
'LSTMSequence-5',
|
||||
'LogicalAnd-1',
|
||||
'LogicalNot-1'
|
||||
'LogicalOr-1'
|
||||
'LogicalXor-1'
|
||||
'LogSoftmax-5',
|
||||
'Loop-5',
|
||||
'MVN-1',
|
||||
'MVN-6',
|
||||
'Maximum-1',
|
||||
'MaxPool-1',
|
||||
'Mish-4',
|
||||
'Multiply-1',
|
||||
'Negative-1',
|
||||
'NonMaxSuppression-4',
|
||||
'NonMaxSuppression-5',
|
||||
'NonZero-3',
|
||||
'NormalizeL2-1',
|
||||
'PriorBox-1',
|
||||
'PriorBoxClustered-1',
|
||||
'Proposal-1',
|
||||
'Proposal-4',
|
||||
'PSROIPooling-1',
|
||||
'RNNSequence-5',
|
||||
'ROIAlign-3',
|
||||
'ROIPooling-2',
|
||||
'Range-1',
|
||||
'Range-4',
|
||||
'ReadValue-6',
|
||||
'ReduceL1-4',
|
||||
'ReduceL2-4',
|
||||
'ReduceLogicalAnd-1',
|
||||
'ReduceLogicalOr-1',
|
||||
'ReduceMax-1',
|
||||
'ReduceMean-1',
|
||||
'ReduceMin-1',
|
||||
'ReduceProd-1',
|
||||
'ReduceSum-1',
|
||||
'RegionYOLO-1',
|
||||
'Relu-1',
|
||||
'ReorgYOLO-2',
|
||||
'Result-1'
|
||||
'ReverseSequence-1',
|
||||
'Round-5',
|
||||
'SpaceToDepth-1',
|
||||
'ScatterElementsUpdate-3',
|
||||
'ScatterNDUpdate-4',
|
||||
'Select-1',
|
||||
'ShapeOf-1',
|
||||
'ShapeOf-3',
|
||||
'ShuffleChannels-1',
|
||||
'Sigmoid-1',
|
||||
'Sign-1',
|
||||
'Sin-1',
|
||||
'Sinh-1'
|
||||
'SoftPlus-4',
|
||||
'Softmax-1',
|
||||
'Split-1',
|
||||
'Squeeze-1',
|
||||
'StridedSlice-1',
|
||||
'Subtract-1',
|
||||
'Swish-4',
|
||||
'Tile-1',
|
||||
'TopK-1',
|
||||
'TopK-3',
|
||||
'Transpose-1',
|
||||
'Unsqueeze-1',
|
||||
'VariadicSplit-1',
|
||||
]
|
@ -78,7 +78,7 @@ protected:
|
||||
class CreateBaseDecorator : public CreateGraphDecorator {
|
||||
public:
|
||||
// always the first decorator => no prev_builder
|
||||
CreateBaseDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{1, 64, 4096, 4096}) :
|
||||
CreateBaseDecorator(const ngraph::Shape& input_data_shape = ngraph::Shape{1, 64, 1, 4096}) :
|
||||
CreateGraphDecorator(nullptr),
|
||||
input_data_shape_(input_data_shape) {}
|
||||
protected:
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "convolution_kernel_b_fs_yx_fsv16.h"
|
||||
#include "kernel_selector_utils.h"
|
||||
#include "reorder/reorder_kernel_base.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
@ -95,6 +96,8 @@ ParamsKey ConvolutionKernel_b_fs_yx_fsv16::GetSupportedKey() const {
|
||||
|
||||
k.EnableInputLayout(DataLayout::b_fs_yx_fsv16);
|
||||
k.EnableOutputLayout(DataLayout::b_fs_yx_fsv16);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
|
||||
k.EnableTensorOffset();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableDilation();
|
||||
@ -176,12 +179,28 @@ bool ConvolutionKernel_b_fs_yx_fsv16::Validate(const Params& p, const optional_p
|
||||
if (input.Feature().pad.before % tuning_data.feature_block_size != 0 || output.Feature().pad.before % tuning_data.feature_block_size != 0)
|
||||
return false;
|
||||
|
||||
// Not supporting batch padding for different format (reorder-fused case)
|
||||
if (input.GetLayout() == DataLayout::b_fs_yx_fsv16 && output.GetLayout() == DataLayout::bfyx) {
|
||||
if (output.Batch().pad.before != 0 || output.Batch().pad.after != 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!params.bias.empty() && params.bias[0].GetDType() != input.GetDType())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool post_reorder_fused(const convolution_params& params) {
|
||||
if (!params.fused_ops.empty()) {
|
||||
if (params.fused_ops.back().GetType() == KernelType::REORDER) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
JitConstants ConvolutionKernel_b_fs_yx_fsv16::GetJitConstants(const convolution_params& params,
|
||||
const DispatchData& dispatchData) const {
|
||||
auto input = params.inputs[0];
|
||||
@ -190,8 +209,18 @@ JitConstants ConvolutionKernel_b_fs_yx_fsv16::GetJitConstants(const convolution_
|
||||
|
||||
ConvolutionTuningData tuning_data = GetTuningParams(params);
|
||||
|
||||
if (post_reorder_fused(params) &&
|
||||
input.GetLayout() == DataLayout::b_fs_yx_fsv16 &&
|
||||
output.GetLayout() == DataLayout::bfyx) {
|
||||
jit.AddConstant(MakeJitConstant("OUTPUT_FORMAT_BFYX", 1));
|
||||
}
|
||||
|
||||
auto blockWidth = dispatchData.cldnnStyle.blockWidth;
|
||||
if (!params.fused_ops.empty()) {
|
||||
DataLayout orig_output_layout = output.GetLayout();
|
||||
if (post_reorder_fused(params)) {
|
||||
orig_output_layout = params.fused_ops.back().GetOpParams<reorder_fuse_params>()->input_layout;
|
||||
}
|
||||
auto input_dt = GetActivationType(params);
|
||||
FusedOpsConfiguration conf_vec = { "_VEC",
|
||||
{"b", "(feature_block * 16)", "y", "x"},
|
||||
@ -201,7 +230,8 @@ JitConstants ConvolutionKernel_b_fs_yx_fsv16::GetJitConstants(const convolution_
|
||||
LoadType::LT_ALIGNED_READ,
|
||||
BoundaryCheck::ENABLED,
|
||||
IndexType::TENSOR_COORD,
|
||||
Tensor::DataChannelName::X };
|
||||
Tensor::DataChannelName::X,
|
||||
{}, false, "", orig_output_layout };
|
||||
FusedOpsConfiguration conf_scalar = { "_SCALAR",
|
||||
{"b", "(feature_block * 16)", "y", "(x + i)"},
|
||||
"dst[i]",
|
||||
@ -210,7 +240,8 @@ JitConstants ConvolutionKernel_b_fs_yx_fsv16::GetJitConstants(const convolution_
|
||||
LoadType::LT_ALIGNED_READ,
|
||||
BoundaryCheck::ENABLED,
|
||||
IndexType::TENSOR_COORD,
|
||||
Tensor::DataChannelName::X };
|
||||
Tensor::DataChannelName::X,
|
||||
{}, false, "", orig_output_layout };
|
||||
jit.Merge(MakeFusedOpsJitConstants(params, {conf_vec, conf_scalar}));
|
||||
}
|
||||
|
||||
|
@ -30,10 +30,15 @@ protected:
|
||||
return (p.groups > 1) ? WeightsLayout::g_os_is_yx_isv16_osv16 : WeightsLayout::os_is_yx_isv16_osv16;
|
||||
}
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
// FusedOpType::REORDER should be registered explicitly here
|
||||
// only when fused_primitive_desc for reorder is added by optimization passes (e.g., remove_redundant_reorder) for corresponding primitive.
|
||||
// The typical usage for fused_primitive_desc for convolution is to get original output layout from jitter,
|
||||
// so that it can decide whether to fuse eltwise along with reorder.
|
||||
return { FusedOpType::ELTWISE,
|
||||
FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ACTIVATION };
|
||||
FusedOpType::ACTIVATION,
|
||||
FusedOpType::REORDER };
|
||||
}
|
||||
|
||||
bool NeedPaddedInput() const override { return false; }
|
||||
|
@ -27,10 +27,15 @@ protected:
|
||||
return (params.groups > 1) ? WeightsLayout::goizyx : WeightsLayout::oizyx;
|
||||
}
|
||||
std::vector<FusedOpType> GetSupportedFusedOps() const override {
|
||||
// FusedOpType::REORDER should be registered explicitly here
|
||||
// only when fused_primitive_desc for reorder is added by optimization passes (e.g., remove_redundant_reorder) for corresponding primitive.
|
||||
// The typical usage for fused_primitive_desc for convolution is to get original output layout from jitter,
|
||||
// so that it can decide whether to fuse eltwise along with reorder.
|
||||
return { FusedOpType::ELTWISE,
|
||||
FusedOpType::QUANTIZE,
|
||||
FusedOpType::SCALE,
|
||||
FusedOpType::ACTIVATION };
|
||||
FusedOpType::ACTIVATION,
|
||||
FusedOpType::REORDER };
|
||||
}
|
||||
|
||||
JitConstants GetJitConstants(const convolution_params& params, const DispatchData& dispatchData) const override;
|
||||
|
@ -43,6 +43,17 @@ struct reorder_optional_params : optional_params {
|
||||
reorder_optional_params() : optional_params(KernelType::REORDER) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// reorder_fuse_params
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
struct reorder_fuse_params : fuse_params {
|
||||
DataLayout input_layout;
|
||||
DataLayout output_layout;
|
||||
|
||||
reorder_fuse_params(DataLayout input_layout, DataLayout output_layout) :
|
||||
fuse_params(KernelType::REORDER), input_layout(input_layout), output_layout(output_layout) {}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// reorder_weights_params
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -41,6 +41,11 @@
|
||||
# error convolution_gpu_bfyx_f16.cl: unsupported filter type
|
||||
#endif
|
||||
|
||||
#if OUTPUT_FORMAT_BFYX
|
||||
# define OUTPUTVTYPE(n) CAT(OUTPUT_TYPE, n)
|
||||
# define TO_OUTPUTVTYPE CAT(convert_, OUTPUTVTYPE(OUTPUT_X_BLOCK_SIZE))
|
||||
# define VSTORE CAT(vstore, OUTPUT_X_BLOCK_SIZE)
|
||||
#else
|
||||
# if OUTPUT_TYPE_SIZE == 1
|
||||
# define OUTPUT_BLOCK_WRITE(ptr, offset, val) BLOCK_WRITE_UC_1((__global uchar*)(ptr) + (offset), as_uchar(val))
|
||||
# define OUTPUT_BLOCK_WRITE2(ptr, offset, val) BLOCK_WRITE_UC_2((__global uchar*)(ptr) + (offset), as_uchar2(val))
|
||||
@ -59,6 +64,7 @@
|
||||
# else
|
||||
# error convolution_gpu_bfyx_f16.cl: unsupported output type
|
||||
# endif
|
||||
#endif // OUTPUT_FORMAT_BFYX
|
||||
|
||||
#if INPUT0_TYPE_SIZE == 2
|
||||
# define AS_INPUT_SRC CAT(as_, MAKE_VECTOR_TYPE(INPUT_TYPE, OUTPUT_X_BLOCK_SIZE))
|
||||
@ -129,18 +135,30 @@ KERNEL(convolution_bfyx_f16)(
|
||||
(INPUT0_PAD_BEFORE_SIZE_X + input_x) * input_x_pitch;
|
||||
|
||||
// Output offset calculations:
|
||||
|
||||
#if OUTPUT_FORMAT_BFYX
|
||||
const uint output_y_pitch = (OUTPUT_PAD_BEFORE_SIZE_X + OUTPUT_SIZE_X + OUTPUT_PAD_AFTER_SIZE_X);
|
||||
const uint output_fs_pitch = output_y_pitch * (OUTPUT_PAD_BEFORE_SIZE_Y + OUTPUT_SIZE_Y + OUTPUT_PAD_AFTER_SIZE_Y);
|
||||
const uint output_b_pitch = output_fs_pitch * (OUTPUT_PAD_BEFORE_FEATURE_NUM + OUTPUT_FEATURE_NUM + OUTPUT_PAD_AFTER_FEATURE_NUM);
|
||||
|
||||
const uint output_offset = b * output_b_pitch +
|
||||
feature_block * (output_fs_pitch * FEATURE_SLICE_SIZE) +
|
||||
(sglid + OUTPUT_PAD_BEFORE_FEATURE_NUM) * output_fs_pitch +
|
||||
(y + OUTPUT_PAD_BEFORE_SIZE_Y) * output_y_pitch +
|
||||
(x + OUTPUT_PAD_BEFORE_SIZE_X);
|
||||
#else
|
||||
const uint output_x_pitch = FEATURE_SLICE_SIZE;
|
||||
const uint output_y_pitch = output_x_pitch * (OUTPUT_PAD_BEFORE_SIZE_X + OUTPUT_SIZE_X + OUTPUT_PAD_AFTER_SIZE_X);
|
||||
const uint output_total_f_size = OUTPUT_PAD_BEFORE_FEATURE_NUM + OUTPUT_FEATURE_NUM + OUTPUT_PAD_AFTER_FEATURE_NUM;
|
||||
const uint output_fs_pitch = output_y_pitch * (OUTPUT_PAD_BEFORE_SIZE_Y + OUTPUT_SIZE_Y + OUTPUT_PAD_AFTER_SIZE_Y);
|
||||
const uint output_b_pitch = output_fs_pitch * ((output_total_f_size + FEATURE_SLICE_SIZE - 1) / FEATURE_SLICE_SIZE);
|
||||
|
||||
const uint output_fs_pad_before = OUTPUT_PAD_BEFORE_FEATURE_NUM / FEATURE_SLICE_SIZE;
|
||||
|
||||
const uint output_offset = b * output_b_pitch +
|
||||
(feature_block + output_fs_pad_before) * output_fs_pitch +
|
||||
(y + OUTPUT_PAD_BEFORE_SIZE_Y) * output_y_pitch +
|
||||
(x + OUTPUT_PAD_BEFORE_SIZE_X) * output_x_pitch;
|
||||
#endif
|
||||
|
||||
// Filter offset calculations:
|
||||
const uint filter_isv_pitch = FEATURE_SLICE_SIZE;
|
||||
@ -383,15 +401,27 @@ KERNEL(convolution_bfyx_f16)(
|
||||
#if OUTPUT_LEFTOVERS
|
||||
if ((feature_block + 1) * FEATURE_SLICE_SIZE >= OUTPUT_FEATURE_NUM) {
|
||||
for (int i = 0; i < OUTPUT_X_BLOCK_SIZE; i++) {
|
||||
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS_SCALAR;
|
||||
# if OUTPUT_FORMAT_BFYX
|
||||
res[i] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SCALAR);
|
||||
# else
|
||||
res[i] = FUSED_OPS_RESULT_SCALAR;
|
||||
# endif
|
||||
#else
|
||||
res[i] = TO_OUTPUT_TYPE(dst[i]);
|
||||
#endif
|
||||
|
||||
#if OUTPUT_FORMAT_BFYX
|
||||
if ((feature_block * FEATURE_SLICE_SIZE + sglid < OUTPUT_FEATURE_NUM) && (x + i) < OUTPUT_SIZE_X) {
|
||||
output[output_offset + i] = res[i];
|
||||
}
|
||||
#else
|
||||
if ((feature_block * FEATURE_SLICE_SIZE + sglid < OUTPUT_FEATURE_NUM) && (x + i) < OUTPUT_SIZE_X) {
|
||||
output[output_offset + i * output_x_pitch + sglid] = res[i];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -400,11 +430,28 @@ KERNEL(convolution_bfyx_f16)(
|
||||
if (x + OUTPUT_X_BLOCK_SIZE <= OUTPUT_SIZE_X || OUTPUT_SIZE_X % OUTPUT_X_BLOCK_SIZE == 0) {
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS_VEC;
|
||||
# if OUTPUT_FORMAT_BFYX
|
||||
res = TO_OUTPUTVTYPE(FUSED_OPS_RESULT_VEC);
|
||||
# else
|
||||
res = FUSED_OPS_RESULT_VEC;
|
||||
# endif
|
||||
#else
|
||||
# if OUTPUT_FORMAT_BFYX
|
||||
res = TO_OUTPUTVTYPE(dst);
|
||||
# else
|
||||
res = dst;
|
||||
# endif
|
||||
#endif
|
||||
// TODO Generalize for other block sizes
|
||||
#if OUTPUT_FORMAT_BFYX
|
||||
#if OUTPUT_X_BLOCK_SIZE == 2 || OUTPUT_X_BLOCK_SIZE == 4 || OUTPUT_X_BLOCK_SIZE == 8
|
||||
VSTORE(res, 0, output + output_offset);
|
||||
#elif OUTPUT_X_BLOCK_SIZE == 1
|
||||
output[output_offset] = res[0];
|
||||
#else
|
||||
# error convolution_gpu_bfyx_f16.cl: unsupported output x block size
|
||||
#endif
|
||||
#else
|
||||
#if OUTPUT_X_BLOCK_SIZE == 8
|
||||
OUTPUT_BLOCK_WRITE8(output, output_offset, res);
|
||||
#elif OUTPUT_X_BLOCK_SIZE == 4
|
||||
@ -416,19 +463,28 @@ KERNEL(convolution_bfyx_f16)(
|
||||
#else
|
||||
# error convolution_gpu_bfyx_f16.cl: unsupported output x block size
|
||||
#endif
|
||||
#endif // OUTPUT_FORMAT_BFYX
|
||||
} else {
|
||||
for (int i = 0; i < OUTPUT_SIZE_X % OUTPUT_X_BLOCK_SIZE; i++) {
|
||||
#if HAS_FUSED_OPS
|
||||
FUSED_OPS_SCALAR;
|
||||
# if OUTPUT_FORMAT_BFYX
|
||||
res[i] = TO_OUTPUT_TYPE(FUSED_OPS_RESULT_SCALAR);
|
||||
# else
|
||||
res[i] = FUSED_OPS_RESULT_SCALAR;
|
||||
# endif
|
||||
#else
|
||||
res[i] = TO_OUTPUT_TYPE(dst[i]);
|
||||
#endif
|
||||
OUTPUT_BLOCK_WRITE(output, output_offset + i * output_x_pitch, res[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if OUTPUT_FORMAT_BFYX
|
||||
output[output_offset + i] = res[i];
|
||||
#else
|
||||
OUTPUT_BLOCK_WRITE(output, output_offset + i * output_x_pitch, res[i]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
#if SLM_DIV_FACTOR > 1
|
||||
}
|
||||
#endif
|
||||
@ -462,7 +518,13 @@ KERNEL(convolution_bfyx_f16)(
|
||||
|
||||
#undef FILTER_BLOCK_READ8
|
||||
|
||||
#if OUTPUT_FORMAT_BFYX
|
||||
# undef OUTPUTVTYPE
|
||||
# undef TO_OUTPUTVTYPE
|
||||
# undef VSTORE
|
||||
#else
|
||||
# undef OUTPUT_BLOCK_WRITE
|
||||
# undef OUTPUT_BLOCK_WRITE2
|
||||
# undef OUTPUT_BLOCK_WRITE4
|
||||
# undef OUTPUT_BLOCK_WRITE8
|
||||
#endif // OUTPUT_FORMAT_BFYX
|
||||
|
@ -1741,8 +1741,10 @@ std::string FusedOpsCodeGenerator::GetJitLoad(const FusedOpsConfiguration& conf,
|
||||
|
||||
// Eltwise fused op can't have full tensor argument when requested vec_size > 1, since it might require
|
||||
// splitting load into several parts and some kind of index recalculation which is not supported
|
||||
DataLayout orig_output_layout = conf.IsPostReorderFused() ? conf.orig_output_layout : prim_output.GetLayout();
|
||||
|
||||
if (desc.GetType() == KernelType::ELTWISE && !valid_broadcast_case &&
|
||||
input_tensor.GetLayout() != prim_output.GetLayout() && conf.vec_size > 1) {
|
||||
input_tensor.GetLayout() != orig_output_layout && conf.vec_size > 1) {
|
||||
throw std::runtime_error("[clDNN] Mixed layouts of input tensors are not supported in fused eltwise:"
|
||||
"\nfused_input: " + toString_v2(input_tensor) +
|
||||
"\noutput: " + toString_v2(prim_output));
|
||||
|
@ -108,6 +108,9 @@ JitConstants KernelBase::MakeFusedOpsJitConstants(const kernel_selector::base_pa
|
||||
if (conf.empty())
|
||||
return jit;
|
||||
|
||||
if (params.fused_ops.size() == 1 && params.fused_ops[0].GetType() == KernelType::REORDER)
|
||||
return jit;
|
||||
|
||||
try {
|
||||
for (auto& c : conf) {
|
||||
std::string fused_ops;
|
||||
@ -119,6 +122,10 @@ JitConstants KernelBase::MakeFusedOpsJitConstants(const kernel_selector::base_pa
|
||||
bool can_all_use_preload = true;
|
||||
|
||||
for (size_t i = 0; i < params.fused_ops.size(); i++) {
|
||||
// Reorder is not processed by jitter
|
||||
if (params.fused_ops[i].GetType() == FusedOpType::REORDER)
|
||||
continue;
|
||||
|
||||
auto fused_dep_codegen = FusedOpsCodeGenerator(params.fused_ops[i]);
|
||||
jit.Merge(fused_dep_codegen.MakeLoadJitConstants(c, params.output));
|
||||
jit.Merge(fused_dep_codegen.MakeOpJitConstants(c, in_name, in_type, out_name));
|
||||
|
@ -469,6 +469,8 @@ struct FusedOpsConfiguration {
|
||||
bool allow_for_partial_preload;
|
||||
// Load index for shuffle fused op
|
||||
std::string shuffle_var_name;
|
||||
// Record original output layout before reorder is fused
|
||||
DataLayout orig_output_layout;
|
||||
|
||||
FusedOpsConfiguration(std::string suffix,
|
||||
std::vector<std::string> bfzyx_idx_order,
|
||||
@ -481,7 +483,8 @@ struct FusedOpsConfiguration {
|
||||
Tensor::DataChannelName vec_axis = Tensor::DataChannelName::COUNT,
|
||||
std::vector<Tensor::DataChannelName> loop_axes = {},
|
||||
bool allow_for_partial_preload = false,
|
||||
std::string shuffle_var_name = "")
|
||||
std::string shuffle_var_name = "",
|
||||
DataLayout orig_output_layout = DataLayout::DataLayoutCount)
|
||||
: suffix(suffix)
|
||||
, bfzyx_idx_order(bfzyx_idx_order)
|
||||
, input_var_name(input_var_name)
|
||||
@ -493,7 +496,8 @@ struct FusedOpsConfiguration {
|
||||
, index_type(index_type)
|
||||
, loop_axes(loop_axes)
|
||||
, allow_for_partial_preload(allow_for_partial_preload)
|
||||
, shuffle_var_name(shuffle_var_name) { }
|
||||
, shuffle_var_name(shuffle_var_name)
|
||||
, orig_output_layout(orig_output_layout) { }
|
||||
|
||||
FusedOpsConfiguration& SetVectorSize(size_t val) { vec_size = val; return *this; }
|
||||
FusedOpsConfiguration& SetLoadType(LoadType val) { load_type = val; return *this; }
|
||||
@ -505,6 +509,7 @@ struct FusedOpsConfiguration {
|
||||
allow_for_partial_preload = partial_preload;
|
||||
return *this; }
|
||||
FusedOpsConfiguration& SetShuffleVarName(std::string val) { shuffle_var_name = val; return *this; }
|
||||
bool IsPostReorderFused(void) const { return orig_output_layout != DataLayout::DataLayoutCount; }
|
||||
};
|
||||
|
||||
// Instance of fused_operation_desc is added to fused_ops vector if a node has been fused to current one using program::fuse_nodes
|
||||
|
@ -334,26 +334,24 @@ void remove_redundant_reorders::run(program& p) {
|
||||
p.remove_if_dangling(node);
|
||||
}
|
||||
|
||||
// This pass removes reorder for Convolution BFYX -> FS_B_YX_FSV32
|
||||
itr = p.get_processing_order().begin();
|
||||
while (itr != p.get_processing_order().end()) {
|
||||
auto& node = *itr++;
|
||||
if (!node->is_type<reorder>() || !node->is_in_data_flow() || node->get_users().size() != 1 || node->get_dependencies().size() != 1)
|
||||
continue;
|
||||
// Remove reorder for Convolution bfyx -> fs_b_yx_fsv32
|
||||
auto try_fuse_reorder_bfyx_to_fsv32 = [&](reorder_node* node) {
|
||||
if (node->get_users().size() != 1)
|
||||
return;
|
||||
|
||||
auto& usr = node->get_users().front();
|
||||
auto& dep = node->get_dependency(0);
|
||||
if (!(usr->is_type<convolution>()) ||
|
||||
(usr->get_output_layout().data_type != dep.get_output_layout().data_type) ||
|
||||
(usr->get_output_layout().format != format::fs_b_yx_fsv32) ||
|
||||
(dep.get_output_layout().format != format::bfyx))
|
||||
continue;
|
||||
(dep.get_output_layout().format != format::bfyx) ||
|
||||
(usr->get_output_layout().format != format::fs_b_yx_fsv32))
|
||||
return;
|
||||
|
||||
if (dep.is_type<input_layout>())
|
||||
continue;
|
||||
return;
|
||||
|
||||
if (usr->as<convolution>().get_primitive()->groups != 1)
|
||||
continue;
|
||||
return;
|
||||
|
||||
dep.merge_output_padding(node->get_output_layout().data_padding);
|
||||
p.replace_all_usages(*node, dep);
|
||||
@ -361,6 +359,83 @@ void remove_redundant_reorders::run(program& p) {
|
||||
p.add_optimized_primitive_info(node->id());
|
||||
p.remove_all_connections(*node);
|
||||
p.remove_if_dangling(*node);
|
||||
};
|
||||
|
||||
// Remove reorder for Convolution b_fs_yx_fsv16 -> bfyx
|
||||
auto try_fuse_reorder_fsv16_to_bfyx = [&](reorder_node* node) {
|
||||
if (!node->get_fused_activations_funcs().empty() ||
|
||||
!node->get_fused_primitives().empty())
|
||||
return;
|
||||
|
||||
auto& input = node->input();
|
||||
|
||||
if (!(input.is_type<convolution>()) ||
|
||||
!(input.get_output_layout().format == format::b_fs_yx_fsv16) ||
|
||||
!(node->get_output_layout().format == format::bfyx))
|
||||
return;
|
||||
|
||||
if (input.as<convolution>().get_primitive()->groups != 1)
|
||||
return;
|
||||
|
||||
if (input.get_users().size() != 1)
|
||||
return;
|
||||
|
||||
auto& input_dep = input.get_dependency(0);
|
||||
if (input_dep.get_output_layout().format != format::b_fs_yx_fsv16 ||
|
||||
input_dep.get_output_layout().data_type == data_types::u8 ||
|
||||
input_dep.get_output_layout().data_type == data_types::i8)
|
||||
return;
|
||||
|
||||
for (auto& user : node->get_users()) {
|
||||
// if concat is reorder's user and concat's axis is 0(Batch) or 1(Feature), conv's output would have padding.
|
||||
// This padding might lead not to select the optimized conv kernel("convolution_gpu_bfyx_f16")
|
||||
if (user->is_type<concatenation>()) {
|
||||
auto& concat_node = user->as<concatenation>();
|
||||
auto concat_axis = concat_node.get_primitive()->axis;
|
||||
if (concat_axis == 0 || concat_axis == 1)
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto output_layout = node->get_output_layout();
|
||||
input.set_output_layout(output_layout, false);
|
||||
if (input.type()->does_possible_implementation_exist(input)) {
|
||||
input.set_output_padding(node->get_output_layout().data_padding);
|
||||
|
||||
// Add fused_primitive_desc of reorder to convolution which propagate original output layout to jitter
|
||||
fused_primitive_desc local_desc;
|
||||
local_desc.node = p.get_node_ptr(node->id());
|
||||
local_desc.dep_start_idx = input.get_fused_primitives().size();
|
||||
local_desc.output_layout = output_layout;
|
||||
local_desc.input_layout = input.get_dependency(0).get_output_layout(); // original convolution's output layout
|
||||
local_desc.activation = activation_func::none;
|
||||
input.add_fused_primitive(local_desc);
|
||||
node->set_input_layout(local_desc.input_layout);
|
||||
|
||||
// remove reorder node
|
||||
node->can_be_optimized(true);
|
||||
p.add_optimized_primitive_info(node->id());
|
||||
p.extract_and_remove(*node);
|
||||
}
|
||||
};
|
||||
|
||||
if (enable_reorder_fusing) {
|
||||
itr = p.get_processing_order().begin();
|
||||
while (itr != p.get_processing_order().end()) {
|
||||
auto& node = *itr++;
|
||||
if (!node->is_type<reorder>())
|
||||
continue;
|
||||
|
||||
if (!node->is_in_data_flow() || node->get_dependencies().size() != 1)
|
||||
continue;
|
||||
|
||||
auto& r_node = node->as<reorder>();
|
||||
|
||||
// Remove reorder for Convolution bfyx -> fs_b_yx_fsv32
|
||||
try_fuse_reorder_bfyx_to_fsv32(&r_node);
|
||||
// Remove reorder for Convolution b_fs_yx_fsv16 -> bfyx
|
||||
try_fuse_reorder_fsv16_to_bfyx(&r_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Additional reshape chains shrink.
|
||||
|
@ -41,6 +41,7 @@ struct fused_primitive_desc {
|
||||
std::vector<primitive_id> fused_deps;
|
||||
activation_func activation;
|
||||
activation_additional_params activation_params;
|
||||
layout input_layout = layout(data_types::f32, format::bfyx, tensor());
|
||||
layout output_layout = layout(data_types::f32, format::bfyx, tensor());
|
||||
};
|
||||
|
||||
|
@ -7,6 +7,8 @@
|
||||
|
||||
#include "cldnn/primitives/reorder.hpp"
|
||||
#include "primitive_inst.h"
|
||||
#include "kernel_selector/core/actual_kernels/reorder/reorder_kernel_base.h"
|
||||
#include "kernel_selector/common/tensor_type.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
@ -33,11 +35,19 @@ public:
|
||||
void requires_reinterpret(bool val) { req_reinterpr = (optimized && val); }
|
||||
|
||||
void set_input_offset(tensor const& io) { input_offset = io; }
|
||||
void set_input_layout(layout const& lo) { input_layout = lo; }
|
||||
tensor get_input_offset() const { return input_offset; }
|
||||
|
||||
std::shared_ptr<kernel_selector::fuse_params> get_fuse_params() const override {
|
||||
kernel_selector::DataLayout ks_input_layout = convert_data_tensor(input_layout).GetLayout();
|
||||
kernel_selector::DataLayout ks_output_layout = convert_data_tensor(get_output_layout()).GetLayout();
|
||||
return std::make_shared<kernel_selector::reorder_fuse_params>(ks_input_layout, ks_output_layout);
|
||||
}
|
||||
|
||||
private:
|
||||
bool req_reinterpr = false;
|
||||
tensor input_offset = tensor{0}; // used by reorder to winograd domain
|
||||
layout input_layout = layout(data_types::f32, format::bfyx, { 0, 0, 0, 0 });
|
||||
};
|
||||
|
||||
using reorder_node = typed_program_node<reorder>;
|
||||
|
@ -7609,6 +7609,222 @@ TEST_P(convolution_general_gpu, conv_fp16_cases) {
|
||||
}
|
||||
}
|
||||
|
||||
struct convolution_gpu_fsv16_to_bfyx : public convolution_general_gpu {};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(conv_b_fs_yx_fsv16_to_bfyx,
|
||||
convolution_gpu_fsv16_to_bfyx,
|
||||
::testing::Values(
|
||||
// Input X size, Input Y size, Input Z size, Input features, Output features,
|
||||
// Kernel size X, Kernel size Y, Kernel size Z, Groups number, Stride, Batch,
|
||||
// Input data format, Implementation name, WithBias
|
||||
TestParamType_general_convolution_gpu(6, 6, 0, 16, 16, 3, 3, 0, 1, 1, 4, format::b_fs_yx_fsv16, "convolution_gpu_fsv16_to_bfyx", false),
|
||||
TestParamType_general_convolution_gpu(6, 6, 0, 32, 32, 3, 3, 0, 1, 1, 1, format::b_fs_yx_fsv16, "convolution_gpu_fsv16_to_bfyx", false),
|
||||
TestParamType_general_convolution_gpu(6, 6, 0, 16, 16, 3, 3, 0, 1, 1, 16, format::b_fs_yx_fsv16, "convolution_gpu_fsv16_to_bfyx", false),
|
||||
TestParamType_general_convolution_gpu(16, 6, 0, 20, 16, 3, 3, 0, 1, 1, 20, format::b_fs_yx_fsv16, "convolution_gpu_fsv16_to_bfyx", false)
|
||||
),
|
||||
convolution_gpu_fsv16_to_bfyx::PrintToStringParamName);
|
||||
|
||||
TEST_P(convolution_gpu_fsv16_to_bfyx, conv_b_fs_yx_fsv16_to_bfyx_padding)
|
||||
{
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
if (!engine.get_device_info().supports_fp16)
|
||||
{
|
||||
std::cout << "[ SKIPPED ] The test is skipped (cl_khr_fp16 is not supported)." << std::endl;
|
||||
EXPECT_EQ(1, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
const int input_b = testing::get<10>(GetParam());
|
||||
const int input_f = testing::get<3>(GetParam());
|
||||
const int input_y = testing::get<1>(GetParam());
|
||||
const int input_x = testing::get<0>(GetParam());
|
||||
|
||||
const int filter_x = testing::get<5>(GetParam());
|
||||
const int filter_y = testing::get<6>(GetParam());
|
||||
const int stride = testing::get<9>(GetParam());
|
||||
|
||||
const int input_offset_y = (filter_y - 1) / 2;
|
||||
const int input_offset_x = (filter_x - 1) / 2;
|
||||
|
||||
auto input_size = tensor(input_b, input_f, input_x, input_y);
|
||||
auto input_data = generate_random_4d<FLOAT16>(input_b, input_f, input_y, input_x, -1, 1);
|
||||
auto input_data_bfyx = flatten_4d(format::bfyx, input_data);
|
||||
auto input_mem = engine.allocate_memory({ data_types::f16, format::bfyx, input_size });
|
||||
set_values(input_mem, input_data_bfyx);
|
||||
|
||||
auto weights_size = tensor(input_b, input_f, filter_x, filter_y, 1);
|
||||
auto weights_data = generate_random_4d<FLOAT16>(input_b, input_f, filter_x, filter_y, -1, 1);
|
||||
auto weights_data_bfyx = flatten_4d(format::bfyx, weights_data);
|
||||
auto weights_mem = engine.allocate_memory({ data_types::f16, format::goiyx, weights_size });
|
||||
set_values(weights_mem, weights_data_bfyx);
|
||||
|
||||
// Set topology
|
||||
topology topology(
|
||||
input_layout("input_origin", input_mem->get_layout()),
|
||||
data("weights_fsv", weights_mem),
|
||||
reorder("input_fsv16", "input_origin", { data_types::f16, format::b_fs_yx_fsv16, input_size })); // format 3 to 8
|
||||
|
||||
// Add convolution
|
||||
auto input_stride = tensor(1, 1, stride, stride);
|
||||
auto input_offset = tensor(0, 0, input_offset_x, input_offset_y);
|
||||
auto input_dilation = tensor(1, 1, 1, 1);
|
||||
auto input_padding_before = tensor(0, 0, input_offset_x, input_offset_y);
|
||||
auto input_padding_after = tensor(0, 0, input_offset_x, input_offset_y);
|
||||
|
||||
auto conv_fsv = convolution("conv_fsv", "input_fsv16", { "weights_fsv" }, input_stride, input_offset, input_dilation, input_padding_before, input_padding_after);
|
||||
conv_fsv.output_padding = padding({ 0, 32, 2, 2 }, 0.f);
|
||||
topology.add(conv_fsv); // format 8 to 8 -> after fusing, format 8 to 3
|
||||
|
||||
// Add reorder to bfyx
|
||||
auto reorder_bfyx = reorder("reorder_bfyx", "conv_fsv", { data_types::f16, format::bfyx, input_size });
|
||||
reorder_bfyx.output_padding = padding({ 0, 16, 1, 1 }, 0.f);
|
||||
topology.add(reorder_bfyx); // format 8 to 3 -> after fusing, removed
|
||||
|
||||
// Exec ref network (non-fusing)
|
||||
build_options options_ref;
|
||||
options_ref.set_option(build_option::optimize_data(false));
|
||||
options_ref.set_option(build_option::allow_static_input_reorder(true));
|
||||
|
||||
network network_ref(engine, topology, options_ref);
|
||||
network_ref.set_input_data("input_origin", input_mem);
|
||||
auto ref_out = network_ref.execute();
|
||||
|
||||
auto ref_out_mem = ref_out.begin()->second.get_memory();
|
||||
cldnn::mem_lock<FLOAT16> ref_out_ptr(ref_out_mem, get_test_stream());
|
||||
|
||||
// Exec target network (fusing: conv+reorder)
|
||||
build_options options_target;
|
||||
implementation_desc conv_impl = { format::b_fs_yx_fsv16, "convolution_gpu_bfyx_f16" };
|
||||
options_target.set_option(build_option::force_implementations({ {"conv_fsv", conv_impl} }));
|
||||
options_target.set_option(build_option::optimize_data(true));
|
||||
|
||||
network network_target(engine, topology, options_target);
|
||||
network_target.set_input_data("input_origin", input_mem);
|
||||
auto target_out = network_target.execute();
|
||||
|
||||
auto target_out_mem = target_out.begin()->second.get_memory();
|
||||
cldnn::mem_lock<FLOAT16> target_out_ptr(target_out_mem, get_test_stream());
|
||||
|
||||
// Compare ref and target result
|
||||
for (size_t i = 0; i < ref_out_ptr.size(); i++) {
|
||||
auto ref_val = static_cast<float>(ref_out_ptr[i]);
|
||||
auto target_val = static_cast<float>(target_out_ptr[i]);
|
||||
auto diff = std::fabs(ref_val - target_val);
|
||||
auto equal = (diff > 1e-5f) ? false : true;
|
||||
|
||||
EXPECT_TRUE(equal);
|
||||
if (!equal)
|
||||
{
|
||||
std::cout << "i:" << i \
|
||||
<< "\t ref_out = " << ref_val \
|
||||
<< "\t target_out = " << target_val \
|
||||
<< std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(convolution_gpu_fsv16_to_bfyx, conv_b_fs_yx_fsv16_to_bfyx_different_type)
|
||||
{
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
if (!engine.get_device_info().supports_fp16)
|
||||
{
|
||||
std::cout << "[ SKIPPED ] The test is skipped (cl_khr_fp16 is not supported)." << std::endl;
|
||||
EXPECT_EQ(1, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
const int input_b = testing::get<10>(GetParam());
|
||||
const int input_f = testing::get<3>(GetParam());
|
||||
const int input_y = testing::get<1>(GetParam());
|
||||
const int input_x = testing::get<0>(GetParam());
|
||||
|
||||
const int filter_x = testing::get<5>(GetParam());
|
||||
const int filter_y = testing::get<6>(GetParam());
|
||||
const int stride = testing::get<9>(GetParam());
|
||||
|
||||
const int input_offset_y = (filter_y - 1) / 2;
|
||||
const int input_offset_x = (filter_x - 1) / 2;
|
||||
|
||||
auto input_size = tensor(input_b, input_f, input_x, input_y);
|
||||
auto input_data = generate_random_4d<FLOAT16>(input_b, input_f, input_y, input_x, -1, 1);
|
||||
auto input_data_bfyx = flatten_4d(format::bfyx, input_data);
|
||||
auto input_mem = engine.allocate_memory({ data_types::f16, format::bfyx, input_size });
|
||||
set_values(input_mem, input_data_bfyx);
|
||||
|
||||
auto weights_size = tensor(input_b, input_f, filter_x, filter_y, 1);
|
||||
auto weights_data = generate_random_4d<FLOAT16>(input_b, input_f, filter_x, filter_y, -1, 1);
|
||||
auto weights_data_bfyx = flatten_4d(format::bfyx, weights_data);
|
||||
auto weights_mem = engine.allocate_memory({ data_types::f16, format::goiyx, weights_size });
|
||||
set_values(weights_mem, weights_data_bfyx);
|
||||
|
||||
// Set topology
|
||||
topology topology(
|
||||
input_layout("input_origin", input_mem->get_layout()),
|
||||
data("weights_fsv", weights_mem),
|
||||
reorder("input_fsv16", "input_origin", { data_types::f16, format::b_fs_yx_fsv16, input_size })); // format 3 to 8
|
||||
|
||||
// Add convolution
|
||||
auto input_stride = tensor(1, 1, stride, stride);
|
||||
auto input_offset = tensor(0, 0, input_offset_x, input_offset_y);
|
||||
auto input_dilation = tensor(1, 1, 1, 1);
|
||||
auto no_padding = tensor(0, 0, input_offset_x, input_offset_y);
|
||||
|
||||
auto conv_fsv = convolution("conv_fsv", "input_fsv16", { "weights_fsv" }, input_stride, input_offset, input_dilation, no_padding, no_padding);
|
||||
topology.add(conv_fsv); // format 8 to 8 -> after fusing, format 8 to 3
|
||||
|
||||
// Add reorder to bfyx
|
||||
auto reorder_bfyx = reorder("reorder_bfyx", "conv_fsv", { data_types::f32, format::bfyx, input_size });
|
||||
topology.add(reorder_bfyx); // format 8 to 3 -> after fusing, removed
|
||||
|
||||
// Exec ref network (non-fusing)
|
||||
build_options options_ref;
|
||||
options_ref.set_option(build_option::optimize_data(false));
|
||||
options_ref.set_option(build_option::allow_static_input_reorder(true));
|
||||
|
||||
network network_ref(engine, topology, options_ref);
|
||||
network_ref.set_input_data("input_origin", input_mem);
|
||||
auto ref_out = network_ref.execute();
|
||||
|
||||
auto ref_out_mem = ref_out.begin()->second.get_memory();
|
||||
cldnn::mem_lock<float> ref_out_ptr(ref_out_mem, get_test_stream());
|
||||
|
||||
// Exec target network (fusing: conv+reorder)
|
||||
build_options options_target;
|
||||
implementation_desc conv_impl = { format::b_fs_yx_fsv16, "convolution_gpu_bfyx_f16" };
|
||||
options_target.set_option(build_option::force_implementations({ {"conv_fsv", conv_impl} }));
|
||||
options_target.set_option(build_option::optimize_data(true));
|
||||
|
||||
network network_target(engine, topology, options_target);
|
||||
network_target.set_input_data("input_origin", input_mem);
|
||||
auto target_out = network_target.execute();
|
||||
|
||||
auto target_out_mem = target_out.begin()->second.get_memory();
|
||||
cldnn::mem_lock<float> target_out_ptr(target_out_mem, get_test_stream());
|
||||
|
||||
// Compare ref and target result
|
||||
for (size_t i = 0; i < ref_out_ptr.size(); i++) {
|
||||
auto ref_val = static_cast<float>(ref_out_ptr[i]);
|
||||
auto target_val = static_cast<float>(target_out_ptr[i]);
|
||||
auto diff = std::abs(ref_val - target_val);
|
||||
auto equal = (diff > 1e-5f) ? false : true;
|
||||
|
||||
EXPECT_TRUE(equal);
|
||||
if (!equal)
|
||||
{
|
||||
std::cout << "i:" << i \
|
||||
<< "\t ref_out = " << ref_val \
|
||||
<< "\t target_out = " << target_val \
|
||||
<< std::endl;
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InputT, typename WeightsT, typename OutputT>
|
||||
class convolution_test_base {
|
||||
public:
|
||||
|
@ -616,6 +616,73 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class conv_fp32_reorder_fsv16_to_bfyx : public ConvFusingTest {};
|
||||
TEST_P(conv_fp32_reorder_fsv16_to_bfyx, basic) {
|
||||
auto p = GetParam();
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
data("weights", get_mem(get_weights_layout(p))),
|
||||
reorder("reorder_fsv16", "input", format::b_fs_yx_fsv16, data_types::f32),
|
||||
convolution("conv_prim", "reorder_fsv16", { "weights" }, p.groups, p.stride, p.pad, p.dilation),
|
||||
reorder("reorder_bfyx", "conv_prim", format::bfyx, data_types::f32)
|
||||
);
|
||||
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_reorder_fsv16_to_bfyx, ::testing::ValuesIn(std::vector<bc_test_params>{
|
||||
bc_test_params{ CASE_CONV_FP32_1, 2, 2},
|
||||
bc_test_params{ CASE_CONV_FP32_2, 2, 2},
|
||||
bc_test_params{ CASE_CONV_FP32_3, 2, 2},
|
||||
bc_test_params{ CASE_CONV_FP32_4, 2, 2 },
|
||||
bc_test_params{ CASE_CONV_FP32_5, 2, 2 },
|
||||
bc_test_params{ CASE_CONV_FP32_14, 2, 2 },
|
||||
|
||||
bc_test_params{ CASE_CONV_FP16_1, 2, 2},
|
||||
bc_test_params{ CASE_CONV_FP16_2, 2, 2},
|
||||
bc_test_params{ CASE_CONV_FP16_3, 2, 2},
|
||||
bc_test_params{ CASE_CONV_FP16_4, 2, 2 },
|
||||
bc_test_params{ CASE_CONV_FP16_5, 2, 2 },
|
||||
bc_test_params{ CASE_CONV_FP16_13, 2, 2}
|
||||
}));
|
||||
|
||||
class conv_fp32_reorder_fsv16_to_bfyx_conv : public ConvFusingTest {};
|
||||
TEST_P(conv_fp32_reorder_fsv16_to_bfyx_conv, basic) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto dw_tensor = cldnn::tensor(group(p.out_shape.feature[0]), batch(1), feature(1), spatial(3, 3));
|
||||
auto dw_weights_layout = layout{ p.default_type, format::goiyx, dw_tensor };
|
||||
auto dw_stride = tensor{ 0, 0, 1, 1 };
|
||||
|
||||
create_topologies(input_layout("input", get_input_layout(p)),
|
||||
data("weights", get_mem(get_weights_layout(p), -127, 127)),
|
||||
data("weights_dw", get_mem(dw_weights_layout, -127, 127)),
|
||||
reorder("reorder_fsv16", "input", format::b_fs_yx_fsv16, data_types::f32),
|
||||
convolution("conv_prim", "reorder_fsv16", { "weights" }, p.groups, p.stride, p.pad, p.dilation),
|
||||
reorder("reorder_bfyx", "conv_prim", format::bfyx, data_types::f32),
|
||||
convolution("conv_output", "reorder_bfyx", { "weights_dw" }, 1, dw_stride, p.pad, p.dilation),
|
||||
activation("activation", "conv_output", activation_func::abs),
|
||||
reorder("reorder_output", "activation", p.default_format, data_types::f32)
|
||||
);
|
||||
|
||||
execute(p);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(fusings_gpu, conv_fp32_reorder_fsv16_to_bfyx_conv, ::testing::ValuesIn(std::vector<bc_test_params>{
|
||||
bc_test_params{ CASE_CONV_FP32_1, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP32_2, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP32_3, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP32_4, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP32_5, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP32_14, 3, 4 },
|
||||
|
||||
bc_test_params{ CASE_CONV_FP16_1, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP16_2, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP16_3, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP16_4, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP16_5, 3, 4 },
|
||||
bc_test_params{ CASE_CONV_FP16_13, 3, 4 },
|
||||
}));
|
||||
|
||||
class conv_fp32_activation : public ConvFusingTest {};
|
||||
TEST_P(conv_fp32_activation, basic) {
|
||||
auto p = GetParam();
|
||||
@ -8279,9 +8346,6 @@ INSTANTIATE_TEST_SUITE_P(fusings_gpu, scatter_nd_update_scale_activation_eltwise
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_4, 2, 5 },
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_5, 2, 5 },
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_6, 2, 5 },
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_7, 2, 5 },
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_9, 2, 5 },
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_5D_8, 2, 5 },
|
||||
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_1, 2, 5 },
|
||||
scatter_nd_update_test_params{ CASE_SCATTER_ND_UPDATE_FP16_6D_2, 2, 5 },
|
||||
|
@ -9,168 +9,20 @@
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/except.hpp"
|
||||
#include "openvino/core/check.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
static inline std::ostream& write_all_to_stream(std::ostream& str) {
|
||||
return str;
|
||||
}
|
||||
template <typename T, typename... TS>
|
||||
static inline std::ostream& write_all_to_stream(std::ostream& str, const T& arg, TS&&... args) {
|
||||
return write_all_to_stream(str << arg, args...);
|
||||
}
|
||||
using ov::write_all_to_stream;
|
||||
|
||||
struct CheckLocInfo {
|
||||
const char* file;
|
||||
int line;
|
||||
const char* check_string;
|
||||
};
|
||||
|
||||
/// Base class for check failure exceptions.
|
||||
class NGRAPH_API CheckFailure : public ngraph_error {
|
||||
public:
|
||||
CheckFailure(const CheckLocInfo& check_loc_info, const std::string& context_info, const std::string& explanation)
|
||||
: ngraph_error(make_what(check_loc_info, context_info, explanation)) {}
|
||||
|
||||
private:
|
||||
static std::string make_what(const CheckLocInfo& check_loc_info,
|
||||
const std::string& context_info,
|
||||
const std::string& explanation);
|
||||
};
|
||||
using ov::CheckFailure;
|
||||
using ov::CheckLocInfo;
|
||||
} // namespace ngraph
|
||||
|
||||
//
|
||||
// Helper macro for defining custom check macros, which throw custom exception classes and provide
|
||||
// useful context information (the check condition, source filename, line number, and any domain-
|
||||
// specific context information [e.g., a summary of the node that was being processed at the time
|
||||
// of the check]).
|
||||
//
|
||||
// For example (actually implemented in node.cpp), let's say we want to define a macro for
|
||||
// checking conditions during node validation, usable as follows:
|
||||
//
|
||||
// NODE_VALIDATION_CHECK(node_being_checked,
|
||||
// node_being_checked->get_input_shape(0).size() == 1,
|
||||
// "Node must have an input rank of 1, but got ",
|
||||
// node_being_checked->get_input_shape(0).size(), ".");
|
||||
//
|
||||
// In case of failure, this will throw an exception of type NodeValidationFailure with a what()
|
||||
// string something like:
|
||||
//
|
||||
// Check 'node_being_checked->get_input_shape(0).size() == 1' failed at foo.cpp:123:
|
||||
// While validating node 'Broadcast[Broadcast_10](Reshape_9: float{1,3,4,5}) -> (??)':
|
||||
// Node must have an input of rank 1, but got 2.
|
||||
//
|
||||
// To implement this, he first step is to define a subclass of CheckFailure (let's say it's called
|
||||
// MyFailure), which must have a constructor of the form:
|
||||
//
|
||||
// MyFailure(const CheckLocInfo& check_loc_info,
|
||||
// T context_info, // "T" can be any type; you'll supply a function to convert "T"
|
||||
// // to std::string
|
||||
// const std::string& explanation)
|
||||
//
|
||||
// Here, we define a custom class for node validation failures as follows:
|
||||
//
|
||||
// static std::string node_validation_failure_loc_string(const Node* node)
|
||||
// {
|
||||
// std::stringstream ss;
|
||||
// ss << "While validating node '" << *node << "'";
|
||||
// return ss.str();
|
||||
// }
|
||||
//
|
||||
// class NodeValidationFailure : public CheckFailure
|
||||
// {
|
||||
// public:
|
||||
// NodeValidationFailure(const CheckLocInfo& check_loc_info,
|
||||
// const Node* node,
|
||||
// const std::string& explanation)
|
||||
// : CheckFailure(check_loc_info, node_validation_failure_loc_string(node), explanation)
|
||||
// {
|
||||
// }
|
||||
// };
|
||||
//
|
||||
// Then, we define the macro NODE_VALIDATION_CHECK as follows:
|
||||
//
|
||||
// #define NODE_VALIDATION_CHECK(node, cond, ...) <backslash>
|
||||
// NGRAPH_CHECK_HELPER(::ngraph::NodeValidationFailure, (node), (cond), ##__VA_ARGS__)
|
||||
//
|
||||
// The macro NODE_VALIDATION_CHECK can now be called on any condition, with a Node* pointer
|
||||
// supplied to generate an informative error message via node_validation_failure_loc_string().
|
||||
//
|
||||
// Take care to fully qualify the exception class name in the macro body.
|
||||
//
|
||||
// The "..." may be filled with expressions of any type that has an "operator<<" overload for
|
||||
// insertion into std::ostream.
|
||||
//
|
||||
// TODO(amprocte): refactor NGRAPH_CHECK_HELPER so we don't have to introduce a locally-scoped
|
||||
// variable (ss___) and risk shadowing.
|
||||
//
|
||||
#define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...) \
|
||||
do { \
|
||||
if (!(check)) { \
|
||||
::std::stringstream ss___; \
|
||||
::ngraph::write_all_to_stream(ss___, __VA_ARGS__); \
|
||||
throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ss___.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#define NGRAPH_CHECK_HELPER2(exc_class, ctx, check, ...) OV_CHECK_HELPER2(exc_class, ctx, check, __VA_ARGS__)
|
||||
|
||||
#define NGRAPH_CHECK_HELPER1(exc_class, ctx, check) \
|
||||
do { \
|
||||
if (!(check)) { \
|
||||
throw exc_class((::ngraph::CheckLocInfo{__FILE__, __LINE__, #check}), (ctx), ""); \
|
||||
} \
|
||||
} while (0)
|
||||
#define NGRAPH_CHECK_HELPER1(exc_class, ctx, check) OV_CHECK_HELPER1(exc_class, ctx, check)
|
||||
|
||||
/// \brief Macro to check whether a boolean condition holds.
|
||||
/// \param cond Condition to check
|
||||
/// \param ... Additional error message info to be added to the error message via the `<<`
|
||||
/// stream-insertion operator. Note that the expressions here will be evaluated lazily,
|
||||
/// i.e., only if the `cond` evalutes to `false`.
|
||||
/// \throws ::ngraph::CheckFailure if `cond` is false.
|
||||
#define NGRAPH_CHECK(...) NGRAPH_CHECK_HELPER(::ngraph::CheckFailure, "", __VA_ARGS__)
|
||||
#define NGRAPH_CHECK(...) OV_CHECK(__VA_ARGS__)
|
||||
|
||||
/// \brief Macro to signal a code path that is unreachable in a successful execution. It's
|
||||
/// implemented with NGRAPH_CHECK macro.
|
||||
/// \param ... Additional error message that should describe why that execution path is unreachable.
|
||||
/// \throws ::ngraph::CheckFailure if the macro is executed.
|
||||
#define NGRAPH_UNREACHABLE(...) NGRAPH_CHECK(false, "Unreachable: ", __VA_ARGS__)
|
||||
#define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) CALL_OVERLOAD(NGRAPH_CHECK_HELPER, exc_class, ctx, __VA_ARGS__)
|
||||
|
||||
#define GLUE(x, y) x y
|
||||
|
||||
#define RETURN_ARG_COUNT(_1_, \
|
||||
_2_, \
|
||||
_3_, \
|
||||
_4_, \
|
||||
_5_, \
|
||||
_6, \
|
||||
_7, \
|
||||
_8, \
|
||||
_9, \
|
||||
_10, \
|
||||
_11, \
|
||||
_12, \
|
||||
_13, \
|
||||
_14, \
|
||||
_15, \
|
||||
_16, \
|
||||
_17, \
|
||||
_18, \
|
||||
_19, \
|
||||
_20, \
|
||||
_21, \
|
||||
_22, \
|
||||
_23, \
|
||||
_24, \
|
||||
_25, \
|
||||
count, \
|
||||
...) \
|
||||
count
|
||||
#define EXPAND_ARGS(args) RETURN_ARG_COUNT args
|
||||
#define COUNT_ARGS_MAXN(...) \
|
||||
EXPAND_ARGS((__VA_ARGS__, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0))
|
||||
|
||||
#define OVERLOAD_MACRO2(name, count) name##count
|
||||
#define OVERLOAD_MACRO1(name, count) OVERLOAD_MACRO2(name, count)
|
||||
#define OVERLOAD_MACRO(name, count) OVERLOAD_MACRO1(name, count)
|
||||
|
||||
#define CALL_OVERLOAD(name, exc_class, ctx, ...) \
|
||||
GLUE(OVERLOAD_MACRO(name, COUNT_ARGS_MAXN(__VA_ARGS__)), (exc_class, ctx, __VA_ARGS__))
|
||||
#define NGRAPH_CHECK_HELPER(exc_class, ctx, ...) OV_CHECK_HELPER(exc_class, ctx, __VA_ARGS__)
|
||||
|
@ -67,6 +67,12 @@ public:
|
||||
m_auto_broadcast = auto_broadcast;
|
||||
}
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
bool has_evaluate() const override;
|
||||
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override {
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t m_levels;
|
||||
AutoBroadcastSpec m_auto_broadcast = op::AutoBroadcastType::NUMPY;
|
||||
|
@ -7,14 +7,10 @@
|
||||
#include <string>
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "openvino/op/op.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
/// Root of all actual ops
|
||||
class NGRAPH_API Op : public Node {
|
||||
protected:
|
||||
Op() : Node() {}
|
||||
Op(const OutputVector& arguments);
|
||||
};
|
||||
using ov::op::Op;
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -9,98 +9,25 @@
|
||||
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
|
||||
#ifdef _WIN32
|
||||
# pragma warning(push)
|
||||
|
||||
# pragma warning(disable : 4100)
|
||||
#endif
|
||||
|
||||
// Prevents the compiler from complaining about or optimizing away variables
|
||||
// that appear unused on Linux
|
||||
#if (defined(__GNUC__) && !defined(__clang__))
|
||||
# undef NG_ATTRIBUTE_UNUSED
|
||||
# define NG_ATTRIBUTE_UNUSED __attribute__((__unused__))
|
||||
#else
|
||||
# define NG_ATTRIBUTE_UNUSED
|
||||
#endif
|
||||
|
||||
#define UNUSED_PARAMETER NG_ATTRIBUTE_UNUSED = 0
|
||||
#include "openvino/op/util/activation_functions.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
namespace error {
|
||||
struct UnknownActivationFunction : ngraph_error {
|
||||
UnknownActivationFunction(const std::string& func_name)
|
||||
: ngraph_error{"Unknown activation function: " + func_name} {}
|
||||
};
|
||||
using ov::op::util::error::UnknownActivationFunction;
|
||||
} // namespace error
|
||||
|
||||
namespace detail {
|
||||
std::shared_ptr<Node> sigmoid(const std::shared_ptr<Node>& arg,
|
||||
float alpha UNUSED_PARAMETER,
|
||||
float beta UNUSED_PARAMETER);
|
||||
std::shared_ptr<Node> tanh(const std::shared_ptr<Node>& arg, float alpha UNUSED_PARAMETER, float beta UNUSED_PARAMETER);
|
||||
std::shared_ptr<Node> relu(const std::shared_ptr<Node>& arg, float alpha UNUSED_PARAMETER, float beta UNUSED_PARAMETER);
|
||||
std::shared_ptr<Node> hardsigmoid(const std::shared_ptr<Node>& arg, float alpha, float beta);
|
||||
using ov::op::util::detail::hardsigmoid;
|
||||
using ov::op::util::detail::relu;
|
||||
using ov::op::util::detail::sigmoid;
|
||||
using ov::op::util::detail::tanh;
|
||||
} // namespace detail
|
||||
|
||||
using ActivationFunctionType = std::shared_ptr<Node> (*)(const std::shared_ptr<Node>&, float, float);
|
||||
|
||||
///
|
||||
/// \brief Class representing activation function used in RNN cells.
|
||||
///
|
||||
class NGRAPH_API ActivationFunction {
|
||||
public:
|
||||
ActivationFunction(ActivationFunctionType f, float alpha, float beta);
|
||||
ActivationFunction(ActivationFunctionType f, float alpha);
|
||||
ActivationFunction(ActivationFunctionType f);
|
||||
ActivationFunction() = default;
|
||||
|
||||
///
|
||||
/// \brief Calls stored activation function with provided node argument.
|
||||
///
|
||||
std::shared_ptr<Node> operator()(const std::shared_ptr<Node>& arg) const;
|
||||
|
||||
void set_alpha(float alpha) {
|
||||
m_alpha = alpha;
|
||||
}
|
||||
void set_beta(float beta) {
|
||||
m_beta = beta;
|
||||
}
|
||||
|
||||
private:
|
||||
/// \brief Activation function wrapper.
|
||||
ActivationFunctionType m_function;
|
||||
/// \brief Activation function alpha parameter (may be unused).
|
||||
float m_alpha;
|
||||
/// \brief Activation function beta parameter (may be unused).
|
||||
float m_beta;
|
||||
};
|
||||
|
||||
/// \brief Gets the activation function by name.
|
||||
///
|
||||
/// \param[in] func_name The function name
|
||||
///
|
||||
/// \throws UnknownActivationFunction When provided func_name is unknown.
|
||||
///
|
||||
/// \return The activation function object.
|
||||
///
|
||||
ActivationFunction get_activation_func_by_name(const std::string& func_name);
|
||||
using ov::op::util::ActivationFunction;
|
||||
using ov::op::util::ActivationFunctionType;
|
||||
using ov::op::util::get_activation_func_by_name;
|
||||
} // namespace util
|
||||
|
||||
} // namespace op
|
||||
|
||||
} // namespace ngraph
|
||||
|
||||
#ifdef _WIN32
|
||||
# pragma warning(pop)
|
||||
#endif
|
||||
|
||||
#ifdef UNUSED_PARAMETER
|
||||
# undef UNUSED_PARAMETER
|
||||
#endif
|
||||
#ifdef NG_ATTRIBUTE_UNUSED
|
||||
# undef NG_ATTRIBUTE_UNUSED
|
||||
#endif
|
||||
|
@ -6,39 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/reduction_base.hpp"
|
||||
#include "openvino/op/util/arithmetic_reduction.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Abstract base class for arithmetic reduction operations, i.e., operations
|
||||
/// where chosen axes of the input tensors are eliminated (reduced out) by
|
||||
/// repeated application of a particular binary arithmetic operation.
|
||||
class NGRAPH_API ArithmeticReduction : public ReductionBase {
|
||||
protected:
|
||||
/// \brief Constructs an arithmetic reduction operation.
|
||||
ArithmeticReduction();
|
||||
|
||||
/// \brief Constructs an arithmetic reduction operation.
|
||||
///
|
||||
/// \param arg Output that produces the first input tensor.
|
||||
/// \param reduction_axes The axis positions (0-based) to be eliminated.
|
||||
ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return true if reduction axes are constant else false.
|
||||
bool reduction_axes_constant() const;
|
||||
|
||||
/// \return The axis positions (0-based) to be eliminated through reduction.
|
||||
/// \throws CheckFailure if the reduction axes are not constant. (Use
|
||||
/// reduction_axes_constant to check.)
|
||||
const AxisSet get_reduction_axes() const;
|
||||
|
||||
/// \brief Change the reduction axes
|
||||
void set_reduction_axes(const AxisSet& reduction_axes);
|
||||
};
|
||||
using ov::op::util::ArithmeticReduction;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -6,37 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/arithmetic_reduction.hpp"
|
||||
#include "openvino/op/util/arithmetic_reductions_keep_dims.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
class NGRAPH_API ArithmeticReductionKeepDims : public util::ArithmeticReduction {
|
||||
protected:
|
||||
ArithmeticReductionKeepDims() = default;
|
||||
|
||||
/// \param arg The tensor to be summed.
|
||||
/// \param reduction_axes The axis positions (0-based) to be eliminated.
|
||||
/// \param keep_dims If set to 1 it holds axes that are used for reduction.
|
||||
ArithmeticReductionKeepDims(const Output<Node>& arg, const Output<Node>& reduction_axes, bool keep_dims = false);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return If set to 1 it holds axes that are used for reduction.
|
||||
/// For each such axis, output dimension is equal to 1.
|
||||
bool get_keep_dims() const {
|
||||
return m_keep_dims;
|
||||
}
|
||||
void set_keep_dims(bool keep_dims) {
|
||||
m_keep_dims = keep_dims;
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_keep_dims = false;
|
||||
};
|
||||
using ov::op::util::ArithmeticReductionKeepDims;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -10,332 +10,20 @@
|
||||
#include "ngraph/attribute_adapter.hpp"
|
||||
#include "ngraph/ngraph_visibility.hpp"
|
||||
#include "ngraph/type.hpp"
|
||||
#include "openvino/op/util/attr_types.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
/// \brief Modes for the `Pad` operator.
|
||||
enum class PadMode { CONSTANT = 0, EDGE, REFLECT, SYMMETRIC };
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const PadMode& type);
|
||||
|
||||
/// \brief Padding Type used for `Convolution` and `Pooling`
|
||||
///
|
||||
/// Follows ONNX padding type definitions
|
||||
/// EXPLICIT - Pad dimensions are explicity specified
|
||||
/// SAME_LOWER - Pad dimensions computed to match input shape
|
||||
/// Ceil(num_dims/2) at the beginning and
|
||||
/// Floor(num_dims/2) at the end
|
||||
/// SAME_UPPER - Pad dimensions computed to match input shape
|
||||
/// Floor(num_dims/2) at the beginning and
|
||||
/// Ceil(num_dims/2) at the end
|
||||
/// VALID - No padding
|
||||
/// AUTO - Deprecated. User should not use it in the future
|
||||
/// NOTSET - Deprecated. User should not use it in the future
|
||||
|
||||
enum class PadType {
|
||||
EXPLICIT = 0,
|
||||
SAME_LOWER,
|
||||
SAME_UPPER,
|
||||
VALID,
|
||||
AUTO = SAME_UPPER,
|
||||
NOTSET = EXPLICIT,
|
||||
};
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const PadType& type);
|
||||
|
||||
/// \brief Rounding Type used for `Pooling` operators.
|
||||
enum class RoundingType {
|
||||
FLOOR = 0,
|
||||
CEIL = 1,
|
||||
};
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const RoundingType& type);
|
||||
|
||||
/// \brief Specifies the algorithm to use for implicit broadcasting of a tensor
|
||||
/// to align with another tensor
|
||||
///
|
||||
/// NONE - No implicit broadcasting of tensor
|
||||
/// NUMPY - Numpy-style implicit broadcasting
|
||||
/// (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
|
||||
/// Right-align dimensions of the two tensors, with missing dimensions
|
||||
/// treated as size 1 dimensions. After alignment, for each dimension,
|
||||
/// their sizes should either match or one of them should be of size 1.
|
||||
/// Size 1 dimension will be implicitly broadcast to match the other
|
||||
/// size.
|
||||
///
|
||||
/// E.g.,
|
||||
/// A: Shape(2, 1, 6)
|
||||
/// B: Shape( 3, 1)
|
||||
/// Result: Shape(2, 3, 6)
|
||||
///
|
||||
/// A: Shape(2, 1, 6)
|
||||
/// B: Shape( 3, 1)
|
||||
/// Result: Shape(2, 3, 6)
|
||||
/// PDPD - PaddlePaddle-style implicit broadcasting
|
||||
/// (https://github.com/PaddlePaddle/Paddle/blob/release/1.5/paddle/
|
||||
/// fluid/operators/elementwise/elementwise_op.h#L126)
|
||||
/// Broadcast B to match the shape of A, where axis is the start
|
||||
/// dimension index to align B with A. If axis is -1 (default), i
|
||||
/// axis = rank(A) - rank(B). The trailing dimensions of size 1 for B
|
||||
/// will be ignored.
|
||||
///
|
||||
/// E.g.,
|
||||
/// A: Shape(2, 3, 4, 5)
|
||||
/// B: Shape( 3, 4 ) with axis =1
|
||||
/// Result: Shape(2, 3, 4, 5)
|
||||
///
|
||||
/// A: Shape(2, 3, 4, 5)
|
||||
/// B: Shape( 3, 1 ) with axis = 1
|
||||
/// Result: Shape(2, 3, 4, 5)
|
||||
///
|
||||
/// TODO: Add more implicit broadcast modes used by frameworks
|
||||
enum class AutoBroadcastType {
|
||||
NONE = 0,
|
||||
EXPLICIT = NONE,
|
||||
NUMPY,
|
||||
PDPD,
|
||||
};
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const AutoBroadcastType& type);
|
||||
/// \brief BroadcastType specifies rules used for mapping of input tensor axes to output
|
||||
/// shape axes.
|
||||
///
|
||||
/// \note Broadcasting rules are different for Broadcast op and for element-wise ops.
|
||||
/// AutoBroadcastType::NUMPY is equivalent of BroadcastType::BIDIRECTIONAL
|
||||
/// according to spec.
|
||||
///
|
||||
/// EXPLICIT - Mapping of the input data shape to output shape
|
||||
/// based on axes_mapping input.
|
||||
/// NUMPY - Numpy broadcasting rules, aligned with ONNX Broadcasting.
|
||||
/// (https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md)
|
||||
/// PDPD - PaddlePaddle-style implicit broadcasting.
|
||||
/// For more informaction see AutoBroadcastType documentation.
|
||||
/// BIDIRECTIONAL - The broadcast rule is similar to
|
||||
/// numpy.array(input) * numpy.ones(target_shape).
|
||||
/// Dimensions are right alignment.
|
||||
enum class BroadcastType { NONE, EXPLICIT = NONE, NUMPY, PDPD, BIDIRECTIONAL };
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const BroadcastType& type);
|
||||
|
||||
/// \brief Specifies how eps is combined with L2 value
|
||||
enum class EpsMode {
|
||||
// Add bias to norm
|
||||
ADD,
|
||||
// Calculate max of norm and bias
|
||||
MAX
|
||||
};
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const EpsMode& type);
|
||||
|
||||
enum class TopKSortType {
|
||||
// Returned values are not sorte
|
||||
NONE,
|
||||
// Sort result based on element indices
|
||||
SORT_INDICES,
|
||||
// Sort result based on element values
|
||||
SORT_VALUES,
|
||||
};
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const TopKSortType& type);
|
||||
|
||||
enum class TopKMode {
|
||||
MAX,
|
||||
MIN,
|
||||
};
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const TopKMode& type);
|
||||
|
||||
/// \brief Implicit broadcast specification
|
||||
struct NGRAPH_API AutoBroadcastSpec {
|
||||
AutoBroadcastSpec() : m_type(AutoBroadcastType::NONE), m_axis(0) {}
|
||||
AutoBroadcastSpec(AutoBroadcastType type) : m_type(type), m_axis(0) {}
|
||||
AutoBroadcastSpec(const char* type) : AutoBroadcastSpec(type_from_string(type)) {}
|
||||
AutoBroadcastSpec(AutoBroadcastType type, int64_t axis) : m_type(type), m_axis(axis) {}
|
||||
|
||||
AutoBroadcastType m_type; // Implicit broadcasting algorithm
|
||||
int64_t m_axis; // Axis to start alignment on
|
||||
|
||||
bool operator==(const AutoBroadcastSpec& a) const {
|
||||
return a.m_type == m_type && a.m_axis == m_axis;
|
||||
}
|
||||
|
||||
bool operator!=(const AutoBroadcastSpec& a) const {
|
||||
return !(*this == a);
|
||||
}
|
||||
static const AutoBroadcastSpec NUMPY;
|
||||
static const AutoBroadcastSpec NONE;
|
||||
|
||||
private:
|
||||
AutoBroadcastType type_from_string(const std::string& type) const;
|
||||
};
|
||||
|
||||
/// \brief Implicit broadcast specification
|
||||
struct NGRAPH_API BroadcastModeSpec {
|
||||
BroadcastModeSpec() : m_type(BroadcastType::NUMPY), m_axis(0) {}
|
||||
BroadcastModeSpec(BroadcastType type) : m_type(type), m_axis(0) {}
|
||||
BroadcastModeSpec(const char* type) : BroadcastModeSpec(as_enum<BroadcastType>(type)) {}
|
||||
BroadcastModeSpec(BroadcastType type, int64_t axis) : m_type(type), m_axis(axis) {}
|
||||
|
||||
BroadcastType m_type; // Implicit broadcasting algorithm
|
||||
int64_t m_axis; // Axis to start alignment on
|
||||
|
||||
bool operator==(const BroadcastModeSpec& a) const {
|
||||
return a.m_type == m_type && a.m_axis == m_axis;
|
||||
}
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief This class defines possible recurrent sequence directions.
|
||||
///
|
||||
enum class RecurrentSequenceDirection { FORWARD, REVERSE, BIDIRECTIONAL };
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const RecurrentSequenceDirection& direction);
|
||||
using ov::op::AutoBroadcastSpec;
|
||||
using ov::op::AutoBroadcastType;
|
||||
using ov::op::BroadcastModeSpec;
|
||||
using ov::op::BroadcastType;
|
||||
using ov::op::EpsMode;
|
||||
using ov::op::PadMode;
|
||||
using ov::op::PadType;
|
||||
using ov::op::RecurrentSequenceDirection;
|
||||
using ov::op::RoundingType;
|
||||
using ov::op::TopKMode;
|
||||
using ov::op::TopKSortType;
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::PadMode> : public EnumAttributeAdapterBase<ngraph::op::PadMode> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::PadMode& value) : EnumAttributeAdapterBase<ngraph::op::PadMode>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PadMode>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::PadType> : public EnumAttributeAdapterBase<ngraph::op::PadType> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::PadType& value) : EnumAttributeAdapterBase<ngraph::op::PadType>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::PadType>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::RoundingType>
|
||||
: public EnumAttributeAdapterBase<ngraph::op::RoundingType> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::RoundingType& value) : EnumAttributeAdapterBase<ngraph::op::RoundingType>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::RoundingType>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::AutoBroadcastType>
|
||||
: public EnumAttributeAdapterBase<ngraph::op::AutoBroadcastType> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::AutoBroadcastType& value)
|
||||
: EnumAttributeAdapterBase<ngraph::op::AutoBroadcastType>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastType>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::BroadcastType>
|
||||
: public EnumAttributeAdapterBase<ngraph::op::BroadcastType> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::BroadcastType& value) : EnumAttributeAdapterBase<ngraph::op::BroadcastType>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::BroadcastType>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::EpsMode> : public EnumAttributeAdapterBase<ngraph::op::EpsMode> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::EpsMode& value) : EnumAttributeAdapterBase<ngraph::op::EpsMode>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::EpsMode>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::TopKSortType>
|
||||
: public EnumAttributeAdapterBase<ngraph::op::TopKSortType> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::TopKSortType& value) : EnumAttributeAdapterBase<ngraph::op::TopKSortType>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::TopKSortType>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::TopKMode> : public EnumAttributeAdapterBase<ngraph::op::TopKMode> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::TopKMode& value) : EnumAttributeAdapterBase<ngraph::op::TopKMode>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::TopKMode>", 1};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
class AttributeAdapter<ngraph::op::AutoBroadcastSpec> : public VisitorAdapter {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::AutoBroadcastSpec& value) : m_ref(value) {}
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::AutoBroadcastSpec>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
protected:
|
||||
ngraph::op::AutoBroadcastSpec& m_ref;
|
||||
};
|
||||
|
||||
template <>
|
||||
class AttributeAdapter<ngraph::op::BroadcastModeSpec> : public VisitorAdapter {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::BroadcastModeSpec& value) : m_ref(value) {}
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::BroadcastModeSpec>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
protected:
|
||||
ngraph::op::BroadcastModeSpec& m_ref;
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::RecurrentSequenceDirection>
|
||||
: public EnumAttributeAdapterBase<ngraph::op::RecurrentSequenceDirection> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::RecurrentSequenceDirection& value)
|
||||
: EnumAttributeAdapterBase<ngraph::op::RecurrentSequenceDirection>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::RecurrentSequenceDirection>", 1};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
} // namespace ov
|
||||
|
@ -6,65 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "openvino/op/util/binary_elementwise_arithmetic.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
// clang-format off
|
||||
/// \brief Abstract base class for elementwise binary arithmetic operations, i.e.,
|
||||
/// operations where the same scalar binary arithmetic operation is applied to
|
||||
/// each corresponding pair of elements in the two input tensors. Implicit
|
||||
/// broadcast of input tensors is supported through one of the AutoBroadcast
|
||||
/// modes.
|
||||
///
|
||||
/// For example, if the underlying arithmetic operation (determined by the subclass) is
|
||||
/// \f$\mathit{op}(x,y)\f$, the input tensors
|
||||
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to
|
||||
/// \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
|
||||
///
|
||||
/// ## Inputs
|
||||
///
|
||||
/// | | Type | Description |
|
||||
/// | ------ | --------------------------------- | ------------------------------------------------------------------------ |
|
||||
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. The element type \f$N\f$ may be any numeric type. |
|
||||
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same element type as `arg0`. |
|
||||
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
|
||||
///
|
||||
/// ## Output
|
||||
///
|
||||
/// | Type | Description |
|
||||
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors (after auto broadcasting). |
|
||||
// clang-format on
|
||||
class NGRAPH_API BinaryElementwiseArithmetic : public Op {
|
||||
protected:
|
||||
BinaryElementwiseArithmetic(const AutoBroadcastSpec& autob);
|
||||
|
||||
/// \brief Constructs a binary elementwise arithmetic operation.
|
||||
///
|
||||
/// \param arg0 Output that produces the first input tensor.
|
||||
/// \param arg1 Output that produces the second input tensor.
|
||||
BinaryElementwiseArithmetic(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
const AutoBroadcastSpec& get_autob() const override {
|
||||
return m_autob;
|
||||
}
|
||||
void set_autob(const AutoBroadcastSpec& autob) {
|
||||
m_autob = autob;
|
||||
}
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||
|
||||
private:
|
||||
AutoBroadcastSpec m_autob;
|
||||
void validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob);
|
||||
};
|
||||
using ov::op::util::BinaryElementwiseArithmetic;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -6,66 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "openvino/op/util/binary_elementwise_comparison.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
// clang-format off
|
||||
/// \brief Abstract base class for elementwise binary comparison operations, i.e.,
|
||||
/// operations where the same scalar binary comparison operation is applied to
|
||||
/// each corresponding pair of elements in two input tensors. Implicit
|
||||
/// broadcast of input tensors is supported through one of the AutoBroadcast
|
||||
/// modes.
|
||||
///
|
||||
/// For example, if the underlying comparison operation (determined by the subclass) is
|
||||
/// \f$\mathit{op}(x,y)\f$, the input tensors \f$[[x_0,y_0],[z_0,w_0]]\f$ and
|
||||
/// \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to
|
||||
/// \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
|
||||
///
|
||||
/// ## Inputs
|
||||
///
|
||||
/// | | Type | Description |
|
||||
/// | ------ | --------------------------------- | ------------------------------------------------------ |
|
||||
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
|
||||
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
|
||||
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
|
||||
///
|
||||
/// ## Output
|
||||
///
|
||||
/// | Type | Description |
|
||||
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
|
||||
// clang-format on
|
||||
class NGRAPH_API BinaryElementwiseComparison : public Op {
|
||||
protected:
|
||||
/// \brief Constructs a binary elementwise comparison operation.
|
||||
BinaryElementwiseComparison(const AutoBroadcastSpec& autob);
|
||||
|
||||
/// \brief Constructs a binary elementwise comparison operation.
|
||||
///
|
||||
/// \param arg0 Output that produces the first input tensor.
|
||||
/// \param arg1 Output that produces the second input tensor.
|
||||
/// \param autob AutoBroadcast mode.
|
||||
BinaryElementwiseComparison(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
const AutoBroadcastSpec& get_autob() const override {
|
||||
return m_autob;
|
||||
}
|
||||
void set_autob(const AutoBroadcastSpec& autob) {
|
||||
m_autob = autob;
|
||||
}
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
AutoBroadcastSpec m_autob;
|
||||
};
|
||||
using ov::op::util::BinaryElementwiseComparison;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -5,64 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/binary_elementwise_logical.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
// clang-format off
|
||||
/// \brief Abstract base class for elementwise binary logical operations, i.e.,
|
||||
/// operations where the same scalar binary logical operation is applied to
|
||||
/// each corresponding pair of elements in two boolean input tensors. Implicit
|
||||
/// broadcast of input tensors is supported through one of the AutoBroadcast
|
||||
/// modes.
|
||||
///
|
||||
/// For example, if the underlying operation (determined by the subclass) is
|
||||
/// \f$\mathit{op}(x,y)\f$, the input tensors \f$[[x_0,y_0],[z_0,w_0]]\f$ and
|
||||
/// \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to
|
||||
/// \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
|
||||
///
|
||||
/// ## Inputs
|
||||
///
|
||||
/// | | Type | Description |
|
||||
/// | ------ | --------------------------------------------- | ------------------------------------------------------ |
|
||||
/// | `arg0` | \f$\texttt{bool}[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape, with element type `bool`. |
|
||||
/// | `arg1` | \f$\texttt{bool}[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
|
||||
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
|
||||
///
|
||||
/// ## Output
|
||||
///
|
||||
/// | Type | Description |
|
||||
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
|
||||
// clang-format on
|
||||
class NGRAPH_API BinaryElementwiseLogical : public Op {
|
||||
protected:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
BinaryElementwiseLogical();
|
||||
|
||||
/// \brief Constructs a binary elementwise logical operation.
|
||||
///
|
||||
/// \param arg0 Output that produces the first input tensor.
|
||||
/// \param arg1 Output that produces the second input tensor.
|
||||
BinaryElementwiseLogical(const Output<Node>& arg0,
|
||||
const Output<Node>& arg1,
|
||||
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
|
||||
|
||||
public:
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
const AutoBroadcastSpec& get_autob() const override {
|
||||
return m_autob;
|
||||
}
|
||||
void set_autob(const AutoBroadcastSpec& autob) {
|
||||
m_autob = autob;
|
||||
}
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
AutoBroadcastSpec m_autob = AutoBroadcastSpec::NUMPY;
|
||||
};
|
||||
using ov::op::util::BinaryElementwiseLogical;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -2,84 +2,18 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/axis_set.hpp"
|
||||
#include "ngraph/axis_vector.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
|
||||
#pragma once
|
||||
#include "openvino/op/util/broadcast_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
class NGRAPH_API BroadcastBase : public Op {
|
||||
protected:
|
||||
BroadcastBase() = default;
|
||||
/// \brief Constructs a broadcast operation.
|
||||
///
|
||||
/// \param arg The input tensor to be broadcast.
|
||||
/// \param target_shape The shape of the output tensor.
|
||||
/// \param axes_mapping The axis positions (0-based) in the result that correspond
|
||||
/// to input axes.
|
||||
/// \param broadcast_mode Broadcast specification to use for determining broadcast
|
||||
/// axes. 'axes_mapping' should not be provided if mode other
|
||||
///
|
||||
BroadcastBase(const Output<Node>& arg,
|
||||
const Output<Node>& target_shape,
|
||||
const Output<Node>& axes_mapping,
|
||||
const BroadcastModeSpec& broadcast_mode = BroadcastType::EXPLICIT);
|
||||
|
||||
/// \brief Constructs a broadcast operation.
|
||||
///
|
||||
/// \param arg The input tensor to be broadcast.
|
||||
/// \param target_shape The shape of the output tensor.
|
||||
/// \param broadcast_mode Broadcast specification to use for determining broadcast
|
||||
/// axes
|
||||
BroadcastBase(const Output<Node>& arg,
|
||||
const Output<Node>& target_shape,
|
||||
const BroadcastModeSpec& broadcast_mode = BroadcastType::NUMPY);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
/// \return true and the AxisSet if broadcast axes can be fully determined.
|
||||
virtual std::pair<bool, AxisSet> get_broadcast_axes() const;
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
|
||||
protected:
|
||||
BroadcastModeSpec m_mode;
|
||||
|
||||
bool evaluate_broadcast(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& out,
|
||||
const std::pair<bool, AxisSet> pair_broadcast_axes,
|
||||
const Shape output_shape) const;
|
||||
|
||||
bool evaluate_broadcast(const HostTensorPtr& arg0, const HostTensorPtr& out, const AxisSet& broadcast_axes) const;
|
||||
|
||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||
|
||||
PartialShape get_result_shape_pdpd(const PartialShape& arg0_shape,
|
||||
const PartialShape& target_shape,
|
||||
const op::BroadcastModeSpec& broadcast_spec) const;
|
||||
|
||||
void validate_target_shape_numpy(const PartialShape& arg_shape, const PartialShape& target_shape) const;
|
||||
|
||||
static std::pair<bool, AxisSet> get_broadcast_axes_numpy_pdpd(const Shape& arg_shape,
|
||||
const Shape& result_shape,
|
||||
const op::BroadcastModeSpec& broadcast_spec);
|
||||
|
||||
static std::pair<bool, AxisSet> get_broadcast_axes_none(const AxisVector axes_mapping_val,
|
||||
const size_t target_shape);
|
||||
|
||||
void validate_target_shape_none(const PartialShape& arg_shape,
|
||||
const AxisVector& axes_mapping_val,
|
||||
const PartialShape& target_shape) const;
|
||||
|
||||
Shape get_target_shape(const HostTensorPtr& input1) const;
|
||||
};
|
||||
using ov::op::util::BroadcastBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -7,99 +7,12 @@
|
||||
#include "ngraph/coordinate_diff.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "openvino/op/util/deformable_convolution_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Base class for operations DeformableConvolution v1 and DeformableConvolution
|
||||
/// v8.
|
||||
class NGRAPH_API DeformableConvolutionBase : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
/// \brief Constructs a conversion operation.
|
||||
DeformableConvolutionBase() = default;
|
||||
|
||||
/// \brief Constructs a conversion operation.
|
||||
/// \param strides Convolution strides.
|
||||
/// \param pads_begin Amount of padding to be added to the beginning along
|
||||
/// each axis. For example in case of a 2D input the value
|
||||
/// of (1, 2) means that 1 element will be added to the
|
||||
/// top and 2 elements to the left.
|
||||
/// \param pads_end Amount of padding to be added to the end along each
|
||||
/// axis.
|
||||
/// \param dilations The distance in width and height between the weights
|
||||
/// in the filters tensor.
|
||||
/// \param auto_pad Specifies how the automatic calculation of padding
|
||||
/// should be done.
|
||||
/// \param group The number of groups which both output and input
|
||||
/// should be split into.
|
||||
/// \param deformable_group The number of groups which deformable values and
|
||||
/// output should be split into along the channel axis.
|
||||
DeformableConvolutionBase(const OutputVector& arguments,
|
||||
const Strides& strides,
|
||||
const CoordinateDiff& pads_begin,
|
||||
const CoordinateDiff& pads_end,
|
||||
const Strides& dilations,
|
||||
const PadType& auto_pad = PadType::EXPLICIT,
|
||||
int64_t group = 1,
|
||||
int64_t deformable_group = 1);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
const Strides& get_strides() const {
|
||||
return m_strides;
|
||||
}
|
||||
void set_strides(const Strides& strides) {
|
||||
m_strides = strides;
|
||||
}
|
||||
const Strides& get_dilations() const {
|
||||
return m_dilations;
|
||||
}
|
||||
void set_dilations(const Strides& dilations) {
|
||||
m_dilations = dilations;
|
||||
}
|
||||
const CoordinateDiff& get_pads_begin() const {
|
||||
return m_pads_begin;
|
||||
}
|
||||
void set_pads_begin(const CoordinateDiff& pads_begin) {
|
||||
m_pads_begin = pads_begin;
|
||||
}
|
||||
const CoordinateDiff& get_pads_end() const {
|
||||
return m_pads_end;
|
||||
}
|
||||
void set_pads_end(const CoordinateDiff& pads_end) {
|
||||
m_pads_end = pads_end;
|
||||
}
|
||||
const PadType& get_auto_pad() const {
|
||||
return m_auto_pad;
|
||||
}
|
||||
void set_auto_pad(const PadType& auto_pad) {
|
||||
m_auto_pad = auto_pad;
|
||||
}
|
||||
int64_t get_group() const {
|
||||
return m_group;
|
||||
}
|
||||
void set_group(const int64_t group) {
|
||||
m_group = group;
|
||||
}
|
||||
int64_t get_deformable_group() const {
|
||||
return m_deformable_group;
|
||||
}
|
||||
void set_deformable_group(const int64_t deformable_group) {
|
||||
m_deformable_group = deformable_group;
|
||||
}
|
||||
|
||||
protected:
|
||||
Strides m_strides;
|
||||
Strides m_dilations;
|
||||
CoordinateDiff m_pads_begin;
|
||||
CoordinateDiff m_pads_end;
|
||||
PadType m_auto_pad;
|
||||
int64_t m_group;
|
||||
int64_t m_deformable_group;
|
||||
};
|
||||
using ov::op::util::DeformableConvolutionBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -5,13 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "openvino/op/util/elementwise_args.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
|
||||
Node* node,
|
||||
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
|
||||
using ov::op::util::validate_and_infer_elementwise_args;
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -6,61 +6,12 @@
|
||||
|
||||
#include "ngraph/axis_set.hpp"
|
||||
#include "ngraph/op/util/index_reduction.hpp"
|
||||
#include "openvino/op/util/embeddingbag_offsets_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Returns embeddings for given indices
|
||||
class NGRAPH_API EmbeddingBagOffsetsBase : public Op {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsBase", 3};
|
||||
const NodeTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
/// \brief Constructs a EmbeddingBagOffsetsBase operation.
|
||||
EmbeddingBagOffsetsBase() = default;
|
||||
/// \brief Constructs a EmbeddingBagOffsetsBase operation.
|
||||
///
|
||||
/// EmbeddingBagOffsetsBase constructs an output tensor by replacing every index in
|
||||
/// a
|
||||
/// given
|
||||
/// input tensor with a row (from the weights matrix) at that index
|
||||
///
|
||||
/// \param emb_table tensor containing the embedding lookup table of the module of
|
||||
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
|
||||
/// \param tensor of shape [num_indices] and of type T_IND. Required
|
||||
/// \param offsets tensor of shape [batch] and of type T_IND containing the starting
|
||||
/// index positions of each "bag" in indices. Required.
|
||||
/// \param per_sample_weigths tensor of the same shape as indices and of type T.
|
||||
/// Each value in this tensor are multiplied with each
|
||||
/// value pooled from embedding table for each index. Optional.
|
||||
/// \param default_index scalar of type T_IND containing default index in embedding
|
||||
/// table to fill empty "bags". If not provided empty "bags"
|
||||
/// are filled with zeros. Optional.
|
||||
|
||||
EmbeddingBagOffsetsBase(const Output<Node>& emb_table,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& offsets,
|
||||
const Output<Node>& default_index,
|
||||
const Output<Node>& per_sample_weights);
|
||||
|
||||
EmbeddingBagOffsetsBase(const Output<Node>& emb_table,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& offsets,
|
||||
const Output<Node>& default_index);
|
||||
|
||||
EmbeddingBagOffsetsBase(const Output<Node>& emb_table, const Output<Node>& indices, const Output<Node>& offsets);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
static constexpr int EMB_TABLE = 0;
|
||||
static constexpr int INDICES = 1;
|
||||
static constexpr int OFFSETS = 2;
|
||||
static constexpr int DEFAULT_INDEX = 3;
|
||||
static constexpr int PER_SAMPLE_WEIGHTS = 4;
|
||||
};
|
||||
using ov::op::util::EmbeddingBagOffsetsBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -6,47 +6,12 @@
|
||||
|
||||
#include "ngraph/axis_set.hpp"
|
||||
#include "ngraph/op/util/index_reduction.hpp"
|
||||
#include "openvino/op/util/embeddingbag_packed_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Returns embeddings for given indices
|
||||
class NGRAPH_API EmbeddingBagPackedBase : public Op {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"EmbeddingBagPackedBase", 3};
|
||||
const NodeTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
/// \brief Constructs a EmbeddingBagPackedBase operation.
|
||||
EmbeddingBagPackedBase() = default;
|
||||
/// \brief Constructs a EmbeddingBagPackedBase operation.
|
||||
///
|
||||
/// EmbeddingBagPackedBase constructs an output tensor by replacing every index in a
|
||||
/// given
|
||||
/// input tensor with a row (from the weights matrix) at that index
|
||||
///
|
||||
/// \param emb_table Tensor containing the embedding lookup table of the module of
|
||||
/// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
|
||||
/// \param indices Tensor of shape `[batch, indices_per_bag]` and of type *T_IND*.
|
||||
/// Required.
|
||||
/// \param per_sample_weigths tensor of the same shape as indices and of type T.
|
||||
/// Each value in this tensor are multiplied with each
|
||||
/// value pooled from embedding table for each index. Optional.
|
||||
|
||||
EmbeddingBagPackedBase(const Output<Node>& emb_table,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& per_sample_weights);
|
||||
|
||||
EmbeddingBagPackedBase(const Output<Node>& emb_table, const Output<Node>& indices);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
static constexpr int EMB_TABLE = 0;
|
||||
static constexpr int INDICES = 1;
|
||||
static constexpr int PER_SAMPLE_WEIGHTS = 2;
|
||||
};
|
||||
using ov::op::util::EmbeddingBagPackedBase;
|
||||
} // namespace util
|
||||
using util::EmbeddingBagPackedBase;
|
||||
} // namespace op
|
||||
|
@ -6,35 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "openvino/op/util/fft_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Base class for operations DFT and DFT.
|
||||
class NGRAPH_API FFTBase : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
FFTBase() = default;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
/// \brief Constructs an FFT operation. FFT is performed for full size axes.
|
||||
///
|
||||
/// \param data Input data
|
||||
/// \param axes Axes to perform FFT
|
||||
FFTBase(const Output<Node>& data, const Output<Node>& axes);
|
||||
|
||||
/// \brief Constructs a FFT operation.
|
||||
///
|
||||
/// \param data Input data
|
||||
/// \param axes Axes to perform FFT
|
||||
/// \param signal_size Signal sizes for 'axes'
|
||||
FFTBase(const Output<Node>& data, const Output<Node>& axes, const Output<Node>& signal_size);
|
||||
|
||||
void validate();
|
||||
};
|
||||
using ov::op::util::FFTBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -5,38 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/gather_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief GatherBase basic class for Gather v1 and v7
|
||||
class NGRAPH_API GatherBase : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
GatherBase() = default;
|
||||
|
||||
/// \param data The tensor from which slices are gathered
|
||||
/// \param indices Tensor with indexes to gather
|
||||
/// \param axis The tensor is a dimension index to gather data from
|
||||
/// \param batch_dims The number of batch dimension in data and indices tensors
|
||||
GatherBase(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims = 0);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
virtual int64_t get_axis() const;
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
|
||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||
|
||||
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
|
||||
|
||||
protected:
|
||||
int64_t m_batch_dims = 0;
|
||||
};
|
||||
using ov::op::util::GatherBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -10,28 +10,12 @@
|
||||
#include <utility>
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/index_reduction.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
class NGRAPH_API IndexReduction : public Op {
|
||||
protected:
|
||||
IndexReduction();
|
||||
|
||||
IndexReduction(const Output<Node>& arg, uint64_t axis, const element::Type& index_element_type);
|
||||
|
||||
public:
|
||||
uint64_t get_reduction_axis() const;
|
||||
void set_reduction_axis(uint64_t value);
|
||||
element::Type get_index_element_type() const;
|
||||
void set_index_element_type(const element::Type& index_element_type);
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
uint64_t m_axis{0};
|
||||
element::Type m_index_element_type;
|
||||
};
|
||||
using ov::op::util::IndexReduction;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -6,41 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/reduction_base.hpp"
|
||||
#include "openvino/op/util/logical_reduction.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Abstract base class for logical reduction operations, i.e., operations where
|
||||
/// chosen axes of the input tensors are eliminated (reduced out) by repeated
|
||||
/// application of a particular binary logical operation.
|
||||
class NGRAPH_API LogicalReduction : public ReductionBase {
|
||||
protected:
|
||||
/// \brief Constructs a logical reduction operation.
|
||||
LogicalReduction();
|
||||
/// \brief Constructs a logical reduction operation.
|
||||
///
|
||||
/// \param arg Output that produces the first input tensor.
|
||||
/// \param reduction_axes The axis positions (0-based) to be eliminated.
|
||||
LogicalReduction(const Output<Node>& arg, const AxisSet& reduction_axes);
|
||||
/// \brief Constructs a 'dynamic' logical reduction operation.
|
||||
///
|
||||
/// \param arg Node that produces the first input tensor.
|
||||
/// \param reduction_axes The axis positions (0-based) to be eliminated.
|
||||
LogicalReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return true if reduction axes are constant else false.
|
||||
bool reduction_axes_constant() const;
|
||||
|
||||
/// \return The axis positions (0-based) to be eliminated through reduction.
|
||||
/// \throws CheckFailure if the reduction axes are not constant. (Use
|
||||
/// reduction_axes_constant to check.)
|
||||
const AxisSet get_reduction_axes() const;
|
||||
void set_reduction_axes(const AxisSet& reduction_axes);
|
||||
};
|
||||
using ov::op::util::LogicalReduction;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -6,37 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/logical_reduction.hpp"
|
||||
#include "openvino/op/util/logical_reduction_keep_dims.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
class NGRAPH_API LogicalReductionKeepDims : public util::LogicalReduction {
|
||||
protected:
|
||||
LogicalReductionKeepDims() = default;
|
||||
|
||||
/// \param arg The tensor to be reduced.
|
||||
/// \param reduction_axes The axis positions (0-based) to be eliminated.
|
||||
/// \param keep_dims If set to 1 it holds axes that are used for reduction.
|
||||
LogicalReductionKeepDims(const Output<Node>& arg, const Output<Node>& reduction_axes, const bool keep_dims = false);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return If set to 1 it holds axes that are used for reduction.
|
||||
/// For each such axis, output dimension is equal to 1.
|
||||
bool get_keep_dims() const {
|
||||
return m_keep_dims;
|
||||
}
|
||||
void set_keep_dims(bool keep_dims) {
|
||||
m_keep_dims = keep_dims;
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_keep_dims = false;
|
||||
};
|
||||
using ov::op::util::LogicalReductionKeepDims;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -6,91 +6,12 @@
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "openvino/op/util/max_pool_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
class NGRAPH_API MaxPoolBase : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
MaxPoolBase() = default;
|
||||
|
||||
/// \param arg The node producing the input data batch tensor.
|
||||
/// \param strides The strides.
|
||||
/// \param pads_begin The beginning of padding shape.
|
||||
/// \param pads_end The end of padding shape.
|
||||
/// \param kernel The kernel shape.
|
||||
/// \param rounding_mode Whether to use ceiling or floor rounding type while
|
||||
/// computing output shape.
|
||||
/// \param auto_pad The pad type for automatically computing padding sizes.
|
||||
MaxPoolBase(const Output<Node>& arg,
|
||||
const Strides& strides,
|
||||
const Shape& pads_begin,
|
||||
const Shape& pads_end,
|
||||
const Shape& kernel,
|
||||
const op::RoundingType rounding_mode = op::RoundingType::FLOOR,
|
||||
const PadType auto_pad = op::PadType::EXPLICIT);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
/// \return The kernel shape.
|
||||
const Shape& get_kernel() const {
|
||||
return m_kernel;
|
||||
}
|
||||
void set_kernel(const Shape& kernel) {
|
||||
m_kernel = kernel;
|
||||
}
|
||||
/// \return The strides.
|
||||
const Strides& get_strides() const {
|
||||
return m_strides;
|
||||
}
|
||||
void set_strides(const Strides& strides) {
|
||||
m_strides = strides;
|
||||
}
|
||||
/// \return The beginning of padding shape.
|
||||
const Shape& get_pads_begin() const {
|
||||
return m_pads_begin;
|
||||
}
|
||||
void set_pads_begin(const Shape& pads_begin) {
|
||||
m_pads_begin = pads_begin;
|
||||
}
|
||||
/// \return The end of padding shape.
|
||||
const Shape& get_pads_end() const {
|
||||
return m_pads_end;
|
||||
}
|
||||
void set_adding_above(const Shape& pads_end) {
|
||||
m_pads_end = pads_end;
|
||||
}
|
||||
/// \return The pad type for pooling.
|
||||
PadType get_auto_pad() const {
|
||||
return m_auto_pad;
|
||||
}
|
||||
void set_auto_pad(const PadType auto_pad) {
|
||||
m_auto_pad = auto_pad;
|
||||
}
|
||||
/// \return The ceiling mode being used for output shape computations
|
||||
op::RoundingType get_rounding_type() const {
|
||||
return m_rounding_type;
|
||||
}
|
||||
void set_rounding_type(op::RoundingType rounding_type) {
|
||||
m_rounding_type = rounding_type;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool update_auto_padding(const PartialShape& in_shape,
|
||||
const Strides& filter_dilations,
|
||||
Shape& new_pads_end,
|
||||
Shape& new_pads_begin) const;
|
||||
|
||||
PartialShape infer_output_shape(const Strides& dilations);
|
||||
|
||||
Shape m_kernel;
|
||||
Strides m_strides;
|
||||
Shape m_pads_begin;
|
||||
Shape m_pads_end;
|
||||
PadType m_auto_pad;
|
||||
op::RoundingType m_rounding_type;
|
||||
};
|
||||
using ov::op::util::MaxPoolBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -8,317 +8,16 @@
|
||||
#include <ngraph/op/parameter.hpp>
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/multi_subgraph_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Abstract base class for sub-graph based ops, i.e ops that have some
|
||||
/// sub-graphs
|
||||
///
|
||||
class NGRAPH_API MultiSubGraphOp : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
/// \brief Abstract class describes a connection between a MultiSubGraphOp input and
|
||||
/// the body.
|
||||
class InputDescription {
|
||||
protected:
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the MultiSubGraphOp input
|
||||
/// \param body_parameter_index Body parameter to receive input
|
||||
///
|
||||
InputDescription(uint64_t input_index, uint64_t body_parameter_index);
|
||||
InputDescription() = default;
|
||||
|
||||
public:
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
virtual ~InputDescription() = default;
|
||||
virtual std::shared_ptr<InputDescription> copy() const = 0;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
uint64_t m_input_index{0};
|
||||
uint64_t m_body_parameter_index{0};
|
||||
};
|
||||
|
||||
/// \brief Abstract class describes how a MultiSubGraphOp output is produced from
|
||||
/// the body.
|
||||
class OutputDescription {
|
||||
protected:
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The MultiSubGraphOp output index
|
||||
///
|
||||
OutputDescription(uint64_t body_value_index, uint64_t output_index);
|
||||
OutputDescription() = default;
|
||||
|
||||
public:
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
virtual ~OutputDescription() = default;
|
||||
virtual std::shared_ptr<OutputDescription> copy() const = 0;
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
uint64_t m_body_value_index{0};
|
||||
uint64_t m_output_index{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Describes a body input formed from slices of an input to
|
||||
/// MultiSubGraphOp.
|
||||
///
|
||||
class NGRAPH_API SliceInputDescription : public InputDescription {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the MultiSubGraphOp input
|
||||
/// \param body_parameter_index Body parameter position to receive input
|
||||
/// \param start First index for slices
|
||||
/// \param stride Step amount for slices
|
||||
/// \param part_size Width of slices
|
||||
/// \param end Last index for slices
|
||||
/// \param axis Axis being sliced
|
||||
///
|
||||
SliceInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
SliceInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
int64_t m_end{0};
|
||||
int64_t m_axis{0};
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Describes a body input initialized from a MultiSubGraphOp input
|
||||
/// on the first iteration, and then a body output thereafter.
|
||||
///
|
||||
class NGRAPH_API MergedInputDescription : public InputDescription {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the MultiSubGraphOp input
|
||||
/// supplying a value to body_parameter for
|
||||
/// the initial iteration.
|
||||
/// \param body_parameter_index Body parameter position to receive input.
|
||||
/// \param body_value_index Body value to supply body_parameter for
|
||||
/// successive
|
||||
/// iterations.
|
||||
///
|
||||
MergedInputDescription(uint64_t input_index, uint64_t body_parameter_index, uint64_t body_value_index);
|
||||
MergedInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
uint64_t m_body_value_index{0};
|
||||
};
|
||||
|
||||
/// \brief Produces an output by concatenating an output from each iteration
|
||||
class NGRAPH_API ConcatOutputDescription : public OutputDescription {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The MultiSubGraphOp output index
|
||||
/// \param start First index for slices
|
||||
/// \param stride Step amount for slices
|
||||
/// \param part_size Width of slices
|
||||
/// \param end Last index for slices
|
||||
/// \param axis Axis being sliced
|
||||
///
|
||||
ConcatOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
ConcatOutputDescription() = default;
|
||||
|
||||
std::shared_ptr<OutputDescription> copy() const override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
int64_t m_end{0};
|
||||
int64_t m_axis{0};
|
||||
};
|
||||
|
||||
/// \brief Produces an input
|
||||
class NGRAPH_API InvariantInputDescription : public InputDescription {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param input_index Position of the MultiSubGraphOp input
|
||||
/// \param body_parameter_index Body parameter to receive input
|
||||
///
|
||||
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
|
||||
InvariantInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
};
|
||||
|
||||
/// \brief Produces an output from a specific iteration
|
||||
class NGRAPH_API BodyOutputDescription : public MultiSubGraphOp::OutputDescription {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
///
|
||||
/// \brief Constructs a new instance.
|
||||
///
|
||||
/// \param body_value_index A body value that produces the output
|
||||
/// \param output_index The SubGraphOp output index
|
||||
/// \param iteration which iteration (typically -1, final) will
|
||||
/// supply the value
|
||||
///
|
||||
BodyOutputDescription(uint64_t body_value_index, uint64_t output_index, int64_t iteration = -1);
|
||||
BodyOutputDescription() = default;
|
||||
std::shared_ptr<MultiSubGraphOp::OutputDescription> copy() const override;
|
||||
int64_t m_iteration{0};
|
||||
};
|
||||
using MultiSubgraphInputDescriptionPtr = std::shared_ptr<MultiSubGraphOp::InputDescription>;
|
||||
using MultiSubgraphOutputDescriptionPtr = std::shared_ptr<MultiSubGraphOp::OutputDescription>;
|
||||
using MultiSubgraphInputDescriptionVector = std::vector<MultiSubgraphInputDescriptionPtr>;
|
||||
using MultiSubgraphOutputDescriptionVector = std::vector<MultiSubgraphOutputDescriptionPtr>;
|
||||
|
||||
/// \brief Gets internal sub-graph by index in MultiSubGraphOp
|
||||
///
|
||||
/// \param index sub-graph's index in op
|
||||
/// \return pointer to ngraph::Function with sub-graph
|
||||
virtual const std::shared_ptr<Function>& get_function(int index) const {
|
||||
return m_bodies[index];
|
||||
};
|
||||
/// \brief Adds sub-graph to MultiSubGraphOp
|
||||
///
|
||||
/// \param index index of new sub-graph
|
||||
/// \param func func new sub_graph as ngraph::Function
|
||||
virtual void set_function(int index, const std::shared_ptr<Function>& func) {
|
||||
m_bodies[index] = func;
|
||||
}
|
||||
/// \brief Gets vector with connections beewtwen operation inputs
|
||||
/// and internal sub-graph parameters
|
||||
///
|
||||
/// \param index index of internal sub-graph
|
||||
/// \return vector of input descriptions
|
||||
const MultiSubgraphInputDescriptionVector& get_input_descriptions(int index) const {
|
||||
return m_input_descriptions[index];
|
||||
}
|
||||
/// \brief Gets vector with connections beewtwen operation inputs
|
||||
/// and internal sub-graph parameters
|
||||
///
|
||||
/// \param index index of internal sub-graph
|
||||
/// \return vector of input descriptions
|
||||
MultiSubgraphInputDescriptionVector& get_input_descriptions(int index) {
|
||||
return m_input_descriptions[index];
|
||||
}
|
||||
/// \brief Gets vector with connections beewtwen operation outputs
|
||||
/// and internal sub-graph results
|
||||
///
|
||||
/// \param index index of internal sub-graph
|
||||
/// \return vector of output descriptions
|
||||
const MultiSubgraphOutputDescriptionVector& get_output_descriptions(int index) const {
|
||||
return m_output_descriptions[index];
|
||||
}
|
||||
/// \brief Gets vector with connections beewtwen operation outputs
|
||||
/// and internal sub-graph results
|
||||
///
|
||||
/// \param index index of internal sub-graph
|
||||
/// \return vector of output descriptions
|
||||
MultiSubgraphOutputDescriptionVector& get_output_descriptions(int index) {
|
||||
return m_output_descriptions[index];
|
||||
}
|
||||
/// \brief Sets vector with connections beewtwen operation inputs
|
||||
/// and internal sub-graph parameters
|
||||
///
|
||||
/// \param index index of internal sub-graph
|
||||
/// \param inputs vector of input descriptions
|
||||
void set_input_descriptions(int index, const MultiSubgraphInputDescriptionVector& inputs) {
|
||||
m_input_descriptions[index] = inputs;
|
||||
}
|
||||
|
||||
/// \brief Sets vector with connections beewtwen operation outputs
|
||||
/// and internal sub-graph results
|
||||
///
|
||||
/// \param index index of internal sub-graph
|
||||
/// \param outputs vector of input descriptions
|
||||
void set_output_descriptions(int index, const MultiSubgraphOutputDescriptionVector& outputs) {
|
||||
m_output_descriptions[index] = outputs;
|
||||
}
|
||||
|
||||
///
|
||||
/// \brief Set input decriptions for MultiSubGraphOp input.
|
||||
///
|
||||
/// \param value The value supplied as an input to the block.
|
||||
/// \param bodies_parameters vector of bodies parameters.
|
||||
virtual void set_invariant_inputs(const Output<Node>& value, const ParameterVector& bodies_parameters);
|
||||
///
|
||||
/// \brief Set output decriptions for MultiSubGraphOp output.
|
||||
///
|
||||
/// \param bodies_results vector of bodies results for one output.
|
||||
/// \return value Output node for bodies_results.
|
||||
virtual Output<Node> set_body_outputs(const ResultVector& bodies_results);
|
||||
|
||||
MultiSubGraphOp(const MultiSubGraphOp&) = delete;
|
||||
MultiSubGraphOp(MultiSubGraphOp&&) = default;
|
||||
|
||||
MultiSubGraphOp& operator=(const MultiSubGraphOp&) = delete;
|
||||
MultiSubGraphOp& operator=(MultiSubGraphOp&&) = default;
|
||||
|
||||
protected:
|
||||
// Find an input corresponding to value, adding one if necessary.
|
||||
Input<Node> input_for_value(const Output<Node>& value);
|
||||
|
||||
MultiSubGraphOp(size_t number_of_bodies);
|
||||
MultiSubGraphOp() = default;
|
||||
MultiSubGraphOp(const OutputVector& args, size_t number_of_bodies);
|
||||
explicit MultiSubGraphOp(const OutputVector& args);
|
||||
|
||||
std::vector<std::shared_ptr<Function>> m_bodies;
|
||||
std::vector<MultiSubgraphInputDescriptionVector> m_input_descriptions;
|
||||
std::vector<MultiSubgraphOutputDescriptionVector> m_output_descriptions;
|
||||
};
|
||||
using MultiSubgraphInputDescriptionPtr = util::MultiSubGraphOp::MultiSubgraphInputDescriptionPtr;
|
||||
using MultiSubgraphOutputDescriptionPtr = util::MultiSubGraphOp::MultiSubgraphOutputDescriptionPtr;
|
||||
using ov::op::util::MultiSubGraphOp;
|
||||
using MultiSubgraphInputDescriptionPtr = ov::op::util::MultiSubGraphOp::InputDescription::Ptr;
|
||||
using MultiSubgraphOutputDescriptionPtr = ov::op::util::MultiSubGraphOp::OutputDescription::Ptr;
|
||||
using MultiSubgraphInputDescriptionVector = util::MultiSubGraphOp::MultiSubgraphInputDescriptionVector;
|
||||
using MultiSubgraphOutputDescriptionVector = util::MultiSubGraphOp::MultiSubgraphOutputDescriptionVector;
|
||||
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>>>
|
||||
: public DirectValueAccessor<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>>> {
|
||||
public:
|
||||
AttributeAdapter(std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>>& value)
|
||||
: DirectValueAccessor<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::InputDescription>>>(
|
||||
value) {}
|
||||
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::OutputDescription>>>
|
||||
: public DirectValueAccessor<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::OutputDescription>>> {
|
||||
public:
|
||||
AttributeAdapter(std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::OutputDescription>>& value)
|
||||
: DirectValueAccessor<std::vector<std::shared_ptr<ngraph::op::util::MultiSubGraphOp::OutputDescription>>>(
|
||||
value) {}
|
||||
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
};
|
||||
|
||||
} // namespace ov
|
||||
|
@ -5,91 +5,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/nms_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Base class for operations NmsBase and MatrixNms
|
||||
///
|
||||
class NGRAPH_API NmsBase : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
enum class SortResultType {
|
||||
CLASSID, // sort selected boxes by class id (ascending) in each batch element
|
||||
SCORE, // sort selected boxes by score (descending) in each batch element
|
||||
NONE // do not guarantee the order in each batch element
|
||||
};
|
||||
|
||||
NmsBase() = delete;
|
||||
|
||||
/// \brief Constructs a NmsBase operation
|
||||
///
|
||||
/// \param output_type Specifies the output tensor type
|
||||
/// \param nms_top_k Specifies maximum number of boxes to be selected per
|
||||
/// class, -1 meaning to keep all boxes
|
||||
/// \param keep_top_k Specifies maximum number of boxes to be selected per
|
||||
/// batch element, -1 meaning to keep all boxes
|
||||
NmsBase(ngraph::element::Type& output_type, int& nms_top_k, int& keep_top_k);
|
||||
|
||||
/// \brief Constructs a NmsBase operation
|
||||
///
|
||||
/// \param boxes Node producing the box coordinates
|
||||
/// \param scores Node producing the box scores
|
||||
/// \param output_type Specifies the output tensor type
|
||||
/// \param nms_top_k Specifies maximum number of boxes to be selected per
|
||||
/// class, -1 meaning to keep all boxes
|
||||
/// \param keep_top_k Specifies maximum number of boxes to be selected per
|
||||
/// batch element, -1 meaning to keep all boxes
|
||||
NmsBase(const Output<Node>& boxes,
|
||||
const Output<Node>& scores,
|
||||
ngraph::element::Type& output_type,
|
||||
int& nms_top_k,
|
||||
int& keep_top_k);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
const element::Type& get_output_type() const {
|
||||
return m_output_type;
|
||||
}
|
||||
void set_output_type(const element::Type& output_type) {
|
||||
m_output_type = output_type;
|
||||
}
|
||||
using Node::set_output_type;
|
||||
|
||||
int get_nms_top_k() const {
|
||||
return m_nms_top_k;
|
||||
}
|
||||
|
||||
int get_keep_top_k() const {
|
||||
return m_keep_top_k;
|
||||
}
|
||||
|
||||
protected:
|
||||
ngraph::element::Type& m_output_type;
|
||||
int& m_nms_top_k;
|
||||
int& m_keep_top_k;
|
||||
virtual void validate();
|
||||
};
|
||||
using ov::op::util::NmsBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
|
||||
NGRAPH_API
|
||||
std::ostream& operator<<(std::ostream& s, const op::util::NmsBase::SortResultType& type);
|
||||
using ov::operator<<;
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<ngraph::op::util::NmsBase::SortResultType>
|
||||
: public EnumAttributeAdapterBase<ngraph::op::util::NmsBase::SortResultType> {
|
||||
public:
|
||||
AttributeAdapter(ngraph::op::util::NmsBase::SortResultType& value)
|
||||
: EnumAttributeAdapterBase<ngraph::op::util::NmsBase::SortResultType>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::util::NmsBase::SortResultType>", 1};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ov
|
||||
|
@ -8,57 +8,20 @@
|
||||
|
||||
#include "ngraph/ngraph_visibility.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
NGRAPH_API
|
||||
bool is_unary_elementwise_arithmetic(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_arithmetic(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_comparison(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_logical(const ngraph::Node* node);
|
||||
|
||||
NGRAPH_API
|
||||
bool supports_auto_broadcast(const ngraph::Node* node);
|
||||
|
||||
NGRAPH_API
|
||||
bool is_op(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_parameter(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_output(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_sink(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_constant(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_commutative(const ngraph::Node* node);
|
||||
|
||||
NGRAPH_API
|
||||
bool is_unary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_comparison(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_logical(const std::shared_ptr<ngraph::Node>& node);
|
||||
|
||||
NGRAPH_API
|
||||
bool supports_auto_broadcast(const std::shared_ptr<ngraph::Node>& node);
|
||||
|
||||
NGRAPH_API
|
||||
bool is_op(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_parameter(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_output(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_sink(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_constant(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_commutative(const std::shared_ptr<ngraph::Node>& node);
|
||||
using ov::op::util::is_binary_elementwise_arithmetic;
|
||||
using ov::op::util::is_binary_elementwise_comparison;
|
||||
using ov::op::util::is_binary_elementwise_logical;
|
||||
using ov::op::util::is_commutative;
|
||||
using ov::op::util::is_constant;
|
||||
using ov::op::util::is_op;
|
||||
using ov::op::util::is_output;
|
||||
using ov::op::util::is_parameter;
|
||||
using ov::op::util::is_sink;
|
||||
using ov::op::util::is_unary_elementwise_arithmetic;
|
||||
using ov::op::util::supports_auto_broadcast;
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -5,31 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/reduction_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
class NGRAPH_API ReductionBase : public Op {
|
||||
protected:
|
||||
/// \brief Constructs a reduction operation.
|
||||
ReductionBase();
|
||||
|
||||
/// \brief Constructs a reduction operation.
|
||||
///
|
||||
/// \param arg Output that produces the first input tensor.
|
||||
/// \param reduction_axes The axis positions (0-based) to be eliminated.
|
||||
ReductionBase(const Output<Node>& arg, const Output<Node>& reduction_axes);
|
||||
|
||||
/// \brief Infers reduction operations output shape.
|
||||
///
|
||||
/// \param[in] keep_dims Reduction operation keeps dimensions.
|
||||
///
|
||||
/// \return Partial shape of the output.
|
||||
PartialShape infer_reduction_output_shape(const bool keep_dims);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
};
|
||||
using ov::op::util::ReductionBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
@ -12,151 +12,14 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/activation_functions.hpp"
|
||||
#include "openvino/op/util/rnn_cell_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
enum class LSTMWeightsFormat {
|
||||
FICO, // IE
|
||||
ICOF, // PyTorch
|
||||
IFCO, // DNNL, TF, MxNet
|
||||
IFOC, // Caffe
|
||||
IOFC, // ONNX
|
||||
};
|
||||
|
||||
///
|
||||
/// \brief Change data format of provided node.
|
||||
///
|
||||
/// \param[in] node The input node to be permuted.
|
||||
///
|
||||
///
|
||||
/// \param[in] from_format Original node weights format.
|
||||
///
|
||||
///
|
||||
/// \param[in] to_format Weights format to convert to.
|
||||
///
|
||||
/// \return Node representing reshaped tensor according to `to_format` weights
|
||||
/// format.
|
||||
///
|
||||
std::shared_ptr<Node> NGRAPH_API convert_lstm_node_format(const Output<Node>& node,
|
||||
LSTMWeightsFormat from_format,
|
||||
LSTMWeightsFormat to_format = LSTMWeightsFormat::FICO,
|
||||
int64_t axis = 0);
|
||||
|
||||
/// \brief Base class for all recurrent network cells.
|
||||
///
|
||||
/// \note It holds all common attributes.
|
||||
///
|
||||
class NGRAPH_API RNNCellBase : public Op {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
///
|
||||
/// \brief Constructs a RNNCellBase class.
|
||||
///
|
||||
/// \param[in] hidden_size The number of hidden units for recurrent cell.
|
||||
/// \param[in] clip The value defining clipping range [-clip, clip]
|
||||
/// on input of activation functions.
|
||||
/// \param[in] activations The vector of activation functions used inside
|
||||
/// recurrent cell.
|
||||
/// \param[in] activations_alpha The vector of alpha parameters for activation
|
||||
/// functions in order respective to activation list.
|
||||
/// \param[in] activations_beta The vector of beta parameters for activation
|
||||
/// functions in order respective to activation list.
|
||||
///
|
||||
RNNCellBase(const OutputVector& args,
|
||||
std::size_t hidden_size,
|
||||
float clip,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta);
|
||||
|
||||
RNNCellBase();
|
||||
virtual ~RNNCellBase() = default;
|
||||
|
||||
///
|
||||
/// \brief Validates static rank and dimension for provided input parameters.
|
||||
/// Additionally input_size dimension is checked for X and W inputs.
|
||||
///
|
||||
///
|
||||
/// \param[in] input Vector with RNN-Cell op inputs in following order:
|
||||
/// X, initial_hidden_state, W, R and B.
|
||||
///
|
||||
void validate_input_rank_dimension(const std::vector<ngraph::PartialShape>& input);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
std::size_t get_hidden_size() const {
|
||||
return m_hidden_size;
|
||||
}
|
||||
float get_clip() const {
|
||||
return m_clip;
|
||||
}
|
||||
const std::vector<std::string>& get_activations() const {
|
||||
return m_activations;
|
||||
}
|
||||
const std::vector<float>& get_activations_alpha() const {
|
||||
return m_activations_alpha;
|
||||
}
|
||||
const std::vector<float>& get_activations_beta() const {
|
||||
return m_activations_beta;
|
||||
}
|
||||
|
||||
protected:
|
||||
///
|
||||
/// \brief Constructs activation function object.
|
||||
///
|
||||
/// \param[in] idx The index of the activation function name.
|
||||
///
|
||||
/// \return The object representing activation function.
|
||||
///
|
||||
ActivationFunction get_activation_function(std::size_t idx) const;
|
||||
///
|
||||
/// \brief Creates node with element-wise add operation with numpy
|
||||
/// broadcasting.
|
||||
///
|
||||
/// \param[in] lhs The left hand side argument node.
|
||||
/// \param[in] rhs The right hand side argument node.
|
||||
///
|
||||
/// \return Node with element-wise add operation.
|
||||
///
|
||||
static std::shared_ptr<Node> add(const Output<Node>& lhs, const Output<Node>& rhs);
|
||||
///
|
||||
/// \brief Creates node with element-wise subtract operation with numpy
|
||||
/// broadcasting.
|
||||
///
|
||||
/// \param[in] lhs The left hand side argument node.
|
||||
/// \param[in] rhs The right hand side argument node.
|
||||
///
|
||||
/// \return Node with element-wise subtract operation.
|
||||
///
|
||||
static std::shared_ptr<Node> sub(const Output<Node>& lhs, const Output<Node>& rhs);
|
||||
///
|
||||
/// \brief Creates node with element-wise multiply operation with numpy
|
||||
/// broadcasting.
|
||||
///
|
||||
/// \param[in] lhs The left hand side argument node.
|
||||
/// \param[in] rhs The right hand side argument node.
|
||||
///
|
||||
/// \return Node with element-wise multiply operation.
|
||||
///
|
||||
static std::shared_ptr<Node> mul(const Output<Node>& lhs, const Output<Node>& rhs);
|
||||
///
|
||||
/// \brief Creates node with element-wise clip operation with numpy
|
||||
/// broadcasting.
|
||||
///
|
||||
/// \param[in] data The input tensor for clipping.
|
||||
///
|
||||
/// \return Node with element-wise clip operation.
|
||||
///
|
||||
std::shared_ptr<Node> clip(const Output<Node>& data) const;
|
||||
|
||||
protected:
|
||||
std::size_t m_hidden_size;
|
||||
float m_clip;
|
||||
std::vector<std::string> m_activations;
|
||||
std::vector<float> m_activations_alpha;
|
||||
std::vector<float> m_activations_beta;
|
||||
};
|
||||
using ov::op::util::convert_lstm_node_format;
|
||||
using ov::op::util::LSTMWeightsFormat;
|
||||
using ov::op::util::RNNCellBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -5,45 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/scatter_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
///
|
||||
/// \brief Base class for ScatterXXX operators.
|
||||
///
|
||||
class NGRAPH_API ScatterBase : public Op {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ScatterBase", 3};
|
||||
const NodeTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
virtual void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
ScatterBase() = default;
|
||||
|
||||
///
|
||||
/// \brief Constructs ScatterBase object.
|
||||
///
|
||||
/// \param inputs The input tensor to be updated.
|
||||
/// \param indices The tensor with indexes which will be updated.
|
||||
/// \param updates The tensor with update values.
|
||||
/// \param[in] axis The axis at which elements will be updated.
|
||||
///
|
||||
ScatterBase(const Output<Node>& inputs,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& updates,
|
||||
const Output<Node>& axis);
|
||||
|
||||
private:
|
||||
// Respective input ordinal number.
|
||||
static constexpr int DATA = 0;
|
||||
static constexpr int INDICES = 1;
|
||||
static constexpr int UPDATES = 2;
|
||||
static constexpr int AXIS = 3;
|
||||
};
|
||||
using ov::op::util::ScatterBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -5,38 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/scatter_nd_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
///
|
||||
/// \brief Base class for ScatterNDXXX operators.
|
||||
///
|
||||
class NGRAPH_API ScatterNDBase : public Op {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"ScatterNDBase", 3};
|
||||
const NodeTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
// Respective input ordinal number.
|
||||
static constexpr int INPUTS = 0;
|
||||
static constexpr int INDICES = 1;
|
||||
static constexpr int UPDATES = 2;
|
||||
virtual void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
ScatterNDBase() = default;
|
||||
|
||||
///
|
||||
/// \brief Constructs ScatterNDBase object.
|
||||
///
|
||||
/// \param inputs The input tensor to be updated.
|
||||
/// \param indices The tensor with indexes which will be updated.
|
||||
/// \param updates The tensor with update values.
|
||||
///
|
||||
ScatterNDBase(const Output<Node>& inputs, const Output<Node>& indices, const Output<Node>& updates);
|
||||
};
|
||||
using ov::op::util::ScatterNDBase;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -7,142 +7,14 @@
|
||||
#include <ngraph/op/parameter.hpp>
|
||||
|
||||
#include "ngraph/op/util/multi_subgraph_base.hpp"
|
||||
#include "openvino/op/util/sub_graph_base.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
/// \brief Abstract base class for sub-graph based ops, i.e ops that have only one
|
||||
/// sub-graph
|
||||
///
|
||||
class NGRAPH_API SubGraphOp : public MultiSubGraphOp {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
virtual const std::shared_ptr<Function>& get_function() const {
|
||||
return m_bodies[0];
|
||||
};
|
||||
virtual void set_function(const std::shared_ptr<Function>& func) {
|
||||
m_bodies[0] = func;
|
||||
};
|
||||
/// \return a reference to the input descriptions.
|
||||
const std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() const {
|
||||
return m_input_descriptions[0];
|
||||
}
|
||||
/// \return a reference to the input descriptions. Can add input descriptions
|
||||
/// before
|
||||
/// validation.
|
||||
std::vector<std::shared_ptr<InputDescription>>& get_input_descriptions() {
|
||||
return m_input_descriptions[0];
|
||||
}
|
||||
/// \return a reference to the output descriptions.
|
||||
const std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions() const {
|
||||
return m_output_descriptions[0];
|
||||
}
|
||||
/// \return a reference to the output descriptions. Can add output descriptions
|
||||
/// before
|
||||
/// validation.
|
||||
std::vector<std::shared_ptr<OutputDescription>>& get_output_descriptions() {
|
||||
return m_output_descriptions[0];
|
||||
}
|
||||
|
||||
///
|
||||
/// \brief Indicate that a body parameter comes from slices of a value
|
||||
///
|
||||
/// \param parameter The parameter to receive the slices
|
||||
/// \param value The value to be sliced. This will be added as an input to
|
||||
/// SubGraphOp.
|
||||
/// \param start First index on axis of the slicing
|
||||
/// \param stride Stepping of the slice
|
||||
/// \param part_size Size of the slice on axis
|
||||
/// \param end The last index on axis of the slicing
|
||||
/// \param axis The axis to slice along
|
||||
///
|
||||
virtual void set_sliced_input(const std::shared_ptr<Parameter>& parameter,
|
||||
const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
///
|
||||
/// \brief Indicates that a body parameter has an initial value in the first
|
||||
/// iteration and computed value thereafter
|
||||
///
|
||||
/// \param[in] body_parameter The body parameter
|
||||
/// \param initial_value Value for the parameter in first iteration. This
|
||||
/// will be added as an input to Loop.
|
||||
/// \param successive_value Value for the parameter in successive iterations.
|
||||
/// The value is what is active in the most recent
|
||||
/// completed iteration.
|
||||
///
|
||||
virtual void set_merged_input(const std::shared_ptr<Parameter>& body_parameter,
|
||||
const Output<Node>& initial_value,
|
||||
const Output<Node>& successive_value);
|
||||
///
|
||||
/// \brief Indicates that a body parameter has an invariant value during
|
||||
/// iteration that may depend on values computed outside of the
|
||||
/// iteration.
|
||||
///
|
||||
/// \param body_parameter The body parameter
|
||||
/// \param value The value supplied as an input to the block
|
||||
///
|
||||
virtual void set_invariant_input(const std::shared_ptr<Parameter>& body_parameter, const Output<Node>& value);
|
||||
///
|
||||
/// \brief Gets a value for a particular iteration point
|
||||
///
|
||||
/// \param body_value The value
|
||||
/// \param iteration The iteration that supplies the value. Negative values
|
||||
/// are from the last iteration.
|
||||
/// Default value -1 (the last iteration).
|
||||
///
|
||||
/// \return The iterator value.
|
||||
///
|
||||
virtual Output<Node> get_iter_value(const Output<Node>& body_value, int64_t iteration = -1);
|
||||
///
|
||||
/// \brief Concatenates slices from all iterations
|
||||
///
|
||||
/// \param value The value supplying slice values from each iteration.
|
||||
/// \param start First index on axis of the slicing
|
||||
/// \param stride Stepping of the slice
|
||||
/// \param part_size Size of the slice on axis
|
||||
/// \param end The last index on axis of the slicing
|
||||
/// \param axis The axis to slice along
|
||||
///
|
||||
/// \return The concatenated slices.
|
||||
///
|
||||
virtual Output<Node> get_concatenated_slices(const Output<Node>& value,
|
||||
int64_t start,
|
||||
int64_t stride,
|
||||
int64_t part_size,
|
||||
int64_t end,
|
||||
int64_t axis);
|
||||
|
||||
SubGraphOp(const SubGraphOp&) = delete;
|
||||
SubGraphOp(SubGraphOp&&) = default;
|
||||
|
||||
SubGraphOp& operator=(const SubGraphOp&) = delete;
|
||||
SubGraphOp& operator=(SubGraphOp&&) = default;
|
||||
|
||||
int64_t get_num_iterations() const {
|
||||
return m_num_iterations;
|
||||
}
|
||||
|
||||
protected:
|
||||
int64_t m_num_iterations = -1; // -1 means infinity for Loop op, inconsistent for TensorIterator
|
||||
|
||||
// Find an input corresponding to value, adding one if necessary.
|
||||
Input<Node> input_for_value(const Output<Node>& value);
|
||||
|
||||
SubGraphOp();
|
||||
explicit SubGraphOp(const OutputVector& args);
|
||||
|
||||
private:
|
||||
using MultiSubGraphOp::get_function;
|
||||
|
||||
using MultiSubGraphOp::set_function;
|
||||
};
|
||||
using InputDescriptionPtr = std::shared_ptr<util::SubGraphOp::InputDescription>;
|
||||
using OutputDescriptionPtr = std::shared_ptr<util::SubGraphOp::OutputDescription>;
|
||||
using ov::op::util::SubGraphOp;
|
||||
using InputDescriptionPtr = util::SubGraphOp::InputDescription::Ptr;
|
||||
using OutputDescriptionPtr = util::SubGraphOp::OutputDescription::Ptr;
|
||||
using InputDescriptionVector = std::vector<InputDescriptionPtr>;
|
||||
using OutputDescriptionVector = std::vector<OutputDescriptionPtr>;
|
||||
} // namespace util
|
||||
|
@ -5,49 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "openvino/op/util/unary_elementwise_arithmetic.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
namespace util {
|
||||
// clang-format off
|
||||
/// \brief Abstract base class for elementwise unary arithmetic operations, i.e.,
|
||||
/// operations where the same scalar arithmetic operation is applied to each
|
||||
/// element.
|
||||
///
|
||||
/// For example, if the underlying operation (determined by the subclass) is
|
||||
/// \f$\mathit{op}(x)\f$, the input tensor \f$[[x,y],[z,w]]\f$ will be mapped to
|
||||
/// \f$[[\mathit{op}(x),\mathit{op}(y)],[\mathit{op}(z),\mathit{op}(w)]]\f$.
|
||||
///
|
||||
/// ## Inputs
|
||||
///
|
||||
/// | | Type | Description |
|
||||
/// | ----- | --------------------------------- | ------------------------------------------------------------------------ |
|
||||
/// | `arg` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. The element type \f$N\f$ may be any numeric type. |
|
||||
///
|
||||
/// ## Output
|
||||
///
|
||||
/// | Type | Description |
|
||||
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensor. |
|
||||
// clang-format on
|
||||
class NGRAPH_API UnaryElementwiseArithmetic : public Op {
|
||||
protected:
|
||||
/// \brief Constructs a unary elementwise arithmetic operation.
|
||||
UnaryElementwiseArithmetic();
|
||||
/// \brief Constructs a unary elementwise arithmetic operation.
|
||||
///
|
||||
/// \param arg Output that produces the input tensor.
|
||||
UnaryElementwiseArithmetic(const Output<Node>& arg);
|
||||
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
void validate_and_infer_elementwise_arithmetic();
|
||||
};
|
||||
using ov::op::util::UnaryElementwiseArithmetic;
|
||||
} // namespace util
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
|
@ -10,51 +10,11 @@
|
||||
#include "ngraph/partial_shape.hpp"
|
||||
#include "ngraph/type.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
#include "openvino/op/util/variable.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
struct VariableInfo {
|
||||
PartialShape data_shape;
|
||||
element::Type data_type;
|
||||
std::string variable_id;
|
||||
|
||||
inline bool operator==(const VariableInfo& other) const {
|
||||
return data_shape == other.data_shape && data_type == other.data_type && variable_id == other.variable_id;
|
||||
}
|
||||
};
|
||||
|
||||
class NGRAPH_API Variable {
|
||||
public:
|
||||
Variable() = default;
|
||||
|
||||
explicit Variable(const VariableInfo& variable_info) : m_info(variable_info) {}
|
||||
|
||||
VariableInfo get_info() const {
|
||||
return m_info;
|
||||
}
|
||||
void update(const VariableInfo& variable_info) {
|
||||
m_info = variable_info;
|
||||
}
|
||||
|
||||
private:
|
||||
VariableInfo m_info;
|
||||
};
|
||||
using ov::op::util::Variable;
|
||||
using ov::op::util::VariableInfo;
|
||||
using VariablePtr = std::shared_ptr<Variable>;
|
||||
using VariableVector = std::vector<VariablePtr>;
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<ngraph::Variable>>
|
||||
: public DirectValueAccessor<std::shared_ptr<ngraph::Variable>> {
|
||||
public:
|
||||
explicit AttributeAdapter(std::shared_ptr<ngraph::Variable>& value)
|
||||
: DirectValueAccessor<std::shared_ptr<ngraph::Variable>>(value) {}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Variable>>", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ov
|
||||
|
@ -11,81 +11,9 @@
|
||||
#include "ngraph/op/util/variable_value.hpp"
|
||||
#include "ngraph/output_vector.hpp"
|
||||
#include "ngraph/variant.hpp"
|
||||
#include "openvino/op/util/variable_context.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
using VariableValuePtr = std::shared_ptr<VariableValue>;
|
||||
using VariableMap = std::unordered_map<VariablePtr, VariableValuePtr>;
|
||||
|
||||
/// VariableContext stores and manages a evaluation context for Variables.
|
||||
class NGRAPH_API VariableContext {
|
||||
public:
|
||||
/// \brief Constructs an uninitialized VariableContext.
|
||||
VariableContext() = default;
|
||||
|
||||
/// \brief Constructor for VariableContext.
|
||||
/// \param variable_values The values associated with a particular Variables.
|
||||
explicit VariableContext(const VariableMap& variable_values) : m_variable_values(variable_values) {}
|
||||
|
||||
/// \brief Sets the reset flags for all stored Variables to true.
|
||||
void reset_variable_context() const {
|
||||
for (const auto& el : m_variable_values) {
|
||||
el.second->set_reset(true);
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Sets the new values for Variables.
|
||||
/// \param variable_values The new values associated with a particular Variable.
|
||||
void set_variable_values(const VariableMap& variable_values) {
|
||||
m_variable_values = variable_values;
|
||||
}
|
||||
|
||||
/// \brief Changes/sets the values for Variable.
|
||||
/// \param variable New or stored Variable.
|
||||
/// \param variable_value The values associated with the variable.
|
||||
void set_variable_value(const VariablePtr& variable, const VariableValuePtr& variable_value) {
|
||||
m_variable_values[variable] = variable_value;
|
||||
}
|
||||
|
||||
/// \brief Removes context for a particular Variable.
|
||||
/// \param variable The variable for which the context will be cleared.
|
||||
void remove_variable_value(const VariablePtr& variable) {
|
||||
m_variable_values.erase(variable);
|
||||
}
|
||||
|
||||
/// \brief Returns the current values for Variables.
|
||||
const VariableMap& get_variable_values() const {
|
||||
return m_variable_values;
|
||||
}
|
||||
|
||||
/// \brief Returns the value for specified Variable.
|
||||
VariableValuePtr get_variable_value(const VariablePtr& variable) const {
|
||||
auto var_value = m_variable_values.find(variable);
|
||||
if (var_value != m_variable_values.end()) {
|
||||
return (*var_value).second;
|
||||
}
|
||||
return VariableValuePtr();
|
||||
}
|
||||
|
||||
private:
|
||||
/// The values associated with a particular Variable.
|
||||
VariableMap m_variable_values;
|
||||
};
|
||||
using ov::op::util::VariableContext;
|
||||
} // namespace ngraph
|
||||
|
||||
namespace ov {
|
||||
template <>
|
||||
class NGRAPH_API VariantWrapper<ngraph::VariableContext> : public VariantImpl<ngraph::VariableContext> {
|
||||
public:
|
||||
static constexpr VariantTypeInfo type_info{"Variant::EvaluationContext::VariableContext", 0};
|
||||
|
||||
const VariantTypeInfo& get_type_info() const override {
|
||||
return type_info;
|
||||
}
|
||||
|
||||
explicit VariantWrapper(const value_type& value) : VariantImpl<value_type>(value) {}
|
||||
|
||||
private:
|
||||
using Variant::init;
|
||||
using Variant::merge;
|
||||
};
|
||||
} // namespace ov
|
||||
|
@ -4,37 +4,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/runtime/host_tensor.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "openvino/op/util/variable_extension.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
class NGRAPH_API VariableExtension {
|
||||
public:
|
||||
VariableExtension() = default;
|
||||
|
||||
/// \brief Returns variable connected to this node.
|
||||
virtual std::shared_ptr<ngraph::Variable> get_variable() const {
|
||||
return m_variable;
|
||||
}
|
||||
|
||||
/// \brief Sets a new variable to be connected to this node.
|
||||
///
|
||||
/// \param variable New variable to be connected to this node.
|
||||
virtual void set_variable(const std::shared_ptr<ngraph::Variable>& variable) {
|
||||
m_variable = variable;
|
||||
}
|
||||
|
||||
/// \brief Sets the identifier to a variable
|
||||
///
|
||||
/// \param variable_id New identifier of the variable.
|
||||
virtual void set_variable_id(const std::string& variable_id) {
|
||||
m_variable->get_info().variable_id = variable_id;
|
||||
};
|
||||
|
||||
/// \brief Returns the identifier of corresponding variable.
|
||||
virtual std::string get_variable_id() const = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ngraph::Variable> m_variable;
|
||||
};
|
||||
using ov::op::util::VariableExtension;
|
||||
} // namespace ngraph
|
||||
|
@ -4,51 +4,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/runtime/host_tensor.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "openvino/op/util/variable_value.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
/// VariableValue stores data and state (reset flag) for a Variable,
|
||||
/// and provides an interface for changing them.
|
||||
class NGRAPH_API VariableValue {
|
||||
public:
|
||||
/// \brief Constructs an uninitialized VariableValue.
|
||||
VariableValue() = default;
|
||||
|
||||
/// \brief Constructor for VariableValue.
|
||||
/// \param value The data for Variable.
|
||||
explicit VariableValue(HostTensorPtr value) : m_value(std::move(value)) {}
|
||||
|
||||
/// \brief Constructor for VariableValue.
|
||||
/// \param value Data for Variable.
|
||||
/// \param reset The current state of the reset flag.
|
||||
VariableValue(HostTensorPtr value, bool reset) : m_reset(reset), m_value(std::move(value)) {}
|
||||
|
||||
/// \brief Sets the reset flag to a new state.
|
||||
/// \param reset The new state of the reset flag.
|
||||
void set_reset(bool reset) {
|
||||
m_reset = reset;
|
||||
}
|
||||
|
||||
/// \brief Returns the current reset flag state.
|
||||
bool get_reset() const {
|
||||
return m_reset;
|
||||
}
|
||||
|
||||
/// \brief Returns the current stored data.
|
||||
const HostTensorPtr& get_value() const {
|
||||
return m_value;
|
||||
}
|
||||
|
||||
/// \brief Sets new values for Variable.
|
||||
/// \param value New data for Variable.
|
||||
void set_value(const HostTensorPtr& value) {
|
||||
m_value = value;
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_reset = true;
|
||||
HostTensorPtr m_value;
|
||||
};
|
||||
using ov::op::util::VariableValue;
|
||||
using VariableValuePtr = std::shared_ptr<VariableValue>;
|
||||
} // namespace ngraph
|
||||
|
@ -5,24 +5,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "openvino/pass/constant_folding.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
/**
|
||||
* @brief Constant folding iterates over the function and tries to evaluate nodes
|
||||
* with constant inputs. Such nodes are then replaced with new Constants containing
|
||||
* the result of a folded operation.
|
||||
*/
|
||||
class NGRAPH_API ConstantFolding : public FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
void copy_runtime_info_to_target_inputs(const std::shared_ptr<Node>& node, const Output<Node>& replacement);
|
||||
/// \brief Folds pre-calculated output tensor values to constants in case lower and
|
||||
/// upper estimations are equal. Traverses graph backwards starting from the results.
|
||||
bool pre_calculated_values_folding(const std::shared_ptr<ngraph::Function>& f);
|
||||
};
|
||||
using ov::pass::ConstantFolding;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -4,14 +4,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include "ngraph/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/convert_fp32_to_fp16.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class NGRAPH_API ConvertFP32ToFP16 : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
};
|
||||
using ov::pass::ConvertFP32ToFP16;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -10,240 +10,17 @@
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
using matcher_pass_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
||||
using graph_rewrite_callback = std::function<bool(ngraph::pattern::Matcher& m)>;
|
||||
using recurrent_graph_rewrite_callback = std::function<bool(ngraph::pattern::RecurrentMatcher& m)>;
|
||||
using handler_callback = std::function<bool(const std::shared_ptr<Node>& node)>;
|
||||
using ov::graph_rewrite_callback;
|
||||
using ov::handler_callback;
|
||||
using ov::matcher_pass_callback;
|
||||
using ov::recurrent_graph_rewrite_callback;
|
||||
namespace pass {
|
||||
/// \brief MatcherPass is a basic block for pattern based transformations. It describes
|
||||
/// pattern and
|
||||
/// action that is applied if pattern is matched.
|
||||
///
|
||||
/// MatcherPass consists of Matcher and matcher_pass_callback that needs to be implemented
|
||||
/// and
|
||||
/// finally registered by using \sa register_matcher. MatcherPass can be executed on node
|
||||
/// within
|
||||
/// \sa apply method. To run matcher pass on Function use GraphRewrite.
|
||||
/// In addition MatcherPass provides a way for adding new operations into GraphRewrite
|
||||
/// execution
|
||||
/// queue. That means that operations that were created inside transformation callback can
|
||||
/// be added
|
||||
/// for matching. To register node use \sa register_new_node method. GraphRewrite
|
||||
/// automatically
|
||||
/// takes registered nodes and put them to execution queue. If multiple nodes were register
|
||||
/// make
|
||||
/// sure that they were registered in topological order.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
class NGRAPH_API MatcherPass : public ngraph::pass::PassBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
MatcherPass() = default;
|
||||
|
||||
MatcherPass(const MatcherPass&) = delete;
|
||||
MatcherPass& operator=(const MatcherPass&) = delete;
|
||||
|
||||
explicit MatcherPass(const std::string& name,
|
||||
const std::shared_ptr<pattern::Matcher>& m,
|
||||
const handler_callback& handler,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE)
|
||||
: PassBase(),
|
||||
m_handler(handler),
|
||||
m_matcher(m) {
|
||||
set_name(name);
|
||||
set_property(property, true);
|
||||
}
|
||||
|
||||
bool apply(std::shared_ptr<ngraph::Node> node);
|
||||
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> register_new_node(Args&&... args) {
|
||||
auto node = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::shared_ptr<T> register_new_node(const std::shared_ptr<T>& node) {
|
||||
m_new_nodes.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<ngraph::Node>>& get_new_nodes() {
|
||||
return m_new_nodes;
|
||||
}
|
||||
void clear_new_nodes() {
|
||||
m_new_nodes.clear();
|
||||
}
|
||||
std::shared_ptr<pattern::Matcher> get_matcher() {
|
||||
return m_matcher;
|
||||
}
|
||||
|
||||
protected:
|
||||
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
|
||||
private:
|
||||
handler_callback m_handler;
|
||||
std::shared_ptr<pattern::Matcher> m_matcher;
|
||||
std::vector<std::shared_ptr<ngraph::Node>> m_new_nodes;
|
||||
};
|
||||
|
||||
/// \brief GraphRewrite is a container for MatcherPasses that allows to run them on Function
|
||||
/// in
|
||||
/// efficient way
|
||||
///
|
||||
/// Graph rewrite pass is used for matcher passes execution on Function.
|
||||
/// To register MatcherPass use \sa add_matcher<T>(args) method where T is a MatcherPass
|
||||
/// class.
|
||||
/// As a default algorithm graph rewrite pass traverse Function in topological order and
|
||||
/// applies
|
||||
/// registered matcher passes for each node. But if all registered matcher passes have type
|
||||
/// based
|
||||
/// root node in Matcher pattern then efficient mechanism is used to execute them.
|
||||
/// Matcher pattern root is type based if it's operation from opset or
|
||||
/// pattern::op::WrapType.
|
||||
/// Note: when implementing pattern for Matcher make sure that root node is an operation
|
||||
/// from opset
|
||||
/// or has ngraph::pattern::op::WrapType. That will help GraphRewrite to execute matcher
|
||||
/// passes more
|
||||
/// efficient.
|
||||
|
||||
class NGRAPH_API GraphRewrite : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
GraphRewrite() = default;
|
||||
|
||||
explicit GraphRewrite(const std::shared_ptr<MatcherPass>& pass) : FunctionPass() {
|
||||
m_matchers.push_back(pass);
|
||||
}
|
||||
|
||||
/// \brief Register given transformation class type to GraphRewrite execution list
|
||||
/// All registered transformations will be executed in a single graph traversal.
|
||||
/// Example below show the basic usage of pass::GraphRewrite
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<MatcherPassA>();
|
||||
/// anchor->add_matcher<MatcherPassB>();
|
||||
/// anchor->set_name("CommonMatchers");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// anchor->add_matcher<MatcherPassB, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T,
|
||||
bool Enabled = true,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::MatcherPass, T>::value, bool>::type = true>
|
||||
std::shared_ptr<T> add_matcher(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::MatcherPass, T>::value, "pass not derived from MatcherPass");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
pass->set_pass_config(pass_config);
|
||||
if (!Enabled && !pass_config->is_enabled<T>()) {
|
||||
pass_config->disable<T>();
|
||||
}
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// \brief Register passes from GraphRewrite class that contains sequence of matcher
|
||||
/// passes registered in its ctor.
|
||||
/// For example:
|
||||
///
|
||||
/// class ngraph::pass::LinFusions: public ngraph::pass::GraphRewrite {
|
||||
/// public:
|
||||
/// NGRAPH_RTTI_DECLARATION;
|
||||
/// Fusions() {
|
||||
/// add_matcher<ngraph::pass::AddFusion>();
|
||||
/// add_matcher<ngraph::pass::MulFusion>();
|
||||
/// }
|
||||
/// };
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// auto anchor = manager.register_pass<GraphRewrite>();
|
||||
/// anchor->add_matcher<LinFusions>();
|
||||
/// anchor->add_matcher<OtherFusions>();
|
||||
/// anchor->set_name("CommonFusions");
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// In this case all matcher passes from LinFusions pass will be united with other
|
||||
/// registered matchers.
|
||||
template <typename T,
|
||||
class... Args,
|
||||
typename std::enable_if<std::is_base_of<pass::GraphRewrite, T>::value, bool>::type = true>
|
||||
void add_matcher(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::GraphRewrite, T>::value, "pass not derived from GraphRewrite");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_config = get_pass_config();
|
||||
|
||||
for (auto& matcher : pass->m_matchers) {
|
||||
pass->set_pass_config(pass_config);
|
||||
m_matchers.push_back(matcher);
|
||||
}
|
||||
}
|
||||
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const ngraph::graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
NGRAPH_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ngraph::graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
|
||||
protected:
|
||||
bool apply_matcher_passes(std::shared_ptr<Function> f, std::deque<std::weak_ptr<Node>> nodes_to_run);
|
||||
|
||||
bool m_enable_shape_inference = false;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
|
||||
class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
BackwardGraphRewrite() = default;
|
||||
|
||||
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass) : GraphRewrite(pass) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
||||
|
||||
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
RecurrentGraphRewrite(size_t num_iters = 10) : FunctionPass(), m_num_iters(num_iters) {}
|
||||
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
// TODO: This interface may deprecate after all passes are refactored.
|
||||
void add_matcher(const std::shared_ptr<pattern::RecurrentMatcher>& m,
|
||||
const ngraph::recurrent_graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
size_t m_num_iters;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
using ov::pass::BackwardGraphRewrite;
|
||||
using ov::pass::GraphRewrite;
|
||||
using ov::pass::MatcherPass;
|
||||
using ov::pass::RecurrentGraphRewrite;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -5,10 +5,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
#include <ngraph/pass/pass.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/pass/graph_rewrite.hpp"
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "openvino/pass/low_latency.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
/**
|
||||
@ -46,38 +48,6 @@ public:
|
||||
LowLatency();
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief The transformation finds all TensorIterator/Loop layers in the network,
|
||||
* processes all back edges that describe a connection between Result and Parameter
|
||||
* of the TensorIterator/Loop bodies,and inserts ReadValue and Assign layers at the
|
||||
* input and output corresponding to this back edge.
|
||||
* Supported platforms: CPU, GNA.
|
||||
*
|
||||
* The example below describes the changes made by the transformation
|
||||
* [] - TensorIterator body
|
||||
* () - new layer
|
||||
* BE - back-edge
|
||||
*
|
||||
* before applying the transformation:
|
||||
* -> input1[BE_1 -> Parameter -> Layers ... -> Result -> BE_1 ]output1->
|
||||
*
|
||||
* after applying the transformation:
|
||||
* ->(ReadValue)-> input1[BE_1 ->Parameter->Layers ...->Result->BE_1]output1 ->(Assign)
|
||||
* \
|
||||
* ->...
|
||||
* After applying the transformation, the resulting network can be inferred
|
||||
* step by step, the states will store between inferences.
|
||||
*/
|
||||
class NGRAPH_API LowLatency2 : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
explicit LowLatency2(bool use_const_initializer = true) : m_use_const_initializer(use_const_initializer) {}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
|
||||
private:
|
||||
bool m_use_const_initializer;
|
||||
};
|
||||
using ov::pass::LowLatency2;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -11,106 +11,10 @@
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "ngraph/pass/validate.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class NGRAPH_API Manager {
|
||||
public:
|
||||
Manager();
|
||||
~Manager();
|
||||
|
||||
//// \brief Construct Manager with shared PassConfig instance
|
||||
explicit Manager(std::shared_ptr<PassConfig> pass_config);
|
||||
|
||||
/// \brief Register given transformation class type to execution list
|
||||
/// Example below show the basic usage of pass::Manager
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<MyTransformation>(/*transformation constructor ars*/);
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// For some purposes transformation can be registered and disabled by default.
|
||||
///
|
||||
/// manager.register_pass<MyTransformation, false>();
|
||||
///
|
||||
/// \return shared_ptr to the transformation instance
|
||||
template <typename T, bool Enable = true, class... Args>
|
||||
std::shared_ptr<T> register_pass(Args&&... args) {
|
||||
auto rc = push_pass<T>(std::forward<Args>(args)...);
|
||||
rc->set_pass_config(m_pass_config);
|
||||
if (m_per_pass_validation) {
|
||||
push_pass<Validate>();
|
||||
}
|
||||
if (!Enable && !m_pass_config->is_enabled<T>()) {
|
||||
m_pass_config->disable<T>();
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
void run_passes(std::shared_ptr<Function>);
|
||||
|
||||
void set_pass_visualization(bool new_state) {
|
||||
m_visualize = new_state;
|
||||
}
|
||||
/// \brief Set flag to enable/disable running Validate pass after executing
|
||||
/// each registered pass
|
||||
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
||||
void set_per_pass_validation(bool new_state) {
|
||||
m_per_pass_validation = new_state;
|
||||
}
|
||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||
/// transformations based on some conditions. In some cases plugins may want not to
|
||||
/// execute some
|
||||
/// transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance
|
||||
/// reasons for
|
||||
/// some cases.
|
||||
/// Callback example:
|
||||
/// auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
/// return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) !=
|
||||
/// nullptr;
|
||||
/// };
|
||||
/// This callback returns true in case of DepthToSpace operation. So when execution
|
||||
/// DepthToSpace
|
||||
/// decomposition pass will check is this decomposition needed or plugin can execute
|
||||
/// this
|
||||
/// operation directly. And of course on transformation side we need to have a response
|
||||
/// for this
|
||||
/// callback.
|
||||
/// if (transformation_callback(batch_to_space)) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// \param callback lamda function that returns true in case if node is supported by
|
||||
/// plugin and
|
||||
/// transformation is not needed
|
||||
NGRAPH_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_pass_config->set_callback(callback);
|
||||
}
|
||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||
/// configuration.
|
||||
/// This object allows to disable/enable transformations execution, set callback to
|
||||
/// particular
|
||||
/// transformation. For mo details see PassConfig class.
|
||||
std::shared_ptr<PassConfig> get_pass_config() {
|
||||
return m_pass_config;
|
||||
}
|
||||
|
||||
protected:
|
||||
template <typename T, class... Args>
|
||||
std::shared_ptr<T> push_pass(Args&&... args) {
|
||||
static_assert(std::is_base_of<pass::PassBase, T>::value, "pass not derived from pass base");
|
||||
auto pass = std::make_shared<T>(std::forward<Args>(args)...);
|
||||
auto pass_base = std::static_pointer_cast<PassBase>(pass);
|
||||
m_pass_list.push_back(pass_base);
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
std::vector<std::shared_ptr<PassBase>> m_pass_list;
|
||||
bool m_visualize = false;
|
||||
bool m_per_pass_validation = true;
|
||||
};
|
||||
using ov::pass::Manager;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -13,105 +13,32 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pass/pass_config.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class Manager;
|
||||
|
||||
}
|
||||
} // namespace ov
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
enum class PassProperty : uint32_t {
|
||||
// Pass requires node shapes to be static
|
||||
REQUIRE_STATIC_SHAPE = 0x1,
|
||||
// Pass transformation will change the function's dynamic state
|
||||
CHANGE_DYNAMIC_STATE = 1 << 1,
|
||||
};
|
||||
|
||||
typedef EnumMask<PassProperty> PassPropertyMask;
|
||||
using ov::pass::FunctionPass;
|
||||
using ov::pass::FusionType;
|
||||
using ov::pass::FusionTypeMask;
|
||||
using ov::pass::Manager;
|
||||
using ov::pass::PassBase;
|
||||
using ov::pass::PassProperty;
|
||||
using ov::pass::PassPropertyMask;
|
||||
NGRAPH_DEPRECATED("This variable is deprecated and will be removed soon.")
|
||||
const PassPropertyMask all_pass_property_off;
|
||||
|
||||
class NGRAPH_API PassBase {
|
||||
friend class Manager;
|
||||
|
||||
public:
|
||||
PassBase();
|
||||
virtual ~PassBase() {}
|
||||
/// Check if this pass has all the pass properties.
|
||||
bool get_property(const PassPropertyMask& prop_mask) const;
|
||||
|
||||
void set_name(const std::string& name) {
|
||||
m_name = name;
|
||||
}
|
||||
std::string get_name() const;
|
||||
|
||||
/// \brief Set callback for particular transformation type.
|
||||
/// This method set global callback. For more details see PassConfig class
|
||||
/// documentation.
|
||||
/// \param callback lambda function that takes node and returns bool
|
||||
void set_callback(const param_callback& callback);
|
||||
|
||||
/// \brief Set PassConfig for particular transformation instance
|
||||
/// \param pass_config is a PassConfig shared_ptr
|
||||
virtual void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) {
|
||||
m_pass_config = pass_config;
|
||||
}
|
||||
|
||||
/// \brief Allows to access PassConfig shared instance
|
||||
/// \return Shared instance of PassConfig class
|
||||
std::shared_ptr<PassConfig> get_pass_config() {
|
||||
return m_pass_config;
|
||||
}
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// This method remains here only for backward compatibility and will be removed
|
||||
/// after all transformations are moved to transformation_callback() method.
|
||||
/// \return result of callback execution for given node
|
||||
NGRAPH_DEPRECATED("Please use transformation_callback method instead")
|
||||
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// \param node which will be used inside callback
|
||||
/// \return result of callback execution for given node
|
||||
bool transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
protected:
|
||||
void set_property(const PassPropertyMask& prop, bool value);
|
||||
|
||||
private:
|
||||
PassPropertyMask m_property;
|
||||
|
||||
std::string m_name;
|
||||
std::shared_ptr<PassConfig> m_pass_config;
|
||||
};
|
||||
|
||||
class NGRAPH_API FunctionPass : public PassBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
virtual ~FunctionPass();
|
||||
virtual bool run_on_function(std::shared_ptr<ngraph::Function>) = 0;
|
||||
};
|
||||
|
||||
class NGRAPH_DEPRECATED("Use MatcherPass or FunctionPass instead.") NGRAPH_API NodePass : public PassBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
virtual ~NodePass();
|
||||
~NodePass() override;
|
||||
virtual bool run_on_node(std::shared_ptr<ngraph::Node>) = 0;
|
||||
};
|
||||
|
||||
class Manager;
|
||||
enum class FusionType : uint32_t {
|
||||
//`DIFFERENTIABLE_FUSIONS` produce ops that support autodiff
|
||||
// i.e. implement `generate_adjoints`
|
||||
DIFFERENTIABLE_FUSIONS = 0x1,
|
||||
REGULAR_FUSIONS = 0x2,
|
||||
//`FOP_FUSIONS` produce ops in the FusedOps category that might
|
||||
// not be supported by all backends
|
||||
FOP_FUSIONS = 0x4,
|
||||
ALL_FUSIONS = 0xFFFFFFFF
|
||||
};
|
||||
typedef EnumMask<FusionType> FusionTypeMask;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -12,164 +12,12 @@
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "openvino/pass/pass_config.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
using param_callback = std::function<bool(const std::shared_ptr<const ::ngraph::Node>)>;
|
||||
using param_callback_map = std::map<ngraph::DiscreteTypeInfo, param_callback>;
|
||||
|
||||
/// \brief Class representing a transformations config that is used for disabling/enabling
|
||||
/// transformations registered inside pass::Manager and also allows to set callback for all
|
||||
/// transformations or for particular transformation.
|
||||
///
|
||||
/// When pass::Manager is created all passes registered inside this manager including nested
|
||||
/// passes will share the same instance of PassConfig class.
|
||||
/// To work with this class first you need to get shared instance of this class by calling
|
||||
/// manager.get_pass_config() method. Then you will be able to disable/enable passes based
|
||||
/// on transformations type_info. For example:
|
||||
///
|
||||
/// pass::Manager manager;
|
||||
/// manager.register_pass<CommonOptimizations>();
|
||||
/// auto pass_config = manager.get_pass_config();
|
||||
/// pass_config->disable<ConvertGELU>(); // this will disable nested pass inside
|
||||
/// // CommonOptimizations pipeline
|
||||
/// manager.run_passes(f);
|
||||
///
|
||||
/// Sometimes it is needed to call transformation inside other transformation manually. And
|
||||
/// for that case before running transformation you need manually check that this pass is
|
||||
/// not disabled and then you need to set current PassConfig instance to this
|
||||
/// transformation. For example:
|
||||
///
|
||||
/// // Inside MatcherPass callback or inside FunctionPass run_on_function() method
|
||||
/// // you need to call get_pass_config() method to get shared instance of PassConfig
|
||||
/// auto pass_config = get_pass_config();
|
||||
///
|
||||
/// // Before running nested transformation you need to check is it disabled or not
|
||||
/// if (!pass_config->is_disabled<ConvertGELU>()) {
|
||||
/// auto pass = ConvertGELU();
|
||||
/// pass->set_pass_config(pass_config);
|
||||
/// pass.apply(node);
|
||||
/// }
|
||||
///
|
||||
/// Following this logic inside your transformations you will guaranty that transformations
|
||||
/// will be executed in a right way.
|
||||
class NGRAPH_API PassConfig {
|
||||
public:
|
||||
/// \brief Disable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void disable(const DiscreteTypeInfo& type_info);
|
||||
/// \brief Disable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void disable() {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
disable(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Enable transformation by its type_info
|
||||
/// \param type_info Transformation type_info
|
||||
void enable(const DiscreteTypeInfo& type_info);
|
||||
/// \brief Enable transformation by its class type (based on type_info)
|
||||
template <typename T>
|
||||
void enable() {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
enable(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Set callback for all kind of transformations
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_callback = callback;
|
||||
}
|
||||
template <typename... Args>
|
||||
typename std::enable_if<sizeof...(Args) == 0>::type set_callback(const param_callback& callback) {}
|
||||
|
||||
/// \brief Set callback for particular transformation class types
|
||||
///
|
||||
/// Example below show how to set callback for one or multiple passes using this method.
|
||||
///
|
||||
/// pass_config->set_callback<ngraph::pass::ConvertBatchToSpace,
|
||||
/// ngraph::pass::ConvertSpaceToBatch>(
|
||||
/// [](const_node_ptr &node) -> bool {
|
||||
/// // Disable transformations for cases when input shape rank is not
|
||||
/// equal to 4
|
||||
/// const auto input_shape_rank =
|
||||
/// node->get_output_partial_shape(0).rank().get_length();
|
||||
/// if (input_shape_rank != 4) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// return true;
|
||||
/// });
|
||||
///
|
||||
/// Note that inside transformations you must provide code that work with this callback.
|
||||
/// See example below:
|
||||
///
|
||||
/// if (transformation_callback(node)) {
|
||||
/// return false; // exit from transformation
|
||||
/// }
|
||||
///
|
||||
template <typename T, class... Args>
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_callback_map[T::type_info] = callback;
|
||||
set_callback<Args...>(callback);
|
||||
}
|
||||
|
||||
/// \brief Get callback for given transformation type_info
|
||||
/// \param type_info Transformation type_info
|
||||
///
|
||||
/// In case if callback wasn't set for given transformation type then global callback
|
||||
/// will be returned. But if even global callback wasn't set then default callback will
|
||||
/// be returned.
|
||||
param_callback get_callback(const DiscreteTypeInfo& type_info) const;
|
||||
|
||||
/// \brief Get callback for given transformation class type
|
||||
/// \return callback lambda function
|
||||
template <typename T>
|
||||
param_callback get_callback() const {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
return get_callback(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Check either transformation type is disabled or not
|
||||
/// \param type_info Transformation type_info
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
bool is_disabled(const DiscreteTypeInfo& type_info) const {
|
||||
return m_disabled.count(type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation class type is disabled or not
|
||||
/// \return true if transformation type was disabled and false otherwise
|
||||
template <typename T>
|
||||
bool is_disabled() const {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
return is_disabled(T::type_info);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
/// \brief Check either transformation type is force enabled or not
|
||||
/// \param type_info Transformation type_info
|
||||
/// \return true if transformation type was force enabled and false otherwise
|
||||
bool is_enabled(const DiscreteTypeInfo& type_info) const {
|
||||
return m_enabled.count(type_info);
|
||||
}
|
||||
|
||||
/// \brief Check either transformation class type is force enabled or not
|
||||
/// \return true if transformation type was force enabled and false otherwise
|
||||
template <typename T>
|
||||
bool is_enabled() const {
|
||||
return is_enabled(T::type_info);
|
||||
}
|
||||
|
||||
void add_disabled_passes(const PassConfig& rhs);
|
||||
|
||||
private:
|
||||
param_callback m_callback = [](const std::shared_ptr<const ::ngraph::Node>&) {
|
||||
return false;
|
||||
};
|
||||
param_callback_map m_callback_map;
|
||||
std::unordered_set<DiscreteTypeInfo> m_disabled;
|
||||
std::unordered_set<DiscreteTypeInfo> m_enabled;
|
||||
};
|
||||
using ov::pass::param_callback;
|
||||
using ov::pass::param_callback_map;
|
||||
using ov::pass::PassConfig;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -5,27 +5,10 @@
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "openvino/pass/validate.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
/// \brief The Validate pass performs sanity checks on attributes and inputs, and
|
||||
/// computes output shapes and element types for all computation nodes in a given
|
||||
/// computation graph.
|
||||
///
|
||||
/// \details The verification and inference is done via invoking each node's specific
|
||||
/// implementation of \link ngraph::Node::validate_and_infer_types() \endlink function.
|
||||
///
|
||||
/// By default, the \ref ngraph::pass::Manager runs this pass after executing every
|
||||
/// optimization pass. This is to ensure that any update to the graph by an optimization
|
||||
/// pass does not break the shape and data type requirement on a computation node.
|
||||
/// This default validation run can be changed via calling the
|
||||
/// \link ngraph::pass::Manager::set_per_pass_validation(bool) \endlink function.
|
||||
class NGRAPH_API Validate : public FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Validate() : FunctionPass() {}
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
||||
using ov::pass::Validate;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -14,44 +14,10 @@
|
||||
#include <utility>
|
||||
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
|
||||
class HeightMap;
|
||||
|
||||
using visualize_tree_ops_map_t =
|
||||
std::unordered_map<ngraph::Node::type_info_t, std::function<void(const ngraph::Node&, std::ostream& ss)>>;
|
||||
#include "openvino/pass/visualize_tree.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
class NGRAPH_API VisualizeTree : public FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
using node_modifiers_t = std::function<void(const Node& node, std::vector<std::string>& attributes)>;
|
||||
VisualizeTree(const std::string& file_name, node_modifiers_t nm = nullptr, bool dot_only = false);
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function>) override;
|
||||
|
||||
void set_ops_to_details(const visualize_tree_ops_map_t& ops_map) {
|
||||
m_ops_to_details = ops_map;
|
||||
}
|
||||
|
||||
protected:
|
||||
void add_node_arguments(std::shared_ptr<Node> node,
|
||||
std::unordered_map<Node*, HeightMap>& height_maps,
|
||||
size_t& fake_node_ctr);
|
||||
std::string add_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_attributes(std::shared_ptr<Node> node);
|
||||
virtual std::string get_node_name(std::shared_ptr<Node> node);
|
||||
std::string get_constant_value(std::shared_ptr<Node> node, size_t max_elements = 7);
|
||||
|
||||
void render() const;
|
||||
|
||||
std::stringstream m_ss;
|
||||
std::string m_name;
|
||||
std::set<std::shared_ptr<Node>> m_nodes_with_attributes;
|
||||
visualize_tree_ops_map_t m_ops_to_details;
|
||||
node_modifiers_t m_node_modifiers = nullptr;
|
||||
bool m_dot_only;
|
||||
static const int max_jump_distance;
|
||||
};
|
||||
using ov::pass::VisualizeTree;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
@ -16,255 +16,21 @@
|
||||
#include "ngraph/pattern/op/any_output.hpp"
|
||||
#include "ngraph/pattern/op/label.hpp"
|
||||
#include "ngraph/pattern/op/skip.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class GraphRewrite;
|
||||
}
|
||||
} // namespace ov
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
using ov::pass::GraphRewrite;
|
||||
}
|
||||
|
||||
namespace pattern {
|
||||
class Matcher;
|
||||
|
||||
class NGRAPH_API MatcherState {
|
||||
public:
|
||||
MatcherState(Matcher*);
|
||||
bool finish(bool is_successful);
|
||||
~MatcherState();
|
||||
|
||||
protected:
|
||||
Matcher* m_matcher;
|
||||
PatternValueMap m_pattern_value_map;
|
||||
PatternValueMaps m_pattern_value_maps;
|
||||
size_t m_watermark;
|
||||
size_t m_capture_size;
|
||||
bool m_restore{true};
|
||||
};
|
||||
|
||||
/// Matcher looks for node patterns in a computation graph. The patterns are described by an
|
||||
/// automaton that is described by an extended computation graph. The matcher executes
|
||||
/// by attempting to match the start node of the pattern to a computation graph value
|
||||
/// (output of a Node). In addition to determing if a match occurs, a pattern node may add
|
||||
/// graph nodes to a list of matched nodes, associate nodes with graph values, and start
|
||||
/// submatches. Submatches add match state changes to the enclosing match if the submatch
|
||||
/// succeeds; otherwise the state is reverted.
|
||||
///
|
||||
/// The default match behavior of a pattern node with a graph nodes is that the computation
|
||||
/// graph value is added to the end of the matched value list and the match succeeds if the
|
||||
/// node/pattern types match and the input values match. In the case of a commutative node,
|
||||
/// the inputs can match in any order. If the matcher is in strict mode, the graph value
|
||||
/// element type and shape must also match.
|
||||
///
|
||||
/// Pattern nodes that have different match behavior are in ngraph::pattern::op and have
|
||||
/// descriptions of their match behavior.
|
||||
class NGRAPH_API Matcher {
|
||||
public:
|
||||
using PatternMap = ngraph::pattern::PatternMap;
|
||||
|
||||
// Avoid implicit string construction from nullptr.
|
||||
Matcher(const std::shared_ptr<Node> pattern_node, std::nullptr_t name) = delete;
|
||||
|
||||
Matcher() {}
|
||||
Matcher(Output<Node>& pattern_node) : m_pattern_node{pattern_node} {}
|
||||
|
||||
Matcher(Output<Node>& pattern_node, const std::string& name) : m_pattern_node(pattern_node), m_name{name} {}
|
||||
|
||||
/// \brief Constructs a Matcher object
|
||||
///
|
||||
/// \param pattern_node is a pattern sub graph that will be matched against input graphs
|
||||
/// \param name is a string which is used for logging and disabling a matcher
|
||||
/// \param strict_mode forces a matcher to consider shapes and ET of nodes
|
||||
Matcher(const Output<Node>& pattern_node, const std::string& name, bool strict_mode)
|
||||
: m_pattern_node(pattern_node),
|
||||
m_name(name),
|
||||
m_strict_mode(strict_mode) {}
|
||||
|
||||
// Some matches should start on a node rather than an output. These three constructors
|
||||
// are transition until we work out the right way to do that.
|
||||
Matcher(std::shared_ptr<Node> pattern_node);
|
||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name);
|
||||
Matcher(std::shared_ptr<Node> pattern_node, const std::string& name, bool strict_mode);
|
||||
|
||||
virtual ~Matcher() {}
|
||||
/// \brief Matches a pattern to \p graph_node
|
||||
///
|
||||
/// \param graph_value is an input graph to be matched against
|
||||
bool match(const Output<Node>& graph_value);
|
||||
|
||||
bool match(std::shared_ptr<Node> graph_node);
|
||||
|
||||
/// \brief Matches a pattern to \p graph_node
|
||||
///
|
||||
/// \param graph_value is an input graph to be matched against
|
||||
/// \param previous_matches contains previous mappings from labels to nodes to use
|
||||
bool match(const Output<Node>& graph_value, const PatternMap& previous_matches);
|
||||
bool match(const Output<Node>& graph_value, const PatternValueMap& previous_matches);
|
||||
|
||||
template <typename T>
|
||||
static std::shared_ptr<T> unique_match(std::shared_ptr<Node> node) {
|
||||
std::shared_ptr<T> matched;
|
||||
for (auto arg : node->input_values()) {
|
||||
if (auto t_casted = ov::as_type_ptr<T>(arg.get_node_shared_ptr())) {
|
||||
if (matched) {
|
||||
throw ngraph_error("There's more than two arguments of the same type");
|
||||
} else {
|
||||
matched = t_casted;
|
||||
}
|
||||
}
|
||||
}
|
||||
return matched;
|
||||
}
|
||||
|
||||
bool is_contained_match(const NodeVector& exclusions = {}, bool ignore_unused = true);
|
||||
const NodeVector get_matched_nodes() {
|
||||
return as_node_vector(m_matched_list);
|
||||
}
|
||||
const OutputVector& get_matched_values() const {
|
||||
return m_matched_list;
|
||||
}
|
||||
OutputVector& get_matched_values() {
|
||||
return m_matched_list;
|
||||
}
|
||||
void reset() {}
|
||||
const std::string& get_name() {
|
||||
return m_name;
|
||||
}
|
||||
std::shared_ptr<Node> get_pattern() {
|
||||
return m_pattern_node.get_node_shared_ptr();
|
||||
}
|
||||
Output<Node> get_pattern_value() {
|
||||
return m_pattern_node;
|
||||
}
|
||||
std::shared_ptr<Node> get_match_root();
|
||||
Output<Node> get_match_value();
|
||||
PatternMap get_pattern_map() const;
|
||||
PatternValueMap& get_pattern_value_map() {
|
||||
return m_pattern_map;
|
||||
}
|
||||
PatternValueMaps& get_pattern_value_maps() {
|
||||
return m_pattern_value_maps;
|
||||
}
|
||||
/// \brief Low-level helper to match recurring patterns
|
||||
///
|
||||
/// \param graph is a graph to be matched against
|
||||
/// \param pattern is a recurring pattern
|
||||
/// \param rpattern specifies a node to recur from next
|
||||
/// \param patterns a map from labels to matches
|
||||
|
||||
size_t add_node(Output<Node> node);
|
||||
|
||||
bool virtual match_value(const ngraph::Output<Node>& pattern_value, const ngraph::Output<Node>& graph_value);
|
||||
|
||||
bool is_strict_mode() {
|
||||
return m_strict_mode;
|
||||
}
|
||||
virtual bool match_arguments(Node* pattern_node, const std::shared_ptr<Node>& graph_node);
|
||||
|
||||
void capture(const std::set<Node*>& static_nodes);
|
||||
|
||||
void clear_state();
|
||||
|
||||
size_t get_number_of_recurrent_matches() const {
|
||||
return m_pattern_value_maps.size();
|
||||
}
|
||||
NodeVector get_bound_nodes_for_pattern(const Output<Node>& pattern) const;
|
||||
size_t get_number_of_bound_labels() const;
|
||||
/// \brief Try a match
|
||||
MatcherState start_match();
|
||||
|
||||
Output<Node> m_match_root;
|
||||
Output<Node> m_pattern_node;
|
||||
PatternValueMap m_pattern_map;
|
||||
PatternValueMaps m_pattern_value_maps;
|
||||
OutputVector m_matched_list;
|
||||
|
||||
protected:
|
||||
bool match_permutation(const OutputVector& pattern_args, const OutputVector& args);
|
||||
|
||||
std::string m_name{"unnamed"};
|
||||
bool m_strict_mode{false};
|
||||
};
|
||||
|
||||
class NGRAPH_API RecurrentMatcher {
|
||||
public:
|
||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||
///
|
||||
/// \param initial_pattern is a pattern sub graph describing the initial cell
|
||||
/// \param pattern is a pattern sub graph describing an individual cell
|
||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||
/// start at
|
||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||
/// across all cells
|
||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||
const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||
: m_initial_pattern(initial_pattern),
|
||||
m_pattern(pattern),
|
||||
m_recurrent_pattern(rpattern),
|
||||
m_correlated_patterns(correlated_patterns) {}
|
||||
|
||||
/// \brief Constructs a RecurrentMatcher object. Reccurent Matchers are used to match
|
||||
/// repeating patterns (e.g. RNN, LSTM, GRU cells)
|
||||
///
|
||||
/// \param pattern is a pattern sub graph describing an individual cell
|
||||
/// \param rpattern is a (recurring) label to denote which node the next match should
|
||||
/// start at
|
||||
/// \param correlated_patterns is a set of labels whose bound nodes must remain the same
|
||||
/// across all cells
|
||||
RecurrentMatcher(const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<Node>>& correlated_patterns)
|
||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||
|
||||
RecurrentMatcher(const Output<Node>& initial_pattern,
|
||||
const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns);
|
||||
|
||||
RecurrentMatcher(const Output<Node>& pattern,
|
||||
const std::shared_ptr<Node>& rpattern,
|
||||
const std::set<std::shared_ptr<op::Label>>& correlated_patterns)
|
||||
: RecurrentMatcher(pattern, pattern, rpattern, correlated_patterns) {}
|
||||
|
||||
/// \brief Returns a vector of bound nodes for a given label (used in a pattern
|
||||
/// describing an individual cell
|
||||
NodeVector get_bound_nodes_for_pattern(const std::shared_ptr<Node>& pattern) const {
|
||||
if (m_matches.count(pattern) == 0) {
|
||||
throw ngraph_error("No bound nodes for a given label");
|
||||
}
|
||||
|
||||
return as_node_vector(m_matches.at(pattern));
|
||||
}
|
||||
|
||||
size_t get_number_of_recurrent_matches() const {
|
||||
if (m_matches.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return (*m_matches.begin()).second.size();
|
||||
}
|
||||
|
||||
size_t get_number_of_bound_labels() const {
|
||||
return m_matches.size();
|
||||
}
|
||||
/// \brief Tries to match a pattern for an individual cell to a given \p graph
|
||||
bool match(Output<Node> graph);
|
||||
|
||||
std::shared_ptr<Node> get_match_root() {
|
||||
return m_match_root.get_node_shared_ptr();
|
||||
}
|
||||
Output<Node> get_match_value() {
|
||||
return m_match_root;
|
||||
}
|
||||
|
||||
private:
|
||||
Output<Node> m_initial_pattern;
|
||||
Output<Node> m_pattern;
|
||||
std::shared_ptr<Node> m_recurrent_pattern;
|
||||
const std::set<std::shared_ptr<Node>> m_correlated_patterns;
|
||||
RPatternValueMap m_matches;
|
||||
Output<Node> m_match_root;
|
||||
};
|
||||
using ov::pass::pattern::Matcher;
|
||||
using ov::pass::pattern::MatcherState;
|
||||
using ov::pass::pattern::RecurrentMatcher;
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,38 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/any.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is to the matched value list. If the predicate is true for the node
|
||||
/// and the arguments match, the match succeeds.
|
||||
class NGRAPH_API Any : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAny", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a Any node containing a sub-pattern described by \sa type and \sa
|
||||
/// shape.
|
||||
Any(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Pattern(wrapped_values, pred) {
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
Any(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Any(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
/// \brief creates a Any node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
Any(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Any(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||
Any(const Output<Node>& node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Any(node.get_element_type(),
|
||||
node.get_partial_shape(),
|
||||
as_value_predicate(pred),
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::Any;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,47 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/any_of.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// The graph value is added to the matched values list. If the predicate is true for
|
||||
/// the
|
||||
/// graph node, a submatch is performed on the input of AnyOf and each input of the
|
||||
/// graph node. The first match that succeeds results in a successful match. Otherwise
|
||||
/// the match fails.
|
||||
///
|
||||
/// AnyOf may be given a type and shape for use in strict mode.
|
||||
class NGRAPH_API AnyOf : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
|
||||
/// \sa shape.
|
||||
AnyOf(const element::Type& type, const PartialShape& s, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Pattern(wrapped_values, pred) {
|
||||
if (wrapped_values.size() != 1) {
|
||||
throw ngraph_error("AnyOf expects exactly one argument");
|
||||
}
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
AnyOf(const element::Type& type, const PartialShape& s, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: AnyOf(
|
||||
type,
|
||||
s,
|
||||
[pred](const Output<Node>& value) {
|
||||
return pred(value.get_node_shared_ptr());
|
||||
},
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
/// \brief creates a AnyOf node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
AnyOf(const Output<Node>& node, ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values) {}
|
||||
AnyOf(std::shared_ptr<Node> node, NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::AnyOf;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,23 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/any_output.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Matches any output of a node
|
||||
class NGRAPH_API AnyOutput : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternAnyOutput", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates an AnyOutput node matching any output of a node
|
||||
/// \param node The node to match
|
||||
AnyOutput(const std::shared_ptr<Node>& pattern) : Pattern({pattern->output(0)}) {}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
};
|
||||
using ov::pass::pattern::op::AnyOutput;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,48 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/branch.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// A branch adds a loop to the pattern. The branch match is successful if the
|
||||
/// destination node pattern matches the graph value. The destination node is a node in
|
||||
/// the pattern graph that will not have been created some time after the Branch node is
|
||||
/// created; use set_destination to add it.
|
||||
///
|
||||
/// The branch destination is not stored as a shared pointer to prevent reference
|
||||
/// cycles. Thus the destination node must be referenced in some other way to prevent it
|
||||
/// from being deleted.
|
||||
class NGRAPH_API Branch : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternBranch", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief Creates a Branch pattern
|
||||
/// \param pattern the destinationing pattern
|
||||
/// \param labels Labels where the destination may occur
|
||||
Branch() : Pattern(OutputVector{}) {
|
||||
set_output_type(0, element::f32, Shape{});
|
||||
}
|
||||
|
||||
void set_destination(const Output<Node>& destination) {
|
||||
m_destination_node = destination.get_node();
|
||||
m_destination_index = destination.get_index();
|
||||
}
|
||||
|
||||
Output<Node> get_destination() const {
|
||||
return m_destination_node == nullptr
|
||||
? Output<Node>()
|
||||
: Output<Node>{m_destination_node->shared_from_this(), m_destination_index};
|
||||
}
|
||||
|
||||
bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
Node* m_destination_node{nullptr};
|
||||
size_t m_destination_index{0};
|
||||
};
|
||||
using ov::pass::pattern::op::Branch;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,37 +6,12 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/capture.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Experimental for support of recurrent matches.
|
||||
///
|
||||
/// Capture adds the pattern value map to a list of pattern value maps and resets
|
||||
/// matches for pattern nodes not in the static node list. The match always succeeds.
|
||||
class NGRAPH_API Capture : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternCapture", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
Capture(const Output<Node>& arg) : Pattern({arg}) {
|
||||
set_output_type(0, arg.get_element_type(), arg.get_partial_shape());
|
||||
}
|
||||
|
||||
/// \brief static nodes are retained after a capture. All other nodes are dropped
|
||||
std::set<Node*> get_static_nodes() {
|
||||
return m_static_nodes;
|
||||
}
|
||||
void set_static_nodes(const std::set<Node*>& static_nodes) {
|
||||
m_static_nodes = static_nodes;
|
||||
}
|
||||
|
||||
virtual bool match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
std::set<Node*> m_static_nodes;
|
||||
};
|
||||
using ov::pass::pattern::op::Capture;
|
||||
} // namespace op
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
@ -6,106 +6,14 @@
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pattern {
|
||||
namespace op {
|
||||
/// Fails if the predicate returns false on the graph value.
|
||||
///
|
||||
/// The graph value is added to the matched values list. If the Label is already
|
||||
/// associated with a value, the match succeeds if the value is the same as the graph
|
||||
/// value. Otherwise, the label is associated with the graph value and the match
|
||||
/// succeeds if the pattern input matches the graph value.
|
||||
///
|
||||
/// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
|
||||
/// more than one inputs are given, an Or pattern of the inputs serves as the input.
|
||||
class NGRAPH_API Label : public Pattern {
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"patternLabel", 0};
|
||||
const NodeTypeInfo& get_type_info() const override;
|
||||
/// \brief creates a Label node containing a sub-pattern described by \sa type and
|
||||
/// \sa shape.
|
||||
///
|
||||
/// this Label node can be bound only to the nodes in the input graph
|
||||
/// that match the pattern specified by \sa wrapped_nodes
|
||||
/// Example:
|
||||
/// \code{.cpp}
|
||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||
/// auto label = std::make_shared<pattern::op::Label>(element::f32,
|
||||
/// Shape{2,2},
|
||||
/// nullptr,
|
||||
/// OutputVector{add});
|
||||
/// \endcode
|
||||
Label(const element::Type& type,
|
||||
const PartialShape& s,
|
||||
const ValuePredicate pred,
|
||||
const OutputVector& wrapped_values)
|
||||
: Pattern(OutputVector{wrap_values(wrapped_values)}, pred) {
|
||||
set_output_type(0, type, s);
|
||||
}
|
||||
|
||||
explicit Label(const element::Type& type = element::dynamic, const PartialShape& s = PartialShape::dynamic())
|
||||
: Label(
|
||||
type,
|
||||
s,
|
||||
[](const Output<Node>&) {
|
||||
return true;
|
||||
},
|
||||
OutputVector()) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
|
||||
: Label(type, s, pred, OutputVector{}) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
|
||||
: Label(type, s, as_value_predicate(pred), OutputVector{}) {}
|
||||
|
||||
Label(const element::Type& type, const PartialShape& s, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values)) {}
|
||||
|
||||
/// \brief creates a Label node containing a sub-pattern described by the type and
|
||||
/// shape of \sa node.
|
||||
///
|
||||
/// this Label node can be bound only to the nodes in the input graph
|
||||
/// that match the pattern specified by \sa wrapped_values
|
||||
/// Example:
|
||||
/// \code{.cpp}
|
||||
/// auto add = a + b; // a and b are op::Parameter in this example
|
||||
/// auto label = std::make_shared<pattern::op::Label>(add,
|
||||
/// nullptr,
|
||||
/// OutputVector{add});
|
||||
/// \endcode
|
||||
Label(const Output<Node>& value, const ValuePredicate pred, const OutputVector& wrapped_values)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, wrapped_values) {}
|
||||
Label(const Output<Node>& value, const ValuePredicate pred)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), pred, OutputVector{}) {}
|
||||
|
||||
Label(const Output<Node>& value, const NodePredicate pred)
|
||||
: Label(value.get_element_type(), value.get_partial_shape(), as_value_predicate(pred), OutputVector{}) {}
|
||||
Label(const Output<Node>& value)
|
||||
: Label(
|
||||
value.get_element_type(),
|
||||
value.get_partial_shape(),
|
||||
[](const Output<Node>&) {
|
||||
return true;
|
||||
},
|
||||
OutputVector{}) {}
|
||||
Label(const Output<Node>& node, const NodePredicate pred, const NodeVector& wrapped_values)
|
||||
: Label(node.get_element_type(),
|
||||
node.get_partial_shape(),
|
||||
as_value_predicate(pred),
|
||||
as_output_vector(wrapped_values)) {}
|
||||
|
||||
bool match_value(Matcher* matcher, const Output<Node>& pattern_value, const Output<Node>& graph_value) override;
|
||||
|
||||
protected:
|
||||
static Output<Node> wrap_values(const OutputVector& wrapped_values);
|
||||
};
|
||||
using ov::pass::pattern::op::Label;
|
||||
} // namespace op
|
||||
|
||||
NGRAPH_API
|
||||
std::shared_ptr<Node> any_input();
|
||||
|
||||
NGRAPH_API
|
||||
std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
|
||||
using ov::pass::pattern::any_input;
|
||||
} // namespace pattern
|
||||
} // namespace ngraph
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user