diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index 3346ba78c..10e1e813e 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -335,7 +335,8 @@ def panel(trans_id): @blueprint.route( - '/initialize/sqleditor////', + '/initialize/sqleditor////' + '', methods=["POST"], endpoint='initialize_sqleditor_with_did' ) @blueprint.route( @@ -373,7 +374,7 @@ def initialize_sqleditor(trans_id, sgid, sid, did=None): } is_error, errmsg, conn_id, version = _init_sqleditor( - trans_id, connect, sgid, sid, did, **kwargs) + trans_id, connect, sgid, sid, did, data['dbname'], **kwargs) if is_error: return errmsg @@ -410,7 +411,7 @@ def _connect(conn, **kwargs): return status, msg, is_ask_password, user, role, password -def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs): +def _init_sqleditor(trans_id, connect, sgid, sid, did, dbname=None, **kwargs): # Create asynchronous connection using random connection id. conn_id = str(secrets.choice(range(1, 9999999))) conn_id_ac = str(secrets.choice(range(1, 9999999))) @@ -429,10 +430,12 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs): return True, internal_server_error(errormsg=str(e)), '', '' try: - conn = manager.connection(did=did, conn_id=conn_id, + conn = manager.connection(conn_id=conn_id, auto_reconnect=False, use_binary_placeholder=True, - array_to_string=True) + array_to_string=True, + **({"database": dbname} if dbname is not None + else {"did": did})) pref = Preferences.module('sqleditor') if connect: @@ -463,10 +466,13 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs): errormsg=str(msg)), '', '' if pref.preference('autocomplete_on_key_press').get(): - conn_ac = manager.connection(did=did, conn_id=conn_id_ac, + conn_ac = manager.connection(conn_id=conn_id_ac, auto_reconnect=False, use_binary_placeholder=True, - array_to_string=True) + array_to_string=True, + **({"database": dbname} + if dbname is not None + else {"did": did})) status, msg, is_ask_password, user, role, password = _connect( conn_ac, **kwargs) @@ -486,6 +492,8 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs): command_obj.set_auto_commit(pref.preference('auto_commit').get()) command_obj.set_auto_rollback(pref.preference('auto_rollback').get()) + # Set the value of database name, that will be used later + command_obj.dbname = dbname if dbname else None # Use pickle to store the command object which will be used # later by the sql grid module. sql_grid_data[str(trans_id)] = { @@ -533,7 +541,8 @@ def update_sqleditor_connection(trans_id, sgid, sid, did): } is_error, errmsg, conn_id, version = _init_sqleditor( - new_trans_id, connect, sgid, sid, did, **kwargs) + new_trans_id, connect, sgid, sid, did, data['database_name'], + **kwargs) if is_error: return errmsg @@ -851,6 +860,7 @@ def start_query_tool(trans_id): Args: trans_id: unique transaction id """ + sql = extract_sql_from_network_parameters( request.data, request.args, request.form ) @@ -1632,7 +1642,10 @@ def cancel_transaction(trans_id): try: manager = get_driver( PG_DEFAULT_DRIVER).connection_manager(trans_obj.sid) - conn = manager.connection(did=trans_obj.did) + conn = manager.connection(**({"database": trans_obj.dbname} + if trans_obj.dbname is not None + else {"did": trans_obj.did})) + except Exception as e: return internal_server_error(errormsg=str(e)) diff --git a/web/pgadmin/tools/sqleditor/static/js/components/QueryToolComponent.jsx b/web/pgadmin/tools/sqleditor/static/js/components/QueryToolComponent.jsx index 9b1106b32..160c48037 100644 --- a/web/pgadmin/tools/sqleditor/static/js/components/QueryToolComponent.jsx +++ b/web/pgadmin/tools/sqleditor/static/js/components/QueryToolComponent.jsx @@ -113,7 +113,7 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN fgcolor: params.fgcolor, bgcolor: params.bgcolor, conn_title: getTitle( - pgAdmin, null, selectedNodeInfo, true, _.unescape(params.server_name), _.escape(params.database_name) || getDatabaseLabel(selectedNodeInfo), + pgAdmin, null, selectedNodeInfo, true, _.unescape(params.server_name), _.unescape(params.database_name) || getDatabaseLabel(selectedNodeInfo), _.unescape(params.role) || _.unescape(params.user), params.is_query_tool == 'true' ? true : false), server_name: _.unescape(params.server_name), database_name: _.unescape(params.database_name) || getDatabaseLabel(selectedNodeInfo), @@ -261,7 +261,8 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN api.post(baseUrl, qtState.params.is_query_tool ? { user: selectedConn.user, role: selectedConn.role, - password: password + password: password, + dbname: selectedConn.database_name } : JSON.stringify(qtState.params.sql_filter)) .then(()=>{ setQtState({ @@ -651,10 +652,13 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN is_selected: true, }; - let existIdx = _.findIndex(qtState.connection_list, (conn)=>( - conn.sid == connectionData.sid && conn.did == connectionData.did - && conn.user == connectionData.user && conn.role == connectionData.role - )); + let existIdx = _.findIndex(qtState.connection_list, (conn)=>{ + conn.role= conn.role == ''? null :conn.role; + return( + conn.sid == connectionData.sid && conn.database_name == connectionData.database_name + && conn.user == connectionData.user && conn.role == connectionData.role + ); + }); if(existIdx > -1) { reject(gettext('Connection with this configuration already present.')); return; diff --git a/web/pgadmin/tools/sqleditor/tests/test_download_csv_query_tool.py b/web/pgadmin/tools/sqleditor/tests/test_download_csv_query_tool.py index dba5f86eb..934d63263 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_download_csv_query_tool.py +++ b/web/pgadmin/tools/sqleditor/tests/test_download_csv_query_tool.py @@ -119,7 +119,6 @@ class TestDownloadCSV(BaseTestGenerator): url = '/sqleditor/query_tool/start/{0}'.format(trans_id) response = self.tester.post(url, data=json.dumps({"sql": sql_query}), content_type='html/json') - self.assertEqual(response.status_code, 200) return async_poll(tester=self.tester, @@ -138,7 +137,9 @@ class TestDownloadCSV(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = self.init_url.format( self.trans_id, test_utils.SERVER_GROUP, self._sid, self._did) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": self._db_name + })) self.assertEqual(response.status_code, 200) res = self.initiate_sql_query_tool(self.trans_id, self.sql) diff --git a/web/pgadmin/tools/sqleditor/tests/test_editor_history.py b/web/pgadmin/tools/sqleditor/tests/test_editor_history.py index 245398893..acfc3a28b 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_editor_history.py +++ b/web/pgadmin/tools/sqleditor/tests/test_editor_history.py @@ -72,7 +72,9 @@ class TestEditorHistory(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": database_info["db_name"] + })) self.assertEqual(response.status_code, 200) def runTest(self): diff --git a/web/pgadmin/tools/sqleditor/tests/test_encoding_charset.py b/web/pgadmin/tools/sqleditor/tests/test_encoding_charset.py index c624ea30f..b872d7fa1 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_encoding_charset.py +++ b/web/pgadmin/tools/sqleditor/tests/test_encoding_charset.py @@ -267,14 +267,16 @@ class TestEncodingCharset(BaseTestGenerator): url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'\ .format(self.trans_id, test_utils.SERVER_GROUP, self.encode_sid, self.encode_did) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": self.encode_db_name + })) self.assertEqual(response.status_code, 200) # Check character url = "/sqleditor/query_tool/start/{0}".format(self.trans_id) sql = "select E'{0}';".format(self.test_str) - response = self.tester.post(url, data=json.dumps({"sql": sql}), - content_type='html/json') + response = (self.tester.post(url, data=json.dumps({"sql": sql}), + content_type='html/json')) self.assertEqual(response.status_code, 200) response = async_poll(tester=self.tester, poll_url='/sqleditor/poll/{0}'.format( diff --git a/web/pgadmin/tools/sqleditor/tests/test_explain_plan.py b/web/pgadmin/tools/sqleditor/tests/test_explain_plan.py index fbb816d6f..8935bc153 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_explain_plan.py +++ b/web/pgadmin/tools/sqleditor/tests/test_explain_plan.py @@ -38,7 +38,9 @@ class TestExplainPlan(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": database_info["db_name"] + })) self.assertEqual(response.status_code, 200) # Start query tool transaction diff --git a/web/pgadmin/tools/sqleditor/tests/test_macros.py b/web/pgadmin/tools/sqleditor/tests/test_macros.py index c94b9ee83..662fd1142 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_macros.py +++ b/web/pgadmin/tools/sqleditor/tests/test_macros.py @@ -108,7 +108,9 @@ class TestMacros(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": database_info["db_name"] + })) self.assertEqual(response.status_code, 200) def runTest(self): diff --git a/web/pgadmin/tools/sqleditor/tests/test_poll_query_tool.py b/web/pgadmin/tools/sqleditor/tests/test_poll_query_tool.py index 054d50f31..40ff48452 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_poll_query_tool.py +++ b/web/pgadmin/tools/sqleditor/tests/test_poll_query_tool.py @@ -79,7 +79,9 @@ NOTICE: Hello, world! self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": database_info["db_name"] + })) self.assertEqual(response.status_code, 200) cnt = 0 @@ -89,7 +91,6 @@ NOTICE: Hello, world! url = '/sqleditor/query_tool/start/{0}'.format(self.trans_id) response = self.tester.post(url, data=json.dumps({"sql": s}), content_type='html/json') - self.assertEqual(response.status_code, 200) url = '/sqleditor/poll/{0}'.format(self.trans_id) diff --git a/web/pgadmin/tools/sqleditor/tests/test_sql_ascii_encoding.py b/web/pgadmin/tools/sqleditor/tests/test_sql_ascii_encoding.py index d4a04bed1..2283ddf6d 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_sql_ascii_encoding.py +++ b/web/pgadmin/tools/sqleditor/tests/test_sql_ascii_encoding.py @@ -98,7 +98,9 @@ class TestSQLASCIIEncoding(BaseTestGenerator): url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'\ .format(self.trans_id, test_utils.SERVER_GROUP, self.encode_sid, self.encode_did) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": self.encode_db_name + })) self.assertEqual(response.status_code, 200) # Check character diff --git a/web/pgadmin/tools/sqleditor/tests/test_start_query_tool.py b/web/pgadmin/tools/sqleditor/tests/test_start_query_tool.py index 9ae2ca25a..4e01e2d3a 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_start_query_tool.py +++ b/web/pgadmin/tools/sqleditor/tests/test_start_query_tool.py @@ -7,6 +7,7 @@ # ########################################################################## +import json from pgadmin.utils.route import BaseTestGenerator from pgadmin.tools.sqleditor import StartRunningQuery from unittest.mock import patch, ANY @@ -30,13 +31,13 @@ class StartQueryTool(BaseTestGenerator): return_value='some result' ) as StartRunningQuery_execute_mock: response = self.tester.post( - '/sqleditor/query_tool/start/1234', - data='"some sql statement"' - ) - + '/sqleditor/query_tool/start/1234', data=json.dumps({ + "sql": "some sql statement"})) self.assertEqual(response.status, '200 OK') self.assertEqual(response.data, b'some result') StartRunningQuery_execute_mock \ .assert_called_with('transformed sql', 1234, ANY, False) extract_sql_from_network_parameters_mock \ - .assert_called_with(b'"some sql statement"', ANY, ANY) + .assert_called_with( + b'{"sql": "some sql statement"}', + ANY, ANY) diff --git a/web/pgadmin/tools/sqleditor/tests/test_transaction_status.py b/web/pgadmin/tools/sqleditor/tests/test_transaction_status.py index 916ea9a96..f91cbd51f 100644 --- a/web/pgadmin/tools/sqleditor/tests/test_transaction_status.py +++ b/web/pgadmin/tools/sqleditor/tests/test_transaction_status.py @@ -306,7 +306,9 @@ class TestTransactionControl(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": self.db_name + })) self.assertEqual(response.status_code, 200) def _initialize_urls(self): diff --git a/web/pgadmin/tools/sqleditor/utils/start_running_query.py b/web/pgadmin/tools/sqleditor/utils/start_running_query.py index 4d8e55bdd..5ce45cd71 100644 --- a/web/pgadmin/tools/sqleditor/utils/start_running_query.py +++ b/web/pgadmin/tools/sqleditor/utils/start_running_query.py @@ -67,6 +67,7 @@ class StartRunningQuery: PG_DEFAULT_DRIVER).connection_manager( transaction_object.sid) conn = manager.connection(did=transaction_object.did, + database=transaction_object.dbname, conn_id=self.connection_id, auto_reconnect=False, use_binary_placeholder=True, diff --git a/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py b/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py index 54d7762e3..02671c3b6 100644 --- a/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py +++ b/web/pgadmin/tools/sqleditor/utils/tests/test_is_query_resultset_updatable.py @@ -8,6 +8,7 @@ ########################################################################## import secrets +import json from pgadmin.browser.server_groups.servers.databases.tests import utils as \ database_utils @@ -204,7 +205,9 @@ class TestQueryUpdatableResultset(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": self.db_name + })) self.assertEqual(response.status_code, 200) def _initialize_urls(self): diff --git a/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py b/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py index 084817374..086a64067 100644 --- a/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py +++ b/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_data.py @@ -923,7 +923,9 @@ class TestSaveChangedData(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": self.db_name + })) self.assertEqual(response.status_code, 200) def _initialize_urls_and_select_sql(self): diff --git a/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_uuid_data.py b/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_uuid_data.py index fbc955c7d..97dbbb998 100644 --- a/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_uuid_data.py +++ b/web/pgadmin/tools/sqleditor/utils/tests/test_save_changed_uuid_data.py @@ -247,7 +247,9 @@ class TestSaveChangedDataUUID(BaseTestGenerator): self.trans_id = str(secrets.choice(range(1, 9999999))) url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format( self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id) - response = self.tester.post(url) + response = self.tester.post(url, data=json.dumps({ + "dbname": self.db_name + })) self.assertEqual(response.status_code, 200) def _initialize_urls_and_select_sql(self): diff --git a/web/pgadmin/tools/sqleditor/utils/tests/test_start_running_query.py b/web/pgadmin/tools/sqleditor/utils/tests/test_start_running_query.py index 004f7eb2a..5406f2447 100644 --- a/web/pgadmin/tools/sqleditor/utils/tests/test_start_running_query.py +++ b/web/pgadmin/tools/sqleditor/utils/tests/test_start_running_query.py @@ -30,7 +30,7 @@ class StartRunningQueryTest(BaseTestGenerator): function_parameters=dict( sql=dict(sql='some sql', explain_plan=None), trans_id=123, - http_session=dict() + http_session=dict(), ), pickle_load_return=None, get_driver_exception=False, @@ -60,7 +60,7 @@ class StartRunningQueryTest(BaseTestGenerator): function_parameters=dict( sql=dict(sql='some sql', explain_plan=None), trans_id=123, - http_session=dict(gridData=dict()) + http_session=dict(gridData=dict()), ), pickle_load_return=None, get_driver_exception=False, @@ -91,7 +91,7 @@ class StartRunningQueryTest(BaseTestGenerator): function_parameters=dict( sql=dict(sql='some sql', explain_plan=None), trans_id=123, - http_session=dict(gridData={'123': dict(command_obj='')}) + http_session=dict(gridData={'123': dict(command_obj='')}), ), pickle_load_return=None, get_driver_exception=False, @@ -126,7 +126,7 @@ class StartRunningQueryTest(BaseTestGenerator): function_parameters=dict( sql=dict(sql='some sql', explain_plan=None), trans_id=123, - http_session=dict(gridData={'123': dict(command_obj='')}) + http_session=dict(gridData={'123': dict(command_obj='')}), ), pickle_load_return=MagicMock(conn_id=1, update_fetched_row_cnt=MagicMock()), @@ -155,7 +155,7 @@ class StartRunningQueryTest(BaseTestGenerator): function_parameters=dict( sql=dict(sql='some sql', explain_plan=None), trans_id=123, - http_session=dict(gridData={'123': dict(command_obj='')}) + http_session=dict(gridData={'123': dict(command_obj='')}), ), pickle_load_return=MagicMock( conn_id=1, @@ -184,7 +184,7 @@ class StartRunningQueryTest(BaseTestGenerator): function_parameters=dict( sql=dict(sql='some sql', explain_plan=None), trans_id=123, - http_session=dict(gridData={'123': dict(command_obj='')}) + http_session=dict(gridData={'123': dict(command_obj='')}), ), pickle_load_return=MagicMock( conn_id=1, @@ -212,7 +212,7 @@ class StartRunningQueryTest(BaseTestGenerator): function_parameters=dict( sql=dict(sql='some sql', explain_plan=None), trans_id=123, - http_session=dict(gridData={'123': dict(command_obj='')}) + http_session=dict(gridData={'123': dict(command_obj='')}), ), pickle_load_return=MagicMock( conn_id=1, diff --git a/web/pgadmin/utils/driver/psycopg3/server_manager.py b/web/pgadmin/utils/driver/psycopg3/server_manager.py index a1991cac4..786a94992 100644 --- a/web/pgadmin/utils/driver/psycopg3/server_manager.py +++ b/web/pgadmin/utils/driver/psycopg3/server_manager.py @@ -204,10 +204,13 @@ class ServerManager(object): if did is not None and did in self.db_info: self.db_info[did]['datname'] = database else: + conn_str = 'CONN:{0}'.format(conn_id) if did is None: database = self.db elif did in self.db_info: database = self.db_info[did]['datname'] + elif conn_id and conn_str in self.connections: + database = self.connections[conn_str].db else: maintenance_db_id = 'DB:{0}'.format(self.db) if maintenance_db_id in self.connections: