Graph comparator: added consumers number check (#15367)
* GraphComparator: add CONSUMERS_COUNT CmpValue * Added tests
This commit is contained in:
parent
8b1b4de21d
commit
b329b005a3
@ -807,6 +807,14 @@ void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::
|
||||
err_log << "Different runtime info detected at output(" << i << ")\n"
|
||||
<< name(node1) << " and " << name(node2) << " not equal runtime info." << std::endl;
|
||||
}
|
||||
|
||||
if (should_compare(CmpValues::CONSUMERS_COUNT)) {
|
||||
if (node1->output(i).get_target_inputs().size() != node2->output(i).get_target_inputs().size()) {
|
||||
err_log << "Different consumers number detected\n"
|
||||
<< name(node1) << " Output(" << i << ") " << node1->output(i).get_target_inputs().size() << " and "
|
||||
<< name(node2) << " Output(" << i << ") " << node2->output(i).get_target_inputs().size() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -28,7 +28,8 @@ public:
|
||||
ATTRIBUTES = 1 << 5,
|
||||
TENSOR_NAMES = 1 << 6,
|
||||
ACCURACY = 1 << 7,
|
||||
SUBGRAPH_DESCRIPTORS = 1 << 8
|
||||
SUBGRAPH_DESCRIPTORS = 1 << 8,
|
||||
CONSUMERS_COUNT = 1 << 9
|
||||
};
|
||||
|
||||
struct Result {
|
||||
|
@ -620,3 +620,52 @@ TEST(GraphComparatorTests, CheckAccuracyNotEnabled) {
|
||||
auto res = comparator.compare(function, function_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
||||
TEST(GraphComparatorTests, CheckConsumersCountPositive) {
|
||||
FunctionsComparator comparator(FunctionsComparator::no_default());
|
||||
std::shared_ptr<ov::Model> function, function_ref;
|
||||
{
|
||||
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
|
||||
auto constant = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
|
||||
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant);
|
||||
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant);
|
||||
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
|
||||
auto constant = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
|
||||
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant);
|
||||
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant);
|
||||
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
comparator.enable(FunctionsComparator::NODES).enable(FunctionsComparator::CONSUMERS_COUNT);
|
||||
auto res = comparator.compare(function, function_ref);
|
||||
ASSERT_TRUE(res.valid) << res.message;
|
||||
}
|
||||
|
||||
TEST(GraphComparatorTests, CheckConsumersCountNegative) {
|
||||
FunctionsComparator comparator(FunctionsComparator::no_default());
|
||||
std::shared_ptr<ov::Model> function, function_ref;
|
||||
{
|
||||
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
|
||||
auto constant = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
|
||||
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant);
|
||||
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant);
|
||||
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
|
||||
auto constant_1 = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
|
||||
auto constant_2 = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
|
||||
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant_1);
|
||||
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant_2);
|
||||
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
comparator.enable(FunctionsComparator::NODES).enable(FunctionsComparator::CONSUMERS_COUNT);
|
||||
auto res = comparator.compare(function, function_ref);
|
||||
ASSERT_FALSE(res.valid) << res.message;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user