mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2024-12-23 07:34:35 -06:00
835 lines
33 KiB
Python
835 lines
33 KiB
Python
##########################################################################
|
|
#
|
|
# pgAdmin 4 - PostgreSQL Tools
|
|
#
|
|
# Copyright (C) 2013 - 2024, The pgAdmin Development Team
|
|
# This software is released under the PostgreSQL Licence
|
|
#
|
|
##########################################################################
|
|
import json
|
|
import os
|
|
import re
|
|
import traceback
|
|
from urllib.parse import urlencode
|
|
from flask import url_for
|
|
import regression
|
|
from regression import parent_node_dict
|
|
from pgadmin.utils.route import BaseTestGenerator
|
|
from regression.python_test_utils import test_utils as utils
|
|
from pgadmin.browser.server_groups.servers.databases.tests import \
|
|
utils as database_utils
|
|
from pgadmin.utils.versioned_template_loader import \
|
|
get_version_mapping_directories
|
|
from pgadmin.utils.constants import DBMS_JOB_SCHEDULER_ID
|
|
|
|
|
|
def create_resql_module_list(all_modules, exclude_pkgs, for_modules):
    """
    Build the module list for reverse engineered SQL tests by iterating
    all the discovered modules.

    :param all_modules: List of all the modules (dotted module names).
    :param exclude_pkgs: List of package names to exclude
        (matched as 'pgadmin.<pkg>' prefixes).
    :param for_modules: Restrict the result to these module names;
        an empty list means "include everything".
    :return: dict mapping module name -> filesystem path of the module.
    """
    resql_module_list = {}

    for module in all_modules:
        module_str = str(module)

        # Only test modules are of interest, and excluded packages are
        # skipped entirely.
        if "tests." not in module_str or any(
            module_str.startswith('pgadmin.' + str(exclude_pkg))
            for exclude_pkg in exclude_pkgs
        ):
            continue

        # 'pgadmin.x.y.tests.test_z' -> package part 'pgadmin.x.y'
        # (split on '.test' also cuts the '.tests' package segment).
        module_name_list = module.split(".test")[0].split(".")
        module_name = module_name_list[-1]

        # An empty 'for_modules' selects every module.
        if not for_modules or module_name in for_modules:
            resql_module_list[module_name] = os.path.join(*module_name_list)

    return resql_module_list
|
|
|
|
|
|
class ReverseEngineeredSQLTestCases(BaseTestGenerator):
    """This class will test the reverse engineered SQL."""

    # A single generated scenario; the real test cases are driven by the
    # JSON files found under each module's 'tests' folder (see runTest).
    scenarios = [
        ('Reverse Engineered SQL Test Cases', dict())
    ]
|
    @classmethod
    def setUpClass(cls):
        # Show full (untruncated) diffs when an SQL comparison fails.
        cls.maxDiff = None
|
    def setUp(self):
        """Connect to the test database and initialise per-test state."""
        # Get the database connection
        self.db_con = database_utils.connect_database(
            self, utils.SERVER_GROUP, self.server_information['server_id'],
            self.server_information['db_id'])

        self.get_db_connection()

        if not self.db_con['info'] == "Database connected.":
            raise Exception("Could not connect to database.")

        # A second connection to the server's configured database; used
        # by check_precondition() for pgAgent-related scenarios.
        self.test_config_db_conn = utils.get_db_connection(
            self.server['db'],
            self.server['username'],
            self.server['db_password'],
            self.server['host'],
            self.server['port']
        )

        # Get the application path
        self.apppath = os.getcwd()
        # Status of the test case; flipped to False on any failure and
        # asserted at the end of runTest().
        self.final_test_status = True
        # Parent object oids (tid/fid/fsid) shared between the scenarios
        # of one JSON file.
        self.parent_ids = dict()
        # Map of object name -> oid used for '<name>' placeholder
        # substitution in scenario data.
        self.all_object_ids = dict()

        # Added line break after scenario name
        print("")
