Fixed restoration of query tool database connection after dropping and re-creating the database with the same name. #6487

This commit is contained in:
Anil Sahoo 2023-11-01 15:27:18 +05:30 committed by GitHub
parent 179332ed5a
commit 0b707be615
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 85 additions and 42 deletions

View File

@ -335,7 +335,8 @@ def panel(trans_id):
@blueprint.route(
'/initialize/sqleditor/<int:trans_id>/<int:sgid>/<int:sid>/<int:did>',
'/initialize/sqleditor/<int:trans_id>/<int:sgid>/<int:sid>/'
'<int:did>',
methods=["POST"], endpoint='initialize_sqleditor_with_did'
)
@blueprint.route(
@ -373,7 +374,7 @@ def initialize_sqleditor(trans_id, sgid, sid, did=None):
}
is_error, errmsg, conn_id, version = _init_sqleditor(
trans_id, connect, sgid, sid, did, **kwargs)
trans_id, connect, sgid, sid, did, data['dbname'], **kwargs)
if is_error:
return errmsg
@ -410,7 +411,7 @@ def _connect(conn, **kwargs):
return status, msg, is_ask_password, user, role, password
def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
def _init_sqleditor(trans_id, connect, sgid, sid, did, dbname=None, **kwargs):
# Create asynchronous connection using random connection id.
conn_id = str(secrets.choice(range(1, 9999999)))
conn_id_ac = str(secrets.choice(range(1, 9999999)))
@ -429,10 +430,12 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
return True, internal_server_error(errormsg=str(e)), '', ''
try:
conn = manager.connection(did=did, conn_id=conn_id,
conn = manager.connection(conn_id=conn_id,
auto_reconnect=False,
use_binary_placeholder=True,
array_to_string=True)
array_to_string=True,
**({"database": dbname} if dbname is not None
else {"did": did}))
pref = Preferences.module('sqleditor')
if connect:
@ -463,10 +466,13 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
errormsg=str(msg)), '', ''
if pref.preference('autocomplete_on_key_press').get():
conn_ac = manager.connection(did=did, conn_id=conn_id_ac,
conn_ac = manager.connection(conn_id=conn_id_ac,
auto_reconnect=False,
use_binary_placeholder=True,
array_to_string=True)
array_to_string=True,
**({"database": dbname}
if dbname is not None
else {"did": did}))
status, msg, is_ask_password, user, role, password = _connect(
conn_ac, **kwargs)
@ -486,6 +492,8 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
command_obj.set_auto_commit(pref.preference('auto_commit').get())
command_obj.set_auto_rollback(pref.preference('auto_rollback').get())
# Set the value of database name, that will be used later
command_obj.dbname = dbname if dbname else None
# Use pickle to store the command object which will be used
# later by the sql grid module.
sql_grid_data[str(trans_id)] = {
@ -533,7 +541,8 @@ def update_sqleditor_connection(trans_id, sgid, sid, did):
}
is_error, errmsg, conn_id, version = _init_sqleditor(
new_trans_id, connect, sgid, sid, did, **kwargs)
new_trans_id, connect, sgid, sid, did, data['database_name'],
**kwargs)
if is_error:
return errmsg
@ -851,6 +860,7 @@ def start_query_tool(trans_id):
Args:
trans_id: unique transaction id
"""
sql = extract_sql_from_network_parameters(
request.data, request.args, request.form
)
@ -1632,7 +1642,10 @@ def cancel_transaction(trans_id):
try:
manager = get_driver(
PG_DEFAULT_DRIVER).connection_manager(trans_obj.sid)
conn = manager.connection(did=trans_obj.did)
conn = manager.connection(**({"database": trans_obj.dbname}
if trans_obj.dbname is not None
else {"did": trans_obj.did}))
except Exception as e:
return internal_server_error(errormsg=str(e))

View File

@ -113,7 +113,7 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN
fgcolor: params.fgcolor,
bgcolor: params.bgcolor,
conn_title: getTitle(
pgAdmin, null, selectedNodeInfo, true, _.unescape(params.server_name), _.escape(params.database_name) || getDatabaseLabel(selectedNodeInfo),
pgAdmin, null, selectedNodeInfo, true, _.unescape(params.server_name), _.unescape(params.database_name) || getDatabaseLabel(selectedNodeInfo),
_.unescape(params.role) || _.unescape(params.user), params.is_query_tool == 'true' ? true : false),
server_name: _.unescape(params.server_name),
database_name: _.unescape(params.database_name) || getDatabaseLabel(selectedNodeInfo),
@ -261,7 +261,8 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN
api.post(baseUrl, qtState.params.is_query_tool ? {
user: selectedConn.user,
role: selectedConn.role,
password: password
password: password,
dbname: selectedConn.database_name
} : JSON.stringify(qtState.params.sql_filter))
.then(()=>{
setQtState({
@ -651,10 +652,13 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN
is_selected: true,
};
let existIdx = _.findIndex(qtState.connection_list, (conn)=>(
conn.sid == connectionData.sid && conn.did == connectionData.did
let existIdx = _.findIndex(qtState.connection_list, (conn)=>{
conn.role= conn.role == ''? null :conn.role;
return(
conn.sid == connectionData.sid && conn.database_name == connectionData.database_name
&& conn.user == connectionData.user && conn.role == connectionData.role
));
);
});
if(existIdx > -1) {
reject(gettext('Connection with this configuration already present.'));
return;

View File

@ -119,7 +119,6 @@ class TestDownloadCSV(BaseTestGenerator):
url = '/sqleditor/query_tool/start/{0}'.format(trans_id)
response = self.tester.post(url, data=json.dumps({"sql": sql_query}),
content_type='html/json')
self.assertEqual(response.status_code, 200)
return async_poll(tester=self.tester,
@ -138,7 +137,9 @@ class TestDownloadCSV(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = self.init_url.format(
self.trans_id, test_utils.SERVER_GROUP, self._sid, self._did)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": self._db_name
}))
self.assertEqual(response.status_code, 200)
res = self.initiate_sql_query_tool(self.trans_id, self.sql)

View File

@ -72,7 +72,9 @@ class TestEditorHistory(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": database_info["db_name"]
}))
self.assertEqual(response.status_code, 200)
def runTest(self):

View File

@ -267,14 +267,16 @@ class TestEncodingCharset(BaseTestGenerator):
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'\
.format(self.trans_id, test_utils.SERVER_GROUP, self.encode_sid,
self.encode_did)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": self.encode_db_name
}))
self.assertEqual(response.status_code, 200)
# Check character
url = "/sqleditor/query_tool/start/{0}".format(self.trans_id)
sql = "select E'{0}';".format(self.test_str)
response = self.tester.post(url, data=json.dumps({"sql": sql}),
content_type='html/json')
response = (self.tester.post(url, data=json.dumps({"sql": sql}),
content_type='html/json'))
self.assertEqual(response.status_code, 200)
response = async_poll(tester=self.tester,
poll_url='/sqleditor/poll/{0}'.format(

View File

@ -38,7 +38,9 @@ class TestExplainPlan(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": database_info["db_name"]
}))
self.assertEqual(response.status_code, 200)
# Start query tool transaction

View File

@ -108,7 +108,9 @@ class TestMacros(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": database_info["db_name"]
}))
self.assertEqual(response.status_code, 200)
def runTest(self):

View File

@ -79,7 +79,9 @@ NOTICE: Hello, world!
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": database_info["db_name"]
}))
self.assertEqual(response.status_code, 200)
cnt = 0
@ -89,7 +91,6 @@ NOTICE: Hello, world!
url = '/sqleditor/query_tool/start/{0}'.format(self.trans_id)
response = self.tester.post(url, data=json.dumps({"sql": s}),
content_type='html/json')
self.assertEqual(response.status_code, 200)
url = '/sqleditor/poll/{0}'.format(self.trans_id)

View File

@ -98,7 +98,9 @@ class TestSQLASCIIEncoding(BaseTestGenerator):
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'\
.format(self.trans_id, test_utils.SERVER_GROUP, self.encode_sid,
self.encode_did)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": self.encode_db_name
}))
self.assertEqual(response.status_code, 200)
# Check character

View File

@ -7,6 +7,7 @@
#
##########################################################################
import json
from pgadmin.utils.route import BaseTestGenerator
from pgadmin.tools.sqleditor import StartRunningQuery
from unittest.mock import patch, ANY
@ -30,13 +31,13 @@ class StartQueryTool(BaseTestGenerator):
return_value='some result'
) as StartRunningQuery_execute_mock:
response = self.tester.post(
'/sqleditor/query_tool/start/1234',
data='"some sql statement"'
)
'/sqleditor/query_tool/start/1234', data=json.dumps({
"sql": "some sql statement"}))
self.assertEqual(response.status, '200 OK')
self.assertEqual(response.data, b'some result')
StartRunningQuery_execute_mock \
.assert_called_with('transformed sql', 1234, ANY, False)
extract_sql_from_network_parameters_mock \
.assert_called_with(b'"some sql statement"', ANY, ANY)
.assert_called_with(
b'{"sql": "some sql statement"}',
ANY, ANY)

View File

@ -306,7 +306,9 @@ class TestTransactionControl(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": self.db_name
}))
self.assertEqual(response.status_code, 200)
def _initialize_urls(self):

View File

@ -67,6 +67,7 @@ class StartRunningQuery:
PG_DEFAULT_DRIVER).connection_manager(
transaction_object.sid)
conn = manager.connection(did=transaction_object.did,
database=transaction_object.dbname,
conn_id=self.connection_id,
auto_reconnect=False,
use_binary_placeholder=True,

View File

@ -8,6 +8,7 @@
##########################################################################
import secrets
import json
from pgadmin.browser.server_groups.servers.databases.tests import utils as \
database_utils
@ -204,7 +205,9 @@ class TestQueryUpdatableResultset(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": self.db_name
}))
self.assertEqual(response.status_code, 200)
def _initialize_urls(self):

View File

@ -923,7 +923,9 @@ class TestSaveChangedData(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": self.db_name
}))
self.assertEqual(response.status_code, 200)
def _initialize_urls_and_select_sql(self):

View File

@ -247,7 +247,9 @@ class TestSaveChangedDataUUID(BaseTestGenerator):
self.trans_id = str(secrets.choice(range(1, 9999999)))
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
response = self.tester.post(url)
response = self.tester.post(url, data=json.dumps({
"dbname": self.db_name
}))
self.assertEqual(response.status_code, 200)
def _initialize_urls_and_select_sql(self):

View File

@ -30,7 +30,7 @@ class StartRunningQueryTest(BaseTestGenerator):
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict()
http_session=dict(),
),
pickle_load_return=None,
get_driver_exception=False,
@ -60,7 +60,7 @@ class StartRunningQueryTest(BaseTestGenerator):
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict(gridData=dict())
http_session=dict(gridData=dict()),
),
pickle_load_return=None,
get_driver_exception=False,
@ -91,7 +91,7 @@ class StartRunningQueryTest(BaseTestGenerator):
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict(gridData={'123': dict(command_obj='')})
http_session=dict(gridData={'123': dict(command_obj='')}),
),
pickle_load_return=None,
get_driver_exception=False,
@ -126,7 +126,7 @@ class StartRunningQueryTest(BaseTestGenerator):
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict(gridData={'123': dict(command_obj='')})
http_session=dict(gridData={'123': dict(command_obj='')}),
),
pickle_load_return=MagicMock(conn_id=1,
update_fetched_row_cnt=MagicMock()),
@ -155,7 +155,7 @@ class StartRunningQueryTest(BaseTestGenerator):
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict(gridData={'123': dict(command_obj='')})
http_session=dict(gridData={'123': dict(command_obj='')}),
),
pickle_load_return=MagicMock(
conn_id=1,
@ -184,7 +184,7 @@ class StartRunningQueryTest(BaseTestGenerator):
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict(gridData={'123': dict(command_obj='')})
http_session=dict(gridData={'123': dict(command_obj='')}),
),
pickle_load_return=MagicMock(
conn_id=1,
@ -212,7 +212,7 @@ class StartRunningQueryTest(BaseTestGenerator):
function_parameters=dict(
sql=dict(sql='some sql', explain_plan=None),
trans_id=123,
http_session=dict(gridData={'123': dict(command_obj='')})
http_session=dict(gridData={'123': dict(command_obj='')}),
),
pickle_load_return=MagicMock(
conn_id=1,

View File

@ -204,10 +204,13 @@ class ServerManager(object):
if did is not None and did in self.db_info:
self.db_info[did]['datname'] = database
else:
conn_str = 'CONN:{0}'.format(conn_id)
if did is None:
database = self.db
elif did in self.db_info:
database = self.db_info[did]['datname']
elif conn_id and conn_str in self.connections:
database = self.connections[conn_str].db
else:
maintenance_db_id = 'DB:{0}'.format(self.db)
if maintenance_db_id in self.connections: