[GPU] Allow to use infinity value as a Pad's fill value (#19201)
This commit is contained in:
parent
043cd86449
commit
d13ae31a61
@ -22,7 +22,7 @@ namespace op {
|
|||||||
namespace util {
|
namespace util {
|
||||||
|
|
||||||
template <class T>
|
template <class T>
|
||||||
bool normalize_single_value(std::vector<T> vec, float& value) {
|
bool normalize_single_value(std::vector<T> vec, float& value, bool check_value_range = true) {
|
||||||
for (const auto& val : vec) {
|
for (const auto& val : vec) {
|
||||||
if (val != *vec.begin())
|
if (val != *vec.begin())
|
||||||
return false;
|
return false;
|
||||||
@ -30,7 +30,8 @@ bool normalize_single_value(std::vector<T> vec, float& value) {
|
|||||||
|
|
||||||
float ref_val = static_cast<float>(*vec.begin());
|
float ref_val = static_cast<float>(*vec.begin());
|
||||||
|
|
||||||
if (ref_val < std::numeric_limits<float>::lowest() || ref_val > std::numeric_limits<float>::max()) {
|
if (check_value_range &&
|
||||||
|
(ref_val < std::numeric_limits<float>::lowest() || ref_val > std::numeric_limits<float>::max())) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,7 +160,9 @@ bool has_constant_value(const std::shared_ptr<Node>& node,
|
|||||||
return const_values == values;
|
return const_values == values;
|
||||||
}
|
}
|
||||||
|
|
||||||
TRANSFORMATIONS_API bool get_single_value(const std::shared_ptr<opset4::Constant>& const_node, float& value);
|
TRANSFORMATIONS_API bool get_single_value(const std::shared_ptr<opset4::Constant>& const_node,
|
||||||
|
float& value,
|
||||||
|
bool check_value_range = true);
|
||||||
|
|
||||||
TRANSFORMATIONS_API std::shared_ptr<Node> normalize_constant(const std::shared_ptr<opset4::Constant>& constant,
|
TRANSFORMATIONS_API std::shared_ptr<Node> normalize_constant(const std::shared_ptr<opset4::Constant>& constant,
|
||||||
const PartialShape& shape);
|
const PartialShape& shape);
|
||||||
|
@ -21,32 +21,32 @@ namespace ov {
|
|||||||
namespace op {
|
namespace op {
|
||||||
namespace util {
|
namespace util {
|
||||||
|
|
||||||
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value) {
|
bool get_single_value(const std::shared_ptr<op::v0::Constant>& const_node, float& value, bool check_value_range) {
|
||||||
switch (const_node->get_element_type()) {
|
switch (const_node->get_element_type()) {
|
||||||
case element::Type_t::f16:
|
case element::Type_t::f16:
|
||||||
return util::normalize_single_value(const_node->get_vector<float16>(), value);
|
return util::normalize_single_value(const_node->get_vector<float16>(), value, check_value_range);
|
||||||
case element::Type_t::f32:
|
case element::Type_t::f32:
|
||||||
return util::normalize_single_value(const_node->get_vector<float>(), value);
|
return util::normalize_single_value(const_node->get_vector<float>(), value, check_value_range);
|
||||||
case element::Type_t::bf16:
|
case element::Type_t::bf16:
|
||||||
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value);
|
return util::normalize_single_value(const_node->get_vector<bfloat16>(), value, check_value_range);
|
||||||
case element::Type_t::f64:
|
case element::Type_t::f64:
|
||||||
return util::normalize_single_value(const_node->get_vector<double>(), value);
|
return util::normalize_single_value(const_node->get_vector<double>(), value, check_value_range);
|
||||||
case element::Type_t::i8:
|
case element::Type_t::i8:
|
||||||
return util::normalize_single_value(const_node->get_vector<int8_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<int8_t>(), value, check_value_range);
|
||||||
case element::Type_t::i16:
|
case element::Type_t::i16:
|
||||||
return util::normalize_single_value(const_node->get_vector<int16_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<int16_t>(), value, check_value_range);
|
||||||
case element::Type_t::i32:
|
case element::Type_t::i32:
|
||||||
return util::normalize_single_value(const_node->get_vector<int32_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<int32_t>(), value, check_value_range);
|
||||||
case element::Type_t::i64:
|
case element::Type_t::i64:
|
||||||
return util::normalize_single_value(const_node->get_vector<int64_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<int64_t>(), value, check_value_range);
|
||||||
case element::Type_t::u8:
|
case element::Type_t::u8:
|
||||||
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<uint8_t>(), value, check_value_range);
|
||||||
case element::Type_t::u16:
|
case element::Type_t::u16:
|
||||||
return util::normalize_single_value(const_node->get_vector<uint16_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<uint16_t>(), value, check_value_range);
|
||||||
case element::Type_t::u32:
|
case element::Type_t::u32:
|
||||||
return util::normalize_single_value(const_node->get_vector<uint32_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<uint32_t>(), value, check_value_range);
|
||||||
case element::Type_t::u64:
|
case element::Type_t::u64:
|
||||||
return util::normalize_single_value(const_node->get_vector<uint64_t>(), value);
|
return util::normalize_single_value(const_node->get_vector<uint64_t>(), value, check_value_range);
|
||||||
default:
|
default:
|
||||||
OPENVINO_THROW("Unsupported precision for const operation: ", const_node->get_friendly_name());
|
OPENVINO_THROW("Unsupported precision for const operation: ", const_node->get_friendly_name());
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,8 @@ static void CreatePadOp(Program& p, const std::shared_ptr<ngraph::op::v1::Pad>&
|
|||||||
if (op->get_pad_mode() == ov::op::PadMode::CONSTANT && op->get_input_size() == 4) {
|
if (op->get_pad_mode() == ov::op::PadMode::CONSTANT && op->get_input_size() == 4) {
|
||||||
auto const_node = std::dynamic_pointer_cast<ngraph::op::v0::Constant>(op->get_input_node_shared_ptr(3));
|
auto const_node = std::dynamic_pointer_cast<ngraph::op::v0::Constant>(op->get_input_node_shared_ptr(3));
|
||||||
if (const_node) {
|
if (const_node) {
|
||||||
OPENVINO_ASSERT(ov::op::util::get_single_value(const_node, pad_value),
|
const bool check_value_range = false; // Allows the usage of infinity value as pad_value
|
||||||
|
OPENVINO_ASSERT(ov::op::util::get_single_value(const_node, pad_value, check_value_range),
|
||||||
"Invalid parameter size in ", op->get_friendly_name(), " (", op->get_type_name(), ")");
|
"Invalid parameter size in ", op->get_friendly_name(), " (", op->get_type_name(), ")");
|
||||||
is_value_const = true;
|
is_value_const = true;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user