|
    def runTest(self):
        """Create the module list on which reverse engineered SQL test
        cases will be executed, then run every scenario found in each
        module's JSON test files."""

        # Schema ID placeholder in JSON file which needs to be replaced
        # while running the test cases
        self.JSON_PLACEHOLDERS = {'schema_id': '<SCHEMA_ID>',
                                  'owner': '<OWNER>',
                                  'timestamptz_1': '<TIMESTAMPTZ_1>',
                                  'password': '<PASSWORD>',
                                  'pga_job_id': '<PGA_JOB_ID>',
                                  'timestamptz_2': '<TIMESTAMPTZ_2>',
                                  'db_name': '<TEST_DB_NAME>',
                                  'db_driver': '<DB_DRIVER>',
                                  'LC_COLLATE': '<LC_COLLATE>',
                                  'LC_CTYPE': '<LC_CTYPE>'}

        resql_module_list = create_resql_module_list(
            BaseTestGenerator.re_sql_module_list,
            BaseTestGenerator.exclude_pkgs,
            getattr(BaseTestGenerator, 'for_modules', []))

        for module in resql_module_list:
            module_path = resql_module_list[module]
            # Get the folder name based on server version number and
            # their existence.
            self.module_path = module_path
            status, self.test_folder = self.get_test_folder(module_path)
            if not status:
                # No version/type folder applies to this server; skip.
                continue

            # Iterate all the files in the test folder and check for
            # the JSON files.
            for filename in os.listdir(self.test_folder):
                if filename.endswith(".json"):
                    complete_file_name = os.path.join(self.test_folder,
                                                      filename)
                    with open(complete_file_name) as jsonfp:
                        try:
                            data = json.load(jsonfp)
                        except Exception as e:
                            # A malformed JSON file is reported but does
                            # not abort the remaining files.
                            print(
                                "Unable to read the json file: {0}".format(
                                    complete_file_name))
                            traceback.print_exc()
                            continue

                        for key, scenarios in data.items():
                            self.execute_test_case(scenarios)

                    # Clear the parent ids stored for one json file.
                    self.parent_ids.clear()
                    self.all_object_ids.clear()

        # Check the final status of the test case
        self.assertEqual(self.final_test_status, True)
|
    def tearDown(self):
        """Disconnect the connections opened in setUp()."""
        database_utils.disconnect_database(
            self, self.server_information['server_id'],
            self.server_information['db_id'])
        self.test_config_db_conn.close()
