diff --git a/inference-engine/src/mkldnn_plugin/utils/shape_inference/shape_inference.cpp b/inference-engine/src/mkldnn_plugin/utils/shape_inference/shape_inference.cpp
index 4dcd650e6bd..f110a061bc7 100644
--- a/inference-engine/src/mkldnn_plugin/utils/shape_inference/shape_inference.cpp
+++ b/inference-engine/src/mkldnn_plugin/utils/shape_inference/shape_inference.cpp
@@ -60,8 +60,9 @@ void shape_inference(ov::Node* op,
         ov::is_type(op) || ov::is_type(op) || ov::is_type(op) ||
         ov::is_type(op) || ov::is_type(op) || ov::is_type(op) ||
-        ov::is_type(op) || ov::is_type(op) ||
-        ov::is_type(op) || ov::is_type(op)) {
+        ov::is_type(op) || ov::is_type(op) ||
+        ov::is_type(op) || ov::is_type(op) ||
+        ov::is_type<ov::op::v8::Softmax>(op)) {
         copy_shape_infer(node, input_shapes, output_shapes);
     } else if (ov::is_type(op) || ov::is_type(op) ||
                ov::is_type(op)) {
diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_softmax_downgrade_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_softmax_downgrade_test.cpp
new file mode 100644
index 00000000000..84c2504bd48
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/transformations/convert_softmax_downgrade_test.cpp
@@ -0,0 +1,95 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/op_conversions/convert_softmax_downgrade.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST_F(TransformationTestsF, ConvertSoftMax8ToSoftMax1) {
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
+        int64_t axis = 1;
+        auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
+
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
+        manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
+        size_t axis = 1;
+        auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
+
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertSoftMax8ToSoftMax1_negative_axis) {
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
+        int64_t axis = -1;
+        auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
+
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
+        manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
+        size_t axis = 1;
+        auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
+
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertSoftMax8ToSoftMax1_input_rank_5) {
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5, 5, 5});
+        int64_t axis = -2;
+        auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
+
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
+        manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 5, 5, 5});
+        size_t axis = 3;
+        auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
+
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
+    }
+}
+
+TEST_F(TransformationTestsF, negative_ConvertSoftMax8ToSoftMax1_dynamic_rank) {
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+        int64_t axis = -3;
+        auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
+
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
+        manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+        int64_t axis = -3;
+        auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
+
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
+    }
+}
diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_softmax_upgrade_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_softmax_upgrade_test.cpp
new file mode 100644
index 00000000000..eabe7b76064
--- /dev/null
+++ b/inference-engine/tests/functional/inference_engine/transformations/convert_softmax_upgrade_test.cpp
@@ -0,0 +1,57 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pass/manager.hpp>
+#include <transformations/op_conversions/convert_softmax_upgrade.hpp>
+#include <transformations/init_node_info.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST_F(TransformationTestsF, ConvertSoftMax1ToSoftMax8) {
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
+        size_t axis = 1;
+        auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
+
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
+        manager.register_pass<ngraph::pass::ConvertSoftMax1ToSoftMax8>();
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
+        int64_t axis = 1;
+        auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
+
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertSoftMax1ToSoftMax8_dynamic_rank) {
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+        size_t axis = 1;
+        auto softmax_1 = std::make_shared<ngraph::opset1::Softmax>(data, axis);
+
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_1}, ngraph::ParameterVector{data});
+        manager.register_pass<ngraph::pass::ConvertSoftMax1ToSoftMax8>();
+    }
+
+    {
+        auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+        int64_t axis = 1;
+        auto softmax_8 = std::make_shared<ngraph::opset8::Softmax>(data, axis);
+
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{softmax_8}, ngraph::ParameterVector{data});
+    }
+}
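The two test files above use the TransformationTestsF harness, which runs the passes registered on `manager` and then compares `function` against `function_ref`. Outside the harness, the conversions are driven by an ordinary ngraph::pass::Manager. A minimal sketch, not part of the patch, using only the pass and header this patch introduces:

    #include <memory>

    #include <ngraph/function.hpp>
    #include <ngraph/pass/manager.hpp>
    #include <transformations/op_conversions/convert_softmax_downgrade.hpp>

    // Rewrites every v8::Softmax in the function as an equivalent v1::Softmax,
    // e.g. before handing the graph to a plugin that only understands opset1.
    void downgrade_softmax(const std::shared_ptr<ngraph::Function>& f) {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
        manager.run_passes(f);
    }

The upgrade pass is applied the same way. Note that the downgrade only fires for inputs with a static rank, which is what the negative_ConvertSoftMax8ToSoftMax1_dynamic_rank test checks: a negative axis cannot be normalized without knowing the rank.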
diff --git a/src/common/transformations/include/transformations/op_conversions/convert_softmax_downgrade.hpp b/src/common/transformations/include/transformations/op_conversions/convert_softmax_downgrade.hpp
new file mode 100644
index 00000000000..4a9ce23c6c0
--- /dev/null
+++ b/src/common/transformations/include/transformations/op_conversions/convert_softmax_downgrade.hpp
@@ -0,0 +1,27 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ConvertSoftMax8ToSoftMax1;
+
+} // namespace pass
+} // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief ConvertSoftMax8ToSoftMax1 converts v8::SoftMax into v1::SoftMax.
+ */
+class ngraph::pass::ConvertSoftMax8ToSoftMax1 : public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    ConvertSoftMax8ToSoftMax1();
+};
diff --git a/src/common/transformations/include/transformations/op_conversions/convert_softmax_upgrade.hpp b/src/common/transformations/include/transformations/op_conversions/convert_softmax_upgrade.hpp
new file mode 100644
index 00000000000..240482593a5
--- /dev/null
+++ b/src/common/transformations/include/transformations/op_conversions/convert_softmax_upgrade.hpp
@@ -0,0 +1,28 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <transformations_visibility.hpp>
+
+#include <ngraph/pass/graph_rewrite.hpp>
+
+namespace ngraph {
+namespace pass {
+
+class TRANSFORMATIONS_API ConvertSoftMax1ToSoftMax8;
+
+} // namespace pass
+} // namespace ngraph
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief ConvertSoftMax1ToSoftMax8 converts v1::SoftMax into v8::SoftMax.
+ */
+
+class ngraph::pass::ConvertSoftMax1ToSoftMax8 : public ngraph::pass::MatcherPass {
+public:
+    NGRAPH_RTTI_DECLARATION;
+    ConvertSoftMax1ToSoftMax8();
+};
diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp
index 24e75f039d9..6e261dd6ef5 100644
--- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp
@@ -63,6 +63,8 @@
 #include "transformations/op_conversions/convert_scatter_elements_to_scatter.hpp"
 #include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
 #include "transformations/op_conversions/convert_subtract.hpp"
+#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
+#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
 #include "transformations/op_conversions/convert_depth_to_space.hpp"
 #include "transformations/op_conversions/convert_space_to_depth.hpp"
 #include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
@@ -177,6 +179,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::Function> f) {
     manager.register_pass();
     manager.register_pass();
     manager.register_pass();
+    manager.register_pass<ngraph::pass::ConvertSoftMax8ToSoftMax1>();
+    manager.register_pass<ngraph::pass::ConvertSoftMax1ToSoftMax8>();
     manager.register_pass();
     manager.register_pass(); // not plugins implemented priorbox8
     manager.register_pass();
diff --git a/src/common/transformations/src/transformations/op_conversions/convert_softmax_downgrade.cpp b/src/common/transformations/src/transformations/op_conversions/convert_softmax_downgrade.cpp
new file mode 100644
index 00000000000..56ad25cece9
--- /dev/null
+++ b/src/common/transformations/src/transformations/op_conversions/convert_softmax_downgrade.cpp
@@ -0,0 +1,40 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/op_conversions/convert_softmax_downgrade.hpp"
+#include <memory>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+#include "itt.hpp"
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSoftMax8ToSoftMax1, "ConvertSoftMax8ToSoftMax1", 0);
+
+ngraph::pass::ConvertSoftMax8ToSoftMax1::ConvertSoftMax8ToSoftMax1() {
+    MATCHER_SCOPE(ConvertSoftMax8ToSoftMax1);
+
+    auto input = pattern::any_input(pattern::has_static_rank());
+    auto softmax_v8_pattern = pattern::wrap_type<ngraph::opset8::Softmax>({input});
+
+    matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto softmax_v8_node = std::dynamic_pointer_cast<ngraph::opset8::Softmax>(m.get_match_root());
+        if (!softmax_v8_node)
+            return false;
+
+        auto v8_axis = softmax_v8_node->get_axis();
+        auto rank = softmax_v8_node->get_input_partial_shape(0).rank().get_length();
+        auto v1_axis = static_cast<size_t>(ov::normalize_axis(softmax_v8_node->description(), v8_axis, rank));
+
+        auto softmax_v1_node = std::make_shared<ngraph::opset1::Softmax>(softmax_v8_node->input_value(0), v1_axis);
+        softmax_v1_node->set_friendly_name(softmax_v8_node->get_friendly_name());
+        copy_runtime_info(softmax_v8_node, softmax_v1_node);
+        replace_node(softmax_v8_node, softmax_v1_node);
+
+        return true;
+    };
+
+    auto m = std::make_shared<pattern::Matcher>(softmax_v8_pattern, matcher_name);
+    register_matcher(m, callback);
+}
diff --git a/src/common/transformations/src/transformations/op_conversions/convert_softmax_upgrade.cpp b/src/common/transformations/src/transformations/op_conversions/convert_softmax_upgrade.cpp
new file mode 100644
index 00000000000..fb06a5438bb
--- /dev/null
+++ b/src/common/transformations/src/transformations/op_conversions/convert_softmax_upgrade.cpp
@@ -0,0 +1,35 @@
+// Copyright (C) 2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "transformations/op_conversions/convert_softmax_upgrade.hpp"
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/rt_info.hpp>
+#include "itt.hpp"
+
+NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSoftMax1ToSoftMax8, "ConvertSoftMax1ToSoftMax8", 0);
+
+ngraph::pass::ConvertSoftMax1ToSoftMax8::ConvertSoftMax1ToSoftMax8() {
+    MATCHER_SCOPE(ConvertSoftMax1ToSoftMax8);
+
+    auto softmax_v1_pattern = pattern::wrap_type<ngraph::opset1::Softmax>();
+
+    matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto softmax_v1_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(m.get_match_root());
+        if (!softmax_v1_node)
+            return false;
+
+        auto axis = static_cast<int64_t>(softmax_v1_node->get_axis());
+        auto softmax_v8_node = std::make_shared<ngraph::opset8::Softmax>(softmax_v1_node->input_value(0), axis);
+        softmax_v8_node->set_friendly_name(softmax_v1_node->get_friendly_name());
+        copy_runtime_info(softmax_v1_node, softmax_v8_node);
+        replace_node(softmax_v1_node, softmax_v8_node);
+
+        return true;
+    };
+
+    auto m = std::make_shared<pattern::Matcher>(softmax_v1_pattern, matcher_name);
+    register_matcher(m, callback);
+}
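The heart of the downgrade is the axis conversion: v8::Softmax accepts an axis in [-rank, rank - 1], while v1::Softmax stores the axis as a non-negative size_t. The ov::normalize_axis call above validates the bounds and performs, in effect, the mapping sketched below (to_v1_axis is a hypothetical helper shown only to illustrate the arithmetic):

    #include <cstddef>
    #include <cstdint>

    // Illustration only: map a v8 axis in [-rank, rank - 1] onto the
    // non-negative v1 axis in [0, rank - 1]. The pass itself delegates to
    // ov::normalize_axis, which also rejects out-of-range axes.
    size_t to_v1_axis(int64_t v8_axis, int64_t rank) {
        return static_cast<size_t>(v8_axis < 0 ? v8_axis + rank : v8_axis);
    }

For example, with rank 5 and axis -2 this yields 3, matching the ConvertSoftMax8ToSoftMax1_input_rank_5 test above.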
diff --git a/src/common/transformations/src/transformations/op_conversions/softmax_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/softmax_decomposition.cpp
index c133feaca9a..a291379c166 100644
--- a/src/common/transformations/src/transformations/op_conversions/softmax_decomposition.cpp
+++ b/src/common/transformations/src/transformations/op_conversions/softmax_decomposition.cpp
@@ -10,31 +10,44 @@
 #include <memory>
 #include <vector>
+#include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset8.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::SoftmaxDecomposition, "SoftmaxDecomposition", 0);
 
 ngraph::pass::SoftmaxDecomposition::SoftmaxDecomposition() {
     MATCHER_SCOPE(SoftmaxDecomposition);
-    auto softmax = pattern::wrap_type<opset8::Softmax>();
-
+    auto softmax = pattern::wrap_type<ngraph::opset1::Softmax, ngraph::opset8::Softmax>();
     ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
-        auto node = std::dynamic_pointer_cast<opset8::Softmax>(m.get_match_root());
-        if (!node || transformation_callback(node)) {
+        auto m_softmax = m.get_match_root();
+        Output<Node> input;
+        int64_t softmax_axis;
+
+        if (transformation_callback(m_softmax)) {
             return false;
         }
 
-        auto input = node->input_value(0);
-        auto axis = opset8::Constant::create(element::i64, Shape{1}, {node->get_axis()});
-        auto reduce_max = std::make_shared<opset8::ReduceMax>(input, axis, true);
-        auto sub = std::make_shared<opset8::Subtract>(input, reduce_max);
-        auto exp = std::make_shared<opset8::Exp>(sub);
-        auto reduce_sum = std::make_shared<opset8::ReduceSum>(exp, axis, true);
-        auto div = std::make_shared<opset8::Divide>(exp, reduce_sum);
+        if (auto m_softmax_v1 = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(m_softmax)) {
+            input = m_softmax_v1->input_value(0);
+            softmax_axis = static_cast<int64_t>(m_softmax_v1->get_axis());
+        } else if (auto m_softmax_v8 = std::dynamic_pointer_cast<ngraph::opset8::Softmax>(m_softmax)) {
+            input = m_softmax_v8->input_value(0);
+            softmax_axis = m_softmax_v8->get_axis();
+        } else {
+            return false;
+        }
 
-        replace_node(node, div);
-        copy_runtime_info(node, {reduce_max, reduce_sum, sub, exp, div});
-        div->set_friendly_name(node->get_friendly_name());
+        auto axis = ngraph::opset8::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {softmax_axis});
+        auto reduce_max = std::make_shared<ngraph::opset8::ReduceMax>(input, axis, true);
+        auto sub = std::make_shared<ngraph::opset8::Subtract>(input, reduce_max);
+        auto exp = std::make_shared<ngraph::opset8::Exp>(sub);
+        auto reduce_sum = std::make_shared<ngraph::opset8::ReduceSum>(exp, axis, true);
+        auto div = std::make_shared<ngraph::opset8::Divide>(exp, reduce_sum);
+
+        replace_node(m_softmax, div);
+        copy_runtime_info(m_softmax, {reduce_max, reduce_sum, sub, exp, div});
+        div->set_friendly_name(m_softmax->get_friendly_name());
         return true;
     };
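Taken together, the ReduceMax/Subtract/Exp/ReduceSum/Divide chain built by this callback computes the numerically stable form of softmax along the chosen axis (both reductions keep the reduced dimension so the broadcasts line up):

    softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m),   where m = max_j x_j along the softmax axis

Subtracting the per-axis maximum leaves the result mathematically unchanged but keeps the exponentials from overflowing for large inputs.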
diff --git a/src/core/include/ngraph/op/softmax.hpp b/src/core/include/ngraph/op/softmax.hpp
index f0fadb841a4..cde1b372f9b 100644
--- a/src/core/include/ngraph/op/softmax.hpp
+++ b/src/core/include/ngraph/op/softmax.hpp
@@ -12,5 +12,9 @@ namespace op {
 namespace v1 {
 using ov::op::v1::Softmax;
 } // namespace v1
+
+namespace v8 {
+using ov::op::v8::Softmax;
+} // namespace v8
 } // namespace op
 } // namespace ngraph
diff --git a/src/core/include/openvino/op/softmax.hpp b/src/core/include/openvino/op/softmax.hpp
index 89123199ab4..7decaa14a46 100644
--- a/src/core/include/openvino/op/softmax.hpp
+++ b/src/core/include/openvino/op/softmax.hpp
@@ -45,5 +45,43 @@ private:
     size_t m_axis{0};
 };
 }  // namespace v1
+
+namespace v8 {
+/// \brief Softmax operation with negative axis values
+class OPENVINO_API Softmax : public Op {
+public:
+    OPENVINO_OP("Softmax", "opset8");
+
+    Softmax() = default;
+    /// \brief Constructs a softmax operation.
+    ///
+    /// \param arg Node that produces the first input tensor.
+    /// `[d0, ...]`
+    /// \param axis The axis position (0-based) in range [-rank(arg), rank(arg) - 1] on which to calculate the softmax.
+    ///
+    /// Output `[d0, ...]`
+    ///
+    Softmax(const Output<Node>& arg, const int64_t axis = 1);
+
+    bool visit_attributes(AttributeVisitor& visitor) override;
+    void validate_and_infer_types() override;
+
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
+
+    const int64_t& get_axis() const {
+        return m_axis;
+    }
+    void set_axis(const int64_t& axis) {
+        m_axis = axis;
+    }
+    OPENVINO_SUPPRESS_DEPRECATED_START
+    bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
+    OPENVINO_SUPPRESS_DEPRECATED_END
+    bool has_evaluate() const override;
+
+private:
+    int64_t m_axis{1};
+};
+}  // namespace v8
 }  // namespace op
 }  // namespace ov
diff --git a/src/core/include/openvino/opsets/opset8_tbl.hpp b/src/core/include/openvino/opsets/opset8_tbl.hpp
index 4f626f4fb91..540a6784286 100644
--- a/src/core/include/openvino/opsets/opset8_tbl.hpp
+++ b/src/core/include/openvino/opsets/opset8_tbl.hpp
@@ -94,7 +94,6 @@ _OPENVINO_OP_REG(Sign, ov::op::v0)
 _OPENVINO_OP_REG(Sigmoid, ov::op::v0)
 _OPENVINO_OP_REG(Sin, ov::op::v0)
 _OPENVINO_OP_REG(Sinh, ov::op::v0)
-_OPENVINO_OP_REG(Softmax, ov::op::v1)
 _OPENVINO_OP_REG(Sqrt, ov::op::v0)
 _OPENVINO_OP_REG(SpaceToDepth, ov::op::v0)
 _OPENVINO_OP_REG(Split, ov::op::v1)
@@ -186,5 +185,6 @@ _OPENVINO_OP_REG(NV12toBGR, ov::op::v8)
 _OPENVINO_OP_REG(NV12toRGB, ov::op::v8)
 _OPENVINO_OP_REG(RandomUniform, ov::op::v8)
 _OPENVINO_OP_REG(Slice, ov::op::v8)
+_OPENVINO_OP_REG(Softmax, ov::op::v8)
 _OPENVINO_OP_REG(If, ov::op::v8)
 _OPENVINO_OP_REG(PriorBox, ov::op::v8)
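With Softmax now registered in opset8 as ov::op::v8::Softmax (and the v1 entry dropped from the opset8 table), the op can be constructed with a negative axis directly. A minimal usage sketch, not part of the patch:

    #include <memory>

    #include <openvino/op/parameter.hpp>
    #include <openvino/op/softmax.hpp>

    int main() {
        auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 3});
        // -1 addresses the last dimension; v1::Softmax would reject a negative axis.
        auto softmax = std::make_shared<ov::op::v8::Softmax>(data, -1);
        return softmax->get_axis() == -1 ? 0 : 1;
    }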
diff --git a/src/core/src/op/softmax.cpp b/src/core/src/op/softmax.cpp
index 602f476b118..22ea4133a1a 100644
--- a/src/core/src/op/softmax.cpp
+++ b/src/core/src/op/softmax.cpp
@@ -99,3 +99,67 @@ bool op::v1::Softmax::has_evaluate() const {
     }
     return false;
 }
+
+// *** SOFTMAX OP SET V8 ***
+
+op::v8::Softmax::Softmax(const Output<Node>& arg, const int64_t axis) : Op({arg}), m_axis(axis) {
+    constructor_validate_and_infer_types();
+}
+
+bool op::v8::Softmax::visit_attributes(AttributeVisitor& visitor) {
+    NGRAPH_OP_SCOPE(v8_Softmax_visit_attributes);
+    visitor.on_attribute("axis", m_axis);
+    return true;
+}
+
+void op::v8::Softmax::validate_and_infer_types() {
+    NGRAPH_OP_SCOPE(v8_Softmax_validate_and_infer_types);
+    const auto& input_shape = get_input_partial_shape(0);
+    if (input_shape.rank().is_static()) {
+        auto rank = static_cast<int64_t>(input_shape.size());
+        NODE_VALIDATION_CHECK(this,
+                              -rank <= m_axis && m_axis < rank,
+                              "Reduction axis (",
+                              m_axis,
+                              ") is out of bounds (argument shape: ",
+                              input_shape,
+                              ").");
+    }
+
+    set_output_type(0, get_input_element_type(0), input_shape);
+}
+
+shared_ptr<Node> op::v8::Softmax::clone_with_new_inputs(const OutputVector& new_args) const {
+    NGRAPH_OP_SCOPE(v8_Softmax_clone_with_new_inputs);
+    check_new_args_count(this, new_args);
+    return make_shared<op::v8::Softmax>(new_args.at(0), m_axis);
+}
+
+bool op::v8::Softmax::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
+    NGRAPH_OP_SCOPE(v8_Softmax_evaluate);
+    NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1) && validate_host_tensor_vector(inputs, 1));
+    outputs[0]->set_unary(inputs[0]);
+    auto rank = static_cast<int64_t>(inputs[0]->get_shape().size());
+    NGRAPH_CHECK(-rank <= m_axis && m_axis < rank,
+                 "Reduction axis (",
+                 m_axis,
+                 ") is out of bounds (argument shape: ",
+                 inputs[0]->get_shape(),
+                 ").");
+    size_t axis = static_cast<size_t>(ov::normalize_axis(this->description(), m_axis, rank));
+    return evaluate_softmax(inputs[0], outputs[0], AxisSet{axis});
+}
+
+bool op::v8::Softmax::has_evaluate() const {
+    NGRAPH_OP_SCOPE(v8_Softmax_has_evaluate);
+    switch (get_input_element_type(0)) {
+    case ngraph::element::bf16:
+    case ngraph::element::f16:
+    case ngraph::element::f32:
+    case ngraph::element::f64:
+        return true;
+    default:
+        break;
+    }
+    return false;
+}
diff --git a/src/core/tests/eval.cpp b/src/core/tests/eval.cpp
index feaa3aa6c72..622b1bf7a10 100644
--- a/src/core/tests/eval.cpp
+++ b/src/core/tests/eval.cpp
@@ -48,6 +48,7 @@
 #include "ngraph/op/sign.hpp"
 #include "ngraph/op/sin.hpp"
 #include "ngraph/op/sinh.hpp"
+#include "ngraph/op/softmax.hpp"
 #include "ngraph/op/sqrt.hpp"
 #include "ngraph/op/squeeze.hpp"
 #include "ngraph/op/tan.hpp"
@@ -1787,3 +1788,18 @@ TEST(eval, evaluate_dynamic_scatter_update_one_elem_i32) {
     vector<int32_t> out{0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0};
     ASSERT_EQ(cval, out);
 }
+
+TEST(eval, evaluate_softmax_8) {
+    const Shape data_shape{1, 2};
+    auto arg = std::make_shared<op::Parameter>(element::f32, PartialShape::dynamic());
+    auto softmax = std::make_shared<op::v8::Softmax>(arg, -1);
+    auto fun = std::make_shared<Function>(OutputVector{softmax}, ParameterVector{arg});
+    auto result_tensor = std::make_shared<HostTensor>();
+
+    ASSERT_TRUE(fun->evaluate({result_tensor}, {make_host_tensor<element::Type_t::f32>(data_shape, {1, 1})}));
+    EXPECT_EQ(result_tensor->get_element_type(), element::f32);
+    EXPECT_EQ(result_tensor->get_partial_shape(), (PartialShape{1, 2}));
+    auto val = read_vector<float>(result_tensor);
+    vector<float> out{0.5, 0.5};
+    ASSERT_EQ(val, out);
+}
diff --git a/src/core/tests/type_prop/softmax.cpp b/src/core/tests/type_prop/softmax.cpp
index 4d3a795a5e2..158fae1b421 100644
--- a/src/core/tests/type_prop/softmax.cpp
+++ b/src/core/tests/type_prop/softmax.cpp
@@ -21,3 +21,40 @@ TEST(type_prop, softmax_out_of_bound_axis) {
     // axis cannot be a negative number
     ASSERT_THROW(make_shared<op::v1::Softmax>(arg, -1), ngraph::NodeValidationFailure);
 }
+
+TEST(type_prop, softmax_8_default_axis) {
+    const Shape arg_shape{2, 3};
+    auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
+    auto sm = make_shared<op::v8::Softmax>(arg);
+    ASSERT_EQ(sm->get_axis(), 1);
+}
+
+TEST(type_prop, softmax_8_out_of_bound_negative_axis) {
+    const Shape arg_shape{2, 3};
+    auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
+    // axis should be in range [-rank, rank - 1]
+    ASSERT_THROW(make_shared<op::v8::Softmax>(arg, -10), ngraph::NodeValidationFailure);
+}
+
+TEST(type_prop, softmax_8_out_of_bound_positive_axis) {
+    const Shape arg_shape{2, 3};
+    auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
+    // axis should be in range [-rank, rank - 1]
+    ASSERT_THROW(make_shared<op::v8::Softmax>(arg, 10), ngraph::NodeValidationFailure);
+}
+
+TEST(type_prop, softmax_8_positive_axis) {
+    const Shape arg_shape{1, 10};
+    auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
+    auto softmax = make_shared<op::v8::Softmax>(arg, 1);
+    ASSERT_EQ(softmax->get_element_type(), element::f32);
+    ASSERT_EQ(softmax->get_shape(), (Shape{1, 10}));
+}
+
+TEST(type_prop, softmax_8_negative_axis) {
+    const Shape arg_shape{1, 10};
+    auto arg = make_shared<op::Parameter>(element::f32, arg_shape);
+    auto softmax = make_shared<op::v8::Softmax>(arg, -1);
+    ASSERT_EQ(softmax->get_element_type(), element::f32);
+    ASSERT_EQ(softmax->get_shape(), (Shape{1, 10}));
+}
diff --git a/tools/mo/openvino/tools/mo/ops/softmax.py b/tools/mo/openvino/tools/mo/ops/softmax.py
index 5f6087af6cd..2133b89cbd2 100644
--- a/tools/mo/openvino/tools/mo/ops/softmax.py
+++ b/tools/mo/openvino/tools/mo/ops/softmax.py
@@ -12,10 +12,10 @@ class Softmax(Op):
     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
-            'type': __class__.op,
-            'op': __class__.op,
-            'version': 'opset1',
-            'infer': Softmax.infer,
+            'type': self.op,
+            'op': self.op,
+            'version': 'opset8',
+            'infer': self.infer,
             'axis': 1,
             'in_ports_count': 1,
             'out_ports_count': 1,
@@ -26,8 +26,6 @@ class Softmax(Op):
 
     @staticmethod
     def infer(node: Node):
-        if node.axis < 0:
-            node.axis = len(node.in_node().shape) + node.axis
         copy_shape_infer(node)
         PermuteAttrs.create_permute_attrs(node, attrs=[('axis', 'input:0')])