[GPU] Allow to use infinity value as a Pad's fill value (#19201)
This commit is contained in:
parent
043cd86449
commit
d13ae31a61
@ -22,7 +22,7 @@ namespace op {
|
||||
namespace util {
|
||||
|
||||
template <class T>
|
||||
bool normalize_single_value(std::vector<T> vec, float& value) {
|
||||
bool normalize_single_value(std::vector<T> vec, float& value, bool check_value_range = true) {
|
||||
for (const auto& val : vec) {
|
||||
if (val != *vec.begin())
|
||||
return false;
|
||||
@ -30,7 +30,8 @@ bool normalize_single_value(std::vector<T> vec, float& value) {
|
||||
|
||||
float ref_val = static_cast<float>(*vec.begin());
|
||||
|
||||
if (ref_val < std::numeric_limits<float>::lowest() || ref_val > std::numeric_limits<float>::max()) {
|
||||
if (check_value_range &&
|
||||
(ref_val < std::numeric_limits<float>::lowest() || ref_val > std::numeric_limits<float>::max())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -159,7 +160,9 @@ bool has_constant_value(const std::shared_ptr<Node>& node,
|
||||
return const_values == values;
|
||||
}
|
||||
|
||||
TRANSFORMATIONS_API bool get_single_value(const std::shared_ptr<opset4::Constant>& const_node, float& value);
|
||||
TRANSFORMATIONS_API bool get_single_value(const std::shared_ptr<opset4::Constant>& const_node,
|
||||
float& value,
|
||||
bool check_value_range = true);
|
||||
|
||||
TRANSFORMATIONS_API std::shared_ptr<Node> normalize_constant(const std::shared_ptr<opset4::Constant>& constant,
|
||||
const PartialShape& shape);
|
||||
|
@ -21,32 +21,32 @@ namespace ov {
|
||||
namespace op {
|
||||
namespace util {
|
||||
|
||||
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value) {
|
||||
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
|
||||
switch (const_node->get_element_type()) {
|
||||
case element::Type_t::f16:
|
||||
return util::normalize_single_value(const_node->get_vector<float16>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<float16>(), value, check_value_range);
|
||||
case element::Type_t::f32:
|
||||
return util::normalize_single_value(const_node->get_vector<float>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<float>(), value, check_value_range);
|
||||
case element::Type_t::bf16:
|
||||
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value, check_value_range);
|
||||
case element::Type_t::f64:
|
||||
return util::normalize_single_value(const_node->get_vector<double>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<double>(), value, check_value_range);
|
||||
case element::Type_t::i8:
|
||||
return util::normalize_single_value(const_node->get_vector<int8_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<int8_t>(), value, check_value_range);
|
||||
case element::Type_t::i16:
|
||||
return util::normalize_single_value(const_node->get_vector<int16_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<int16_t>(), value, check_value_range);
|
||||
case element::Type_t::i32:
|
||||
return util::normalize_single_value(const_node->get_vector<int32_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<int32_t>(), value, check_value_range);
|
||||
case element::Type_t::i64:
|
||||
return util::normalize_single_value(const_node->get_vector<int64_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<int64_t>(), value, check_value_range);
|
||||
case element::Type_t::u8:
|
||||
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value, check_value_range);
|
||||
case element::Type_t::u16:
|
||||
return util::normalize_single_value(const_node->get_vector<uint16_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<uint16_t>(), value, check_value_range);
|
||||
case element::Type_t::u32:
|
||||
return util::normalize_single_value(const_node->get_vector<uint32_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<uint32_t>(), value, check_value_range);
|
||||
case element::Type_t::u64:
|
||||
return util::normalize_single_value(const_node->get_vector<uint64_t>(), value);
|
||||
return util::normalize_single_value(const_node->get_vector<uint64_t>(), value, check_value_range);
|
||||
default:
|
||||
OPENVINO_THROW("Unsupported precision for const operation: ", const_node->get_friendly_name());
|
||||
}
|
||||
|
@ -44,7 +44,8 @@ static void CreatePadOp(Program& p, const std::shared_ptr<ngraph::op::v1::Pad>&
|
||||
if (op->get_pad_mode() == ov::op::PadMode::CONSTANT && op->get_input_size() == 4) {
|
||||
auto const_node = std::dynamic_pointer_cast<ngraph::op::v0::Constant>(op->get_input_node_shared_ptr(3));
|
||||
if (const_node) {
|
||||
OPENVINO_ASSERT(ov::op::util::get_single_value(const_node, pad_value),
|
||||
const bool check_value_range = false; // Allows the usage of infinity value as pad_value
|
||||
OPENVINO_ASSERT(ov::op::util::get_single_value(const_node, pad_value, check_value_range),
|
||||
"Invalid parameter size in ", op->get_friendly_name(), " (", op->get_type_name(), ")");
|
||||
is_value_const = true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user