Add LogSoftmax-5 to ngraph (#2645)
This commit is contained in:
@@ -244,6 +244,7 @@ packages = [
|
||||
"ngraph.opset2",
|
||||
"ngraph.opset3",
|
||||
"ngraph.opset4",
|
||||
"ngraph.opset5",
|
||||
"ngraph.utils",
|
||||
"ngraph.impl",
|
||||
"ngraph.impl.op",
|
||||
|
||||
@@ -28,146 +28,147 @@ from ngraph.impl import Function
|
||||
from ngraph.helpers import function_from_cnn
|
||||
from ngraph.helpers import function_to_cnn
|
||||
|
||||
from ngraph.opset4 import absolute
|
||||
from ngraph.opset4 import absolute as abs
|
||||
from ngraph.opset4 import acos
|
||||
from ngraph.opset4 import acosh
|
||||
from ngraph.opset4 import add
|
||||
from ngraph.opset4 import asin
|
||||
from ngraph.opset4 import asinh
|
||||
from ngraph.opset4 import assign
|
||||
from ngraph.opset4 import atan
|
||||
from ngraph.opset4 import atanh
|
||||
from ngraph.opset4 import avg_pool
|
||||
from ngraph.opset4 import batch_norm_inference
|
||||
from ngraph.opset4 import batch_to_space
|
||||
from ngraph.opset4 import binary_convolution
|
||||
from ngraph.opset4 import broadcast
|
||||
from ngraph.opset4 import bucketize
|
||||
from ngraph.opset4 import ceiling
|
||||
from ngraph.opset4 import ceiling as ceil
|
||||
from ngraph.opset4 import clamp
|
||||
from ngraph.opset4 import concat
|
||||
from ngraph.opset4 import constant
|
||||
from ngraph.opset4 import convert
|
||||
from ngraph.opset4 import convert_like
|
||||
from ngraph.opset4 import convolution
|
||||
from ngraph.opset4 import convolution_backprop_data
|
||||
from ngraph.opset4 import cos
|
||||
from ngraph.opset4 import cosh
|
||||
from ngraph.opset4 import ctc_greedy_decoder
|
||||
from ngraph.opset4 import ctc_loss
|
||||
from ngraph.opset4 import cum_sum
|
||||
from ngraph.opset4 import cum_sum as cumsum
|
||||
from ngraph.opset4 import deformable_convolution
|
||||
from ngraph.opset4 import deformable_psroi_pooling
|
||||
from ngraph.opset4 import depth_to_space
|
||||
from ngraph.opset4 import detection_output
|
||||
from ngraph.opset4 import divide
|
||||
from ngraph.opset4 import elu
|
||||
from ngraph.opset4 import embedding_bag_offsets_sum
|
||||
from ngraph.opset4 import embedding_bag_packed_sum
|
||||
from ngraph.opset4 import embedding_segments_sum
|
||||
from ngraph.opset4 import extract_image_patches
|
||||
from ngraph.opset4 import equal
|
||||
from ngraph.opset4 import erf
|
||||
from ngraph.opset4 import exp
|
||||
from ngraph.opset4 import fake_quantize
|
||||
from ngraph.opset4 import floor
|
||||
from ngraph.opset4 import floor_mod
|
||||
from ngraph.opset4 import gather
|
||||
from ngraph.opset4 import gather_tree
|
||||
from ngraph.opset4 import gelu
|
||||
from ngraph.opset4 import greater
|
||||
from ngraph.opset4 import greater_equal
|
||||
from ngraph.opset4 import grn
|
||||
from ngraph.opset4 import group_convolution
|
||||
from ngraph.opset4 import group_convolution_backprop_data
|
||||
from ngraph.opset4 import gru_cell
|
||||
from ngraph.opset4 import hard_sigmoid
|
||||
from ngraph.opset4 import hswish
|
||||
from ngraph.opset4 import interpolate
|
||||
from ngraph.opset4 import less
|
||||
from ngraph.opset4 import less_equal
|
||||
from ngraph.opset4 import log
|
||||
from ngraph.opset4 import logical_and
|
||||
from ngraph.opset4 import logical_not
|
||||
from ngraph.opset4 import logical_or
|
||||
from ngraph.opset4 import logical_xor
|
||||
from ngraph.opset4 import lrn
|
||||
from ngraph.opset4 import lstm_cell
|
||||
from ngraph.opset4 import lstm_sequence
|
||||
from ngraph.opset4 import matmul
|
||||
from ngraph.opset4 import max_pool
|
||||
from ngraph.opset4 import maximum
|
||||
from ngraph.opset4 import minimum
|
||||
from ngraph.opset4 import mish
|
||||
from ngraph.opset4 import mod
|
||||
from ngraph.opset4 import multiply
|
||||
from ngraph.opset4 import mvn
|
||||
from ngraph.opset4 import negative
|
||||
from ngraph.opset4 import non_max_suppression
|
||||
from ngraph.opset4 import non_zero
|
||||
from ngraph.opset4 import normalize_l2
|
||||
from ngraph.opset4 import not_equal
|
||||
from ngraph.opset4 import one_hot
|
||||
from ngraph.opset4 import pad
|
||||
from ngraph.opset4 import parameter
|
||||
from ngraph.opset4 import power
|
||||
from ngraph.opset4 import prelu
|
||||
from ngraph.opset4 import prior_box
|
||||
from ngraph.opset4 import prior_box_clustered
|
||||
from ngraph.opset4 import psroi_pooling
|
||||
from ngraph.opset4 import proposal
|
||||
from ngraph.opset4 import range
|
||||
from ngraph.opset4 import read_value
|
||||
from ngraph.opset4 import reduce_l1
|
||||
from ngraph.opset4 import reduce_l2
|
||||
from ngraph.opset4 import reduce_logical_and
|
||||
from ngraph.opset4 import reduce_logical_or
|
||||
from ngraph.opset4 import reduce_max
|
||||
from ngraph.opset4 import reduce_mean
|
||||
from ngraph.opset4 import reduce_min
|
||||
from ngraph.opset4 import reduce_prod
|
||||
from ngraph.opset4 import reduce_sum
|
||||
from ngraph.opset4 import region_yolo
|
||||
from ngraph.opset4 import reorg_yolo
|
||||
from ngraph.opset4 import relu
|
||||
from ngraph.opset4 import reshape
|
||||
from ngraph.opset4 import result
|
||||
from ngraph.opset4 import reverse_sequence
|
||||
from ngraph.opset4 import rnn_cell
|
||||
from ngraph.opset4 import roi_align
|
||||
from ngraph.opset4 import roi_pooling
|
||||
from ngraph.opset4 import scatter_elements_update
|
||||
from ngraph.opset4 import scatter_update
|
||||
from ngraph.opset4 import select
|
||||
from ngraph.opset4 import selu
|
||||
from ngraph.opset4 import shape_of
|
||||
from ngraph.opset4 import shuffle_channels
|
||||
from ngraph.opset4 import sigmoid
|
||||
from ngraph.opset4 import sign
|
||||
from ngraph.opset4 import sin
|
||||
from ngraph.opset4 import sinh
|
||||
from ngraph.opset4 import softmax
|
||||
from ngraph.opset4 import softplus
|
||||
from ngraph.opset4 import space_to_batch
|
||||
from ngraph.opset4 import space_to_depth
|
||||
from ngraph.opset4 import split
|
||||
from ngraph.opset4 import sqrt
|
||||
from ngraph.opset4 import squared_difference
|
||||
from ngraph.opset4 import squeeze
|
||||
from ngraph.opset4 import strided_slice
|
||||
from ngraph.opset4 import subtract
|
||||
from ngraph.opset4 import swish
|
||||
from ngraph.opset4 import tan
|
||||
from ngraph.opset4 import tanh
|
||||
from ngraph.opset4 import tensor_iterator
|
||||
from ngraph.opset4 import tile
|
||||
from ngraph.opset4 import topk
|
||||
from ngraph.opset4 import transpose
|
||||
from ngraph.opset4 import unsqueeze
|
||||
from ngraph.opset4 import variadic_split
|
||||
from ngraph.opset5 import absolute
|
||||
from ngraph.opset5 import absolute as abs
|
||||
from ngraph.opset5 import acos
|
||||
from ngraph.opset5 import acosh
|
||||
from ngraph.opset5 import add
|
||||
from ngraph.opset5 import asin
|
||||
from ngraph.opset5 import asinh
|
||||
from ngraph.opset5 import assign
|
||||
from ngraph.opset5 import atan
|
||||
from ngraph.opset5 import atanh
|
||||
from ngraph.opset5 import avg_pool
|
||||
from ngraph.opset5 import batch_norm_inference
|
||||
from ngraph.opset5 import batch_to_space
|
||||
from ngraph.opset5 import binary_convolution
|
||||
from ngraph.opset5 import broadcast
|
||||
from ngraph.opset5 import bucketize
|
||||
from ngraph.opset5 import ceiling
|
||||
from ngraph.opset5 import ceiling as ceil
|
||||
from ngraph.opset5 import clamp
|
||||
from ngraph.opset5 import concat
|
||||
from ngraph.opset5 import constant
|
||||
from ngraph.opset5 import convert
|
||||
from ngraph.opset5 import convert_like
|
||||
from ngraph.opset5 import convolution
|
||||
from ngraph.opset5 import convolution_backprop_data
|
||||
from ngraph.opset5 import cos
|
||||
from ngraph.opset5 import cosh
|
||||
from ngraph.opset5 import ctc_greedy_decoder
|
||||
from ngraph.opset5 import ctc_loss
|
||||
from ngraph.opset5 import cum_sum
|
||||
from ngraph.opset5 import cum_sum as cumsum
|
||||
from ngraph.opset5 import deformable_convolution
|
||||
from ngraph.opset5 import deformable_psroi_pooling
|
||||
from ngraph.opset5 import depth_to_space
|
||||
from ngraph.opset5 import detection_output
|
||||
from ngraph.opset5 import divide
|
||||
from ngraph.opset5 import elu
|
||||
from ngraph.opset5 import embedding_bag_offsets_sum
|
||||
from ngraph.opset5 import embedding_bag_packed_sum
|
||||
from ngraph.opset5 import embedding_segments_sum
|
||||
from ngraph.opset5 import extract_image_patches
|
||||
from ngraph.opset5 import equal
|
||||
from ngraph.opset5 import erf
|
||||
from ngraph.opset5 import exp
|
||||
from ngraph.opset5 import fake_quantize
|
||||
from ngraph.opset5 import floor
|
||||
from ngraph.opset5 import floor_mod
|
||||
from ngraph.opset5 import gather
|
||||
from ngraph.opset5 import gather_tree
|
||||
from ngraph.opset5 import gelu
|
||||
from ngraph.opset5 import greater
|
||||
from ngraph.opset5 import greater_equal
|
||||
from ngraph.opset5 import grn
|
||||
from ngraph.opset5 import group_convolution
|
||||
from ngraph.opset5 import group_convolution_backprop_data
|
||||
from ngraph.opset5 import gru_cell
|
||||
from ngraph.opset5 import hard_sigmoid
|
||||
from ngraph.opset5 import hswish
|
||||
from ngraph.opset5 import interpolate
|
||||
from ngraph.opset5 import less
|
||||
from ngraph.opset5 import less_equal
|
||||
from ngraph.opset5 import log
|
||||
from ngraph.opset5 import logical_and
|
||||
from ngraph.opset5 import logical_not
|
||||
from ngraph.opset5 import logical_or
|
||||
from ngraph.opset5 import logical_xor
|
||||
from ngraph.opset5 import log_softmax
|
||||
from ngraph.opset5 import lrn
|
||||
from ngraph.opset5 import lstm_cell
|
||||
from ngraph.opset5 import lstm_sequence
|
||||
from ngraph.opset5 import matmul
|
||||
from ngraph.opset5 import max_pool
|
||||
from ngraph.opset5 import maximum
|
||||
from ngraph.opset5 import minimum
|
||||
from ngraph.opset5 import mish
|
||||
from ngraph.opset5 import mod
|
||||
from ngraph.opset5 import multiply
|
||||
from ngraph.opset5 import mvn
|
||||
from ngraph.opset5 import negative
|
||||
from ngraph.opset5 import non_max_suppression
|
||||
from ngraph.opset5 import non_zero
|
||||
from ngraph.opset5 import normalize_l2
|
||||
from ngraph.opset5 import not_equal
|
||||
from ngraph.opset5 import one_hot
|
||||
from ngraph.opset5 import pad
|
||||
from ngraph.opset5 import parameter
|
||||
from ngraph.opset5 import power
|
||||
from ngraph.opset5 import prelu
|
||||
from ngraph.opset5 import prior_box
|
||||
from ngraph.opset5 import prior_box_clustered
|
||||
from ngraph.opset5 import psroi_pooling
|
||||
from ngraph.opset5 import proposal
|
||||
from ngraph.opset5 import range
|
||||
from ngraph.opset5 import read_value
|
||||
from ngraph.opset5 import reduce_l1
|
||||
from ngraph.opset5 import reduce_l2
|
||||
from ngraph.opset5 import reduce_logical_and
|
||||
from ngraph.opset5 import reduce_logical_or
|
||||
from ngraph.opset5 import reduce_max
|
||||
from ngraph.opset5 import reduce_mean
|
||||
from ngraph.opset5 import reduce_min
|
||||
from ngraph.opset5 import reduce_prod
|
||||
from ngraph.opset5 import reduce_sum
|
||||
from ngraph.opset5 import region_yolo
|
||||
from ngraph.opset5 import reorg_yolo
|
||||
from ngraph.opset5 import relu
|
||||
from ngraph.opset5 import reshape
|
||||
from ngraph.opset5 import result
|
||||
from ngraph.opset5 import reverse_sequence
|
||||
from ngraph.opset5 import rnn_cell
|
||||
from ngraph.opset5 import roi_align
|
||||
from ngraph.opset5 import roi_pooling
|
||||
from ngraph.opset5 import scatter_elements_update
|
||||
from ngraph.opset5 import scatter_update
|
||||
from ngraph.opset5 import select
|
||||
from ngraph.opset5 import selu
|
||||
from ngraph.opset5 import shape_of
|
||||
from ngraph.opset5 import shuffle_channels
|
||||
from ngraph.opset5 import sigmoid
|
||||
from ngraph.opset5 import sign
|
||||
from ngraph.opset5 import sin
|
||||
from ngraph.opset5 import sinh
|
||||
from ngraph.opset5 import softmax
|
||||
from ngraph.opset5 import softplus
|
||||
from ngraph.opset5 import space_to_batch
|
||||
from ngraph.opset5 import space_to_depth
|
||||
from ngraph.opset5 import split
|
||||
from ngraph.opset5 import sqrt
|
||||
from ngraph.opset5 import squared_difference
|
||||
from ngraph.opset5 import squeeze
|
||||
from ngraph.opset5 import strided_slice
|
||||
from ngraph.opset5 import subtract
|
||||
from ngraph.opset5 import swish
|
||||
from ngraph.opset5 import tan
|
||||
from ngraph.opset5 import tanh
|
||||
from ngraph.opset5 import tensor_iterator
|
||||
from ngraph.opset5 import tile
|
||||
from ngraph.opset5 import topk
|
||||
from ngraph.opset5 import transpose
|
||||
from ngraph.opset5 import unsqueeze
|
||||
from ngraph.opset5 import variadic_split
|
||||
|
||||
|
||||
# Extend Node class to support binary operators
|
||||
|
||||
158
ngraph/python/src/ngraph/opset5/__init__.py
Normal file
158
ngraph/python/src/ngraph/opset5/__init__.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# ******************************************************************************
|
||||
# Copyright 2017-2020 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ******************************************************************************
|
||||
|
||||
|
||||
from ngraph.opset1.ops import absolute
|
||||
from ngraph.opset1.ops import absolute as abs
|
||||
from ngraph.opset1.ops import acos
|
||||
from ngraph.opset4.ops import acosh
|
||||
from ngraph.opset1.ops import add
|
||||
from ngraph.opset1.ops import asin
|
||||
from ngraph.opset4.ops import asinh
|
||||
from ngraph.opset3.ops import assign
|
||||
from ngraph.opset1.ops import atan
|
||||
from ngraph.opset4.ops import atanh
|
||||
from ngraph.opset1.ops import avg_pool
|
||||
from ngraph.opset1.ops import batch_norm_inference
|
||||
from ngraph.opset2.ops import batch_to_space
|
||||
from ngraph.opset1.ops import binary_convolution
|
||||
from ngraph.opset3.ops import broadcast
|
||||
from ngraph.opset3.ops import bucketize
|
||||
from ngraph.opset1.ops import ceiling
|
||||
from ngraph.opset1.ops import ceiling as ceil
|
||||
from ngraph.opset1.ops import clamp
|
||||
from ngraph.opset1.ops import concat
|
||||
from ngraph.opset1.ops import constant
|
||||
from ngraph.opset1.ops import convert
|
||||
from ngraph.opset1.ops import convert_like
|
||||
from ngraph.opset1.ops import convolution
|
||||
from ngraph.opset1.ops import convolution_backprop_data
|
||||
from ngraph.opset1.ops import cos
|
||||
from ngraph.opset1.ops import cosh
|
||||
from ngraph.opset1.ops import ctc_greedy_decoder
|
||||
from ngraph.opset4.ops import ctc_loss
|
||||
from ngraph.opset3.ops import cum_sum
|
||||
from ngraph.opset3.ops import cum_sum as cumsum
|
||||
from ngraph.opset1.ops import deformable_convolution
|
||||
from ngraph.opset1.ops import deformable_psroi_pooling
|
||||
from ngraph.opset1.ops import depth_to_space
|
||||
from ngraph.opset1.ops import detection_output
|
||||
from ngraph.opset1.ops import divide
|
||||
from ngraph.opset1.ops import elu
|
||||
from ngraph.opset3.ops import embedding_bag_offsets_sum
|
||||
from ngraph.opset3.ops import embedding_bag_packed_sum
|
||||
from ngraph.opset3.ops import embedding_segments_sum
|
||||
from ngraph.opset3.ops import extract_image_patches
|
||||
from ngraph.opset1.ops import equal
|
||||
from ngraph.opset1.ops import erf
|
||||
from ngraph.opset1.ops import exp
|
||||
from ngraph.opset1.ops import fake_quantize
|
||||
from ngraph.opset1.ops import floor
|
||||
from ngraph.opset1.ops import floor_mod
|
||||
from ngraph.opset1.ops import gather
|
||||
from ngraph.opset1.ops import gather_tree
|
||||
from ngraph.opset2.ops import gelu
|
||||
from ngraph.opset1.ops import greater
|
||||
from ngraph.opset1.ops import greater_equal
|
||||
from ngraph.opset1.ops import grn
|
||||
from ngraph.opset1.ops import group_convolution
|
||||
from ngraph.opset1.ops import group_convolution_backprop_data
|
||||
from ngraph.opset3.ops import gru_cell
|
||||
from ngraph.opset1.ops import hard_sigmoid
|
||||
from ngraph.opset4.ops import hswish
|
||||
from ngraph.opset1.ops import interpolate
|
||||
from ngraph.opset1.ops import less
|
||||
from ngraph.opset1.ops import less_equal
|
||||
from ngraph.opset1.ops import log
|
||||
from ngraph.opset1.ops import logical_and
|
||||
from ngraph.opset1.ops import logical_not
|
||||
from ngraph.opset1.ops import logical_or
|
||||
from ngraph.opset1.ops import logical_xor
|
||||
from ngraph.opset5.ops import log_softmax
|
||||
from ngraph.opset1.ops import lrn
|
||||
from ngraph.opset4.ops import lstm_cell
|
||||
from ngraph.opset1.ops import lstm_sequence
|
||||
from ngraph.opset1.ops import matmul
|
||||
from ngraph.opset1.ops import max_pool
|
||||
from ngraph.opset1.ops import maximum
|
||||
from ngraph.opset1.ops import minimum
|
||||
from ngraph.opset4.ops import mish
|
||||
from ngraph.opset1.ops import mod
|
||||
from ngraph.opset1.ops import multiply
|
||||
from ngraph.opset2.ops import mvn
|
||||
from ngraph.opset1.ops import negative
|
||||
from ngraph.opset4.ops import non_max_suppression
|
||||
from ngraph.opset3.ops import non_zero
|
||||
from ngraph.opset1.ops import normalize_l2
|
||||
from ngraph.opset1.ops import not_equal
|
||||
from ngraph.opset1.ops import one_hot
|
||||
from ngraph.opset1.ops import pad
|
||||
from ngraph.opset1.ops import parameter
|
||||
from ngraph.opset1.ops import power
|
||||
from ngraph.opset1.ops import prelu
|
||||
from ngraph.opset1.ops import prior_box
|
||||
from ngraph.opset1.ops import prior_box_clustered
|
||||
from ngraph.opset1.ops import psroi_pooling
|
||||
from ngraph.opset4.ops import proposal
|
||||
from ngraph.opset1.ops import range
|
||||
from ngraph.opset3.ops import read_value
|
||||
from ngraph.opset4.ops import reduce_l1
|
||||
from ngraph.opset4.ops import reduce_l2
|
||||
from ngraph.opset1.ops import reduce_logical_and
|
||||
from ngraph.opset1.ops import reduce_logical_or
|
||||
from ngraph.opset1.ops import reduce_max
|
||||
from ngraph.opset1.ops import reduce_mean
|
||||
from ngraph.opset1.ops import reduce_min
|
||||
from ngraph.opset1.ops import reduce_prod
|
||||
from ngraph.opset1.ops import reduce_sum
|
||||
from ngraph.opset1.ops import region_yolo
|
||||
from ngraph.opset2.ops import reorg_yolo
|
||||
from ngraph.opset1.ops import relu
|
||||
from ngraph.opset1.ops import reshape
|
||||
from ngraph.opset1.ops import result
|
||||
from ngraph.opset1.ops import reverse_sequence
|
||||
from ngraph.opset3.ops import rnn_cell
|
||||
from ngraph.opset3.ops import roi_align
|
||||
from ngraph.opset2.ops import roi_pooling
|
||||
from ngraph.opset3.ops import scatter_elements_update
|
||||
from ngraph.opset3.ops import scatter_update
|
||||
from ngraph.opset1.ops import select
|
||||
from ngraph.opset1.ops import selu
|
||||
from ngraph.opset3.ops import shape_of
|
||||
from ngraph.opset3.ops import shuffle_channels
|
||||
from ngraph.opset1.ops import sigmoid
|
||||
from ngraph.opset1.ops import sign
|
||||
from ngraph.opset1.ops import sin
|
||||
from ngraph.opset1.ops import sinh
|
||||
from ngraph.opset1.ops import softmax
|
||||
from ngraph.opset4.ops import softplus
|
||||
from ngraph.opset2.ops import space_to_batch
|
||||
from ngraph.opset1.ops import space_to_depth
|
||||
from ngraph.opset1.ops import split
|
||||
from ngraph.opset1.ops import sqrt
|
||||
from ngraph.opset1.ops import squared_difference
|
||||
from ngraph.opset1.ops import squeeze
|
||||
from ngraph.opset1.ops import strided_slice
|
||||
from ngraph.opset1.ops import subtract
|
||||
from ngraph.opset4.ops import swish
|
||||
from ngraph.opset1.ops import tan
|
||||
from ngraph.opset1.ops import tanh
|
||||
from ngraph.opset1.ops import tensor_iterator
|
||||
from ngraph.opset1.ops import tile
|
||||
from ngraph.opset3.ops import topk
|
||||
from ngraph.opset1.ops import transpose
|
||||
from ngraph.opset1.ops import unsqueeze
|
||||
from ngraph.opset1.ops import variadic_split
|
||||
69
ngraph/python/src/ngraph/opset5/ops.py
Normal file
69
ngraph/python/src/ngraph/opset5/ops.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# ******************************************************************************
|
||||
# Copyright 2017-2020 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ******************************************************************************
|
||||
|
||||
"""Factory functions for all ngraph ops."""
|
||||
from typing import Callable, Iterable, List, Optional, Set, Union
|
||||
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
|
||||
from ngraph.impl import Node, Shape
|
||||
from ngraph.impl.op import Constant, Parameter
|
||||
from ngraph.opset_utils import _get_node_factory
|
||||
from ngraph.utils.decorators import binary_op, nameable_op, unary_op
|
||||
from ngraph.utils.input_validation import (
|
||||
assert_list_of_ints,
|
||||
check_valid_attributes,
|
||||
is_non_negative_value,
|
||||
is_positive_value,
|
||||
)
|
||||
from ngraph.utils.node_factory import NodeFactory
|
||||
from ngraph.utils.tensor_iterator_types import (
|
||||
GraphBody,
|
||||
TensorIteratorSliceInputDesc,
|
||||
TensorIteratorMergedInputDesc,
|
||||
TensorIteratorInvariantInputDesc,
|
||||
TensorIteratorBodyOutputDesc,
|
||||
TensorIteratorConcatOutputDesc,
|
||||
)
|
||||
from ngraph.utils.types import (
|
||||
NodeInput,
|
||||
NumericData,
|
||||
NumericType,
|
||||
ScalarData,
|
||||
TensorShape,
|
||||
as_node,
|
||||
as_nodes,
|
||||
get_dtype,
|
||||
get_element_type,
|
||||
get_element_type_str,
|
||||
make_constant_node,
|
||||
)
|
||||
|
||||
_get_node_factory_opset5 = partial(_get_node_factory, "opset5")
|
||||
|
||||
# -------------------------------------------- ops ------------------------------------------------
|
||||
|
||||
|
||||
@nameable_op
|
||||
def log_softmax(data: NodeInput, axis: int, name: Optional[str] = None) -> Node:
|
||||
"""Apply LogSoftmax operation on each element of input tensor.
|
||||
|
||||
:param data: The tensor providing input data.
|
||||
:param axis: An axis along which LogSoftmax should be calculated
|
||||
:return: The new node with LogSoftmax operation applied on each element.
|
||||
"""
|
||||
return _get_node_factory_opset5().create("LogSoftmax", [as_node(data)], {"axis": axis})
|
||||
@@ -5,7 +5,7 @@ from _pyngraph import NodeFactory as _NodeFactory
|
||||
|
||||
from ngraph.impl import Node
|
||||
|
||||
DEFAULT_OPSET = "opset4"
|
||||
DEFAULT_OPSET = "opset5"
|
||||
|
||||
|
||||
class NodeFactory(object):
|
||||
|
||||
@@ -91,6 +91,7 @@ namespace
|
||||
{"opset2", OpsetFunction(ngraph::get_opset2)},
|
||||
{"opset3", OpsetFunction(ngraph::get_opset3)},
|
||||
{"opset4", OpsetFunction(ngraph::get_opset4)},
|
||||
{"opset5", OpsetFunction(ngraph::get_opset5)},
|
||||
};
|
||||
|
||||
auto it = s_opsets.find(opset_ver);
|
||||
|
||||
29
ngraph/python/tests/test_ngraph/test_log_softmax.py
Normal file
29
ngraph/python/tests/test_ngraph/test_log_softmax.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# ******************************************************************************
|
||||
# Copyright 2020 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ******************************************************************************
|
||||
import numpy as np
|
||||
import ngraph as ng
|
||||
from ngraph.impl import Shape, Type
|
||||
|
||||
|
||||
def test_log_softmax():
|
||||
float_dtype = np.float32
|
||||
data = ng.parameter(Shape([3, 10]), dtype=float_dtype, name="data")
|
||||
|
||||
node = ng.log_softmax(data, 1)
|
||||
assert node.get_type_name() == "LogSoftmax"
|
||||
assert node.get_output_size() == 1
|
||||
assert list(node.get_output_shape(0)) == [3, 10]
|
||||
assert node.get_output_element_type(0) == Type.f32
|
||||
Reference in New Issue
Block a user