Support SSL in the regression tests. Fixes #2170

This commit is contained in:
Murtuza Zabuawala
2017-07-18 15:23:11 +01:00
committed by Dave Page
parent 6396b8ce18
commit bab3da24e6
55 changed files with 266 additions and 119 deletions

View File

@@ -25,13 +25,16 @@ SERVER_GROUP = test_setup.config_data['server_group']
file_name = os.path.realpath(__file__)
def get_db_connection(db, username, password, host, port):
def get_db_connection(db, username, password, host, port, sslmode="prefer"):
"""This function returns the connection object of psycopg"""
connection = psycopg2.connect(database=db,
user=username,
password=password,
host=host,
port=port)
connection = psycopg2.connect(
database=db,
user=username,
password=password,
host=host,
port=port,
sslmode=sslmode
)
return connection
@@ -116,7 +119,8 @@ def create_database(server, db_name):
server['username'],
server['db_password'],
server['host'],
server['port'])
server['port'],
server['sslmode'])
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
pg_cursor = connection.cursor()
@@ -154,7 +158,8 @@ def create_table(server, db_name, table_name):
server['username'],
server['db_password'],
server['host'],
server['port'])
server['port'],
server['sslmode'])
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
pg_cursor = connection.cursor()
@@ -191,7 +196,8 @@ def create_table_with_query(server, db_name, query):
server['username'],
server['db_password'],
server['host'],
server['port'])
server['port'],
server['sslmode'])
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
pg_cursor = connection.cursor()
@@ -211,7 +217,8 @@ def create_constraint(
server['username'],
server['db_password'],
server['host'],
server['port'])
server['port'],
server['sslmode'])
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
pg_cursor = connection.cursor()
@@ -234,7 +241,8 @@ def create_debug_function(server, db_name, function_name="test_func"):
server['username'],
server['db_password'],
server['host'],
server['port'])
server['port'],
server['sslmode'])
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
pg_cursor = connection.cursor()
@@ -267,7 +275,8 @@ def drop_debug_function(server, db_name, function_name="test_func"):
server['username'],
server['db_password'],
server['host'],
server['port'])
server['port'],
server['sslmode'])
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
pg_cursor = connection.cursor()
@@ -390,7 +399,8 @@ def create_parent_server_node(server_info):
server_info['username'],
server_info['db_password'],
server_info['host'],
server_info['port'])
server_info['port'],
server_info['sslmode'])
schema = regression.schema_utils.create_schema(connection, schema_name)
add_schema_to_parent_node_dict(srv_id, db_id, schema[0],
@@ -413,7 +423,8 @@ def delete_test_server(tester):
servers_dict['username'],
servers_dict['db_password'],
servers_dict['host'],
servers_dict['port'])
servers_dict['port'],
servers_dict['sslmode'])
database_name = database["db_name"]
# Drop database
drop_database(connection, database_name)
@@ -422,7 +433,8 @@ def delete_test_server(tester):
servers_dict['username'],
servers_dict['db_password'],
servers_dict['host'],
servers_dict['port'])
servers_dict['port'],
servers_dict['sslmode'])
# Delete role
regression.roles_utils.delete_role(connection,
role["role_name"])
@@ -431,7 +443,8 @@ def delete_test_server(tester):
servers_dict['username'],
servers_dict['db_password'],
servers_dict['host'],
servers_dict['port'])
servers_dict['port'],
servers_dict['sslmode'])
# Delete tablespace
regression.tablespace_utils.delete_tablespace(
connection, tablespace["tablespace_name"])
@@ -462,7 +475,7 @@ def get_db_server(sid):
conn = sqlite3.connect(config.TEST_SQLITE_PATH)
cur = conn.cursor()
server = cur.execute('SELECT name, host, port, maintenance_db,'
' username FROM server where id=%s' % sid)
' username, ssl_mode FROM server where id=%s' % sid)
server = server.fetchone()
if server:
name = server[0]
@@ -470,6 +483,7 @@ def get_db_server(sid):
db_port = server[2]
db_name = server[3]
username = server[4]
ssl_mode = server[5]
config_servers = test_setup.config_data['server_credentials']
# Get the db password from config file for appropriate server
db_password = get_db_password(config_servers, name, host, db_port)
@@ -479,7 +493,8 @@ def get_db_server(sid):
username,
db_password,
host,
db_port)
db_port,
ssl_mode)
conn.close()
return connection
@@ -625,19 +640,23 @@ class Database:
def __enter__(self):
self.name = "test_db_{0}".format(str(uuid.uuid4())[0:7])
self.maintenance_connection = get_db_connection(self.server['db'],
self.server[
'username'],
self.server[
'db_password'],
self.server['host'],
self.server['port'])
self.maintenance_connection = get_db_connection(
self.server['db'],
self.server['username'],
self.server['db_password'],
self.server['host'],
self.server['port'],
self.server['sslmode']
)
create_database(self.server, self.name)
self.connection = get_db_connection(self.name,
self.server['username'],
self.server['db_password'],
self.server['host'],
self.server['port'])
self.connection = get_db_connection(
self.name,
self.server['username'],
self.server['db_password'],
self.server['host'],
self.server['port'],
self.server['sslmode']
)
return self.connection, self.name
def __exit__(self, type, value, traceback):