Improve node validation failure message for shape infer (#18520)
* Add NODE_SHAPE_INFER_CHECK macro throws NodeValidationFailure for shape inference * Use NODE_SHAPE_INFER_CHECK in topk shape inference * Move description to header file * export NodeValidationFailure::create function
This commit is contained in:
@@ -533,13 +533,34 @@ public:
|
||||
const Node* node,
|
||||
const std::string& explanation);
|
||||
|
||||
template <class TShape>
|
||||
[[noreturn]] static void create(const CheckLocInfo& check_loc_info,
|
||||
std::pair<const Node*, const std::vector<TShape>*>&& ctx,
|
||||
const std::string& explanation);
|
||||
|
||||
protected:
|
||||
explicit NodeValidationFailure(const std::string& what_arg) : ov::AssertFailure(what_arg) {}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Specialization to throw the `NodeValidationFailure` for shape inference using `PartialShape`
|
||||
*
|
||||
* @param check_loc_info Exception location details to print.
|
||||
* @param ctx NodeValidationFailure context which got pointer to node and input shapes used for shape
|
||||
* inference.
|
||||
* @param explanation Exception explanation string.
|
||||
*/
|
||||
template <>
|
||||
OPENVINO_API void NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
|
||||
std::pair<const Node*, const std::vector<PartialShape>*>&& ctx,
|
||||
const std::string& explanation);
|
||||
} // namespace ov
|
||||
#define NODE_VALIDATION_CHECK(node, ...) OPENVINO_ASSERT_HELPER(::ov::NodeValidationFailure, (node), __VA_ARGS__)
|
||||
|
||||
/** \brief Throw NodeValidationFailure with additional information about input shapes used during shape inference. */
|
||||
#define NODE_SHAPE_INFER_CHECK(node, input_shapes, ...) \
|
||||
NODE_VALIDATION_CHECK(std::make_pair(static_cast<const ::ov::Node*>((node)), &(input_shapes)), __VA_ARGS__)
|
||||
|
||||
namespace ov {
|
||||
template <typename T>
|
||||
void check_new_args_count(const Node* node, T new_args) {
|
||||
|
||||
36
src/core/shape_inference/include/shape_validation.hpp
Normal file
36
src/core/shape_inference/include/shape_validation.hpp
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iterator>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace validate {
|
||||
/**
|
||||
* @brief Provides `NodeValidationFailure` exception explanation string.
|
||||
*
|
||||
* @param shapes Vector of shapes used for inference to be printed before explanation.
|
||||
* @param explanation String with exception explanation.
|
||||
* @return Explanation string.
|
||||
*/
|
||||
template <class TShape>
|
||||
std::string shape_infer_explanation_str(const std::vector<TShape>& shapes, const std::string& explanation) {
|
||||
std::stringstream o;
|
||||
o << "Shape inference input shapes {";
|
||||
std::copy(shapes.cbegin(), std::prev(shapes.cend()), std::ostream_iterator<TShape>(o, ","));
|
||||
if (!shapes.empty()) {
|
||||
o << shapes.back();
|
||||
}
|
||||
o << "}\n" << explanation;
|
||||
return o.str();
|
||||
}
|
||||
} // namespace validate
|
||||
} // namespace op
|
||||
} // namespace ov
|
||||
@@ -58,12 +58,13 @@ std::vector<TShape> shape_infer(const util::TopKBase* op,
|
||||
const auto& input_shape = input_shapes[0];
|
||||
const auto input_rank = input_shape.rank();
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_rank.is_dynamic() || input_rank.get_length() > 0,
|
||||
"Input rank must be greater than 0.");
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
input_rank.is_dynamic() || input_rank.get_length() > 0,
|
||||
"Input rank must be greater than 0.");
|
||||
|
||||
const auto& k_shape = input_shapes[1];
|
||||
NODE_VALIDATION_CHECK(op, k_shape.rank().compatible(0), "The 'K' input must be a scalar.");
|
||||
NODE_SHAPE_INFER_CHECK(op, input_shapes, k_shape.rank().compatible(0), "The 'K' input must be a scalar.");
|
||||
|
||||
auto output_shape = input_shape;
|
||||
if (input_shape.rank().is_static()) {
|
||||
@@ -73,12 +74,13 @@ std::vector<TShape> shape_infer(const util::TopKBase* op,
|
||||
auto& dim_axis = output_shape[normalized_axis];
|
||||
|
||||
if (auto k_as_shape = get_input_const_data_as_shape<TShape>(op, 1, constant_data, GetK<TDimValue>(op))) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
k_as_shape->size() == 1,
|
||||
"Only one value (scalar) should be provided as the 'K' input to TopK",
|
||||
" (got ",
|
||||
k_as_shape->size(),
|
||||
" elements).");
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
k_as_shape->size() == 1,
|
||||
"Only one value (scalar) should be provided as the 'K' input to TopK",
|
||||
" (got ",
|
||||
k_as_shape->size(),
|
||||
" elements).");
|
||||
|
||||
const auto& k = (*k_as_shape)[0];
|
||||
if (k.is_static()) {
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "openvino/core/descriptor/input.hpp"
|
||||
#include "openvino/pass/constant_folding.hpp"
|
||||
#include "shape_util.hpp"
|
||||
#include "shape_validation.hpp"
|
||||
#include "shared_node_info.hpp"
|
||||
#include "tensor_conversion_util.hpp"
|
||||
|
||||
@@ -40,6 +41,15 @@ void ov::NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
|
||||
throw ov::NodeValidationFailure(make_what(check_loc_info, node_validation_failure_loc_string(node), explanation));
|
||||
}
|
||||
|
||||
template <>
|
||||
void ov::NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
|
||||
std::pair<const Node*, const std::vector<PartialShape>*>&& ctx,
|
||||
const std::string& explanation) {
|
||||
throw ov::NodeValidationFailure(make_what(check_loc_info,
|
||||
node_validation_failure_loc_string(ctx.first),
|
||||
op::validate::shape_infer_explanation_str(*ctx.second, explanation)));
|
||||
}
|
||||
|
||||
atomic<size_t> ov::Node::m_next_instance_id(0);
|
||||
|
||||
ov::Node::Node() = default;
|
||||
|
||||
49
src/core/tests/node_test.cpp
Normal file
49
src/core/tests/node_test.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/core/node.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
class TestNode : public ov::Node {
|
||||
public:
|
||||
TestNode() : Node() {}
|
||||
|
||||
static const type_info_t& get_type_info_static() {
|
||||
static const type_info_t info{"TestNode", ""};
|
||||
info.hash();
|
||||
return info;
|
||||
}
|
||||
|
||||
const type_info_t& get_type_info() const override {
|
||||
return get_type_info_static();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector&) const override {
|
||||
return std::make_shared<TestNode>();
|
||||
}
|
||||
};
|
||||
|
||||
class NodeValidationFailureTest : public Test {
|
||||
protected:
|
||||
TestNode test_node;
|
||||
};
|
||||
|
||||
TEST_F(NodeValidationFailureTest, node_failure_message) {
|
||||
OV_EXPECT_THROW(NODE_VALIDATION_CHECK(&test_node, false, "Test message"),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("':\nTest message"));
|
||||
}
|
||||
|
||||
TEST_F(NodeValidationFailureTest, node_shape_infer_failure_message) {
|
||||
const auto input_shapes = std::vector<ov::PartialShape>{{1, 2, 3}, {1}};
|
||||
|
||||
OV_EXPECT_THROW(NODE_SHAPE_INFER_CHECK(&test_node, input_shapes, false, "Test message"),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("':\nShape inference input shapes {[1,2,3],[1]}\nTest message"));
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shape_validation.hpp"
|
||||
#include "static_shape.hpp"
|
||||
|
||||
namespace ov {
|
||||
@@ -265,4 +266,13 @@ ov::Shape StaticShapeRef::get_shape() const {
|
||||
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
|
||||
template <>
|
||||
void NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
|
||||
std::pair<const Node*, const std::vector<intel_cpu::StaticShape>*>&& ctx,
|
||||
const std::string& explanation) {
|
||||
throw ov::NodeValidationFailure(make_what(check_loc_info,
|
||||
node_validation_failure_loc_string(ctx.first),
|
||||
ov::op::validate::shape_infer_explanation_str(*ctx.second, explanation)));
|
||||
}
|
||||
} // namespace ov
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "cpu_types.h"
|
||||
#include "openvino/core/attribute_adapter.hpp"
|
||||
#include "openvino/core/except.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/core/partial_shape.hpp"
|
||||
#include "openvino/core/rank.hpp"
|
||||
#include "openvino/core/shape.hpp"
|
||||
@@ -368,4 +369,17 @@ constexpr typename std::enable_if<is_static_shape_adapter<T>() && is_static_shap
|
||||
return (lhs.size() == rhs.size()) && (lhs.empty() || std::equal(lhs.cbegin(), lhs.cend(), rhs.cbegin()));
|
||||
}
|
||||
} // namespace intel_cpu
|
||||
|
||||
/**
|
||||
* @brief Specialization to throw the `NodeValidationFailure` for shape inference using `StaticShape`
|
||||
*
|
||||
* @param check_loc_info Exception location details to print.
|
||||
* @param ctx NodeValidationFailure context which got pointer to node and input shapes used for shape
|
||||
* inference.
|
||||
* @param explanation Exception explanation string.
|
||||
*/
|
||||
template <>
|
||||
void NodeValidationFailure::create(const CheckLocInfo& check_loc_info,
|
||||
std::pair<const ov::Node*, const std::vector<intel_cpu::StaticShape>*>&& ctx,
|
||||
const std::string& explanation);
|
||||
} // namespace ov
|
||||
|
||||
Reference in New Issue
Block a user