Added support for psycopg3 along with psycopg2. #5011

This commit is contained in:
Khushboo Vashi
2023-02-15 11:31:29 +05:30
committed by Akshay Joshi
parent 7a4951f211
commit 5e0daccf76
635 changed files with 6500 additions and 1534 deletions

View File

@@ -31,7 +31,10 @@ class SQLTemplateTestBase(BaseTestGenerator):
# To be implemented by child classes
pass
def generate_sql(self, version):
def get_server_version(self, connection):
return connection.info.server_version
def generate_sql(self, connection):
# To be implemented by child classes
pass
@@ -50,7 +53,7 @@ class SQLTemplateTestBase(BaseTestGenerator):
cursor = connection.cursor()
self.test_setup(connection, cursor)
sql = self.generate_sql(connection.server_version)
sql = self.generate_sql(connection)
cursor = connection.cursor()
cursor.execute(sql)

View File

@@ -12,7 +12,6 @@ import traceback
import os
import sys
import uuid
import psycopg2
import sqlite3
import shutil
from functools import partial
@@ -36,11 +35,17 @@ import regression
from regression import test_setup
from pgadmin.utils.preferences import Preferences
from pgadmin.utils.constants import BINARY_PATHS
from pgadmin.utils.constants import BINARY_PATHS, PSYCOPG3
from pgadmin.utils import set_binary_path
from functools import wraps
# Remove this condition, once psycopg2 will be removed completely
if config.PG_DEFAULT_DRIVER == PSYCOPG3:
import psycopg
else:
import psycopg2 as psycopg
CURRENT_PATH = os.path.abspath(os.path.join(os.path.dirname(
os.path.realpath(__file__)), "../"))
@@ -51,8 +56,8 @@ file_name = os.path.realpath(__file__)
def get_db_connection(db, username, password, host, port, sslmode="prefer"):
"""This function returns the connection object of psycopg"""
connection = psycopg2.connect(
database=db,
connection = psycopg.connect(
dbname=db,
user=username,
password=password,
host=host,
@@ -62,6 +67,19 @@ def get_db_connection(db, username, password, host, port, sslmode="prefer"):
return connection
def get_server_version(connection):
return connection.info.server_version
def set_isolation_level(connection, level):
if level == 0:
connection.rollback()
connection.autocommit = True
else:
connection.autocommit = False
connection.isolation_level = level
def login_tester_account(tester):
"""
This function login the test client using env variables email and password
@@ -141,7 +159,8 @@ def create_database(server, db_name, encoding=None):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
connection.autocommit = True
pg_cursor = connection.cursor()
if encoding is None:
pg_cursor.execute(
@@ -151,7 +170,8 @@ def create_database(server, db_name, encoding=None):
'''CREATE DATABASE "%s" TEMPLATE template0
ENCODING='%s' LC_COLLATE='%s' LC_CTYPE='%s' ''' %
(db_name, encoding[0], encoding[1], encoding[1]))
connection.set_isolation_level(old_isolation_level)
connection.autocommit = False
set_isolation_level(connection, old_isolation_level)
connection.commit()
# Get 'oid' from newly created database
@@ -208,7 +228,7 @@ def create_table(server, db_name, table_name, extra_columns=[]):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
extra_columns_sql = ", " + ", ".join(extra_columns) \
if len(extra_columns) > 0 else ''
@@ -230,7 +250,7 @@ def create_table(server, db_name, table_name, extra_columns=[]):
VALUES ('Yet-Another-Name', 14,
'cool info')''' % table_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -287,10 +307,10 @@ def create_table_with_query(server, db_name, query):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute(query)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -312,14 +332,14 @@ def create_constraint(server,
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute('''
ALTER TABLE "%s"
ADD CONSTRAINT "%s" %s (some_column)
''' % (table_name, constraint_name, constraint_type.upper()))
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -349,7 +369,7 @@ def create_type(server, db_name, type_name, type_fields=[]):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
type_fields_sql = ", ".join(type_fields)
@@ -357,7 +377,7 @@ def create_type(server, db_name, type_name, type_fields=[]):
pg_cursor.execute(
'''CREATE TYPE %s AS (%s)''' % (type_name, type_fields_sql))
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -375,7 +395,7 @@ def create_debug_function(server, db_name, function_name="test_func"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
try:
pg_cursor.execute('''CREATE EXTENSION pldbgapi;''')
@@ -396,7 +416,7 @@ def create_debug_function(server, db_name, function_name="test_func"):
END;
$function$;
''' % function_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -414,12 +434,12 @@ def drop_debug_function(server, db_name, function_name="test_func"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute('''
DROP FUNCTION public."%s"();
''' % function_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -458,14 +478,14 @@ def grant_role(server, db_name, role_name="test_role",
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
sql_query = '''GRANT "%s" TO %s;''' % (grant_role, role_name)
pg_cursor.execute(
sql_query
)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -483,7 +503,7 @@ def create_role(server, db_name, role_name="test_role"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
sql_query = '''
CREATE USER "%s" WITH
@@ -493,13 +513,13 @@ def create_role(server, db_name, role_name="test_role"):
CREATEDB
NOCREATEROLE
''' % (role_name)
if connection.server_version > 90100:
if get_server_version(connection) > 90100:
sql_query += '\nNOREPLICATION'
pg_cursor.execute(
sql_query
)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -517,12 +537,12 @@ def drop_role(server, db_name, role_name="test_role"):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute('''
DROP USER "%s"
''' % role_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -533,7 +553,7 @@ def drop_database(connection, database_name):
"""This function used to drop the database"""
if database_name not in ["postgres", "template1", "template0"]:
pg_cursor = connection.cursor()
if connection.server_version >= 90100:
if connection.info.server_version >= 90100:
pg_cursor.execute(
"SELECT pg_terminate_backend(pg_stat_activity.pid) "
"FROM pg_stat_activity "
@@ -551,9 +571,9 @@ def drop_database(connection, database_name):
" db.datname='%s'" % database_name)
if pg_cursor.fetchall():
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor.execute('''DROP DATABASE "%s"''' % database_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
connection.close()
@@ -563,7 +583,7 @@ def drop_database_multiple(connection, database_names):
for database_name in database_names:
if database_name not in ["postgres", "template1", "template0"]:
pg_cursor = connection.cursor()
if connection.server_version >= 90100:
if get_server_version(connection) >= 90100:
pg_cursor.execute(
"SELECT pg_terminate_backend(pg_stat_activity.pid) "
"FROM pg_stat_activity "
@@ -581,9 +601,9 @@ def drop_database_multiple(connection, database_names):
" db.datname='%s'" % database_name)
if pg_cursor.fetchall():
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor.execute('''DROP DATABASE "%s"''' % database_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
connection.close()
@@ -597,9 +617,9 @@ def drop_tablespace(connection):
for table_space in table_spaces:
if table_space[0] not in ["pg_default", "pg_global"]:
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor.execute("DROP TABLESPACE %s" % table_space[0])
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
connection.close()
@@ -1140,11 +1160,11 @@ def create_schema(server, db_name, schema_name):
server['sslmode']
)
old_isolation_level = connection.isolation_level
connection.set_isolation_level(0)
set_isolation_level(connection, 0)
pg_cursor = connection.cursor()
pg_cursor.execute(
'''CREATE SCHEMA "%s"''' % schema_name)
connection.set_isolation_level(old_isolation_level)
set_isolation_level(connection, old_isolation_level)
connection.commit()
except Exception:
@@ -1205,7 +1225,7 @@ def check_binary_path_or_skip_test(cls, utility_name):
def get_driver_version():
version = getattr(psycopg2, '__version__', None)
version = getattr(psycopg, '__version__', None)
return version