|
def get_db_connection(self):
|
|
"""Get the database connection."""
|
|
self.database_info = parent_node_dict["database"][-1]
|
|
self.db_name = self.database_info["db_name"]
|
|
|
|
if (not hasattr(self, 'connection')) or \
|
|
(hasattr(self, 'connection') and self.connection.closed == 1):
|
|
self.connection = utils.get_db_connection(
|
|
self.db_name,
|
|
self.server['username'],
|
|
self.server['db_password'],
|
|
self.server['host'],
|
|
self.server['port']
|
|
)
|
|
|
|
def get_url(self, endpoint, object_id=None):
|
|
"""
|
|
This function is used to get the url.
|
|
|
|
:param endpoint:
|
|
:param object_id:
|
|
:return:
|
|
"""
|
|
object_url = None
|
|
for rule in self.app.url_map.iter_rules(endpoint):
|
|
options = {}
|
|
for arg in rule.arguments:
|
|
if arg == 'gid':
|
|
options['gid'] = int(utils.SERVER_GROUP)
|
|
elif arg == 'sid':
|
|
options['sid'] = int(self.server_information['server_id'])
|
|
elif arg == 'did':
|
|
# For database node object_id is the actual database id.
|
|
if endpoint.__contains__('NODE-database') and \
|
|
object_id is not None:
|
|
options['did'] = int(object_id)
|
|
else:
|
|
options['did'] = int(self.server_information['db_id'])
|
|
elif arg == 'scid':
|
|
# For schema node object_id is the actual schema id.
|
|
if endpoint.__contains__('NODE-schema') and \
|
|
object_id is not None:
|
|
options['scid'] = int(object_id)
|
|
else:
|
|
options['scid'] = int(self.schema_id)
|
|
# tid represents table oid
|
|
elif arg == 'tid' and 'tid' in self.parent_ids:
|
|
options['tid'] = int(self.parent_ids['tid'])
|
|
# fid represents FDW oid
|
|
elif arg == 'fid' and 'fid' in self.parent_ids:
|
|
options['fid'] = int(self.parent_ids['fid'])
|
|
# fsid represents Foreign Server oid
|
|
elif arg == 'fsid' and 'fsid' in self.parent_ids:
|
|
options['fsid'] = int(self.parent_ids['fsid'])
|
|
elif arg == 'jsid':
|
|
options['jsid'] = DBMS_JOB_SCHEDULER_ID
|
|
else:
|
|
if object_id is not None:
|
|
try:
|
|
options[arg] = int(object_id)
|
|
except ValueError:
|
|
options[arg] = object_id
|
|
|
|
with self.app.test_request_context():
|
|
object_url = url_for(rule.endpoint, **options)
|
|
|
|
return object_url
|
|
|
|
def execute_test_case(self, scenarios):
|
|
"""
|
|
This function will run the test cases for specific module.
|
|
|
|
:param scenarios: List of scenarios
|
|
:return:
|
|
"""
|
|
object_id = None
|
|
|
|
for scenario in scenarios:
|
|
skip_test_case = True
|
|
if 'precondition_sql' in scenario:
|
|
if 'pgagent_test' in scenario and self.check_precondition(
|
|
scenario['precondition_sql'], True):
|
|
skip_test_case = False
|
|
elif self.check_precondition(
|
|
scenario['precondition_sql'], False):
|
|
skip_test_case = False
|
|
else:
|
|
skip_test_case = False
|
|
|
|
if skip_test_case:
|
|
print(scenario['name'] +
|
|
"... skipped (pre-condition SQL not satisfied)")
|
|
continue
|
|
|
|
# Check precondition for schema
|
|
self.check_schema_precondition(scenario)
|
|
|
|
# Preprocessed data to replace any place holder if available
|
|
if 'preprocess_data' in scenario and \
|
|
scenario['preprocess_data'] and 'data' in scenario:
|
|
scenario['data'] = self.preprocess_data(scenario['data'])
|
|
|
|
# If msql_endpoint exists then validate the modified sql
|
|
if 'msql_endpoint' in scenario\
|
|
and scenario['msql_endpoint']:
|
|
if not self.check_msql(scenario, object_id):
|
|
print_msg = scenario['name']
|
|
if 'expected_msql_file' in scenario:
|
|
print_msg += " Expected MSQL File:" + scenario[
|
|
'expected_msql_file']
|
|
print_msg = print_msg + "... FAIL"
|
|
print(print_msg)
|
|
continue
|
|
else:
|
|
print(scenario['name'] + " (MSQL) ... ok")
|
|
|
|
try:
|
|
if 'type' in scenario and scenario['type'] == 'create':
|
|
# Get the url and create the specific node.
|
|
|
|
create_url = self.get_url(scenario['endpoint'])
|
|
response = self.tester.post(
|
|
create_url, data=json.dumps(scenario['data']),
|
|
content_type='html/json')
|
|
try:
|
|
self.assertEqual(response.status_code, 200)
|
|
except Exception as e:
|
|
response = self.tester.post(create_url,
|
|
data=json.dumps(
|
|
scenario['data']),
|
|
content_type='html/json')
|
|
self.final_test_status = False
|
|
print(scenario['name'] + "... FAIL")
|
|
traceback.print_exc()
|
|
continue
|
|
|
|
resp_data = json.loads(response.data.decode('utf8'))
|
|
print('object_id set', object_id)
|
|
object_id = resp_data['node']['_id']
|
|
|
|
# Store the object id based on endpoints
|
|
if 'store_object_id' in scenario:
|
|
self.store_object_ids(object_id,
|
|
scenario['data'],
|
|
scenario['endpoint'])
|
|
|
|
# Compare the reverse engineering SQL
|
|
if not self.check_re_sql(scenario, object_id):
|
|
print(scenario['name'] + "... FAIL")
|
|
|
|
if 'expected_sql_file' in scenario:
|
|
print_msg = " - Expected SQL File: " + \
|
|
os.path.join(
|
|
self.test_folder,
|
|
scenario['expected_sql_file'])
|
|
print(print_msg)
|
|
continue
|
|
elif 'type' in scenario and scenario['type'] == 'alter':
|
|
# Get the url and create the specific node.
|
|
|
|
alter_url = self.get_url(scenario['endpoint'], object_id)
|
|
response = self.tester.put(
|
|
alter_url, data=json.dumps(scenario['data']),
|
|
follow_redirects=True)
|
|
try:
|
|
self.assertEqual(response.status_code, 200)
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
alter_url = self.get_url(
|
|
scenario['endpoint'], object_id)
|
|
response = self.tester.put(alter_url,
|
|
data=json.dumps(
|
|
scenario['data']),
|
|
follow_redirects=True)
|
|
print(scenario['name'] + "... FAIL")
|
|
traceback.print_exc()
|
|
continue
|
|
|
|
resp_data = json.loads(response.data.decode('utf8'))
|
|
object_id = resp_data['node']['_id']
|
|
|
|
# Compare the reverse engineering SQL
|
|
if not self.check_re_sql(scenario, object_id):
|
|
print_msg = scenario['name']
|
|
if 'expected_sql_file' in scenario:
|
|
print_msg = \
|
|
print_msg + " Expected SQL File:" + \
|
|
scenario['expected_sql_file']
|
|
print_msg = print_msg + "... FAIL"
|
|
print(print_msg)
|
|
continue
|
|
elif 'type' in scenario and scenario['type'] == 'delete':
|
|
# Get the delete url and delete the object created above.
|
|
delete_url = self.get_url(scenario['endpoint'], object_id)
|
|
delete_response = self.tester.delete(
|
|
delete_url, data=json.dumps(scenario.get('data', {})),
|
|
follow_redirects=True)
|
|
try:
|
|
self.assertEqual(delete_response.status_code, 200)
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
print(scenario['name'] + "... FAIL")
|
|
traceback.print_exc()
|
|
continue
|
|
|
|
print(scenario['name'] + "... ok")
|
|
except Exception as _:
|
|
print(scenario['name'] + "... FAIL")
|
|
raise
|
|
|
|
def get_test_folder(self, module_path):
|
|
"""
|
|
This function will get the appropriate test folder based on
|
|
server version and their existence.
|
|
|
|
:param module_path: Path of the module to be tested.
|
|
:return:
|
|
"""
|
|
# Join the application path, module path and tests folder
|
|
tests_folder_path = os.path.join(self.apppath, module_path, 'tests')
|
|
|
|
# A folder name matching the Server Type (pg, ppas) takes priority so
|
|
# check whether that exists or not. If so, than check the version
|
|
# folder in it, else look directly in the 'tests' folder.
|
|
absolute_path = os.path.join(tests_folder_path, self.server['type'])
|
|
if not os.path.exists(absolute_path):
|
|
absolute_path = tests_folder_path
|
|
|
|
# Iterate the version mapping directories.
|
|
for version_mapping in get_version_mapping_directories():
|
|
if version_mapping['number'] > \
|
|
self.server_information['server_version']:
|
|
continue
|
|
|
|
complete_path = os.path.join(absolute_path,
|
|
version_mapping['name'])
|
|
|
|
if os.path.exists(complete_path):
|
|
return True, complete_path
|
|
|
|
return False, None
|
|
|
|
def get_test_file(self, file_name):
|
|
"""
|
|
This function will get the appropriate test file based on
|
|
server version and their existence.
|
|
|
|
:param file_name: File containing expected output .
|
|
:return:
|
|
"""
|
|
# Join the application path, module path and tests folder
|
|
tests_folder_path = \
|
|
os.path.join(self.apppath, self.module_path, 'tests')
|
|
|
|
# A folder name matching the Server Type (pg, ppas) takes priority so
|
|
# check whether that exists or not. If so, than check the version
|
|
# folder in it, else look directly in the 'tests' folder.
|
|
absolute_path = os.path.join(tests_folder_path, self.server['type'])
|
|
if not os.path.exists(absolute_path):
|
|
absolute_path = tests_folder_path
|
|
|
|
# Iterate the version mapping directories.
|
|
for version_mapping in get_version_mapping_directories():
|
|
if version_mapping['number'] > \
|
|
self.server_information['server_version']:
|
|
continue
|
|
complete_path = os.path.join(absolute_path,
|
|
version_mapping['name'], file_name)
|
|
|
|
if os.path.exists(complete_path):
|
|
return True, complete_path
|
|
|
|
return False, None
|
|
|
|
def check_msql(self, scenario, object_id):
|
|
"""
|
|
This function is used to check the modified SQL.
|
|
:param scenario:
|
|
:param object_id:
|
|
:return:
|
|
"""
|
|
|
|
msql_url = self.get_url(scenario['msql_endpoint'],
|
|
object_id)
|
|
|
|
# As msql data is passed as URL params, dict, list types data has to
|
|
# be converted to string using json.dumps before passing it to
|
|
# urlencode
|
|
msql_data = {
|
|
key: json.dumps(val)
|
|
if isinstance(val, dict) or isinstance(val, list) else
|
|
(val if val is not None else 'null')
|
|
for key, val in scenario['data'].items()}
|
|
|
|
params = urlencode(msql_data)
|
|
params = params.replace('False', 'false').replace('True', 'true')
|
|
url = msql_url + "?%s" % params
|
|
response = self.tester.get(url,
|
|
follow_redirects=True)
|
|
try:
|
|
self.assertEqual(response.status_code, 200)
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
print(scenario['name'] + "... FAIL")
|
|
traceback.print_exc()
|
|
return False
|
|
try:
|
|
if isinstance(response.data, bytes):
|
|
response_data = response.data.decode('utf8')
|
|
resp = json.loads(response_data)
|
|
else:
|
|
resp = json.loads(response.data)
|
|
resp_sql = resp['data']
|
|
except Exception:
|
|
print("Unable to decode the response data from url: ", url)
|
|
return False
|
|
|
|
# Remove first and last double quotes
|
|
if resp_sql.startswith('"') and resp_sql.endswith('"'):
|
|
resp_sql = resp_sql[1:-1]
|
|
|
|
# Remove triling \n
|
|
resp_sql = resp_sql.rstrip()
|
|
|
|
# Check if expected sql is given in JSON file or path of the output
|
|
# file is given
|
|
if 'expected_msql_file' in scenario:
|
|
file_found, output_file = \
|
|
self.get_test_file(scenario['expected_msql_file'])
|
|
|
|
if file_found:
|
|
fp = open(output_file, "r")
|
|
# Used rstrip to remove trailing \n
|
|
sql = fp.read().rstrip()
|
|
sql = self.preprocess_expected_sql(scenario, sql, resp_sql,
|
|
object_id)
|
|
try:
|
|
self.assertEqual(sql, resp_sql)
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
traceback.print_exc()
|
|
return False
|
|
else:
|
|
try:
|
|
self.assertFalse("Expected Modified SQL File not found")
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
traceback.print_exc()
|
|
return False
|
|
return True
|
|
|
|
def check_re_sql(self, scenario, object_id):
|
|
"""
|
|
This function is used to get the reverse engineered SQL.
|
|
:param scenario:
|
|
:param object_id:
|
|
:return:
|
|
"""
|
|
|
|
sql_url = self.get_url(scenario['sql_endpoint'], object_id)
|
|
response = self.tester.get(sql_url)
|
|
|
|
try:
|
|
self.assertEqual(response.status_code, 200)
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
traceback.print_exc()
|
|
return False
|
|
|
|
resp_sql = response.data.decode('unicode_escape')
|
|
|
|
# Remove first and last double quotes
|
|
if resp_sql.startswith('"') and resp_sql.endswith('"'):
|
|
resp_sql = resp_sql[1:-1]
|
|
|
|
# Remove triling \n
|
|
resp_sql = resp_sql.rstrip()
|
|
|
|
# Check if expected sql is given in JSON file or path of the output
|
|
# file is given
|
|
if 'expected_sql_file' in scenario:
|
|
file_found, output_file = \
|
|
self.get_test_file(scenario['expected_sql_file'])
|
|
|
|
if os.path.exists(output_file):
|
|
fp = open(output_file, "r")
|
|
# Used rstrip to remove trailing \n
|
|
sql = fp.read().rstrip()
|
|
sql = self.preprocess_expected_sql(scenario, sql, resp_sql,
|
|
object_id)
|
|
try:
|
|
self.assertEqual(sql, resp_sql)
|
|
except Exception as e:
|
|
print(sql)
|
|
print(resp_sql)
|
|
self.final_test_status = False
|
|
traceback.print_exc()
|
|
return False
|
|
else:
|
|
try:
|
|
self.assertFalse("Expected SQL File not found")
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
traceback.print_exc()
|
|
return False
|
|
elif 'expected_sql' in scenario:
|
|
exp_sql = scenario['expected_sql']
|
|
exp_sql = self.preprocess_expected_sql(scenario, exp_sql, resp_sql,
|
|
object_id)
|
|
try:
|
|
self.assertEqual(exp_sql, resp_sql)
|
|
except Exception as e:
|
|
self.final_test_status = False
|
|
traceback.print_exc()
|
|
return False
|
|
|
|
return True
|
|
|
|
def check_precondition(self, precondition_sql, use_test_config_db_conn):
|
|
"""
|
|
This method executes precondition_sql and returns appropriate result
|
|
:param precondition_sql: SQL query in format select count(*) from ...
|
|
:return: True/False depending on precondition_sql result
|
|
"""
|
|
precondition_flag = False
|
|
if not use_test_config_db_conn:
|
|
self.get_db_connection()
|
|
pg_cursor = self.connection.cursor()
|
|
else:
|
|
pg_cursor = self.test_config_db_conn.cursor()
|
|
|
|
try:
|
|
pg_cursor.execute(precondition_sql)
|
|
precondition_result = pg_cursor.fetchone()
|
|
if len(precondition_result) >= 1 and precondition_result[0] == '1':
|
|
precondition_flag = True
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
pg_cursor.close()
|
|
return precondition_flag
|
|
|
|
    def check_schema_precondition(self, scenario):
        """
        This function will check the given schema is exist or not. If exist
        then fetch the oid and if not then create it.

        Sets self.schema_id for 'create' scenarios, and substitutes the
        <SCHEMA_ID> placeholder in the scenario data.

        :param scenario: Scenario dict describing the test case.
        :return:
        """
        if 'type' in scenario and scenario['type'] == 'create':
            if 'data' in scenario and 'schema' in scenario['data']:
                # If schema is already exist then fetch the oid
                self.get_db_connection()
                schema = regression.schema_utils.verify_schemas(
                    self.server, self.db_name,
                    scenario['data']['schema']
                )

                if schema:
                    self.schema_id = schema[0]
                else:
                    # If schema doesn't exist then create it
                    schema = regression.schema_utils.create_schema(
                        self.connection,
                        scenario['data']['schema'])
                    self.schema_id = schema[0]
            else:
                # No explicit schema in the payload; use the default one
                # from the server information.
                self.schema_id = self.server_information['schema_id']

        # NOTE(review): the placeholder is read from 'schema_id' but the
        # substituted value is written to the 'schema' key — confirm this
        # is what the JSON scenario files expect.
        if 'data' in scenario and 'schema_id' in scenario['data'] and \
                scenario['data']['schema_id'] == \
                self.JSON_PLACEHOLDERS['schema_id']:
            scenario['data']['schema'] = self.schema_id
