Added support for psycopg3 along with psycopg2. #5011

This commit is contained in:
Khushboo Vashi
2023-02-15 11:31:29 +05:30
committed by Akshay Joshi
parent 7a4951f211
commit 5e0daccf76
635 changed files with 6500 additions and 1534 deletions

View File

@@ -31,7 +31,10 @@ class SQLTemplateTestBase(BaseTestGenerator):
# To be implemented by child classes
pass
def generate_sql(self, version):
def get_server_version(self, connection):
return connection.info.server_version
def generate_sql(self, connection):
# To be implemented by child classes
pass
@@ -50,7 +53,7 @@ class SQLTemplateTestBase(BaseTestGenerator):
cursor = connection.cursor()
self.test_setup(connection, cursor)
sql = self.generate_sql(connection.server_version)
sql = self.generate_sql(connection)
cursor = connection.cursor()
cursor.execute(sql)

View File

@@ -12,7 +12,6 @@ import traceback
import os
import sys
import uuid
import psycopg2
import sqlite3
import shutil
from functools import partial
@@ -36,11 +35,17 @@ import regression
from regression import test_setup
from pgadmin.utils.preferences import Preferences
from pgadmin.utils.constants import BINARY_PATHS
from pgadmin.utils.constants import BINARY_PATHS, PSYCOPG3
from pgadmin.utils import set_binary_path
from functools import wraps
# Remove this condition, once psycopg2 will be removed completely
if config.PG_DEFAULT_DRIVER == PSYCOPG3:
import psycopg
else:
import psycopg2 as psycopg
CURRENT_PATH = os.path.abspath(os.path.join(os.path.dirname(
os.path.realpath(__file__)), "../"))
@@ -51,8 +56,8 @@ file_name = os.path.realpath(__file__)
def get_db_connection(db, username, password, host, port, sslmode="prefer"):
"""This function returns the connection object of psycopg"""
connection = psycopg2.connect(
database=db,
connection = psycopg.connect(
dbname=db,
user=username,
password=password,
host=host,
@@ -62,6 +67,19 @@ def get_db_connection(db, username, password, host, port, sslmode="prefer"):
return connection
def get_server_version(connection):
return connection.info.server_version
def set_isolation_level(connection, level):
if level == 0:
connection.rollback()
connection.autocommit = True
else:
connection.autocommit = False
connection.isolation_level = level
def login_tester_account(tester):
"""
This function login the test client using env variables email and password
@@ -141,7 +159,8 @@ def create_database(server, db_name, encoding=None):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
connection.autocommit = True
pg_cursor = connection.cursor()
if encoding is None:
pg_cursor.execute(
@@ -151,7 +170,8 @@ def create_database(server, db_name, encoding=None):
'''CREATE DATABASE "%s" TEMPLATE template0
ENCODING='%s' LC_COLLATE='%s' LC_CTYPE='%s' ''' %
(db_name, encoding[0], encoding[1], encoding[1]))
connection.set_isolation_level(old_isolation_level)
connection.autocommit = False
set_isolation_level(connection, old_isolation_level)
connection.commit()
# Get 'oid' from newly created database
@@ -208,7 +228,7 @@ def create_table(server, db_name, table_name, extra_columns=[]):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
extra_columns_sql = ", " + ", ".join(extra_columns) \
if len(extra_columns) > 0 else ''
@@ -230,7 +250,7 @@ def create_table(server, db_name, table_name, extra_columns=[]):
VALUES ('Yet-Another-Name', 14,
'cool info')''' % table_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -287,10 +307,10 @@ def create_table_with_query(server, db_name, query):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute(query)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -312,14 +332,14 @@ def create_constraint(server,
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute('''
ALTER TABLE "%s"
ADD CONSTRAINT "%s" %s (some_column)
''' % (table_name, constraint_name, constraint_type.upper()))
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -349,7 +369,7 @@ def create_type(server, db_name, type_name, type_fields=[]):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
type_fields_sql = ", ".join(type_fields)
@@ -357,7 +377,7 @@ def create_type(server, db_name, type_name, type_fields=[]):
pg_cursor.execute(
'''CREATE TYPE %s AS (%s)''' % (type_name, type_fields_sql))
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -375,7 +395,7 @@ def create_debug_function(server, db_name, function_name="test_func"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
try:
pg_cursor.execute('''CREATE EXTENSION pldbgapi;''')
@@ -396,7 +416,7 @@ def create_debug_function(server, db_name, function_name="test_func"):
END;
$function$;
''' % function_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -414,12 +434,12 @@ def drop_debug_function(server, db_name, function_name="test_func"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute('''
DROP FUNCTION public."%s"();
''' % function_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -458,14 +478,14 @@ def grant_role(server, db_name, role_name="test_role",
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
sql_query = '''GRANT "%s" TO %s;''' % (grant_role, role_name)
pg_cursor.execute(
sql_query
)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -483,7 +503,7 @@ def create_role(server, db_name, role_name="test_role"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
sql_query = '''
CREATE USER "%s" WITH
@@ -493,13 +513,13 @@ def create_role(server, db_name, role_name="test_role"):
CREATEDB
NOCREATEROLE
''' % (role_name)
if connection.server_version > 90100:
if get_server_version(connection) > 90100:
sql_query += '\nNOREPLICATION'
pg_cursor.execute(
sql_query
)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -517,12 +537,12 @@ def drop_role(server, db_name, role_name="test_role"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute('''
DROP USER "%s"
''' % role_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -533,7 +553,7 @@ def drop_database(connection, database_name):
"""This function used to drop the database"""
if database_name not in ["postgres", "template1", "template0"]:
pg_cursor = connection.cursor()
if connection.server_version >= 90100:
if connection.info.server_version >= 90100:
pg_cursor.execute(
"SELECT pg_terminate_backend(pg_stat_activity.pid) "
"FROM pg_stat_activity "
@@ -551,9 +571,9 @@ def drop_database(connection, database_name):
" db.datname='%s'" % database_name)
if pg_cursor.fetchall():
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor.execute('''DROP DATABASE "%s"''' % database_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
connection.close()
@@ -563,7 +583,7 @@ def drop_database_multiple(connection, database_names):
for database_name in database_names:
if database_name not in ["postgres", "template1", "template0"]:
pg_cursor = connection.cursor()
if connection.server_version >= 90100:
if get_server_version(connection) >= 90100:
pg_cursor.execute(
"SELECT pg_terminate_backend(pg_stat_activity.pid) "
"FROM pg_stat_activity "
@@ -581,9 +601,9 @@ def drop_database_multiple(connection, database_names):
" db.datname='%s'" % database_name)
if pg_cursor.fetchall():
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor.execute('''DROP DATABASE "%s"''' % database_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
connection.close()
@@ -597,9 +617,9 @@ def drop_tablespace(connection):
for table_space in table_spaces:
if table_space[0] not in ["pg_default", "pg_global"]:
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor.execute("DROP TABLESPACE %s" % table_space[0])
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
connection.close()
@@ -1140,11 +1160,11 @@ def create_schema(server, db_name, schema_name):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute(
'''CREATE SCHEMA "%s"''' % schema_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -1205,7 +1225,7 @@ def check_binary_path_or_skip_test(cls, utility_name):
def get_driver_version():
version = getattr(psycopg2, '__version__', None)
version = getattr(psycopg, '__version__', None)
return version

View File

@@ -13,6 +13,7 @@ import traceback
from urllib.parse import urlencode
from flask import url_for
import regression
import config
from regression import parent_node_dict
from pgadmin.utils.route import BaseTestGenerator
from regression.python_test_utils import test_utils as utils
@@ -20,6 +21,7 @@ from pgadmin.browser.server_groups.servers.databases.tests import \
utils as database_utils
from pgadmin.utils.versioned_template_loader import \
get_version_mapping_directories
from config import PG_DEFAULT_DRIVER
def create_resql_module_list(all_modules, exclude_pkgs, for_modules):
@@ -99,13 +101,15 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
# Schema ID placeholder in JSON file which needs to be replaced
# while running the test cases
self.JSON_PLACEHOLDERS = {'schema_id': '<SCHEMA_ID>',
'owner': '<OWNER>',
'timestamptz_1': '<TIMESTAMPTZ_1>',
'password': '<PASSWORD>',
'pga_job_id': '<PGA_JOB_ID>',
'timestamptz_2': '<TIMESTAMPTZ_2>',
'db_name': '<TEST_DB_NAME>'}
'db_name': '<TEST_DB_NAME>',
'db_driver': '<DB_DRIVER>'}
resql_module_list = create_resql_module_list(
BaseTestGenerator.re_sql_module_list,
@@ -232,6 +236,9 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
elif self.check_precondition(
scenario['precondition_sql'], False):
skip_test_case = False
elif 'pg_driver' in scenario and\
scenario['pg_driver'] != PG_DEFAULT_DRIVER:
skip_test_case = True
else:
skip_test_case = False
@@ -272,6 +279,10 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
try:
self.assertEqual(response.status_code, 200)
except Exception as e:
response = self.tester.post(create_url,
data=json.dumps(
scenario['data']),
content_type='html/json')
self.final_test_status = False
print(scenario['name'] + "... FAIL")
traceback.print_exc()
@@ -307,6 +318,11 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
self.assertEqual(response.status_code, 200)
except Exception as e:
self.final_test_status = False
alter_url = self.get_url(scenario['endpoint'], object_id)
response = self.tester.put(alter_url,
data=json.dumps(
scenario['data']),
follow_redirects=True)
print(scenario['name'] + "... FAIL")
traceback.print_exc()
continue
@@ -493,7 +509,6 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
try:
self.assertEqual(response.status_code, 200)
except Exception as e:
self.final_test_status = False
traceback.print_exc()
return False
@@ -522,6 +537,8 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
try:
self.assertEqual(sql, resp_sql)
except Exception as e:
print(sql)
print(resp_sql)
self.final_test_status = False
traceback.print_exc()
return False
@@ -636,6 +653,8 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
scenario['data'][key_attr]['added'][0]):
self.get_db_connection()
pg_cursor = self.connection.cursor()
pg_cursor.execute("SET DateStyle=ISO;")
try:
if is_tz_columns_list:
query = "SELECT timestamp with time zone '" \

View File

@@ -22,6 +22,10 @@ import secrets
import threading
import time
import unittest
import asyncio
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
if sys.version_info < (3, 4):
raise RuntimeError('The test suite must be run under Python 3.4 or later.')
@@ -495,7 +499,7 @@ def execute_test(test_module_list_passed, server_passed, driver_passed,
)
# Add the server version in server information
server_information['server_version'] = connection.server_version
server_information['server_version'] = connection.info.server_version
server_information['type'] = server_passed['type']
# Drop the database if already exists.