diff --git a/web/pgadmin/tools/schema_diff/__init__.py b/web/pgadmin/tools/schema_diff/__init__.py index 7845a51df..71ed391b1 100644 --- a/web/pgadmin/tools/schema_diff/__init__.py +++ b/web/pgadmin/tools/schema_diff/__init__.py @@ -28,9 +28,12 @@ from pgadmin.utils.driver import get_driver from pgadmin.utils.constants import PREF_LABEL_DISPLAY, MIMETYPE_APP_JS,\ ERROR_MSG_TRANS_ID_NOT_FOUND from sqlalchemy import or_ +from pgadmin.authenticate import socket_login_required +from pgadmin import socketio MODULE_NAME = 'schema_diff' COMPARE_MSG = gettext("Comparing objects...") +SOCKETIO_NAMESPACE = '/{0}'.format(MODULE_NAME) class SchemaDiffModule(PgAdminModule): @@ -59,9 +62,6 @@ class SchemaDiffModule(PgAdminModule): 'schema_diff.servers', 'schema_diff.databases', 'schema_diff.schemas', - 'schema_diff.compare_database', - 'schema_diff.compare_schema', - 'schema_diff.poll', 'schema_diff.ddl_compare', 'schema_diff.connect_server', 'schema_diff.connect_database', @@ -436,39 +436,38 @@ def schemas(sid, did): return make_json_response(data=res) -@blueprint.route( - '/compare_database////' - '///' - '', - methods=["GET"], - endpoint="compare_database" -) -@login_required -def compare_database(trans_id, source_sid, source_did, target_sid, target_did, - ignore_owner, ignore_whitespaces): +@socketio.on('compare_database', namespace=SOCKETIO_NAMESPACE) +@socket_login_required +def compare_database(params): """ This function will compare the two databases. """ # Check the pre validation before compare status, error_msg, diff_model_obj, session_obj = \ - compare_pre_validation(trans_id, source_sid, target_sid) + compare_pre_validation(params['trans_id'], params['source_sid'], + params['target_sid']) if not status: + socketio.emit('compare_database_failed', error_msg, + namespace=SOCKETIO_NAMESPACE, to=request.sid) return error_msg comparison_result = [] - diff_model_obj.set_comparison_info(COMPARE_MSG, 0) - update_session_diff_transaction(trans_id, session_obj, + socketio.emit('compare_status', {'diff_percentage': 0, + 'compare_msg': COMPARE_MSG}, namespace=SOCKETIO_NAMESPACE, + to=request.sid) + update_session_diff_transaction(params['trans_id'], session_obj, diff_model_obj) try: - ignore_owner = bool(ignore_owner) - ignore_whitespaces = bool(ignore_whitespaces) + ignore_owner = bool(params['ignore_owner']) + ignore_whitespaces = bool(params['ignore_whitespaces']) # Fetch all the schemas of source and target database # Compare them and get the status. - schema_result = fetch_compare_schemas(source_sid, source_did, - target_sid, target_did) + schema_result = \ + fetch_compare_schemas(params['source_sid'], params['source_did'], + params['target_sid'], params['target_did']) total_schema = len(schema_result['source_only']) + len( schema_result['target_only']) + len( @@ -483,12 +482,13 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did, # Compare Database objects comparison_schema_result, total_percent = \ compare_database_objects( - trans_id=trans_id, session_obj=session_obj, - source_sid=source_sid, source_did=source_did, - target_sid=target_sid, target_did=target_did, + trans_id=params['trans_id'], session_obj=session_obj, + source_sid=params['source_sid'], + source_did=params['source_did'], + target_sid=params['target_sid'], + target_did=params['target_did'], diff_model_obj=diff_model_obj, total_percent=total_percent, - node_percent=node_percent, - ignore_owner=ignore_owner, + node_percent=node_percent, ignore_owner=ignore_owner, ignore_whitespaces=ignore_whitespaces) comparison_result = \ comparison_result + comparison_schema_result @@ -499,10 +499,12 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did, for item in schema_result['source_only']: comparison_schema_result, total_percent = \ compare_schema_objects( - trans_id=trans_id, session_obj=session_obj, - source_sid=source_sid, source_did=source_did, - source_scid=item['scid'], target_sid=target_sid, - target_did=target_did, target_scid=None, + trans_id=params['trans_id'], session_obj=session_obj, + source_sid=params['source_sid'], + source_did=params['source_did'], + source_scid=item['scid'], + target_sid=params['target_sid'], + target_did=params['target_did'], target_scid=None, schema_name=item['schema_name'], diff_model_obj=diff_model_obj, total_percent=total_percent, @@ -519,10 +521,12 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did, for item in schema_result['target_only']: comparison_schema_result, total_percent = \ compare_schema_objects( - trans_id=trans_id, session_obj=session_obj, - source_sid=source_sid, source_did=source_did, - source_scid=None, target_sid=target_sid, - target_did=target_did, target_scid=item['scid'], + trans_id=params['trans_id'], session_obj=session_obj, + source_sid=params['source_sid'], + source_did=params['source_did'], + source_scid=None, target_sid=params['target_sid'], + target_did=params['target_did'], + target_scid=item['scid'], schema_name=item['schema_name'], diff_model_obj=diff_model_obj, total_percent=total_percent, @@ -539,10 +543,13 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did, for item in schema_result['in_both_database']: comparison_schema_result, total_percent = \ compare_schema_objects( - trans_id=trans_id, session_obj=session_obj, - source_sid=source_sid, source_did=source_did, - source_scid=item['src_scid'], target_sid=target_sid, - target_did=target_did, target_scid=item['tar_scid'], + trans_id=params['trans_id'], session_obj=session_obj, + source_sid=params['source_sid'], + source_did=params['source_did'], + source_scid=item['src_scid'], + target_sid=params['target_sid'], + target_did=params['target_did'], + target_scid=item['tar_scid'], schema_name=item['schema_name'], diff_model_obj=diff_model_obj, total_percent=total_percent, @@ -555,54 +562,54 @@ def compare_database(trans_id, source_sid, source_did, target_sid, target_did, msg = gettext("Successfully compare the specified databases.") total_percent = 100 - diff_model_obj.set_comparison_info(msg, total_percent) # Update the message and total percentage done in session object - update_session_diff_transaction(trans_id, session_obj, diff_model_obj) + update_session_diff_transaction(params['trans_id'], session_obj, + diff_model_obj) except Exception as e: app.logger.exception(e) + socketio.emit('compare_database_failed', str(e), + namespace=SOCKETIO_NAMESPACE, to=request.sid) - return make_json_response(data=comparison_result) + socketio.emit('compare_database_success', comparison_result, + namespace=SOCKETIO_NAMESPACE, to=request.sid) -@blueprint.route( - '/compare_schema////' - '////' - '/', - methods=["GET"], - endpoint="compare_schema" -) -@login_required -def compare_schema(trans_id, source_sid, source_did, source_scid, - target_sid, target_did, target_scid, ignore_owner, - ignore_whitespaces): +@socketio.on('compare_schema', namespace=SOCKETIO_NAMESPACE) +@socket_login_required +def compare_schema(params): """ This function will compare the two schema. """ # Check the pre validation before compare status, error_msg, diff_model_obj, session_obj = \ - compare_pre_validation(trans_id, source_sid, target_sid) + compare_pre_validation(params['trans_id'], params['source_sid'], + params['target_sid']) if not status: + socketio.emit('compare_schema_failed', error_msg, + namespace=SOCKETIO_NAMESPACE, to=request.sid) return error_msg comparison_result = [] - diff_model_obj.set_comparison_info(COMPARE_MSG, 0) - update_session_diff_transaction(trans_id, session_obj, + update_session_diff_transaction(params['trans_id'], session_obj, diff_model_obj) try: - ignore_owner = bool(ignore_owner) - ignore_whitespaces = bool(ignore_whitespaces) + ignore_owner = bool(params['ignore_owner']) + ignore_whitespaces = bool(params['ignore_whitespaces']) all_registered_nodes = SchemaDiffRegistry.get_registered_nodes() node_percent = round(100 / len(all_registered_nodes)) total_percent = 0 comparison_schema_result, total_percent = \ compare_schema_objects( - trans_id=trans_id, session_obj=session_obj, - source_sid=source_sid, source_did=source_did, - source_scid=source_scid, target_sid=target_sid, - target_did=target_did, target_scid=target_scid, + trans_id=params['trans_id'], session_obj=session_obj, + source_sid=params['source_sid'], + source_did=params['source_did'], + source_scid=params['source_scid'], + target_sid=params['target_sid'], + target_did=params['target_did'], + target_scid=params['target_scid'], schema_name=gettext('Schema Objects'), diff_model_obj=diff_model_obj, total_percent=total_percent, @@ -615,43 +622,16 @@ def compare_schema(trans_id, source_sid, source_did, source_scid, msg = gettext("Successfully compare the specified schemas.") total_percent = 100 - diff_model_obj.set_comparison_info(msg, total_percent) # Update the message and total percentage done in session object - update_session_diff_transaction(trans_id, session_obj, diff_model_obj) + update_session_diff_transaction(params['trans_id'], session_obj, + diff_model_obj) except Exception as e: app.logger.exception(e) - - return make_json_response(data=comparison_result) - - -@blueprint.route( - '/poll/', methods=["GET"], endpoint="poll" -) -@login_required -def poll(trans_id): - """ - This function is used to check the schema comparison is completed or not. - :param trans_id: - :return: - """ - - # Check the transaction and connection status - status, error_msg, diff_model_obj, session_obj = \ - check_transaction_status(trans_id) - - if error_msg == ERROR_MSG_TRANS_ID_NOT_FOUND: - return make_json_response(success=0, errormsg=error_msg, status=404) - - msg, diff_percentage = diff_model_obj.get_comparison_info() - - if diff_percentage == 100: - diff_model_obj.set_comparison_info(COMPARE_MSG, 0) - update_session_diff_transaction(trans_id, session_obj, - diff_model_obj) - - return make_json_response(data={'compare_msg': msg, - 'diff_percentage': diff_percentage}) + socketio.emit('compare_schema_failed', str(e), + namespace=SOCKETIO_NAMESPACE, to=request.sid) + socketio.emit('compare_schema_success', comparison_result, + namespace=SOCKETIO_NAMESPACE, to=request.sid) @blueprint.route( @@ -781,7 +761,9 @@ def compare_database_objects(**kwargs): msg = gettext('Comparing {0}'). \ format(gettext(view.blueprint.collection_label)) app.logger.debug(msg) - diff_model_obj.set_comparison_info(msg, total_percent) + socketio.emit('compare_status', {'diff_percentage': total_percent, + 'compare_msg': msg}, namespace=SOCKETIO_NAMESPACE, + to=request.sid) # Update the message and total percentage in session object update_session_diff_transaction(trans_id, session_obj, diff_model_obj) @@ -843,7 +825,9 @@ def compare_schema_objects(**kwargs): format(gettext(view.blueprint.collection_label), gettext(schema_name)) app.logger.debug(msg) - diff_model_obj.set_comparison_info(msg, total_percent) + socketio.emit('compare_status', {'diff_percentage': total_percent, + 'compare_msg': msg}, namespace=SOCKETIO_NAMESPACE, + to=request.sid) # Update the message and total percentage in session object update_session_diff_transaction(trans_id, session_obj, diff_model_obj) @@ -943,3 +927,15 @@ def compare_pre_validation(trans_id, source_sid, target_sid): return False, res, None, None return True, '', diff_model_obj, session_obj + + +@socketio.on('connect', namespace=SOCKETIO_NAMESPACE) +def connect(): + """ + Connect to the server through socket. + :return: + :rtype: + """ + socketio.emit('connected', {'sid': request.sid}, + namespace=SOCKETIO_NAMESPACE, + to=request.sid) diff --git a/web/pgadmin/tools/schema_diff/model.py b/web/pgadmin/tools/schema_diff/model.py index 2dfb53f28..44428ea36 100644 --- a/web/pgadmin/tools/schema_diff/model.py +++ b/web/pgadmin/tools/schema_diff/model.py @@ -32,8 +32,6 @@ class SchemaDiffModel(object): **kwargs : N number of parameters """ self._comparison_result = dict() - self._comparison_msg = gettext('Comparision started...') - self._comparison_percentage = 0 def clear_data(self): """ @@ -59,20 +57,3 @@ class SchemaDiffModel(object): return self._comparison_result[node_name] return self._comparison_result - - def get_comparison_info(self): - """ - This function is used to get the comparison information. - :return: - """ - return self._comparison_msg, self._comparison_percentage - - def set_comparison_info(self, msg, percentage): - """ - This function is used to set the comparison information. - :param msg: - :param percentage: - :return: - """ - self._comparison_msg = msg - self._comparison_percentage = percentage diff --git a/web/pgadmin/tools/schema_diff/static/js/components/SchemaDiffCompare.jsx b/web/pgadmin/tools/schema_diff/static/js/components/SchemaDiffCompare.jsx index 3eca58828..b6df3b426 100644 --- a/web/pgadmin/tools/schema_diff/static/js/components/SchemaDiffCompare.jsx +++ b/web/pgadmin/tools/schema_diff/static/js/components/SchemaDiffCompare.jsx @@ -30,7 +30,8 @@ import { InputComponent } from './InputComponent'; import { SchemaDiffButtonComponent } from './SchemaDiffButtonComponent'; import { SchemaDiffContext, SchemaDiffEventsContext } from './SchemaDiffComponent'; import { ResultGridComponent } from './ResultGridComponent'; - +import { openSocket, socketApiGet } from '../../../../../static/js/socket_instance'; +import { parseApiError } from '../../../../../static/js/api_instance'; const useStyles = makeStyles(() => ({ table: { @@ -268,15 +269,13 @@ export function SchemaDiffCompare({ params }) { } }; - const triggerCompareDiff = ({ sourceData, targetData, compareParams, filterParams }) => { + const triggerCompareDiff = async ({ sourceData, targetData, compareParams, filterParams }) => { setGridData([]); setIsInit(false); if (JSON.stringify(sourceData) === JSON.stringify(targetData)) { Notifier.alert(gettext('Selection Error'), gettext('Please select the different source and target.')); } else { - getCompareStatus(); - let schemaDiffPollInterval = setInterval(getCompareStatus, 1000); setLoaderText('Comparing objects... (this may take a few minutes)...'); let url_params = { 'trans_id': params.transId, @@ -287,28 +286,34 @@ export function SchemaDiffCompare({ params }) { 'ignore_owner': compareParams['ignoreOwner'], 'ignore_whitespaces': compareParams['ignoreWhitespaces'], }; - - let baseUrl = url_for('schema_diff.compare_database', url_params); + let socketEndpoint = 'compare_database'; if (sourceData['scid'] != null && targetData['scid'] != null) { url_params['source_scid'] = sourceData['scid']; url_params['target_scid'] = targetData['scid']; - baseUrl = url_for('schema_diff.compare_schema', url_params); + socketEndpoint = 'compare_schema'; } - - setCompareOptions(compareParams); - schemaDiffToolContext.api.get(baseUrl).then((res) => { + let resData = []; + let socket; + try { + setCompareOptions(compareParams); + socket = await openSocket('/schema_diff'); + socket.on('compare_status', res=>{ + let msg = res.compare_msg; + msg = msg + gettext(` (this may take a few minutes)... ${res.diff_percentage} %`); + setLoaderText(msg); + }); + resData = await socketApiGet(socket, socketEndpoint, url_params); setShowResultGrid(true); setLoaderText(null); - clearInterval(schemaDiffPollInterval); setFilterOptions(filterParams); - getResultGridData(res.data.data, filterParams); - }).catch((err) => { - clearInterval(schemaDiffPollInterval); + getResultGridData(resData, filterParams); + } catch (error) { setLoaderText(null); setShowResultGrid(false); - Notifier.alert(gettext('Schema compare error'), gettext(err.response.data.errormsg)); - }); - + console.error(error); + Notifier.alert(gettext('Error'), parseApiError(error)); + } + socket?.disconnect(); } }; @@ -561,22 +566,6 @@ export function SchemaDiffCompare({ params }) { setAllRowIdList([...new Set(allRowIds)]); } - const getCompareStatus = () => { - let url_params = { 'trans_id': params.transId }; - - schemaDiffToolContext.api.get(url_for('schema_diff.poll', url_params)).then((res) => { - let msg = res.data.data.compare_msg; - if (res.data.data.diff_percentage != 100) { - msg = msg + gettext(` (this may take a few minutes)... ${res.data.data.diff_percentage} %`); - setLoaderText(msg); - } - - }) - .catch((err) => { - Notifier.error(gettext(err.message)); - }); - }; - const connectDatabase = (sid, selectedDB, diff_type, databaseList) => { schemaDiffToolContext.api({ method: 'POST', diff --git a/web/pgadmin/tools/schema_diff/tests/test_schema_diff_comp.py b/web/pgadmin/tools/schema_diff/tests/test_schema_diff_comp.py index d1714fda2..84ce680d2 100644 --- a/web/pgadmin/tools/schema_diff/tests/test_schema_diff_comp.py +++ b/web/pgadmin/tools/schema_diff/tests/test_schema_diff_comp.py @@ -12,7 +12,7 @@ import json import os import secrets -from pgadmin.utils.route import BaseTestGenerator +from pgadmin.utils.route import BaseTestGenerator, BaseSocketTestGenerator from regression import parent_node_dict from regression.python_test_utils import test_utils as utils from .utils import restore_schema @@ -20,15 +20,17 @@ from pgadmin.utils.versioned_template_loader import \ get_version_mapping_directories -class SchemaDiffTestCase(BaseTestGenerator): +class SchemaDiffTestCase(BaseSocketTestGenerator): """ This class will test the schema diff. """ scenarios = [ # Fetching default URL for database node. ('Schema diff comparison', dict( url='schema_diff/compare_database/{0}/{1}/{2}/{3}/{4}/0/0')) ] + SOCKET_NAMESPACE = '/schema_diff' def setUp(self): + super(SchemaDiffTestCase, self).setUp() self.src_database = "db_schema_diff_src_%s" % str(uuid.uuid4())[1:8] self.tar_database = "db_schema_diff_tar_%s" % str(uuid.uuid4())[1:8] @@ -108,16 +110,22 @@ class SchemaDiffTestCase(BaseTestGenerator): return None def compare(self): - comp_url = self.url.format(self.trans_id, self.server_id, - self.src_db_id, - self.server_id, - self.tar_db_id - ) - - response = self.tester.get(comp_url) - - self.assertEqual(response.status_code, 200) - return json.loads(response.data.decode('utf-8')) + data = { + 'trans_id': self.trans_id, + 'source_sid': self.server_id, + 'source_did': self.src_db_id, + 'target_sid': self.server_id, + 'target_did': self.tar_db_id, + 'ignore_owner': 0, + 'ignore_whitespaces': 0 + } + self.socket_client.emit('compare_database', data, + namespace=self.SOCKET_NAMESPACE) + received = self.socket_client.get_received(self.SOCKET_NAMESPACE) + response_data = received[-1]['args'][0] + self.assertEqual(received[-1]['name'], "compare_database_success", + response_data) + return response_data def runTest(self): """ This function will test the schema diff.""" @@ -127,6 +135,9 @@ class SchemaDiffTestCase(BaseTestGenerator): response_data = json.loads(response.data.decode('utf-8')) self.trans_id = response_data['data']['schemaDiffTransId'] + received = self.socket_client.get_received(self.SOCKET_NAMESPACE) + assert received[0]['name'] == 'connected' + url = 'schema_diff/server/connect/{}'.format(self.server_id) data = {'password': self.server['db_password']} response = self.tester.post(url, @@ -148,7 +159,7 @@ class SchemaDiffTestCase(BaseTestGenerator): str(secrets.choice(range(1, 99999))))) file_obj = open(diff_file, 'a') - for diff in response_data['data']: + for diff in response_data: if diff['status'] == 'Identical': src_obj_oid = diff['source_oid'] tar_obj_oid = diff['target_oid'] @@ -186,7 +197,7 @@ class SchemaDiffTestCase(BaseTestGenerator): os.remove(diff_file) response_data = self.compare() - for diff in response_data['data']: + for diff in response_data: self.assertEqual(diff['status'], 'Identical') except Exception as e: if os.path.exists(diff_file): @@ -194,6 +205,7 @@ class SchemaDiffTestCase(BaseTestGenerator): def tearDown(self): """This function drop the added database""" + super(SchemaDiffTestCase, self).tearDown() connection = utils.get_db_connection(self.server['db'], self.server['username'], self.server['db_password'],