mirror of
https://github.com/pgadmin-org/pgadmin4.git
synced 2024-11-24 09:40:21 -06:00
Fixed restoration of query tool database connection after dropping and re-creating the database with the same name. #6487
This commit is contained in:
parent
179332ed5a
commit
0b707be615
@ -335,7 +335,8 @@ def panel(trans_id):
|
||||
|
||||
|
||||
@blueprint.route(
|
||||
'/initialize/sqleditor/<int:trans_id>/<int:sgid>/<int:sid>/<int:did>',
|
||||
'/initialize/sqleditor/<int:trans_id>/<int:sgid>/<int:sid>/'
|
||||
'<int:did>',
|
||||
methods=["POST"], endpoint='initialize_sqleditor_with_did'
|
||||
)
|
||||
@blueprint.route(
|
||||
@ -373,7 +374,7 @@ def initialize_sqleditor(trans_id, sgid, sid, did=None):
|
||||
}
|
||||
|
||||
is_error, errmsg, conn_id, version = _init_sqleditor(
|
||||
trans_id, connect, sgid, sid, did, **kwargs)
|
||||
trans_id, connect, sgid, sid, did, data['dbname'], **kwargs)
|
||||
if is_error:
|
||||
return errmsg
|
||||
|
||||
@ -410,7 +411,7 @@ def _connect(conn, **kwargs):
|
||||
return status, msg, is_ask_password, user, role, password
|
||||
|
||||
|
||||
def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
|
||||
def _init_sqleditor(trans_id, connect, sgid, sid, did, dbname=None, **kwargs):
|
||||
# Create asynchronous connection using random connection id.
|
||||
conn_id = str(secrets.choice(range(1, 9999999)))
|
||||
conn_id_ac = str(secrets.choice(range(1, 9999999)))
|
||||
@ -429,10 +430,12 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
|
||||
return True, internal_server_error(errormsg=str(e)), '', ''
|
||||
|
||||
try:
|
||||
conn = manager.connection(did=did, conn_id=conn_id,
|
||||
conn = manager.connection(conn_id=conn_id,
|
||||
auto_reconnect=False,
|
||||
use_binary_placeholder=True,
|
||||
array_to_string=True)
|
||||
array_to_string=True,
|
||||
**({"database": dbname} if dbname is not None
|
||||
else {"did": did}))
|
||||
pref = Preferences.module('sqleditor')
|
||||
|
||||
if connect:
|
||||
@ -463,10 +466,13 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
|
||||
errormsg=str(msg)), '', ''
|
||||
|
||||
if pref.preference('autocomplete_on_key_press').get():
|
||||
conn_ac = manager.connection(did=did, conn_id=conn_id_ac,
|
||||
conn_ac = manager.connection(conn_id=conn_id_ac,
|
||||
auto_reconnect=False,
|
||||
use_binary_placeholder=True,
|
||||
array_to_string=True)
|
||||
array_to_string=True,
|
||||
**({"database": dbname}
|
||||
if dbname is not None
|
||||
else {"did": did}))
|
||||
status, msg, is_ask_password, user, role, password = _connect(
|
||||
conn_ac, **kwargs)
|
||||
|
||||
@ -486,6 +492,8 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, **kwargs):
|
||||
command_obj.set_auto_commit(pref.preference('auto_commit').get())
|
||||
command_obj.set_auto_rollback(pref.preference('auto_rollback').get())
|
||||
|
||||
# Set the value of database name, that will be used later
|
||||
command_obj.dbname = dbname if dbname else None
|
||||
# Use pickle to store the command object which will be used
|
||||
# later by the sql grid module.
|
||||
sql_grid_data[str(trans_id)] = {
|
||||
@ -533,7 +541,8 @@ def update_sqleditor_connection(trans_id, sgid, sid, did):
|
||||
}
|
||||
|
||||
is_error, errmsg, conn_id, version = _init_sqleditor(
|
||||
new_trans_id, connect, sgid, sid, did, **kwargs)
|
||||
new_trans_id, connect, sgid, sid, did, data['database_name'],
|
||||
**kwargs)
|
||||
|
||||
if is_error:
|
||||
return errmsg
|
||||
@ -851,6 +860,7 @@ def start_query_tool(trans_id):
|
||||
Args:
|
||||
trans_id: unique transaction id
|
||||
"""
|
||||
|
||||
sql = extract_sql_from_network_parameters(
|
||||
request.data, request.args, request.form
|
||||
)
|
||||
@ -1632,7 +1642,10 @@ def cancel_transaction(trans_id):
|
||||
try:
|
||||
manager = get_driver(
|
||||
PG_DEFAULT_DRIVER).connection_manager(trans_obj.sid)
|
||||
conn = manager.connection(did=trans_obj.did)
|
||||
conn = manager.connection(**({"database": trans_obj.dbname}
|
||||
if trans_obj.dbname is not None
|
||||
else {"did": trans_obj.did}))
|
||||
|
||||
except Exception as e:
|
||||
return internal_server_error(errormsg=str(e))
|
||||
|
||||
|
@ -113,7 +113,7 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN
|
||||
fgcolor: params.fgcolor,
|
||||
bgcolor: params.bgcolor,
|
||||
conn_title: getTitle(
|
||||
pgAdmin, null, selectedNodeInfo, true, _.unescape(params.server_name), _.escape(params.database_name) || getDatabaseLabel(selectedNodeInfo),
|
||||
pgAdmin, null, selectedNodeInfo, true, _.unescape(params.server_name), _.unescape(params.database_name) || getDatabaseLabel(selectedNodeInfo),
|
||||
_.unescape(params.role) || _.unescape(params.user), params.is_query_tool == 'true' ? true : false),
|
||||
server_name: _.unescape(params.server_name),
|
||||
database_name: _.unescape(params.database_name) || getDatabaseLabel(selectedNodeInfo),
|
||||
@ -261,7 +261,8 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN
|
||||
api.post(baseUrl, qtState.params.is_query_tool ? {
|
||||
user: selectedConn.user,
|
||||
role: selectedConn.role,
|
||||
password: password
|
||||
password: password,
|
||||
dbname: selectedConn.database_name
|
||||
} : JSON.stringify(qtState.params.sql_filter))
|
||||
.then(()=>{
|
||||
setQtState({
|
||||
@ -651,10 +652,13 @@ export default function QueryToolComponent({params, pgWindow, pgAdmin, selectedN
|
||||
is_selected: true,
|
||||
};
|
||||
|
||||
let existIdx = _.findIndex(qtState.connection_list, (conn)=>(
|
||||
conn.sid == connectionData.sid && conn.did == connectionData.did
|
||||
&& conn.user == connectionData.user && conn.role == connectionData.role
|
||||
));
|
||||
let existIdx = _.findIndex(qtState.connection_list, (conn)=>{
|
||||
conn.role= conn.role == ''? null :conn.role;
|
||||
return(
|
||||
conn.sid == connectionData.sid && conn.database_name == connectionData.database_name
|
||||
&& conn.user == connectionData.user && conn.role == connectionData.role
|
||||
);
|
||||
});
|
||||
if(existIdx > -1) {
|
||||
reject(gettext('Connection with this configuration already present.'));
|
||||
return;
|
||||
|
@ -119,7 +119,6 @@ class TestDownloadCSV(BaseTestGenerator):
|
||||
url = '/sqleditor/query_tool/start/{0}'.format(trans_id)
|
||||
response = self.tester.post(url, data=json.dumps({"sql": sql_query}),
|
||||
content_type='html/json')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
return async_poll(tester=self.tester,
|
||||
@ -138,7 +137,9 @@ class TestDownloadCSV(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = self.init_url.format(
|
||||
self.trans_id, test_utils.SERVER_GROUP, self._sid, self._did)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": self._db_name
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
res = self.initiate_sql_query_tool(self.trans_id, self.sql)
|
||||
|
@ -72,7 +72,9 @@ class TestEditorHistory(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": database_info["db_name"]
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def runTest(self):
|
||||
|
@ -267,14 +267,16 @@ class TestEncodingCharset(BaseTestGenerator):
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'\
|
||||
.format(self.trans_id, test_utils.SERVER_GROUP, self.encode_sid,
|
||||
self.encode_did)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": self.encode_db_name
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Check character
|
||||
url = "/sqleditor/query_tool/start/{0}".format(self.trans_id)
|
||||
sql = "select E'{0}';".format(self.test_str)
|
||||
response = self.tester.post(url, data=json.dumps({"sql": sql}),
|
||||
content_type='html/json')
|
||||
response = (self.tester.post(url, data=json.dumps({"sql": sql}),
|
||||
content_type='html/json'))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
response = async_poll(tester=self.tester,
|
||||
poll_url='/sqleditor/poll/{0}'.format(
|
||||
|
@ -38,7 +38,9 @@ class TestExplainPlan(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": database_info["db_name"]
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Start query tool transaction
|
||||
|
@ -108,7 +108,9 @@ class TestMacros(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": database_info["db_name"]
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def runTest(self):
|
||||
|
@ -79,7 +79,9 @@ NOTICE: Hello, world!
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": database_info["db_name"]
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
cnt = 0
|
||||
@ -89,7 +91,6 @@ NOTICE: Hello, world!
|
||||
url = '/sqleditor/query_tool/start/{0}'.format(self.trans_id)
|
||||
response = self.tester.post(url, data=json.dumps({"sql": s}),
|
||||
content_type='html/json')
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
url = '/sqleditor/poll/{0}'.format(self.trans_id)
|
||||
|
||||
|
@ -98,7 +98,9 @@ class TestSQLASCIIEncoding(BaseTestGenerator):
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'\
|
||||
.format(self.trans_id, test_utils.SERVER_GROUP, self.encode_sid,
|
||||
self.encode_did)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": self.encode_db_name
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
# Check character
|
||||
|
@ -7,6 +7,7 @@
|
||||
#
|
||||
##########################################################################
|
||||
|
||||
import json
|
||||
from pgadmin.utils.route import BaseTestGenerator
|
||||
from pgadmin.tools.sqleditor import StartRunningQuery
|
||||
from unittest.mock import patch, ANY
|
||||
@ -30,13 +31,13 @@ class StartQueryTool(BaseTestGenerator):
|
||||
return_value='some result'
|
||||
) as StartRunningQuery_execute_mock:
|
||||
response = self.tester.post(
|
||||
'/sqleditor/query_tool/start/1234',
|
||||
data='"some sql statement"'
|
||||
)
|
||||
|
||||
'/sqleditor/query_tool/start/1234', data=json.dumps({
|
||||
"sql": "some sql statement"}))
|
||||
self.assertEqual(response.status, '200 OK')
|
||||
self.assertEqual(response.data, b'some result')
|
||||
StartRunningQuery_execute_mock \
|
||||
.assert_called_with('transformed sql', 1234, ANY, False)
|
||||
extract_sql_from_network_parameters_mock \
|
||||
.assert_called_with(b'"some sql statement"', ANY, ANY)
|
||||
.assert_called_with(
|
||||
b'{"sql": "some sql statement"}',
|
||||
ANY, ANY)
|
||||
|
@ -306,7 +306,9 @@ class TestTransactionControl(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": self.db_name
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def _initialize_urls(self):
|
||||
|
@ -67,6 +67,7 @@ class StartRunningQuery:
|
||||
PG_DEFAULT_DRIVER).connection_manager(
|
||||
transaction_object.sid)
|
||||
conn = manager.connection(did=transaction_object.did,
|
||||
database=transaction_object.dbname,
|
||||
conn_id=self.connection_id,
|
||||
auto_reconnect=False,
|
||||
use_binary_placeholder=True,
|
||||
|
@ -8,6 +8,7 @@
|
||||
##########################################################################
|
||||
|
||||
import secrets
|
||||
import json
|
||||
|
||||
from pgadmin.browser.server_groups.servers.databases.tests import utils as \
|
||||
database_utils
|
||||
@ -204,7 +205,9 @@ class TestQueryUpdatableResultset(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": self.db_name
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def _initialize_urls(self):
|
||||
|
@ -923,7 +923,9 @@ class TestSaveChangedData(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": self.db_name
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def _initialize_urls_and_select_sql(self):
|
||||
|
@ -247,7 +247,9 @@ class TestSaveChangedDataUUID(BaseTestGenerator):
|
||||
self.trans_id = str(secrets.choice(range(1, 9999999)))
|
||||
url = '/sqleditor/initialize/sqleditor/{0}/{1}/{2}/{3}'.format(
|
||||
self.trans_id, utils.SERVER_GROUP, self.server_id, self.db_id)
|
||||
response = self.tester.post(url)
|
||||
response = self.tester.post(url, data=json.dumps({
|
||||
"dbname": self.db_name
|
||||
}))
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def _initialize_urls_and_select_sql(self):
|
||||
|
@ -30,7 +30,7 @@ class StartRunningQueryTest(BaseTestGenerator):
|
||||
function_parameters=dict(
|
||||
sql=dict(sql='some sql', explain_plan=None),
|
||||
trans_id=123,
|
||||
http_session=dict()
|
||||
http_session=dict(),
|
||||
),
|
||||
pickle_load_return=None,
|
||||
get_driver_exception=False,
|
||||
@ -60,7 +60,7 @@ class StartRunningQueryTest(BaseTestGenerator):
|
||||
function_parameters=dict(
|
||||
sql=dict(sql='some sql', explain_plan=None),
|
||||
trans_id=123,
|
||||
http_session=dict(gridData=dict())
|
||||
http_session=dict(gridData=dict()),
|
||||
),
|
||||
pickle_load_return=None,
|
||||
get_driver_exception=False,
|
||||
@ -91,7 +91,7 @@ class StartRunningQueryTest(BaseTestGenerator):
|
||||
function_parameters=dict(
|
||||
sql=dict(sql='some sql', explain_plan=None),
|
||||
trans_id=123,
|
||||
http_session=dict(gridData={'123': dict(command_obj='')})
|
||||
http_session=dict(gridData={'123': dict(command_obj='')}),
|
||||
),
|
||||
pickle_load_return=None,
|
||||
get_driver_exception=False,
|
||||
@ -126,7 +126,7 @@ class StartRunningQueryTest(BaseTestGenerator):
|
||||
function_parameters=dict(
|
||||
sql=dict(sql='some sql', explain_plan=None),
|
||||
trans_id=123,
|
||||
http_session=dict(gridData={'123': dict(command_obj='')})
|
||||
http_session=dict(gridData={'123': dict(command_obj='')}),
|
||||
),
|
||||
pickle_load_return=MagicMock(conn_id=1,
|
||||
update_fetched_row_cnt=MagicMock()),
|
||||
@ -155,7 +155,7 @@ class StartRunningQueryTest(BaseTestGenerator):
|
||||
function_parameters=dict(
|
||||
sql=dict(sql='some sql', explain_plan=None),
|
||||
trans_id=123,
|
||||
http_session=dict(gridData={'123': dict(command_obj='')})
|
||||
http_session=dict(gridData={'123': dict(command_obj='')}),
|
||||
),
|
||||
pickle_load_return=MagicMock(
|
||||
conn_id=1,
|
||||
@ -184,7 +184,7 @@ class StartRunningQueryTest(BaseTestGenerator):
|
||||
function_parameters=dict(
|
||||
sql=dict(sql='some sql', explain_plan=None),
|
||||
trans_id=123,
|
||||
http_session=dict(gridData={'123': dict(command_obj='')})
|
||||
http_session=dict(gridData={'123': dict(command_obj='')}),
|
||||
),
|
||||
pickle_load_return=MagicMock(
|
||||
conn_id=1,
|
||||
@ -212,7 +212,7 @@ class StartRunningQueryTest(BaseTestGenerator):
|
||||
function_parameters=dict(
|
||||
sql=dict(sql='some sql', explain_plan=None),
|
||||
trans_id=123,
|
||||
http_session=dict(gridData={'123': dict(command_obj='')})
|
||||
http_session=dict(gridData={'123': dict(command_obj='')}),
|
||||
),
|
||||
pickle_load_return=MagicMock(
|
||||
conn_id=1,
|
||||
|
@ -204,10 +204,13 @@ class ServerManager(object):
|
||||
if did is not None and did in self.db_info:
|
||||
self.db_info[did]['datname'] = database
|
||||
else:
|
||||
conn_str = 'CONN:{0}'.format(conn_id)
|
||||
if did is None:
|
||||
database = self.db
|
||||
elif did in self.db_info:
|
||||
database = self.db_info[did]['datname']
|
||||
elif conn_id and conn_str in self.connections:
|
||||
database = self.connections[conn_str].db
|
||||
else:
|
||||
maintenance_db_id = 'DB:{0}'.format(self.db)
|
||||
if maintenance_db_id in self.connections:
|
||||
|
Loading…
Reference in New Issue
Block a user