Extend nGraph for operation Softmax-8 (#8157)

Yegor Kruglov 2021-12-09 16:21:55 +03:00 committed by GitHub
parent 12d322807f
commit 7cb1bd6a1b
16 changed files with 480 additions and 23 deletions

View File

@@ -60,8 +60,9 @@ void shape_inference(ov::Node* op,
ov::is_type<ov::opset1::LogicalNot>(op) || ov::is_type<ov::opset4::Mish>(op) ||
ov::is_type<ov::opset2::MVN>(op) || ov::is_type<ov::opset6::MVN>(op) ||
ov::is_type<ov::opset1::PRelu>(op) || ov::is_type<ov::opset1::Relu>(op) ||
ov::is_type<ov::opset4::Swish>(op) || ov::is_type<ov::opset1::Softmax>(op) ||
ov::is_type<ov::opset1::Elu>(op) || ov::is_type<ov::opset5::Round>(op)) {
ov::is_type<ov::opset4::Swish>(op) || ov::is_type<ov::opset1::Elu>(op) ||
ov::is_type<ov::opset1::Softmax>(op) || ov::is_type<ov::opset8::Softmax>(op) ||
ov::is_type<ov::opset5::Round>(op)) {
copy_shape_infer(node, input_shapes, output_shapes);
} else if (ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(op) ||
ov::is_type<ov::op::util::BinaryElementwiseComparison>(op) || ov::is_type<ov::op::util::BinaryElementwiseLogical>(op)) {
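
Both Softmax opset versions are shape-preserving, which is why the dispatch above routes them to copy_shape_infer. A minimal sketch of the idea (a hypothetical helper, not the actual OpenVINO signature): the single output shape is simply a copy of the input shape.

#include <cstdint>
#include <vector>

using ShapeSketch = std::vector<int64_t>;

// Hypothetical illustration of copy-style shape inference for a shape-preserving
// op such as Softmax: the output shape always equals the input shape.
ShapeSketch softmax_shape_infer_sketch(const ShapeSketch& input_shape) {
    // Softmax normalizes values along one axis but never changes dimensions.
    return input_shape;
}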

View File

@@ -0,0 +1,95 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_softmax_downgrade.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST_F(TransformationTestsF, ConvertSoftMax8ToSoftMax1) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
int64_t axis = 1;
auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
size_t axis = 1;
auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
}
}
TEST_F(TransformationTestsF, ConvertSoftMax8ToSoftMax1_negative_axis) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
int64_t axis = -1;
auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
size_t axis = 1;
auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
}
}
TEST_F(TransformationTestsF, ConvertSoftMax8ToSoftMax1_input_rank_5) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5, 5, 5});
int64_t axis = -2;
auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5, 5, 5});
size_t axis = 3;
auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
}
}
TEST_F(TransformationTestsF, negative_ConvertSoftMax8ToSoftMax1_dynamic_rank) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
int64_t axis = -3;
auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
int64_t axis = -3;
auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
}
}

View File

@@ -0,0 +1,57 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_softmax_upgrade.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST_F(TransformationTestsF, ConvertSoftMax1ToSoftMax8) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
size_t axis = 1;
auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertSoftMax1ToSoftMax8>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
int64_t axis = 1;
auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
}
}
TEST_F(TransformationTestsF, ConvertSoftMax1ToSoftMax8_dynamic_rank) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
size_t axis = 1;
auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
manager.register_pass<ngraph::pass::ConvertSoftMax1ToSoftMax8>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
int64_t axis = 1;
auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
}
}

View File

@@ -0,0 +1,27 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertSoftMax8ToSoftMax1;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief ConvertSoftMax8ToSoftMax1 converts v8::Softmax into v1::Softmax.
*/
class ngraph::pass::ConvertSoftMax8ToSoftMax1 : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertSoftMax8ToSoftMax1();
};

View File

@@ -0,0 +1,28 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API ConvertSoftMax1ToSoftMax8;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief ConvertSoftMax1ToSoftMax8 converts v1::Softmax into v8::Softmax.
*/
class ngraph::pass::ConvertSoftMax1ToSoftMax8 : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertSoftMax1ToSoftMax8();
};

View File

@@ -63,6 +63,8 @@
#include "transformations/op_conversions/convert_scatter_elements_to_scatter.hpp"
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
#include "transformations/op_conversions/convert_subtract.hpp"
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include "transformations/op_conversions/convert_depth_to_space.hpp"
#include "transformations/op_conversions/convert_space_to_depth.hpp"
#include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
@@ -177,6 +179,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
manager.register_pass<ngraph::pass::ConvertGather1ToGather7, false>();
manager.register_pass<ngraph::pass::ConvertGather7ToGather8, false>();
manager.register_pass<ngraph::pass::ConvertDeformableConv8To1>();
manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
manager.register_pass<ngraph::pass::ConvertSoftMax1ToSoftMax8, false>();
manager.register_pass<ngraph::pass::ConvertMaxPool8ToMaxPool1>();
manager.register_pass<ngraph::pass::ConvertPriorBox8To0>(); // plugins have not implemented PriorBox-8 yet
manager.register_pass<ngraph::pass::ConvertDetectionOutput8ToDetectionOutput1>();
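
The downgrade pass runs unconditionally, while the trailing false appears to register ConvertSoftMax1ToSoftMax8 as disabled by default so that plugins can opt in. A minimal sketch of running the upgrade pass on a standalone manager, mirroring what the unit tests in this commit do (the helper function name is hypothetical):

#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/op_conversions/convert_softmax_upgrade.hpp>

// Hypothetical helper: rewrite every v1::Softmax in a function as v8::Softmax.
void upgrade_softmax(const std::shared_ptr<ngraph::Function>& f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::ConvertSoftMax1ToSoftMax8>();
    manager.run_passes(f);
}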

View File

@@ -0,0 +1,40 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/validation_util.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSoftMax8ToSoftMax1, "ConvertSoftMax8ToSoftMax1", 0);
ngraph::pass::ConvertSoftMax8ToSoftMax1::ConvertSoftMax8ToSoftMax1() {
MATCHER_SCOPE(ConvertSoftMax8ToSoftMax1);
auto input = pattern::any_input(pattern::has_static_rank());
auto softmax_v8_pattern = pattern::wrap_type<opset8::Softmax>({input});
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto softmax_v8_node = std::dynamic_pointer_cast<opset8::Softmax>(m.get_match_root());
if (!softmax_v8_node)
return false;
auto v8_axis = softmax_v8_node->get_axis();
auto rank = softmax_v8_node->get_input_partial_shape(0).rank().get_length();
auto v1_axis = static_cast<size_t>(ov::normalize_axis(softmax_v8_node->description(), v8_axis, rank));
auto softmax_v1_node = std::make_shared<opset1::Softmax>(softmax_v8_node->input_value(0), v1_axis);
softmax_v1_node->set_friendly_name(softmax_v8_node->get_friendly_name());
copy_runtime_info(softmax_v8_node, softmax_v1_node);
replace_node(softmax_v8_node, softmax_v1_node);
return true;
};
auto m = std::make_shared<pattern::Matcher>(softmax_v8_pattern, matcher_name);
register_matcher(m, callback);
}
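
The v8 axis can be negative, so the callback runs it through ov::normalize_axis before constructing v1::Softmax, which only accepts non-negative axes. A rough sketch of the arithmetic involved, matching the ConvertSoftMax8ToSoftMax1_input_rank_5 test above (the helper name is hypothetical, not the real ov::normalize_axis signature):

#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in: map an axis from [-rank, rank - 1] to [0, rank - 1]
// by adding the rank to negative values.
int64_t normalize_axis_sketch(int64_t axis, int64_t rank) {
    if (axis < -rank || axis >= rank)
        throw std::out_of_range("softmax axis is out of bounds");
    return axis < 0 ? axis + rank : axis;
}
// Example: rank 5, axis -2 -> 3, the axis the downgraded v1::Softmax is expected to use.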

View File

@@ -0,0 +1,35 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSoftMax1ToSoftMax8, "ConvertSoftMax1ToSoftMax8", 0);
ngraph::pass::ConvertSoftMax1ToSoftMax8::ConvertSoftMax1ToSoftMax8() {
MATCHER_SCOPE(ConvertSoftMax1ToSoftMax8);
auto softmax_v1_pattern = pattern::wrap_type<opset1::Softmax>();
matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto softmax_v1_node = std::dynamic_pointer_cast<opset1::Softmax>(m.get_match_root());
if (!softmax_v1_node)
return false;
auto axis = static_cast<int64_t>(softmax_v1_node->get_axis());
auto softmax_v8_node = std::make_shared<opset8::Softmax>(softmax_v1_node->input_value(0), axis);
softmax_v8_node->set_friendly_name(softmax_v1_node->get_friendly_name());
copy_runtime_info(softmax_v1_node, softmax_v8_node);
replace_node(softmax_v1_node, softmax_v8_node);
return true;
};
auto m = std::make_shared<pattern::Matcher>(softmax_v1_pattern, matcher_name);
register_matcher(m, callback);
}

View File

@@ -10,31 +10,44 @@
#include <ngraph/rt_info.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::SoftmaxDecomposition, "SoftmaxDecomposition", 0);
ngraph::pass::SoftmaxDecomposition::SoftmaxDecomposition() {
MATCHER_SCOPE(SoftmaxDecomposition);
auto softmax = pattern::wrap_type<ngraph::opset8::Softmax>();
auto softmax = pattern::wrap_type<ngraph::opset1::Softmax, ngraph::opset8::Softmax>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto node = std::dynamic_pointer_cast<opset8::Softmax>(m.get_match_root());
if (!node || transformation_callback(node)) {
auto m_softmax = m.get_match_root();
Output<Node> input;
int64_t softmax_axis;
if (transformation_callback(m_softmax)) {
return false;
}
auto input = node->input_value(0);
auto axis = opset8::Constant::create(element::i64, Shape{1}, {node->get_axis()});
auto reduce_max = std::make_shared<opset8::ReduceMax>(input, axis, true);
auto sub = std::make_shared<opset8::Subtract>(input, reduce_max);
auto exp = std::make_shared<opset8::Exp>(sub);
auto reduce_sum = std::make_shared<opset8::ReduceSum>(exp, axis, true);
auto div = std::make_shared<opset8::Divide>(exp, reduce_sum);
if (auto m_softmax_v1 = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(m_softmax)) {
input = m_softmax_v1->input_value(0);
softmax_axis = static_cast<int64_t>(m_softmax_v1->get_axis());
} else if (auto m_softmax_v8 = std::dynamic_pointer_cast<ngraph::opset8::Softmax>(m_softmax)) {
input = m_softmax_v8->input_value(0);
softmax_axis = m_softmax_v8->get_axis();
} else {
return false;
}
replace_node(node, div);
copy_runtime_info(node, {reduce_max, reduce_sum, sub, exp, div});
div->set_friendly_name(node->get_friendly_name());
auto axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {softmax_axis});
auto reduce_max = std::make_shared<ngraph::opset8::ReduceMax>(input, axis, true);
auto sub = std::make_shared<ngraph::opset8::Subtract>(input, reduce_max);
auto exp = std::make_shared<ngraph::opset8::Exp>(sub);
auto reduce_sum = std::make_shared<ngraph::opset8::ReduceSum>(exp, axis, true);
auto div = std::make_shared<ngraph::opset8::Divide>(exp, reduce_sum);
replace_node(m_softmax, div);
copy_runtime_info(m_softmax, {reduce_max, reduce_sum, sub, exp, div});
div->set_friendly_name(m_softmax->get_friendly_name());
return true;
};
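
For either opset version the decomposition builds the numerically stable form of softmax: softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)), where the max and sum are reduced along the softmax axis with keep_dims = true so the intermediate tensors broadcast back against the input. Subtracting the per-axis maximum does not change the result but keeps exp() from overflowing for large inputs.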

View File

@@ -12,5 +12,9 @@ namespace op {
namespace v1 {
using ov::op::v1::Softmax;
} // namespace v1
namespace v8 {
using ov::op::v8::Softmax;
} // namespace v8
} // namespace op
} // namespace ngraph

View File

@@ -45,5 +45,43 @@ private:
size_t m_axis{0};
};
} // namespace v1
namespace v8 {
/// \brief Softmax operation with support for negative axis values
class OPENVINO_API Softmax : public Op {
public:
OPENVINO_OP("Softmax", "opset8");
Softmax() = default;
/// \brief Constructs a softmax operation.
///
/// \param arg Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param axis The axis position in the range [-rank(arg), rank(arg) - 1] along which the softmax is calculated; negative values count back from the last dimension.
///
/// Output `[d0, ...]`
///
Softmax(const Output<Node>& arg, const int64_t axis = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
const int64_t& get_axis() const {
return m_axis;
}
void set_axis(const int64_t& axis) {
m_axis = axis;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool has_evaluate() const override;
private:
int64_t m_axis{1};
};
} // namespace v8
} // namespace op
} // namespace ov
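
A minimal usage sketch of the new class under the assumptions of this header (include paths are assumed; a negative axis is accepted and validated against the input rank):

#include <memory>
#include <openvino/op/parameter.hpp>  // assumed header path for ov::op::v0::Parameter
#include <openvino/op/softmax.hpp>    // assumed header path for ov::op::v8::Softmax

std::shared_ptr<ov::op::v8::Softmax> make_last_axis_softmax() {
    auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 3});
    // axis = -1 selects the last dimension; v1::Softmax would reject a negative value.
    return std::make_shared<ov::op::v8::Softmax>(data, -1);
}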

View File

@@ -94,7 +94,6 @@ _OPENVINO_OP_REG(Sign, ov::op::v0)
_OPENVINO_OP_REG(Sigmoid, ov::op::v0)
_OPENVINO_OP_REG(Sin, ov::op::v0)
_OPENVINO_OP_REG(Sinh, ov::op::v0)
_OPENVINO_OP_REG(Softmax, ov::op::v1)
_OPENVINO_OP_REG(Sqrt, ov::op::v0)
_OPENVINO_OP_REG(SpaceToDepth, ov::op::v0)
_OPENVINO_OP_REG(Split, ov::op::v1)
@@ -186,5 +185,6 @@ _OPENVINO_OP_REG(NV12toBGR, ov::op::v8)
_OPENVINO_OP_REG(NV12toRGB, ov::op::v8)
_OPENVINO_OP_REG(RandomUniform, ov::op::v8)
_OPENVINO_OP_REG(Slice, ov::op::v8)
_OPENVINO_OP_REG(Softmax, ov::op::v8)
_OPENVINO_OP_REG(If, ov::op::v8)
_OPENVINO_OP_REG(PriorBox, ov::op::v8)

View File

@@ -99,3 +99,67 @@ bool op::v1::Softmax::has_evaluate() const {
}
return false;
}
// *** SOFTMAX OP SET V8 ***
op::v8::Softmax::Softmax(const Output<Node>& arg, const int64_t axis) : Op({arg}), m_axis(axis) {
constructor_validate_and_infer_types();
}
bool op::v8::Softmax::visit_attributes(AttributeVisitor& visitor) {
NGRAPH_OP_SCOPE(v8_Softmax_visit_attributes);
visitor.on_attribute("axis", m_axis);
return true;
}
void op::v8::Softmax::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v8_Softmax_validate_and_infer_types);
const auto& input_shape = get_input_partial_shape(0);
if (input_shape.rank().is_static()) {
auto rank = static_cast<int64_t>(input_shape.size());
NODE_VALIDATION_CHECK(this,
-rank <= m_axis && m_axis < rank,
"Reduction axis (",
m_axis,
") is out of bounds (argument shape: ",
input_shape,
").");
}
set_output_type(0, get_input_element_type(0), input_shape);
}
shared_ptr<Node> op::v8::Softmax::clone_with_new_inputs(const OutputVector& new_args) const {
NGRAPH_OP_SCOPE(v8_Softmax_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<op::v8::Softmax>(new_args.at(0), m_axis);
}
bool op::v8::Softmax::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
NGRAPH_OP_SCOPE(v8_Softmax_evaluate);
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1) && validate_host_tensor_vector(inputs, 1));
outputs[0]->set_unary(inputs[0]);
auto rank = static_cast<int64_t>(inputs[0]->get_shape().size());
NGRAPH_CHECK(-rank <= m_axis && m_axis < rank,
"Reduction axis (",
m_axis,
") is out of bounds (argument shape: ",
inputs[0]->get_shape(),
").");
size_t axis = static_cast<size_t>(ov::normalize_axis(this->description(), m_axis, rank));
return evaluate_softmax(inputs[0], outputs[0], AxisSet{axis});
}
bool op::v8::Softmax::has_evaluate() const {
NGRAPH_OP_SCOPE(v8_Softmax_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::bf16:
case ngraph::element::f16:
case ngraph::element::f32:
case ngraph::element::f64:
return true;
default:
break;
}
return false;
}

View File

@@ -48,6 +48,7 @@
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/squeeze.hpp"
#include "ngraph/op/tan.hpp"
@@ -1787,3 +1788,18 @@ TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32) {
vector<int32_t> out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0};
ASSERT_EQ(cval, out);
}
TEST(eval, evaluate_softmax_8) {
const Shape data_shape{1, 2};
auto arg = std::make_shared<ngraph::op::Parameter>(element::f32, PartialShape::dynamic());
auto softmax = std::make_shared<ngraph::op::v8::Softmax>(arg, -1);
auto fun = std::make_shared<Function>(OutputVector{softmax}, ParameterVector{arg});
auto result_tensor = std::make_shared<HostTensor>();
ASSERT_TRUE(fun->evaluate({result_tensor}, {make_host_tensor<element::Type_t::f32>(data_shape, {1, 1})}));
EXPECT_EQ(result_tensor->get_element_type(), element::f32);
EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{1, 2}));
auto val = read_vector<float>(result_tensor);
vector<float> out{0.5, 0.5};
ASSERT_EQ(val, out);
}

View File

@@ -21,3 +21,40 @@ TEST(type_prop, softmax_out_of_bound_axis) {
// axis cannot be a negative number
ASSERT_THROW(make_shared<op::v1::Softmax>(arg, -1), ngraph::NodeValidationFailure);
}
TEST(type_prop, softmax_8_default_axis) {
const Shape arg_shape{2, 3};
auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
auto sm = make_shared<op::v8::Softmax>(arg);
ASSERT_EQ(sm->get_axis(), 1);
}
TEST(type_prop, softmax_8_out_of_bound_negative_axis) {
const Shape arg_shape{2, 3};
auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
// axis should be in range [-rank, rank - 1]
ASSERT_THROW(make_shared<op::v8::Softmax>(arg, -10), ngraph::NodeValidationFailure);
}
TEST(type_prop, softmax_8_out_of_bound_positive_axis) {
const Shape arg_shape{2, 3};
auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
// axis should be in range [-rank, rank - 1]
ASSERT_THROW(make_shared<op::v8::Softmax>(arg, 10), ngraph::NodeValidationFailure);
}
TEST(type_prop, softmax_8_positive_axis) {
const Shape arg_shape{1, 10};
auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
auto softmax = make_shared<op::v8::Softmax>(arg, 1);
ASSERT_EQ(softmax->get_element_type(), element::f32);
ASSERT_EQ(softmax->get_shape(), (Shape{1, 10}));
}
TEST(type_prop, softmax_8_negative_axis) {
const Shape arg_shape{1, 10};
auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
auto softmax = make_shared<op::v8::Softmax>(arg, -1);
ASSERT_EQ(softmax->get_element_type(), element::f32);
ASSERT_EQ(softmax->get_shape(), (Shape{1, 10}));
}

View File

@@ -12,10 +12,10 @@ class Softmax(Op):
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'type': __class__.op,
'op': __class__.op,
'version': 'opset1',
'infer': Softmax.infer,
'type': self.op,
'op': self.op,
'version': 'opset8',
'infer': self.infer,
'axis': 1,
'in_ports_count': 1,
'out_ports_count': 1,
@@ -26,8 +26,6 @@ class Softmax(Op):
@staticmethod
def infer(node: Node):
if node.axis < 0:
node.axis = len(node.in_node().shape) + node.axis
copy_shape_infer(node)
PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])