[GPU] Allow to use infinity value as a Pad's fill value (#19201)

This commit is contained in:
Sergey Shlyapnikov 2023-08-15 18:31:49 +04:00 committed by GitHub
parent 043cd86449
commit d13ae31a61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 17 deletions

View File

@ -22,7 +22,7 @@ namespace op {
namespace util {
template <class T>
bool normalize_single_value(std::vector<T> vec, float& value) {
bool normalize_single_value(std::vector<T> vec, float& value, bool check_value_range = true) {
for (const auto& val : vec) {
if (val != *vec.begin())
return false;
@ -30,7 +30,8 @@ bool normalize_single_value(std::vector<T> vec, float& value) {
float ref_val = static_cast<float>(*vec.begin());
if (ref_val < std::numeric_limits<float>::lowest() || ref_val > std::numeric_limits<float>::max()) {
if (check_value_range &&
(ref_val < std::numeric_limits<float>::lowest() || ref_val > std::numeric_limits<float>::max())) {
return false;
}
@ -159,7 +160,9 @@ bool has_constant_value(const std::shared_ptr<Node>& node,
return const_values == values;
}
TRANSFORMATIONS_API bool get_single_value(const std::shared_ptr<opset4::Constant>& const_node, float& value);
TRANSFORMATIONS_API bool get_single_value(const std::shared_ptr<opset4::Constant>& const_node,
float& value,
bool check_value_range = true);
TRANSFORMATIONS_API std::shared_ptr<Node> normalize_constant(const std::shared_ptr<opset4::Constant>& constant,
const PartialShape& shape);

View File

@ -21,32 +21,32 @@ namespace ov {
namespace op {
namespace util {
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value) {
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
switch (const_node->get_element_type()) {
case element::Type_t::f16:
return util::normalize_single_value(const_node->get_vector<float16>(), value);
return util::normalize_single_value(const_node->get_vector<float16>(), value, check_value_range);
case element::Type_t::f32:
return util::normalize_single_value(const_node->get_vector<float>(), value);
return util::normalize_single_value(const_node->get_vector<float>(), value, check_value_range);
case element::Type_t::bf16:
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value);
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value, check_value_range);
case element::Type_t::f64:
return util::normalize_single_value(const_node->get_vector<double>(), value);
return util::normalize_single_value(const_node->get_vector<double>(), value, check_value_range);
case element::Type_t::i8:
return util::normalize_single_value(const_node->get_vector<int8_t>(), value);
return util::normalize_single_value(const_node->get_vector<int8_t>(), value, check_value_range);
case element::Type_t::i16:
return util::normalize_single_value(const_node->get_vector<int16_t>(), value);
return util::normalize_single_value(const_node->get_vector<int16_t>(), value, check_value_range);
case element::Type_t::i32:
return util::normalize_single_value(const_node->get_vector<int32_t>(), value);
return util::normalize_single_value(const_node->get_vector<int32_t>(), value, check_value_range);
case element::Type_t::i64:
return util::normalize_single_value(const_node->get_vector<int64_t>(), value);
return util::normalize_single_value(const_node->get_vector<int64_t>(), value, check_value_range);
case element::Type_t::u8:
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value);
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value, check_value_range);
case element::Type_t::u16:
return util::normalize_single_value(const_node->get_vector<uint16_t>(), value);
return util::normalize_single_value(const_node->get_vector<uint16_t>(), value, check_value_range);
case element::Type_t::u32:
return util::normalize_single_value(const_node->get_vector<uint32_t>(), value);
return util::normalize_single_value(const_node->get_vector<uint32_t>(), value, check_value_range);
case element::Type_t::u64:
return util::normalize_single_value(const_node->get_vector<uint64_t>(), value);
return util::normalize_single_value(const_node->get_vector<uint64_t>(), value, check_value_range);
default:
OPENVINO_THROW("Unsupported precision for const operation: ", const_node->get_friendly_name());
}

View File

@ -44,7 +44,8 @@ static void CreatePadOp(Program& p, const std::shared_ptr<ngraph::op::v1::Pad>&
if (op->get_pad_mode() == ov::op::PadMode::CONSTANT && op->get_input_size() == 4) {
auto const_node = std::dynamic_pointer_cast<ngraph::op::v0::Constant>(op->get_input_node_shared_ptr(3));
if (const_node) {
OPENVINO_ASSERT(ov::op::util::get_single_value(const_node, pad_value),
const bool check_value_range = false; // Allows the usage of infinity value as pad_value
OPENVINO_ASSERT(ov::op::util::get_single_value(const_node, pad_value, check_value_range),
"Invalid parameter size in ", op->get_friendly_name(), " (", op->get_type_name(), ")");
is_value_const = true;
}