Graph comparator: added consumers number check (#15367)

* GraphComparator: add CONSUMERS_COUNT CmpValue

* Added tests
This commit is contained in:
Vladislav Golubev 2023-02-10 16:36:54 +01:00 committed by GitHub
parent 8b1b4de21d
commit b329b005a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 1 deletions

View File

@ -807,6 +807,14 @@ void Comparator::compare_outputs(ngraph::Node* node1, ngraph::Node* node2, std::
err_log << "Different runtime info detected at output(" << i << ")\n"
<< name(node1) << " and " << name(node2) << " not equal runtime info." << std::endl;
}
if (should_compare(CmpValues::CONSUMERS_COUNT)) {
if (node1->output(i).get_target_inputs().size() != node2->output(i).get_target_inputs().size()) {
err_log << "Different consumers number detected\n"
<< name(node1) << " Output(" << i << ") " << node1->output(i).get_target_inputs().size() << " and "
<< name(node2) << " Output(" << i << ") " << node2->output(i).get_target_inputs().size() << std::endl;
}
}
}
}

View File

@ -28,7 +28,8 @@ public:
ATTRIBUTES = 1 << 5,
TENSOR_NAMES = 1 << 6,
ACCURACY = 1 << 7,
SUBGRAPH_DESCRIPTORS = 1 << 8
SUBGRAPH_DESCRIPTORS = 1 << 8,
CONSUMERS_COUNT = 1 << 9
};
struct Result {

View File

@ -620,3 +620,52 @@ TEST(GraphComparatorTests, CheckAccuracyNotEnabled) {
auto res = comparator.compare(function, function_ref);
ASSERT_TRUE(res.valid) << res.message;
}
TEST(GraphComparatorTests, CheckConsumersCountPositive) {
FunctionsComparator comparator(FunctionsComparator::no_default());
std::shared_ptr<ov::Model> function, function_ref;
{
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
auto constant = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant);
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant);
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
}
{
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
auto constant = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant);
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant);
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
}
comparator.enable(FunctionsComparator::NODES).enable(FunctionsComparator::CONSUMERS_COUNT);
auto res = comparator.compare(function, function_ref);
ASSERT_TRUE(res.valid) << res.message;
}
TEST(GraphComparatorTests, CheckConsumersCountNegative) {
FunctionsComparator comparator(FunctionsComparator::no_default());
std::shared_ptr<ov::Model> function, function_ref;
{
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
auto constant = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant);
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant);
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
}
{
auto input = std::make_shared<ov::opset8::Parameter>(ngraph::element::i64, ov::Shape{1});
auto constant_1 = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
auto constant_2 = ov::opset8::Constant::create(ngraph::element::i64, {1}, {0});
auto add_1 = std::make_shared<ov::opset8::Add>(input, constant_1);
auto add_2 = std::make_shared<ov::opset8::Add>(input, constant_2);
auto mul = std::make_shared<ov::opset8::Multiply>(add_1, add_2);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ mul }, ngraph::ParameterVector{ input });
}
comparator.enable(FunctionsComparator::NODES).enable(FunctionsComparator::CONSUMERS_COUNT);
auto res = comparator.compare(function, function_ref);
ASSERT_FALSE(res.valid) << res.message;
}