|
    def convert_timestamptz(self, scenario, sql):
        """
        This function will convert the given timestamptz with database
        servers timestamptz and replace that in given sql.

        The scenario's 'convert_timestamp_columns' may be either a plain
        list of column names (values read from scenario['data'][col]) or
        a single-key dict mapping an attribute name to a column list
        (values read from the create payload's first row or the alter
        payload's 'added' rows).

        :param scenario: Scenario dict describing the test case.
        :param sql: Expected SQL containing <TIMESTAMPTZ_n> placeholders.
        :return: the SQL with placeholders replaced.
        """
        if 'convert_timestamp_columns' in scenario:
            col_list = list()
            key_attr = ''
            is_tz_columns_list = False
            # Counter used to pick the <TIMESTAMPTZ_1>/<TIMESTAMPTZ_2>
            # placeholder for successive converted columns.
            tz_index = 0
            if isinstance(scenario['convert_timestamp_columns'], dict):
                # Only the first key of the dict is considered.
                for key, value in scenario[
                        'convert_timestamp_columns'].items():
                    col_list = scenario['convert_timestamp_columns'][key]
                    key_attr = key
                    break
            else:
                col_list = scenario['convert_timestamp_columns']
                is_tz_columns_list = True

            for col in col_list:
                # Only convert columns actually present in the payload,
                # at the location matching the scenario type.
                if ('data' in scenario and col in scenario['data']) or \
                        (key_attr and 'data' in scenario and 'type' in
                         scenario and scenario['type'] == 'create' and col in
                         scenario['data'][key_attr][0]) or \
                        (key_attr and 'data' in scenario and 'type' in
                         scenario and scenario['type'] == 'alter' and col in
                         scenario['data'][key_attr]['added'][0]):
                    self.get_db_connection()
                    pg_cursor = self.connection.cursor()
                    # ISO DateStyle so the server's rendering matches
                    # the expected-SQL files.
                    pg_cursor.execute("SET DateStyle=ISO;")

                    try:
                        if is_tz_columns_list:
                            query = "SELECT timestamp with time zone '" \
                                    + scenario['data'][col] + "'"
                        elif scenario['type'] == 'create':
                            query = "SELECT timestamp with time zone '" \
                                    + scenario['data'][key_attr][0][col] + "'"
                        else:
                            query = "SELECT timestamp with time zone '" \
                                    + scenario['data'][key_attr][
                                        'added'][0][col] + "'"

                        pg_cursor.execute(query)
                        converted_tz = pg_cursor.fetchone()
                        if len(converted_tz) >= 1:
                            tz_index = tz_index + 1
                            tz_str = "timestamptz_{0}".format(tz_index)
                            sql = sql.replace(
                                self.JSON_PLACEHOLDERS[tz_str],
                                converted_tz[0])
                    except Exception as e:
                        traceback.print_exc()
                    pg_cursor.close()

        return sql
