Files
LBPM/threadpool/test/test_atomic_list.cpp
2018-02-06 10:50:43 -05:00

222 lines
7.5 KiB
C++

#include "threadpool/atomic_list.h"
#include "common/UnitTest.h"
#include "common/Utilities.h"
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
static void modify_list( AtomicList<int, 1024> &list )
{
const int N_count = 50000;
for ( int i = 0; i < N_count; i++ ) {
auto v1 = list.remove_first();
auto v2 = list.remove( []( int ) { return true; } );
auto v3 = list.remove( []( int v ) { return v >= ( rand() / 8 ); } );
auto v4 = list.remove( []( int v ) { return v >= ( rand() / 4 ); } );
auto v5 = list.remove( []( int v ) { return v >= ( rand() / 2 ); } );
if ( v1 != -1 ) {
list.insert( v1 );
}
if ( v2 != -1 ) {
list.insert( v2 );
}
if ( v3 != -1 ) {
list.insert( v3 );
}
if ( v4 != -1 ) {
list.insert( v4 );
}
if ( v5 != -1 ) {
list.insert( v5 );
}
}
}
static bool check_list( const std::vector<int> &x, AtomicList<int, 1024> &list )
{
bool pass = list.check();
pass = pass && (int) x.size() == list.size();
if ( pass ) {
for ( int i : x )
pass = pass && i == list.remove( []( int ) { return true; } );
}
// Restore the list
for ( int i = 0; i < list.size(); i++ )
list.remove_first();
for ( int i : x )
list.insert( i );
return pass;
}
static inline void clear_list( AtomicList<int, 1024> &list )
{
for ( int i = 0; i < list.size(); i++ )
list.remove_first();
}
/******************************************************************
* The main program *
******************************************************************/
int main( int, char *[] )
{
UnitTest ut;
int N_threads = 8; // Number of threads
// Create the list
AtomicList<int, 1024> list( -1 );
if ( list.size() == 0 && list.check() )
ut.passes( "Initialize" );
else
ut.failure( "Initialize" );
// Initialize the list with some empty values
for ( int i = 0; i < 80; i++ )
list.insert( rand() );
list.insert( 2 );
list.insert( 1 );
list.insert( rand() );
// Try to pull off a couple of values
int v1 = list.remove( []( int a ) { return a == 1; } ); // Find the entry with 1
int v2 = list.remove( []( int ) { return true; } ); // Get the first entry
int v3 = list.remove( []( int ) { return false; } ); // Fail to get an entry
if ( v1 == 1 && v2 == 2 && v3 == -1 && list.size() == 81 && list.check() )
ut.passes( "Basic sanity test" );
else
ut.failure( "Basic sanity test" );
// Clear the list
while ( list.remove( []( int ) { return true; } ) != -1 ) {
}
// Create a list of known values
// std::vector<int> data0(512);
std::vector<int> data0( 5 * N_threads );
for ( int &i : data0 )
i = rand();
auto data = data0;
std::sort( data.begin(), data.end() );
// Test the cost to insert
int N_it = 20;
for ( int i = 0; i < list.size(); i++ )
list.remove( []( int ) { return true; } );
std::chrono::duration<double> time;
std::chrono::time_point<std::chrono::high_resolution_clock> start, stop;
time = time.zero();
for ( int it = 0; it < N_it; it++ ) {
clear_list( list );
start = std::chrono::high_resolution_clock::now();
for ( int i : data0 )
list.insert( i );
stop = std::chrono::high_resolution_clock::now();
time += ( stop - start );
}
printf( "insert time/item = %0.0f ns\n", 1e9 * time.count() / ( N_it * data0.size() ) );
// Test the cost to remove (first)
time = time.zero();
for ( int it = 0; it < N_it; it++ ) {
check_list( data, list );
start = std::chrono::high_resolution_clock::now();
for ( size_t i = 0; i < data0.size(); i++ )
list.remove_first();
stop = std::chrono::high_resolution_clock::now();
time += ( stop - start );
}
printf( "remove (first) time/item = %0.0f ns\n", 1e9 * time.count() / ( N_it * data0.size() ) );
// Test the cost to remove (in order)
time = time.zero();
for ( int it = 0; it < N_it; it++ ) {
check_list( data, list );
start = std::chrono::high_resolution_clock::now();
for ( size_t i = 0; i < data0.size(); i++ )
list.remove( []( int ) { return true; } );
stop = std::chrono::high_resolution_clock::now();
time += ( stop - start );
}
printf(
"remove (ordered) time/item = %0.0f ns\n", 1e9 * time.count() / ( N_it * data0.size() ) );
// Test the cost to remove (out order)
time = time.zero();
for ( int it = 0; it < N_it; it++ ) {
check_list( data, list );
start = std::chrono::high_resolution_clock::now();
for ( int tmp : data0 ) {
list.remove( [tmp]( int v ) { return v == tmp; } );
}
stop = std::chrono::high_resolution_clock::now();
time += ( stop - start );
}
printf(
"remove (unordered) time/item = %0.0f ns\n", 1e9 * time.count() / ( N_it * data0.size() ) );
// Read/write to the list and check the results
int64_t N0 = list.N_remove();
check_list( data, list );
start = std::chrono::high_resolution_clock::now();
modify_list( list );
stop = std::chrono::high_resolution_clock::now();
double time_serial = std::chrono::duration<double>( stop - start ).count();
int64_t N1 = list.N_remove();
bool pass = check_list( data, list );
if ( pass )
ut.passes( "Serial get/insert" );
else
ut.failure( "Serial get/insert" );
printf( "serial time = %0.5f s\n", time_serial );
printf( "serial time/item = %0.0f ns\n", 1e9 * time_serial / ( N1 - N0 ) );
// Have multiple threads reading/writing to the list simultaneously
std::vector<std::thread> threads( N_threads );
start = std::chrono::high_resolution_clock::now();
for ( int i = 0; i < N_threads; i++ )
threads[i] = std::thread( modify_list, std::ref( list ) );
for ( int i = 0; i < N_threads; i++ )
threads[i].join();
stop = std::chrono::high_resolution_clock::now();
double time_parallel = std::chrono::duration<double>( stop - start ).count();
int64_t N2 = list.N_remove();
pass = check_list( data, list );
if ( pass )
ut.passes( "Parallel get/insert" );
else
ut.failure( "Parallel get/insert" );
printf( "parallel time = %0.5f s\n", time_parallel );
printf( "parallel time/item = %0.0f ns\n", 1e9 * time_parallel / ( N2 - N1 ) );
// Try to over-fill the list
while ( !list.empty() )
list.remove_first();
for ( int i = 1; i <= list.capacity(); i++ )
list.insert( i );
try {
list.insert( list.capacity() + 1 );
ut.failure( "List overflow" );
} catch ( const std::exception &e ) {
ut.passes( "List overflow" );
} catch ( ... ) {
ut.failure( "List overflow (unknown exception)" );
}
// Finished
ut.report();
auto N_errors = static_cast<int>( ut.NumFailGlobal() );
return N_errors;
}