Improve node validation failure message for shape infer (#18520)

* Add NODE_SHAPE_INFER_CHECK macro
throws NodeValidationFailure for shape inference

* Use NODE_SHAPE_INFER_CHECK in topk shape inference

* Move description to header file

* export NodeValidationFailure::create function
This commit is contained in:
Pawel Raasz
2023-07-18 09:41:19 +02:00
committed by GitHub
parent 787796d88f
commit ff5b56ee07
7 changed files with 152 additions and 10 deletions

View File

@@ -533,13 +533,34 @@ public:
const Node* node,
const std::string& explanation);
template <class TShape>
[[noreturn]] static void create(const CheckLocInfo& check_loc_info,
std::pair<const Node*, const std::vector<TShape>*>&& ctx,
const std::string& explanation);
protected:
explicit NodeValidationFailure(const std::string& what_arg) : ov::AssertFailure(what_arg) {}
};
/**
* @brief Specialization to throw the `NodeValidationFailure` for shape inference using `PartialShape`
*
* @param check_loc_info Exception location details to print.
* @param ctx NodeValidationFailure context which got pointer to node and input shapes used for shape
* inference.
* @param explanation Exception explanation string.
*/
template <>
OPENVINO_API void NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
std::pair<const Node*, const std::vector<PartialShape>*>&& ctx,
const std::string& explanation);
} // namespace ov
#define NODE_VALIDATION_CHECK(node, ...) OPENVINO_ASSERT_HELPER(::ov::NodeValidationFailure, (node), __VA_ARGS__)
/** \brief Throw NodeValidationFailure with additional information about input shapes used during shape inference. */
#define NODE_SHAPE_INFER_CHECK(node, input_shapes, ...) \
NODE_VALIDATION_CHECK(std::make_pair(static_cast<const ::ov::Node*>((node)), &(input_shapes)), __VA_ARGS__)
namespace ov {
template <typename T>
void check_new_args_count(const Node* node, T new_args) {

View File

@@ -0,0 +1,36 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <iterator>
#include <sstream>
#include <string>
#include "openvino/core/node.hpp"
namespace ov {
namespace op {
namespace validate {
/**
* @brief Provides `NodeValidationFailure` exception explanation string.
*
* @param shapes Vector of shapes used for inference to be printed before explanation.
* @param explanation String with exception explanation.
* @return Explanation string.
*/
template <class TShape>
std::string shape_infer_explanation_str(const std::vector<TShape>& shapes, const std::string& explanation) {
std::stringstream o;
o << "Shape inference input shapes {";
std::copy(shapes.cbegin(), std::prev(shapes.cend()), std::ostream_iterator<TShape>(o, ","));
if (!shapes.empty()) {
o << shapes.back();
}
o << "}\n" << explanation;
return o.str();
}
} // namespace validate
} // namespace op
} // namespace ov

View File

@@ -58,12 +58,13 @@ std::vector<TShape> shape_infer(const util::TopKBase* op,
const auto& input_shape = input_shapes[0];
const auto input_rank = input_shape.rank();
NODE_VALIDATION_CHECK(op,
input_rank.is_dynamic() || input_rank.get_length() > 0,
"Input rank must be greater than 0.");
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
input_rank.is_dynamic() || input_rank.get_length() > 0,
"Input rank must be greater than 0.");
const auto& k_shape = input_shapes[1];
NODE_VALIDATION_CHECK(op, k_shape.rank().compatible(0), "The 'K' input must be a scalar.");
NODE_SHAPE_INFER_CHECK(op, input_shapes, k_shape.rank().compatible(0), "The 'K' input must be a scalar.");
auto output_shape = input_shape;
if (input_shape.rank().is_static()) {
@@ -73,12 +74,13 @@ std::vector<TShape> shape_infer(const util::TopKBase* op,
auto& dim_axis = output_shape[normalized_axis];
if (auto k_as_shape = get_input_const_data_as_shape<TShape>(op, 1, constant_data, GetK<TDimValue>(op))) {
NODE_VALIDATION_CHECK(op,
k_as_shape->size() == 1,
"Only one value (scalar) should be provided as the 'K' input to TopK",
" (got ",
k_as_shape->size(),
" elements).");
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
k_as_shape->size() == 1,
"Only one value (scalar) should be provided as the 'K' input to TopK",
" (got ",
k_as_shape->size(),
" elements).");
const auto& k = (*k_as_shape)[0];
if (k.is_static()) {

View File

@@ -21,6 +21,7 @@
#include "openvino/core/descriptor/input.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "shape_util.hpp"
#include "shape_validation.hpp"
#include "shared_node_info.hpp"
#include "tensor_conversion_util.hpp"
@@ -40,6 +41,15 @@ void ov::NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
throw ov::NodeValidationFailure(make_what(check_loc_info, node_validation_failure_loc_string(node), explanation));
}
template <>
void ov::NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
std::pair<const Node*, const std::vector<PartialShape>*>&& ctx,
const std::string& explanation) {
throw ov::NodeValidationFailure(make_what(check_loc_info,
node_validation_failure_loc_string(ctx.first),
op::validate::shape_infer_explanation_str(*ctx.second, explanation)));
}
atomic<size_t> ov::Node::m_next_instance_id(0);
ov::Node::Node() = default;

View File

@@ -0,0 +1,49 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/core/node.hpp"
#include <gtest/gtest.h>
#include "common_test_utils/test_assertions.hpp"
using namespace testing;
class TestNode : public ov::Node {
public:
TestNode() : Node() {}
static const type_info_t& get_type_info_static() {
static const type_info_t info{"TestNode", ""};
info.hash();
return info;
}
const type_info_t& get_type_info() const override {
return get_type_info_static();
}
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector&) const override {
return std::make_shared<TestNode>();
}
};
class NodeValidationFailureTest : public Test {
protected:
TestNode test_node;
};
TEST_F(NodeValidationFailureTest, node_failure_message) {
OV_EXPECT_THROW(NODE_VALIDATION_CHECK(&test_node, false, "Test message"),
ov::NodeValidationFailure,
HasSubstr("':\nTest message"));
}
TEST_F(NodeValidationFailureTest, node_shape_infer_failure_message) {
const auto input_shapes = std::vector<ov::PartialShape>{{1, 2, 3}, {1}};
OV_EXPECT_THROW(NODE_SHAPE_INFER_CHECK(&test_node, input_shapes, false, "Test message"),
ov::NodeValidationFailure,
HasSubstr("':\nShape inference input shapes {[1,2,3],[1]}\nTest message"));
}

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "shape_validation.hpp"
#include "static_shape.hpp"
namespace ov {
@@ -265,4 +266,13 @@ ov::Shape StaticShapeRef::get_shape() const {
}
} // namespace intel_cpu
template <>
void NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
std::pair<const Node*, const std::vector<intel_cpu::StaticShape>*>&& ctx,
const std::string& explanation) {
throw ov::NodeValidationFailure(make_what(check_loc_info,
node_validation_failure_loc_string(ctx.first),
ov::op::validate::shape_infer_explanation_str(*ctx.second, explanation)));
}
} // namespace ov

View File

@@ -9,6 +9,7 @@
#include "cpu_types.h"
#include "openvino/core/attribute_adapter.hpp"
#include "openvino/core/except.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/rank.hpp"
#include "openvino/core/shape.hpp"
@@ -368,4 +369,17 @@ constexpr typename std::enable_if<is_static_shape_adapter<T>() && is_static_shap
return (lhs.size() == rhs.size()) && (lhs.empty() || std::equal(lhs.cbegin(), lhs.cend(), rhs.cbegin()));
}
} // namespace intel_cpu
/**
* @brief Specialization to throw the `NodeValidationFailure` for shape inference using `StaticShape`
*
* @param check_loc_info Exception location details to print.
* @param ctx NodeValidationFailure context which got pointer to node and input shapes used for shape
* inference.
* @param explanation Exception explanation string.
*/
template <>
void NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
std::pair<const ov::Node*, const std::vector<intel_cpu::StaticShape>*>&& ctx,
const std::string& explanation);
} // namespace ov