LogicalNot: convert precision (#17061)
* CVS-108362 LogicalNot: convert precision * Test
This commit is contained in:
committed by
GitHub
parent
2e3deb8d8f
commit
cd4c012f08
@@ -84,14 +84,12 @@ bool fuse_type_to_logical(const std::shared_ptr<ngraph::Node>& node, const preci
|
|||||||
const auto& to = it->second;
|
const auto& to = it->second;
|
||||||
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
|
if (auto type_relaxed = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node)) {
|
||||||
type_relaxed->set_overridden_output_type(to);
|
type_relaxed->set_overridden_output_type(to);
|
||||||
type_relaxed->set_origin_input_type(ov::element::boolean, 0);
|
for (size_t i = 0; i < node->get_input_size(); ++i)
|
||||||
type_relaxed->set_origin_input_type(ov::element::boolean, 1);
|
type_relaxed->set_origin_input_type(ov::element::boolean, i);
|
||||||
return true;
|
return true;
|
||||||
} else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
|
} else if (auto casted = std::dynamic_pointer_cast<T>(node)) {
|
||||||
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<T>>(
|
ov::element::TypeVector input_types(node->get_input_size(), ov::element::boolean);
|
||||||
*casted,
|
auto relaxed_op = std::make_shared<ov::op::TypeRelaxed<T>>(*casted, input_types, ov::element::TypeVector{to});
|
||||||
ov::element::TypeVector{ov::element::boolean, ov::element::boolean},
|
|
||||||
ov::element::TypeVector{to});
|
|
||||||
replace_node(node, relaxed_op);
|
replace_node(node, relaxed_op);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -671,6 +671,14 @@ TEST(TransformationTests, ConvertPrecision_LogicalNot) {
|
|||||||
|
|
||||||
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
ASSERT_FALSE(has_type<element::Type_t::boolean>(f));
|
||||||
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
ASSERT_TRUE(has_type<element::Type_t::u8>(f));
|
||||||
|
|
||||||
|
std::shared_ptr<ov::op::TypeRelaxedBase> tr;
|
||||||
|
for (const auto& node : f->get_ordered_ops())
|
||||||
|
if (auto op = std::dynamic_pointer_cast<ov::op::TypeRelaxedBase>(node))
|
||||||
|
tr = op;
|
||||||
|
ASSERT_TRUE(tr != nullptr);
|
||||||
|
ASSERT_EQ(tr->get_origin_input_type(0), element::boolean);
|
||||||
|
ASSERT_EQ(tr->get_origin_input_type(1), element::undefined);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, ConvertPrecision_Select) {
|
TEST(TransformationTests, ConvertPrecision_Select) {
|
||||||
|
|||||||
Reference in New Issue
Block a user