Use SocketIO instead of REST for schema diff compare. #4841

This commit is contained in:
Pravesh Sharma
2022-10-21 09:29:19 +05:30
committed by GitHub
parent 0384f55de1
commit 1647fc54e1
4 changed files with 139 additions and 161 deletions

View File

@@ -28,9 +28,12 @@ from pgadmin.utils.driver import get_driver
from pgadmin.utils.constants import PREF_LABEL_DISPLAY, MIMETYPE_APP_JS,\
ERROR_MSG_TRANS_ID_NOT_FOUND
from sqlalchemy import or_
from pgadmin.authenticate import socket_login_required
from pgadmin import socketio
MODULE_NAME = 'schema_diff'
COMPARE_MSG = gettext("Comparing objects...")
SOCKETIO_NAMESPACE = '/{0}'.format(MODULE_NAME)
class SchemaDiffModule(PgAdminModule):
@@ -59,9 +62,6 @@ class SchemaDiffModule(PgAdminModule):
'schema_diff.servers',
'schema_diff.databases',
'schema_diff.schemas',
'schema_diff.compare_database',
'schema_diff.compare_schema',
'schema_diff.poll',
'schema_diff.ddl_compare',
'schema_diff.connect_server',
'schema_diff.connect_database',
@@ -436,39 +436,38 @@ def schemas(sid, did):
return make_json_response(data=res)
@blueprint.route(
'/compare_database/<int:trans_id>/<int:source_sid>/<int:source_did>/'
'<int:target_sid>/<int:target_did>/<int:ignore_owner>/'
'<int:ignore_whitespaces>',
methods=["GET"],
endpoint="compare_database"
)
@login_required
def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
ignore_owner, ignore_whitespaces):
@socketio.on('compare_database', namespace=SOCKETIO_NAMESPACE)
@socket_login_required
def compare_database(params):
"""
This function will compare the two databases.
"""
# Check the pre validation before compare
status, error_msg, diff_model_obj, session_obj = \
compare_pre_validation(trans_id, source_sid, target_sid)
compare_pre_validation(params['trans_id'], params['source_sid'],
params['target_sid'])
if not status:
socketio.emit('compare_database_failed', error_msg,
namespace=SOCKETIO_NAMESPACE, to=request.sid)
return error_msg
comparison_result = []
diff_model_obj.set_comparison_info(COMPARE_MSG, 0)
update_session_diff_transaction(trans_id, session_obj,
socketio.emit('compare_status', {'diff_percentage': 0,
'compare_msg': COMPARE_MSG}, namespace=SOCKETIO_NAMESPACE,
to=request.sid)
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)
try:
ignore_owner = bool(ignore_owner)
ignore_whitespaces = bool(ignore_whitespaces)
ignore_owner = bool(params['ignore_owner'])
ignore_whitespaces = bool(params['ignore_whitespaces'])
# Fetch all the schemas of source and target database
# Compare them and get the status.
schema_result = fetch_compare_schemas(source_sid, source_did,
target_sid, target_did)
schema_result = \
fetch_compare_schemas(params['source_sid'], params['source_did'],
params['target_sid'], params['target_did'])
total_schema = len(schema_result['source_only']) + len(
schema_result['target_only']) + len(
@@ -483,12 +482,13 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
# Compare Database objects
comparison_schema_result, total_percent = \
compare_database_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
target_sid=target_sid, target_did=target_did,
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
target_sid=params['target_sid'],
target_did=params['target_did'],
diff_model_obj=diff_model_obj, total_percent=total_percent,
node_percent=node_percent,
ignore_owner=ignore_owner,
node_percent=node_percent, ignore_owner=ignore_owner,
ignore_whitespaces=ignore_whitespaces)
comparison_result = \
comparison_result + comparison_schema_result
@@ -499,10 +499,12 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
for item in schema_result['source_only']:
comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=item['scid'], target_sid=target_sid,
target_did=target_did, target_scid=None,
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=item['scid'],
target_sid=params['target_sid'],
target_did=params['target_did'], target_scid=None,
schema_name=item['schema_name'],
diff_model_obj=diff_model_obj,
total_percent=total_percent,
@@ -519,10 +521,12 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
for item in schema_result['target_only']:
comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=None, target_sid=target_sid,
target_did=target_did, target_scid=item['scid'],
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=None, target_sid=params['target_sid'],
target_did=params['target_did'],
target_scid=item['scid'],
schema_name=item['schema_name'],
diff_model_obj=diff_model_obj,
total_percent=total_percent,
@@ -539,10 +543,13 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
for item in schema_result['in_both_database']:
comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=item['src_scid'], target_sid=target_sid,
target_did=target_did, target_scid=item['tar_scid'],
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=item['src_scid'],
target_sid=params['target_sid'],
target_did=params['target_did'],
target_scid=item['tar_scid'],
schema_name=item['schema_name'],
diff_model_obj=diff_model_obj,
total_percent=total_percent,
@@ -555,54 +562,54 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did,
msg = gettext("Successfully compare the specified databases.")
total_percent = 100
diff_model_obj.set_comparison_info(msg, total_percent)
# Update the message and total percentage done in session object
update_session_diff_transaction(trans_id, session_obj, diff_model_obj)
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)
except Exception as e:
app.logger.exception(e)
socketio.emit('compare_database_failed', str(e),
namespace=SOCKETIO_NAMESPACE, to=request.sid)
return make_json_response(data=comparison_result)
socketio.emit('compare_database_success', comparison_result,
namespace=SOCKETIO_NAMESPACE, to=request.sid)
@blueprint.route(
'/compare_schema/<int:trans_id>/<int:source_sid>/<int:source_did>/'
'<int:source_scid>/<int:target_sid>/<int:target_did>/<int:target_scid>/'
'<int:ignore_owner>/<int:ignore_whitespaces>',
methods=["GET"],
endpoint="compare_schema"
)
@login_required
def compare_schema(trans_id, source_sid, source_did, source_scid,
target_sid, target_did, target_scid, ignore_owner,
ignore_whitespaces):
@socketio.on('compare_schema', namespace=SOCKETIO_NAMESPACE)
@socket_login_required
def compare_schema(params):
"""
This function will compare the two schema.
"""
# Check the pre validation before compare
status, error_msg, diff_model_obj, session_obj = \
compare_pre_validation(trans_id, source_sid, target_sid)
compare_pre_validation(params['trans_id'], params['source_sid'],
params['target_sid'])
if not status:
socketio.emit('compare_schema_failed', error_msg,
namespace=SOCKETIO_NAMESPACE, to=request.sid)
return error_msg
comparison_result = []
diff_model_obj.set_comparison_info(COMPARE_MSG, 0)
update_session_diff_transaction(trans_id, session_obj,
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)
try:
ignore_owner = bool(ignore_owner)
ignore_whitespaces = bool(ignore_whitespaces)
ignore_owner = bool(params['ignore_owner'])
ignore_whitespaces = bool(params['ignore_whitespaces'])
all_registered_nodes = SchemaDiffRegistry.get_registered_nodes()
node_percent = round(100 / len(all_registered_nodes))
total_percent = 0
comparison_schema_result, total_percent = \
compare_schema_objects(
trans_id=trans_id, session_obj=session_obj,
source_sid=source_sid, source_did=source_did,
source_scid=source_scid, target_sid=target_sid,
target_did=target_did, target_scid=target_scid,
trans_id=params['trans_id'], session_obj=session_obj,
source_sid=params['source_sid'],
source_did=params['source_did'],
source_scid=params['source_scid'],
target_sid=params['target_sid'],
target_did=params['target_did'],
target_scid=params['target_scid'],
schema_name=gettext('Schema Objects'),
diff_model_obj=diff_model_obj,
total_percent=total_percent,
@@ -615,43 +622,16 @@ def compare_schema(trans_id, source_sid, source_did, source_scid,
msg = gettext("Successfully compare the specified schemas.")
total_percent = 100
diff_model_obj.set_comparison_info(msg, total_percent)
# Update the message and total percentage done in session object
update_session_diff_transaction(trans_id, session_obj, diff_model_obj)
update_session_diff_transaction(params['trans_id'], session_obj,
diff_model_obj)
except Exception as e:
app.logger.exception(e)
return make_json_response(data=comparison_result)
@blueprint.route(
'/poll/<int:trans_id>', methods=["GET"], endpoint="poll"
)
@login_required
def poll(trans_id):
"""
This function is used to check the schema comparison is completed or not.
:param trans_id:
:return:
"""
# Check the transaction and connection status
status, error_msg, diff_model_obj, session_obj = \
check_transaction_status(trans_id)
if error_msg == ERROR_MSG_TRANS_ID_NOT_FOUND:
return make_json_response(success=0, errormsg=error_msg, status=404)
msg, diff_percentage = diff_model_obj.get_comparison_info()
if diff_percentage == 100:
diff_model_obj.set_comparison_info(COMPARE_MSG, 0)
update_session_diff_transaction(trans_id, session_obj,
diff_model_obj)
return make_json_response(data={'compare_msg': msg,
'diff_percentage': diff_percentage})
socketio.emit('compare_schema_failed', str(e),
namespace=SOCKETIO_NAMESPACE, to=request.sid)
socketio.emit('compare_schema_success', comparison_result,
namespace=SOCKETIO_NAMESPACE, to=request.sid)
@blueprint.route(
@@ -781,7 +761,9 @@ def compare_database_objects(**kwargs):
msg = gettext('Comparing {0}'). \
format(gettext(view.blueprint.collection_label))
app.logger.debug(msg)
diff_model_obj.set_comparison_info(msg, total_percent)
socketio.emit('compare_status', {'diff_percentage': total_percent,
'compare_msg': msg}, namespace=SOCKETIO_NAMESPACE,
to=request.sid)
# Update the message and total percentage in session object
update_session_diff_transaction(trans_id, session_obj,
diff_model_obj)
@@ -843,7 +825,9 @@ def compare_schema_objects(**kwargs):
format(gettext(view.blueprint.collection_label),
gettext(schema_name))
app.logger.debug(msg)
diff_model_obj.set_comparison_info(msg, total_percent)
socketio.emit('compare_status', {'diff_percentage': total_percent,
'compare_msg': msg}, namespace=SOCKETIO_NAMESPACE,
to=request.sid)
# Update the message and total percentage in session object
update_session_diff_transaction(trans_id, session_obj,
diff_model_obj)
@@ -943,3 +927,15 @@ def compare_pre_validation(trans_id, source_sid, target_sid):
return False, res, None, None
return True, '', diff_model_obj, session_obj
@socketio.on('connect', namespace=SOCKETIO_NAMESPACE)
def connect():
"""
Connect to the server through socket.
:return:
:rtype:
"""
socketio.emit('connected', {'sid': request.sid},
namespace=SOCKETIO_NAMESPACE,
to=request.sid)

View File

@@ -32,8 +32,6 @@ class SchemaDiffModel(object):
**kwargs : N number of parameters
"""
self._comparison_result = dict()
self._comparison_msg = gettext('Comparision started...')
self._comparison_percentage = 0
def clear_data(self):
"""
@@ -59,20 +57,3 @@ class SchemaDiffModel(object):
return self._comparison_result[node_name]
return self._comparison_result
def get_comparison_info(self):
"""
This function is used to get the comparison information.
:return:
"""
return self._comparison_msg, self._comparison_percentage
def set_comparison_info(self, msg, percentage):
"""
This function is used to set the comparison information.
:param msg:
:param percentage:
:return:
"""
self._comparison_msg = msg
self._comparison_percentage = percentage

View File

@@ -30,7 +30,8 @@ import { InputComponent } from './InputComponent';
import { SchemaDiffButtonComponent } from './SchemaDiffButtonComponent';
import { SchemaDiffContext, SchemaDiffEventsContext } from './SchemaDiffComponent';
import { ResultGridComponent } from './ResultGridComponent';
import { openSocket, socketApiGet } from '../../../../../static/js/socket_instance';
import { parseApiError } from '../../../../../static/js/api_instance';
const useStyles = makeStyles(() => ({
table: {
@@ -268,15 +269,13 @@ export function SchemaDiffCompare({ params }) {
}
};
const triggerCompareDiff = ({ sourceData, targetData, compareParams, filterParams }) => {
const triggerCompareDiff = async ({ sourceData, targetData, compareParams, filterParams }) => {
setGridData([]);
setIsInit(false);
if (JSON.stringify(sourceData) === JSON.stringify(targetData)) {
Notifier.alert(gettext('Selection Error'),
gettext('Please select the different source and target.'));
} else {
getCompareStatus();
let schemaDiffPollInterval = setInterval(getCompareStatus, 1000);
setLoaderText('Comparing objects... (this may take a few minutes)...');
let url_params = {
'trans_id': params.transId,
@@ -287,28 +286,34 @@ export function SchemaDiffCompare({ params }) {
'ignore_owner': compareParams['ignoreOwner'],
'ignore_whitespaces': compareParams['ignoreWhitespaces'],
};
let baseUrl = url_for('schema_diff.compare_database', url_params);
let socketEndpoint = 'compare_database';
if (sourceData['scid'] != null && targetData['scid'] != null) {
url_params['source_scid'] = sourceData['scid'];
url_params['target_scid'] = targetData['scid'];
baseUrl = url_for('schema_diff.compare_schema', url_params);
socketEndpoint = 'compare_schema';
}
setCompareOptions(compareParams);
schemaDiffToolContext.api.get(baseUrl).then((res) => {
let resData = [];
let socket;
try {
setCompareOptions(compareParams);
socket = await openSocket('/schema_diff');
socket.on('compare_status', res=>{
let msg = res.compare_msg;
msg = msg + gettext(` (this may take a few minutes)... ${res.diff_percentage} %`);
setLoaderText(msg);
});
resData = await socketApiGet(socket, socketEndpoint, url_params);
setShowResultGrid(true);
setLoaderText(null);
clearInterval(schemaDiffPollInterval);
setFilterOptions(filterParams);
getResultGridData(res.data.data, filterParams);
}).catch((err) => {
clearInterval(schemaDiffPollInterval);
getResultGridData(resData, filterParams);
} catch (error) {
setLoaderText(null);
setShowResultGrid(false);
Notifier.alert(gettext('Schema compare error'), gettext(err.response.data.errormsg));
});
console.error(error);
Notifier.alert(gettext('Error'), parseApiError(error));
}
socket?.disconnect();
}
};
@@ -561,22 +566,6 @@ export function SchemaDiffCompare({ params }) {
setAllRowIdList([...new Set(allRowIds)]);
}
const getCompareStatus = () => {
let url_params = { 'trans_id': params.transId };
schemaDiffToolContext.api.get(url_for('schema_diff.poll', url_params)).then((res) => {
let msg = res.data.data.compare_msg;
if (res.data.data.diff_percentage != 100) {
msg = msg + gettext(` (this may take a few minutes)... ${res.data.data.diff_percentage} %`);
setLoaderText(msg);
}
})
.catch((err) => {
Notifier.error(gettext(err.message));
});
};
const connectDatabase = (sid, selectedDB, diff_type, databaseList) => {
schemaDiffToolContext.api({
method: 'POST',

View File

@@ -12,7 +12,7 @@ import json
import os
import secrets
from pgadmin.utils.route import BaseTestGenerator
from pgadmin.utils.route import BaseTestGenerator, BaseSocketTestGenerator
from regression import parent_node_dict
from regression.python_test_utils import test_utils as utils
from .utils import restore_schema
@@ -20,15 +20,17 @@ from pgadmin.utils.versioned_template_loader import \
get_version_mapping_directories
class SchemaDiffTestCase(BaseTestGenerator):
class SchemaDiffTestCase(BaseSocketTestGenerator):
""" This class will test the schema diff. """
scenarios = [
# Fetching default URL for database node.
('Schema diff comparison', dict(
url='schema_diff/compare_database/{0}/{1}/{2}/{3}/{4}/0/0'))
]
SOCKET_NAMESPACE = '/schema_diff'
def setUp(self):
super(SchemaDiffTestCase, self).setUp()
self.src_database = "db_schema_diff_src_%s" % str(uuid.uuid4())[1:8]
self.tar_database = "db_schema_diff_tar_%s" % str(uuid.uuid4())[1:8]
@@ -108,16 +110,22 @@ class SchemaDiffTestCase(BaseTestGenerator):
return None
def compare(self):
comp_url = self.url.format(self.trans_id, self.server_id,
self.src_db_id,
self.server_id,
self.tar_db_id
)
response = self.tester.get(comp_url)
self.assertEqual(response.status_code, 200)
return json.loads(response.data.decode('utf-8'))
data = {
'trans_id': self.trans_id,
'source_sid': self.server_id,
'source_did': self.src_db_id,
'target_sid': self.server_id,
'target_did': self.tar_db_id,
'ignore_owner': 0,
'ignore_whitespaces': 0
}
self.socket_client.emit('compare_database', data,
namespace=self.SOCKET_NAMESPACE)
received = self.socket_client.get_received(self.SOCKET_NAMESPACE)
response_data = received[-1]['args'][0]
self.assertEqual(received[-1]['name'], "compare_database_success",
response_data)
return response_data
def runTest(self):
""" This function will test the schema diff."""
@@ -127,6 +135,9 @@ class SchemaDiffTestCase(BaseTestGenerator):
response_data = json.loads(response.data.decode('utf-8'))
self.trans_id = response_data['data']['schemaDiffTransId']
received = self.socket_client.get_received(self.SOCKET_NAMESPACE)
assert received[0]['name'] == 'connected'
url = 'schema_diff/server/connect/{}'.format(self.server_id)
data = {'password': self.server['db_password']}
response = self.tester.post(url,
@@ -148,7 +159,7 @@ class SchemaDiffTestCase(BaseTestGenerator):
str(secrets.choice(range(1, 99999)))))
file_obj = open(diff_file, 'a')
for diff in response_data['data']:
for diff in response_data:
if diff['status'] == 'Identical':
src_obj_oid = diff['source_oid']
tar_obj_oid = diff['target_oid']
@@ -186,7 +197,7 @@ class SchemaDiffTestCase(BaseTestGenerator):
os.remove(diff_file)
response_data = self.compare()
for diff in response_data['data']:
for diff in response_data:
self.assertEqual(diff['status'], 'Identical')
except Exception as e:
if os.path.exists(diff_file):
@@ -194,6 +205,7 @@ class SchemaDiffTestCase(BaseTestGenerator):
def tearDown(self):
"""This function drop the added database"""
super(SchemaDiffTestCase, self).tearDown()
connection = utils.get_db_connection(self.server['db'],
self.server['username'],
self.server['db_password'],