########################################################################## # # pgAdmin 4 - PostgreSQL Tools # # Copyright (C) 2013 - 2022, The pgAdmin Development Team # This software is released under the PostgreSQL Licence # ########################################################################## import os import sys import json import subprocess from collections import defaultdict from operator import attrgetter from flask import Blueprint, current_app from flask_babel import gettext from flask_security import current_user, login_required from threading import Lock from .paths import get_storage_directory from .preferences import Preferences from pgadmin.utils.constants import UTILITIES_ARRAY, USER_NOT_FOUND from pgadmin.model import db, User, ServerGroup, Server from urllib.parse import unquote ADD_SERVERS_MSG = "Added %d Server Group(s) and %d Server(s)." class PgAdminModule(Blueprint): """ Base class for every PgAdmin Module. This class defines a set of method and attributes that every module should implement. """ def __init__(self, name, import_name, **kwargs): kwargs.setdefault('url_prefix', '/' + name) kwargs.setdefault('template_folder', 'templates') kwargs.setdefault('static_folder', 'static') self.submodules = [] self.parentmodules = [] super(PgAdminModule, self).__init__(name, import_name, **kwargs) def create_module_preference(): # Create preference for each module by default if hasattr(self, 'LABEL'): self.preference = Preferences(self.name, self.LABEL) else: self.preference = Preferences(self.name, None) self.register_preferences() # Create and register the module preference object and preferences for # it just before the first request self.before_app_first_request(create_module_preference) def register_preferences(self): # To be implemented by child classes pass def register(self, app, options): """ Override the default register function to automagically register sub-modules at once. """ super(PgAdminModule, self).register(app, options) for module in self.submodules: module.parentmodules.append(self) if app.blueprints.get(module.name) is None: app.register_blueprint(module) app.register_logout_hook(module) def get_own_stylesheets(self): """ Returns: list: the stylesheets used by this module, not including any stylesheet needed by the submodules. """ return [] def get_own_messages(self): """ Returns: dict: the i18n messages used by this module, not including any messages needed by the submodules. """ return dict() def get_own_menuitems(self): """ Returns: dict: the menuitems for this module, not including any needed from the submodules. """ return defaultdict(list) def get_panels(self): """ Returns: list: a list of panel objects to add """ return [] def get_exposed_url_endpoints(self): """ Returns: list: a list of url endpoints exposed to the client. """ return [] @property def stylesheets(self): stylesheets = self.get_own_stylesheets() for module in self.submodules: stylesheets.extend(module.stylesheets) return stylesheets @property def messages(self): res = self.get_own_messages() for module in self.submodules: res.update(module.messages) return res @property def menu_items(self): menu_items = self.get_own_menuitems() for module in self.submodules: for key, value in module.menu_items.items(): menu_items[key].extend(value) menu_items = dict((key, sorted(value, key=attrgetter('priority'))) for key, value in menu_items.items()) return menu_items @property def exposed_endpoints(self): res = self.get_exposed_url_endpoints() for module in self.submodules: res += module.exposed_endpoints return res IS_WIN = (os.name == 'nt') sys_encoding = sys.getdefaultencoding() if not sys_encoding or sys_encoding == 'ascii': # Fall back to 'utf-8', if we couldn't determine the default encoding, # or 'ascii'. sys_encoding = 'utf-8' fs_encoding = sys.getfilesystemencoding() if not fs_encoding or fs_encoding == 'ascii': # Fall back to 'utf-8', if we couldn't determine the file-system encoding, # or 'ascii'. fs_encoding = 'utf-8' def u_encode(_s, _encoding=sys_encoding): return _s def file_quote(_p): return _p if IS_WIN: import ctypes from ctypes import wintypes def env(name): if name in os.environ: return os.environ[name] return None _GetShortPathNameW = ctypes.windll.kernel32.GetShortPathNameW _GetShortPathNameW.argtypes = [ wintypes.LPCWSTR, wintypes.LPWSTR, wintypes.DWORD ] _GetShortPathNameW.restype = wintypes.DWORD def fs_short_path(_path): """ Gets the short path name of a given long path. http://stackoverflow.com/a/23598461/200291 """ buf_size = len(_path) while True: res = ctypes.create_unicode_buffer(buf_size) # Note:- _GetShortPathNameW may return empty value # if directory doesn't exist. needed = _GetShortPathNameW(_path, res, buf_size) if buf_size >= needed: return res.value else: buf_size += needed def document_dir(): CSIDL_PERSONAL = 5 # My Documents SHGFP_TYPE_CURRENT = 0 # Get current, not default value buf = ctypes.create_unicode_buffer(wintypes.MAX_PATH) ctypes.windll.shell32.SHGetFolderPathW( None, CSIDL_PERSONAL, None, SHGFP_TYPE_CURRENT, buf ) return buf.value else: def env(name): if name in os.environ: return os.environ[name] return None def fs_short_path(_path): return _path def document_dir(): return os.path.realpath(os.path.expanduser('~/')) def get_complete_file_path(file, validate=True): """ Args: file: File returned by file manager Returns: Full path for the file """ if not file: return None # If desktop mode if current_app.PGADMIN_RUNTIME or not current_app.config['SERVER_MODE']: return file if os.path.isfile(file) else None storage_dir = get_storage_directory() if storage_dir: file = os.path.join( storage_dir, file.lstrip('/').lstrip('\\') ) if IS_WIN: file = file.replace('\\', '/') file = fs_short_path(file) if validate: return file if os.path.isfile(file) else None else: return file def does_utility_exist(file): """ This function will check the utility file exists on given path. :return: """ error_msg = None if file is None: error_msg = gettext("Utility file not found. Please correct the Binary" " Path in the Preferences dialog") return error_msg if not os.path.exists(file): error_msg = gettext("'%s' file not found. Please correct the Binary" " Path in the Preferences dialog" % file) return error_msg def get_server(sid): """ # Fetch the server etc :param sid: :return: server """ server = Server.query.filter_by(id=sid).first() return server def set_binary_path(binary_path, bin_paths, server_type, version_number=None, set_as_default=False): """ This function is used to iterate through the utilities and set the default binary path. """ path_with_dir = binary_path if "$DIR" in binary_path else None # Check if "$DIR" present in binary path binary_path = replace_binary_path(binary_path) for utility in UTILITIES_ARRAY: full_path = os.path.abspath( os.path.join(binary_path, (utility if os.name != 'nt' else (utility + '.exe')))) try: # if version_number is provided then no need to fetch it. if version_number is None: # Get the output of the '--version' command version_string = \ subprocess.getoutput('"{0}" --version'.format(full_path)) # Get the version number by splitting the result string version_number = \ version_string.split(") ", 1)[1].split('.', 1)[0] elif version_number.find('.'): version_number = version_number.split('.', 1)[0] # Get the paths array based on server type if 'pg_bin_paths' in bin_paths or 'as_bin_paths' in bin_paths: paths_array = bin_paths['pg_bin_paths'] if server_type == 'ppas': paths_array = bin_paths['as_bin_paths'] else: paths_array = bin_paths for path in paths_array: if path['version'].find(version_number) == 0 and \ path['binaryPath'] is None: path['binaryPath'] = path_with_dir \ if path_with_dir is not None else binary_path if set_as_default: path['isDefault'] = True break break except Exception: continue def replace_binary_path(binary_path): """ This function is used to check if $DIR is present in the binary path. If it is there then replace it with module. """ if "$DIR" in binary_path: # When running as an WSGI application, we will not find the # '__file__' attribute for the '__main__' module. main_module_file = getattr( sys.modules['__main__'], '__file__', None ) if main_module_file is not None: binary_path = binary_path.replace( "$DIR", os.path.dirname(main_module_file) ) return binary_path def add_value(attr_dict, key, value): """Add a value to the attribute dict if non-empty. Args: attr_dict (dict): The dictionary to add the values to key (str): The key for the new value value (str): The value to add Returns: The updated attribute dictionary """ if value != "" and value is not None: attr_dict[key] = value return attr_dict def dump_database_servers(output_file, selected_servers, dump_user=current_user, from_setup=False): """Dump the server groups and servers. """ user = _does_user_exist(dump_user, from_setup) if user is None: return False, USER_NOT_FOUND % dump_user user_id = user.id # Dict to collect the output object_dict = {} # Counters servers_dumped = 0 # Dump servers servers = Server.query.filter_by(user_id=user_id).all() server_dict = {} for server in servers: if selected_servers is None or str(server.id) in selected_servers: # Get the group name group_name = ServerGroup.query.filter_by( user_id=user_id, id=server.servergroup_id).first().name attr_dict = {} add_value(attr_dict, "Name", server.name) add_value(attr_dict, "Group", group_name) add_value(attr_dict, "Host", server.host) add_value(attr_dict, "HostAddr", server.hostaddr) add_value(attr_dict, "Port", server.port) add_value(attr_dict, "MaintenanceDB", server.maintenance_db) add_value(attr_dict, "Username", server.username) add_value(attr_dict, "Role", server.role) add_value(attr_dict, "SSLMode", server.ssl_mode) add_value(attr_dict, "Comment", server.comment) add_value(attr_dict, "Shared", server.shared) add_value(attr_dict, "DBRestriction", server.db_res) add_value(attr_dict, "PassFile", server.passfile) add_value(attr_dict, "SSLCert", server.sslcert) add_value(attr_dict, "SSLKey", server.sslkey) add_value(attr_dict, "SSLRootCert", server.sslrootcert) add_value(attr_dict, "SSLCrl", server.sslcrl) add_value(attr_dict, "SSLCompression", server.sslcompression) add_value(attr_dict, "BGColor", server.bgcolor) add_value(attr_dict, "FGColor", server.fgcolor) add_value(attr_dict, "Service", server.service) add_value(attr_dict, "Timeout", server.connect_timeout) add_value(attr_dict, "UseSSHTunnel", server.use_ssh_tunnel) add_value(attr_dict, "TunnelHost", server.tunnel_host) add_value(attr_dict, "TunnelPort", server.tunnel_port) add_value(attr_dict, "TunnelUsername", server.tunnel_username) add_value(attr_dict, "TunnelAuthentication", server.tunnel_authentication) servers_dumped = servers_dumped + 1 server_dict[servers_dumped] = attr_dict object_dict["Servers"] = server_dict # retrieve storage directory path storage_manager_path = None if not from_setup: storage_manager_path = get_storage_directory(user) # generate full path of file file_path = unquote(output_file) from pgadmin.misc.file_manager import Filemanager try: Filemanager.check_access_permission(storage_manager_path, file_path) except Exception as e: return _handle_error(str(e), from_setup) if storage_manager_path is not None: file_path = os.path.join( storage_manager_path, file_path.lstrip('/').lstrip('\\') ) # write to file file_content = json.dumps(object_dict, indent=4) error_str = "Error: {0}" try: with open(file_path, 'w') as output_file: output_file.write(file_content) except IOError as e: err_msg = error_str.format(e.strerror) return _handle_error(err_msg, from_setup) except Exception as e: err_msg = error_str.format(e.strerror) return _handle_error(err_msg, from_setup) msg = gettext("Configuration for %s servers dumped to %s" % (servers_dumped, output_file.name)) print(msg) return True, msg def validate_json_data(data, is_admin): """ Used internally by load_servers to validate servers data. :param data: servers data :param is_admin: :return: error message if any """ skip_servers = [] # Loop through the servers... if "Servers" not in data: return gettext("'Servers' attribute not found in the specified file.") for server in data["Servers"]: obj = data["Servers"][server] # Check if server is shared.Won't import if user is non-admin if obj.get('Shared', None) and not is_admin: print("Won't import the server '%s' as it is shared " % obj["Name"]) skip_servers.append(server) continue def check_attrib(attrib): if attrib not in obj: return gettext("'%s' attribute not found for server '%s'" % (attrib, server)) return None def check_is_integer(value): if not isinstance(value, int): return gettext("Port must be integer for server '%s'" % server) return None for attrib in ("Group", "Name"): errmsg = check_attrib(attrib) if errmsg: return errmsg is_service_attrib_available = obj.get("Service", None) is not None if not is_service_attrib_available: for attrib in ("Port", "Username"): errmsg = check_attrib(attrib) if errmsg: return errmsg if attrib == 'Port': errmsg = check_is_integer(obj[attrib]) if errmsg: return errmsg for attrib in ("SSLMode", "MaintenanceDB"): errmsg = check_attrib(attrib) if errmsg: return errmsg if "Host" not in obj and "HostAddr" not in obj and not \ is_service_attrib_available: return gettext("'Host', 'HostAddr' or 'Service' attribute not " "found for server '%s'" % server) for server in skip_servers: del data["Servers"][server] return None def load_database_servers(input_file, selected_servers, load_user=current_user, from_setup=False): """Load server groups and servers. """ user = _does_user_exist(load_user, from_setup) if user is None: return False, USER_NOT_FOUND % load_user # retrieve storage directory path storage_manager_path = None if not from_setup: storage_manager_path = get_storage_directory(user) # generate full path of file file_path = unquote(input_file) if storage_manager_path: # generate full path of file file_path = os.path.join( storage_manager_path, file_path.lstrip('/').lstrip('\\') ) try: with open(file_path) as f: data = json.load(f) except json.decoder.JSONDecodeError as e: return _handle_error(gettext("Error parsing input file %s: %s" % (file_path, e)), from_setup) except Exception as e: return _handle_error(gettext("Error reading input file %s: [%d] %s" % (file_path, e.errno, e.strerror)), from_setup) f.close() user_id = user.id # Counters groups_added = 0 servers_added = 0 # Get the server groups groups = ServerGroup.query.filter_by(user_id=user_id) # Validate server data error_msg = validate_json_data(data, user.has_role("Administrator")) if error_msg is not None and from_setup: print(ADD_SERVERS_MSG % (groups_added, servers_added)) return _handle_error(error_msg, from_setup) for server in data["Servers"]: if selected_servers is None or str(server) in selected_servers: obj = data["Servers"][server] # Get the group. Create if necessary group_id = next( (g.id for g in groups if g.name == obj["Group"]), -1) if group_id == -1: new_group = ServerGroup() new_group.name = obj["Group"] new_group.user_id = user_id db.session.add(new_group) try: db.session.commit() except Exception as e: if from_setup: print(ADD_SERVERS_MSG % (groups_added, servers_added)) return _handle_error( gettext("Error creating server group '%s': %s" % (new_group.name, e)), from_setup) group_id = new_group.id groups_added = groups_added + 1 groups = ServerGroup.query.filter_by(user_id=user_id) # Create the server new_server = Server() new_server.name = obj["Name"] new_server.servergroup_id = group_id new_server.user_id = user_id new_server.ssl_mode = obj["SSLMode"] new_server.maintenance_db = obj["MaintenanceDB"] new_server.host = obj.get("Host", None) new_server.hostaddr = obj.get("HostAddr", None) new_server.port = obj.get("Port", None) new_server.username = obj.get("Username", None) new_server.role = obj.get("Role", None) new_server.ssl_mode = obj["SSLMode"] new_server.comment = obj.get("Comment", None) new_server.db_res = obj.get("DBRestriction", None) new_server.passfile = obj.get("PassFile", None) new_server.sslcert = obj.get("SSLCert", None) new_server.sslkey = obj.get("SSLKey", None) new_server.sslrootcert = obj.get("SSLRootCert", None) new_server.sslcrl = obj.get("SSLCrl", None) new_server.sslcompression = obj.get("SSLCompression", None) new_server.bgcolor = obj.get("BGColor", None) new_server.fgcolor = obj.get("FGColor", None) new_server.service = obj.get("Service", None) new_server.connect_timeout = obj.get("Timeout", None) new_server.use_ssh_tunnel = obj.get("UseSSHTunnel", None) new_server.tunnel_host = obj.get("TunnelHost", None) new_server.tunnel_port = obj.get("TunnelPort", None) new_server.tunnel_username = obj.get("TunnelUsername", None) new_server.tunnel_authentication = \ obj.get("TunnelAuthentication", None) new_server.shared = \ obj.get("Shared", None) db.session.add(new_server) try: db.session.commit() except Exception as e: if from_setup: print(ADD_SERVERS_MSG % (groups_added, servers_added)) return _handle_error(gettext("Error creating server '%s': %s" % (new_server.name, e)), from_setup) servers_added = servers_added + 1 msg = ADD_SERVERS_MSG % (groups_added, servers_added) print(msg) return True, msg def clear_database_servers(load_user=current_user, from_setup=False): """Clear groups and servers configurations. """ user = _does_user_exist(load_user, from_setup) if user is None: return False user_id = user.id # Remove all servers servers = Server.query.filter_by(user_id=user_id) for server in servers: db.session.delete(server) # Remove all groups groups = ServerGroup.query.filter_by(user_id=user_id) for group in groups: db.session.delete(group) servers = Server.query.filter_by(user_id=user_id) for server in servers: db.session.delete(server) try: db.session.commit() except Exception as e: error_msg = \ gettext("Error clearing server configuration with error (%s)" % str(e)) if from_setup: print(error_msg) sys.exit(1) return False, error_msg def _does_user_exist(user, from_setup): """ This function will check user is exist or not. If exist then return """ if isinstance(user, User): user = user.email user = User.query.filter_by(email=user).first() if user is None: print(USER_NOT_FOUND % user) if from_setup: sys.exit(1) return user def _handle_error(error_msg, from_setup): """ This function is used to print the error msg and exit from app if called from setup.py """ if from_setup: print(error_msg) sys.exit(1) return False, error_msg # Shortcut configuration for Accesskey ACCESSKEY_FIELDS = [ { 'name': 'key', 'type': 'keyCode', 'label': gettext('Key') } ] # Shortcut configuration SHORTCUT_FIELDS = [ { 'name': 'key', 'type': 'keyCode', 'label': gettext('Key') }, { 'name': 'shift', 'type': 'checkbox', 'label': gettext('Shift') }, { 'name': 'control', 'type': 'checkbox', 'label': gettext('Ctrl') }, { 'name': 'alt', 'type': 'checkbox', 'label': gettext('Alt/Option') } ] class KeyManager: def __init__(self): self.users = dict() self.lock = Lock() @login_required def get(self): user = self.users.get(current_user.id, None) if user is not None: return user.get('key', None) @login_required def set(self, _key, _new_login=True): with self.lock: user = self.users.get(current_user.id, None) if user is None: self.users[current_user.id] = dict( session_count=1, key=_key) else: if _new_login: user['session_count'] += 1 user['key'] = _key @login_required def reset(self): with self.lock: user = self.users.get(current_user.id, None) if user is not None: # This will not decrement if session expired user['session_count'] -= 1 if user['session_count'] == 0: del self.users[current_user.id] @login_required def hard_reset(self): with self.lock: user = self.users.get(current_user.id, None) if user is not None: del self.users[current_user.id]