|
def store_object_ids(self, object_id, object_data, endpoint):
|
|
"""
|
|
This functions will store the object id based on endpoints
|
|
:param object_id: Object id of the created node
|
|
:param object_name: Object name
|
|
:param endpoint:
|
|
:return:
|
|
"""
|
|
object_name = object_data.get('name', '')
|
|
if endpoint.__contains__("NODE-table"):
|
|
self.parent_ids['tid'] = object_id
|
|
elif endpoint.__contains__("NODE-foreign_data_wrapper"):
|
|
self.parent_ids['fid'] = object_id
|
|
elif endpoint.__contains__("NODE-foreign_server"):
|
|
self.parent_ids['fsid'] = object_id
|
|
elif endpoint.__contains__("NODE-role.obj"):
|
|
object_name = object_data['rolname']
|
|
elif endpoint.__contains__("NODE-foreign_table"):
|
|
self.parent_ids['tid'] = object_id
|
|
|
|
# Store object id with object name
|
|
self.all_object_ids[object_name] = object_id
|
|
|
|
def preprocess_data(self, data):
|
|
"""
|
|
This function iterate through data and check for any place holder
|
|
starts with '<' and ends with '>' and replace with respective object
|
|
ids.
|
|
:param data: Data
|
|
:return:
|
|
"""
|
|
|
|
if isinstance(data, dict):
|
|
for key, val in data.items():
|
|
if isinstance(val, dict) or isinstance(val, list):
|
|
data[key] = self.preprocess_data(val)
|
|
else:
|
|
data[key] = self.replace_placeholder_with_id(val)
|
|
elif isinstance(data, list):
|
|
ret_list = []
|
|
for item in data:
|
|
if isinstance(item, dict) or isinstance(item, list):
|
|
ret_list.append(self.preprocess_data(item))
|
|
else:
|
|
ret_list.append(self.replace_placeholder_with_id(item))
|
|
return ret_list
|
|
|
|
return data
|
|
|
|
    def preprocess_expected_sql(self, scenario, sql, resp_sql, object_id):
        """
        This function preprocesses expected sql before comparing
        it with response sql, substituting the scenario's placeholders
        (<OWNER>, <PASSWORD>, <PGA_JOB_ID>, <TEST_DB_NAME>, locale
        placeholders, timestamps, regex patterns).

        :param scenario: Scenario dict (controls which substitutions run).
        :param sql: Expected SQL containing placeholders.
        :param resp_sql: SQL actually returned by the server.
        :param object_id: Oid of the object under test.
        :return: the expected SQL with placeholders substituted.
        """
        # Replace place holder <owner> with the current username
        # used to connect to the database
        if 'username' in self.server:
            sql = sql.replace(self.JSON_PLACEHOLDERS['owner'],
                              self.server['username'])
        # Convert timestamp with timezone from json file to the
        # database server's correct timestamp
        sql = self.convert_timestamptz(scenario, sql)

        # extract password fields from response and replace in expected
        # to match the response
        if 'replace_password' in scenario:
            password = ''
            for line in resp_sql.split('\n'):
                if 'PASSWORD' in line:
                    # Take the first quoted value on the PASSWORD line.
                    found = re.search(r"'([\w\W]*)'", line)
                    if found:
                        password = found.groups(0)[0]
                        break

            sql = sql.replace(self.JSON_PLACEHOLDERS['password'], password)

        # For each configured pattern, copy the server's actual match
        # into the expected SQL so volatile values compare equal.
        if 'replace_regex_pattern' in scenario:
            for a_patten in scenario['replace_regex_pattern']:
                found = re.findall(a_patten, resp_sql)
                if len(found) > 0:
                    sql = re.sub(a_patten, found[0], sql)

        # Replace place holder <PGA_JOB_ID> with the id of the job
        # created by this scenario.
        if 'pga_job_id' in scenario:
            sql = sql.replace(self.JSON_PLACEHOLDERS['pga_job_id'],
                              str(object_id))

        if 'TEST_DB_NAME' in scenario:
            sql = sql.replace(self.JSON_PLACEHOLDERS['db_name'],
                              self.server_information['test_db_name'])

        # get the database connection
        if 'REPLACE_LOCALE' in scenario:
            self.get_db_connection()
            pg_cursor = self.connection.cursor()

            db_name = self.server_information['test_db_name']
            # Database name if specify in scenario
            if 'data' in scenario and 'name' in scenario['data'] and \
                    'db' in self.server:
                db_name = self.server['db']

            # Fetch the lc_collate and lc_ctype
            pg_cursor.execute(
                "SELECT datcollate as cname FROM pg_database WHERE datname = "
                "'{0}'".format(db_name))
            lc_collate = ''.join(pg_cursor.fetchone())
            pg_cursor.execute(
                "SELECT datctype as cname FROM pg_database WHERE datname = "
                "'{0}'".format(db_name))
            lc_ctype = ''.join(pg_cursor.fetchone())
            pg_cursor.close()

            sql = sql.replace(self.JSON_PLACEHOLDERS['LC_COLLATE'], lc_collate)
            sql = sql.replace(self.JSON_PLACEHOLDERS['LC_CTYPE'], lc_ctype)

        return sql
|
def replace_placeholder_with_id(self, value):
|
|
"""
|
|
This function is used to replace the place holder with id.
|
|
:param value:
|
|
:return:
|
|
"""
|
|
|
|
if isinstance(value, str) and \
|
|
value.startswith('<') and value.endswith('>'):
|
|
# Remove < and > from the string
|
|
temp_value = value[1:-1]
|
|
# Find the place holder OID in dictionary
|
|
if temp_value in self.all_object_ids:
|
|
return self.all_object_ids[temp_value]
|
|
return value
|