mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2025-02-25 18:55:31 -06:00
Added support for psycopg3 along with psycopg2. #5011
This commit is contained in:
committed by
Akshay Joshi
parent
7a4951f211
commit
5e0daccf76
@@ -31,7 +31,10 @@ class SQLTemplateTestBase(BaseTestGenerator):
|
||||
# To be implemented by child classes
|
||||
pass
|
||||
|
||||
def generate_sql(self, version):
|
||||
def get_server_version(self, connection):
|
||||
return connection.info.server_version
|
||||
|
||||
def generate_sql(self, connection):
|
||||
# To be implemented by child classes
|
||||
pass
|
||||
|
||||
@@ -50,7 +53,7 @@ class SQLTemplateTestBase(BaseTestGenerator):
|
||||
cursor = connection.cursor()
|
||||
self.test_setup(connection, cursor)
|
||||
|
||||
sql = self.generate_sql(connection.server_version)
|
||||
sql = self.generate_sql(connection)
|
||||
|
||||
cursor = connection.cursor()
|
||||
cursor.execute(sql)
|
||||
|
||||
@@ -12,7 +12,6 @@ import traceback
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
import psycopg2
|
||||
import sqlite3
|
||||
import shutil
|
||||
from functools import partial
|
||||
@@ -36,11 +35,17 @@ import regression
|
||||
from regression import test_setup
|
||||
|
||||
from pgadmin.utils.preferences import Preferences
|
||||
from pgadmin.utils.constants import BINARY_PATHS
|
||||
from pgadmin.utils.constants import BINARY_PATHS, PSYCOPG3
|
||||
from pgadmin.utils import set_binary_path
|
||||
|
||||
from functools import wraps
|
||||
|
||||
# Remove this condition, once psycopg2 will be removed completely
|
||||
if config.PG_DEFAULT_DRIVER == PSYCOPG3:
|
||||
import psycopg
|
||||
else:
|
||||
import psycopg2 as psycopg
|
||||
|
||||
CURRENT_PATH = os.path.abspath(os.path.join(os.path.dirname(
|
||||
os.path.realpath(__file__)), "../"))
|
||||
|
||||
@@ -51,8 +56,8 @@ file_name = os.path.realpath(__file__)
|
||||
|
||||
def get_db_connection(db, username, password, host, port, sslmode="prefer"):
|
||||
"""This function returns the connection object of psycopg"""
|
||||
connection = psycopg2.connect(
|
||||
database=db,
|
||||
connection = psycopg.connect(
|
||||
dbname=db,
|
||||
user=username,
|
||||
password=password,
|
||||
host=host,
|
||||
@@ -62,6 +67,19 @@ def get_db_connection(db, username, password, host, port, sslmode="prefer"):
|
||||
return connection
|
||||
|
||||
|
||||
def get_server_version(connection):
|
||||
return connection.info.server_version
|
||||
|
||||
|
||||
def set_isolation_level(connection, level):
|
||||
if level == 0:
|
||||
connection.rollback()
|
||||
connection.autocommit = True
|
||||
else:
|
||||
connection.autocommit = False
|
||||
connection.isolation_level = level
|
||||
|
||||
|
||||
def login_tester_account(tester):
|
||||
"""
|
||||
This function login the test client using env variables email and password
|
||||
@@ -141,7 +159,8 @@ def create_database(server, db_name, encoding=None):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
connection.autocommit = True
|
||||
pg_cursor = connection.cursor()
|
||||
if encoding is None:
|
||||
pg_cursor.execute(
|
||||
@@ -151,7 +170,8 @@ def create_database(server, db_name, encoding=None):
|
||||
'''CREATE DATABASE "%s" TEMPLATE template0
|
||||
ENCODING='%s' LC_COLLATE='%s' LC_CTYPE='%s' ''' %
|
||||
(db_name, encoding[0], encoding[1], encoding[1]))
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
connection.autocommit = False
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
# Get 'oid' from newly created database
|
||||
@@ -208,7 +228,7 @@ def create_table(server, db_name, table_name, extra_columns=[]):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
|
||||
extra_columns_sql = ", " + ", ".join(extra_columns) \
|
||||
if len(extra_columns) > 0 else ''
|
||||
@@ -230,7 +250,7 @@ def create_table(server, db_name, table_name, extra_columns=[]):
|
||||
VALUES ('Yet-Another-Name', 14,
|
||||
'cool info')''' % table_name)
|
||||
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -287,10 +307,10 @@ def create_table_with_query(server, db_name, query):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
pg_cursor.execute(query)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -312,14 +332,14 @@ def create_constraint(server,
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
pg_cursor.execute('''
|
||||
ALTER TABLE "%s"
|
||||
ADD CONSTRAINT "%s" %s (some_column)
|
||||
''' % (table_name, constraint_name, constraint_type.upper()))
|
||||
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -349,7 +369,7 @@ def create_type(server, db_name, type_name, type_fields=[]):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
|
||||
type_fields_sql = ", ".join(type_fields)
|
||||
|
||||
@@ -357,7 +377,7 @@ def create_type(server, db_name, type_name, type_fields=[]):
|
||||
pg_cursor.execute(
|
||||
'''CREATE TYPE %s AS (%s)''' % (type_name, type_fields_sql))
|
||||
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -375,7 +395,7 @@ def create_debug_function(server, db_name, function_name="test_func"):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
try:
|
||||
pg_cursor.execute('''CREATE EXTENSION pldbgapi;''')
|
||||
@@ -396,7 +416,7 @@ def create_debug_function(server, db_name, function_name="test_func"):
|
||||
END;
|
||||
$function$;
|
||||
''' % function_name)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -414,12 +434,12 @@ def drop_debug_function(server, db_name, function_name="test_func"):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
pg_cursor.execute('''
|
||||
DROP FUNCTION public."%s"();
|
||||
''' % function_name)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -458,14 +478,14 @@ def grant_role(server, db_name, role_name="test_role",
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
sql_query = '''GRANT "%s" TO %s;''' % (grant_role, role_name)
|
||||
|
||||
pg_cursor.execute(
|
||||
sql_query
|
||||
)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -483,7 +503,7 @@ def create_role(server, db_name, role_name="test_role"):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
sql_query = '''
|
||||
CREATE USER "%s" WITH
|
||||
@@ -493,13 +513,13 @@ def create_role(server, db_name, role_name="test_role"):
|
||||
CREATEDB
|
||||
NOCREATEROLE
|
||||
''' % (role_name)
|
||||
if connection.server_version > 90100:
|
||||
if get_server_version(connection) > 90100:
|
||||
sql_query += '\nNOREPLICATION'
|
||||
|
||||
pg_cursor.execute(
|
||||
sql_query
|
||||
)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -517,12 +537,12 @@ def drop_role(server, db_name, role_name="test_role"):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
pg_cursor.execute('''
|
||||
DROP USER "%s"
|
||||
''' % role_name)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -533,7 +553,7 @@ def drop_database(connection, database_name):
|
||||
"""This function used to drop the database"""
|
||||
if database_name not in ["postgres", "template1", "template0"]:
|
||||
pg_cursor = connection.cursor()
|
||||
if connection.server_version >= 90100:
|
||||
if connection.info.server_version >= 90100:
|
||||
pg_cursor.execute(
|
||||
"SELECT pg_terminate_backend(pg_stat_activity.pid) "
|
||||
"FROM pg_stat_activity "
|
||||
@@ -551,9 +571,9 @@ def drop_database(connection, database_name):
|
||||
" db.datname='%s'" % database_name)
|
||||
if pg_cursor.fetchall():
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor.execute('''DROP DATABASE "%s"''' % database_name)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
connection.close()
|
||||
|
||||
@@ -563,7 +583,7 @@ def drop_database_multiple(connection, database_names):
|
||||
for database_name in database_names:
|
||||
if database_name not in ["postgres", "template1", "template0"]:
|
||||
pg_cursor = connection.cursor()
|
||||
if connection.server_version >= 90100:
|
||||
if get_server_version(connection) >= 90100:
|
||||
pg_cursor.execute(
|
||||
"SELECT pg_terminate_backend(pg_stat_activity.pid) "
|
||||
"FROM pg_stat_activity "
|
||||
@@ -581,9 +601,9 @@ def drop_database_multiple(connection, database_names):
|
||||
" db.datname='%s'" % database_name)
|
||||
if pg_cursor.fetchall():
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor.execute('''DROP DATABASE "%s"''' % database_name)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
connection.close()
|
||||
|
||||
@@ -597,9 +617,9 @@ def drop_tablespace(connection):
|
||||
for table_space in table_spaces:
|
||||
if table_space[0] not in ["pg_default", "pg_global"]:
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor.execute("DROP TABLESPACE %s" % table_space[0])
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
connection.close()
|
||||
|
||||
@@ -1140,11 +1160,11 @@ def create_schema(server, db_name, schema_name):
|
||||
server['sslmode']
|
||||
)
|
||||
old_isolation_level = connection.isolation_level
|
||||
connection.set_isolation_level(0)
|
||||
set_isolation_level(connection, 0)
|
||||
pg_cursor = connection.cursor()
|
||||
pg_cursor.execute(
|
||||
'''CREATE SCHEMA "%s"''' % schema_name)
|
||||
connection.set_isolation_level(old_isolation_level)
|
||||
set_isolation_level(connection, old_isolation_level)
|
||||
connection.commit()
|
||||
|
||||
except Exception:
|
||||
@@ -1205,7 +1225,7 @@ def check_binary_path_or_skip_test(cls, utility_name):
|
||||
|
||||
|
||||
def get_driver_version():
|
||||
version = getattr(psycopg2, '__version__', None)
|
||||
version = getattr(psycopg, '__version__', None)
|
||||
return version
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import traceback
|
||||
from urllib.parse import urlencode
|
||||
from flask import url_for
|
||||
import regression
|
||||
import config
|
||||
from regression import parent_node_dict
|
||||
from pgadmin.utils.route import BaseTestGenerator
|
||||
from regression.python_test_utils import test_utils as utils
|
||||
@@ -20,6 +21,7 @@ from pgadmin.browser.server_groups.servers.databases.tests import \
|
||||
utils as database_utils
|
||||
from pgadmin.utils.versioned_template_loader import \
|
||||
get_version_mapping_directories
|
||||
from config import PG_DEFAULT_DRIVER
|
||||
|
||||
|
||||
def create_resql_module_list(all_modules, exclude_pkgs, for_modules):
|
||||
@@ -99,13 +101,15 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
||||
|
||||
# Schema ID placeholder in JSON file which needs to be replaced
|
||||
# while running the test cases
|
||||
|
||||
self.JSON_PLACEHOLDERS = {'schema_id': '<SCHEMA_ID>',
|
||||
'owner': '<OWNER>',
|
||||
'timestamptz_1': '<TIMESTAMPTZ_1>',
|
||||
'password': '<PASSWORD>',
|
||||
'pga_job_id': '<PGA_JOB_ID>',
|
||||
'timestamptz_2': '<TIMESTAMPTZ_2>',
|
||||
'db_name': '<TEST_DB_NAME>'}
|
||||
'db_name': '<TEST_DB_NAME>',
|
||||
'db_driver': '<DB_DRIVER>'}
|
||||
|
||||
resql_module_list = create_resql_module_list(
|
||||
BaseTestGenerator.re_sql_module_list,
|
||||
@@ -232,6 +236,9 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
||||
elif self.check_precondition(
|
||||
scenario['precondition_sql'], False):
|
||||
skip_test_case = False
|
||||
elif 'pg_driver' in scenario and\
|
||||
scenario['pg_driver'] != PG_DEFAULT_DRIVER:
|
||||
skip_test_case = True
|
||||
else:
|
||||
skip_test_case = False
|
||||
|
||||
@@ -272,6 +279,10 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
||||
try:
|
||||
self.assertEqual(response.status_code, 200)
|
||||
except Exception as e:
|
||||
response = self.tester.post(create_url,
|
||||
data=json.dumps(
|
||||
scenario['data']),
|
||||
content_type='html/json')
|
||||
self.final_test_status = False
|
||||
print(scenario['name'] + "... FAIL")
|
||||
traceback.print_exc()
|
||||
@@ -307,6 +318,11 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
||||
self.assertEqual(response.status_code, 200)
|
||||
except Exception as e:
|
||||
self.final_test_status = False
|
||||
alter_url = self.get_url(scenario['endpoint'], object_id)
|
||||
response = self.tester.put(alter_url,
|
||||
data=json.dumps(
|
||||
scenario['data']),
|
||||
follow_redirects=True)
|
||||
print(scenario['name'] + "... FAIL")
|
||||
traceback.print_exc()
|
||||
continue
|
||||
@@ -493,7 +509,6 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
||||
try:
|
||||
self.assertEqual(response.status_code, 200)
|
||||
except Exception as e:
|
||||
|
||||
self.final_test_status = False
|
||||
traceback.print_exc()
|
||||
return False
|
||||
@@ -522,6 +537,8 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
||||
try:
|
||||
self.assertEqual(sql, resp_sql)
|
||||
except Exception as e:
|
||||
print(sql)
|
||||
print(resp_sql)
|
||||
self.final_test_status = False
|
||||
traceback.print_exc()
|
||||
return False
|
||||
@@ -636,6 +653,8 @@ class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
||||
scenario['data'][key_attr]['added'][0]):
|
||||
self.get_db_connection()
|
||||
pg_cursor = self.connection.cursor()
|
||||
pg_cursor.execute("SET DateStyle=ISO;")
|
||||
|
||||
try:
|
||||
if is_tz_columns_list:
|
||||
query = "SELECT timestamp with time zone '" \
|
||||
|
||||
@@ -22,6 +22,10 @@ import secrets
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
if sys.version_info < (3, 4):
|
||||
raise RuntimeError('The test suite must be run under Python 3.4 or later.')
|
||||
@@ -495,7 +499,7 @@ def execute_test(test_module_list_passed, server_passed, driver_passed,
|
||||
)
|
||||
|
||||
# Add the server version in server information
|
||||
server_information['server_version'] = connection.server_version
|
||||
server_information['server_version'] = connection.info.server_version
|
||||
server_information['type'] = server_passed['type']
|
||||
|
||||
# Drop the database if already exists.
|
||||
|
||||
Reference in New Issue
Block a user