pgadmin4/web/setup.py

447 lines
15 KiB
Python

#########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2020, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""Perform the initial setup of the application, by creating the auth
and settings database."""
import argparse
import json
import os
import sys
import builtins
from pgadmin.model import db, User, Version, ServerGroup, Server, \
SCHEMA_VERSION as CURRENT_SCHEMA_VERSION
# Publish SERVER_MODE through builtins: use the value the runtime placed in
# this module's globals if there is one, otherwise record it as None.
builtins.SERVER_MODE = globals().get('SERVER_MODE')

# When running in the standalone runtime the directory containing this file
# must be the first entry on sys.path so that everything we need can be
# found.
root = os.path.dirname(os.path.realpath(__file__))
if root != sys.path[0]:
    sys.path.insert(0, root)
from pgadmin import create_app
def add_value(attr_dict, key, value):
    """Store ``value`` under ``key`` unless the value is empty.

    Args:
        attr_dict (dict): The dictionary to update (modified in place)
        key (str): The key to store the value under
        value: The candidate value; ignored when it is None or ""

    Returns:
        The dictionary that was passed in
    """
    # Only None and the empty string count as "empty"; other falsy values
    # such as 0 or False are stored.
    if value is None or value == "":
        return attr_dict

    attr_dict[key] = value
    return attr_dict
def dump_servers(args):
    """Dump the servers of one user to a JSON file.

    The output document has a single top-level "Servers" key whose value
    maps a 1-based sequence number to the attribute dictionary of each
    dumped server. Empty attributes (None or "") are omitted.

    Args:
        args (ArgParser): The parsed command line options. Reads
            ``user``, ``sqlite_path``, ``servers`` (optional list of server
            IDs to restrict the dump to) and ``dump_servers`` (output path).

    Raises:
        SystemExit: With status 1 if the user cannot be found or the output
            file cannot be opened or written.
    """
    # What user?
    dump_user = args.user if args.user is not None else config.DESKTOP_USER

    # And the sqlite path
    if args.sqlite_path is not None:
        config.SQLITE_PATH = args.sqlite_path

    print('----------')
    print('Dumping servers with:')
    print('User:', dump_user)
    print('SQLite pgAdmin config:', config.SQLITE_PATH)
    print('----------')

    # (JSON key, Server model attribute) pairs, in output order. "Name" and
    # "Group" are written first and handled separately below.
    exported_fields = (
        ("Host", "host"),
        ("HostAddr", "hostaddr"),
        ("Port", "port"),
        ("MaintenanceDB", "maintenance_db"),
        ("Username", "username"),
        ("Role", "role"),
        ("SSLMode", "ssl_mode"),
        ("Comment", "comment"),
        ("Shared", "shared"),
        ("DBRestriction", "db_res"),
        ("PassFile", "passfile"),
        ("SSLCert", "sslcert"),
        ("SSLKey", "sslkey"),
        ("SSLRootCert", "sslrootcert"),
        ("SSLCrl", "sslcrl"),
        ("SSLCompression", "sslcompression"),
        ("BGColor", "bgcolor"),
        ("FGColor", "fgcolor"),
        ("Service", "service"),
        ("Timeout", "connect_timeout"),
        ("UseSSHTunnel", "use_ssh_tunnel"),
        ("TunnelHost", "tunnel_host"),
        ("TunnelPort", "tunnel_port"),
        ("TunnelUsername", "tunnel_username"),
        ("TunnelAuthentication", "tunnel_authentication"),
    )

    app = create_app(config.APP_NAME + '-cli')
    with app.app_context():
        user = User.query.filter_by(email=dump_user).first()

        if user is None:
            print("The specified user ID (%s) could not be found." %
                  dump_user)
            sys.exit(1)

        user_id = user.id

        # Counters
        servers_dumped = 0

        # Dump servers
        servers = Server.query.filter_by(user_id=user_id).all()
        server_dict = {}
        for server in servers:
            # args.servers holds strings from the command line, hence str()
            if args.servers is not None and \
                    str(server.id) not in args.servers:
                continue

            # Get the group name
            group_name = ServerGroup.query.filter_by(
                user_id=user_id, id=server.servergroup_id).first().name

            attr_dict = {}
            add_value(attr_dict, "Name", server.name)
            add_value(attr_dict, "Group", group_name)
            for json_key, model_attr in exported_fields:
                add_value(attr_dict, json_key, getattr(server, model_attr))

            servers_dumped = servers_dumped + 1
            server_dict[servers_dumped] = attr_dict

        # Serialise before touching the output file so that a serialisation
        # failure cannot leave a truncated file behind.
        output = json.dumps({"Servers": server_dict}, indent=4)

        try:
            out_file = open(args.dump_servers, "w")
        except OSError as e:
            print("Error opening output file %s: [%s] %s" %
                  (args.dump_servers, e.errno, e.strerror))
            sys.exit(1)

        try:
            # The context manager closes the file on every path; closing
            # inside the try also reports errors raised by the final flush.
            with out_file:
                out_file.write(output)
        except OSError as e:
            print("Error writing output file %s: [%s] %s" %
                  (args.dump_servers, e.errno, e.strerror))
            sys.exit(1)

        print("Configuration for %s servers dumped to %s." %
              (servers_dumped, args.dump_servers))
def _validate_servers_data(data, is_admin):
"""
Used internally by load_servers to validate servers data.
:param data: servers data
:return: error message if any
"""
skip_servers = []
# Loop through the servers...
if "Servers" not in data:
return ("'Servers' attribute not found in file '%s'" %
args.load_servers)
for server in data["Servers"]:
obj = data["Servers"][server]
# Check if server is shared.Won't import if user is non-admin
if obj.get('Shared', None) and not is_admin:
print("Won't import the server '%s' as it is shared " %
obj["Name"])
skip_servers.append(server)
continue
def check_attrib(attrib):
if attrib not in obj:
return ("'%s' attribute not found for server '%s'" %
(attrib, server))
check_attrib("Name")
check_attrib("Group")
is_service_attrib_available = obj.get("Service", None) is not None
if not is_service_attrib_available:
check_attrib("Port")
check_attrib("Username")
check_attrib("SSLMode")
check_attrib("MaintenanceDB")
if "Host" not in obj and "HostAddr" not in obj and not \
is_service_attrib_available:
return ("'Host', 'HostAddr' or 'Service' attribute "
"not found for server '%s'" % server)
for server in skip_servers:
del data["Servers"][server]
return None
def load_servers(args):
    """Load server groups and servers from a JSON file.

    The file is parsed and validated, then every server in its "Servers"
    mapping is added for the selected user. Server groups named in the file
    that the user does not have yet are created on the fly. A summary of the
    number of groups and servers added is printed on success and before
    every error exit that happens after the user has been resolved.

    Args:
        args (ArgParser): The parsed command line options. Reads ``user``,
            ``sqlite_path`` and ``load_servers`` (input path).

    Raises:
        SystemExit: With status 1 if the file cannot be read or parsed, the
            user cannot be found, validation fails, or a commit fails.
    """
    # What user?
    load_user = args.user if args.user is not None else config.DESKTOP_USER

    # And the sqlite path
    if args.sqlite_path is not None:
        config.SQLITE_PATH = args.sqlite_path

    print('----------')
    print('Loading servers with:')
    print('User:', load_user)
    print('SQLite pgAdmin config:', config.SQLITE_PATH)
    print('----------')

    try:
        # The context manager closes the file; no explicit close is needed.
        with open(args.load_servers) as f:
            data = json.load(f)
    except (json.decoder.JSONDecodeError, UnicodeDecodeError) as e:
        print("Error parsing input file %s: %s" %
              (args.load_servers, e))
        sys.exit(1)
    except OSError as e:
        print("Error reading input file %s: [%s] %s" %
              (args.load_servers, e.errno, e.strerror))
        sys.exit(1)

    # (Server model attribute, JSON key) pairs for the attributes that are
    # optional in the input file; missing keys are stored as None.
    optional_fields = (
        ("host", "Host"),
        ("hostaddr", "HostAddr"),
        ("port", "Port"),
        ("username", "Username"),
        ("role", "Role"),
        ("comment", "Comment"),
        ("db_res", "DBRestriction"),
        ("passfile", "PassFile"),
        ("sslcert", "SSLCert"),
        ("sslkey", "SSLKey"),
        ("sslrootcert", "SSLRootCert"),
        ("sslcrl", "SSLCrl"),
        ("sslcompression", "SSLCompression"),
        ("bgcolor", "BGColor"),
        ("fgcolor", "FGColor"),
        ("service", "Service"),
        ("connect_timeout", "Timeout"),
        ("use_ssh_tunnel", "UseSSHTunnel"),
        ("tunnel_host", "TunnelHost"),
        ("tunnel_port", "TunnelPort"),
        ("tunnel_username", "TunnelUsername"),
        ("tunnel_authentication", "TunnelAuthentication"),
    )

    app = create_app(config.APP_NAME + '-cli')
    with app.app_context():
        user = User.query.filter_by(email=load_user).first()

        if user is None:
            print("The specified user ID (%s) could not be found." %
                  load_user)
            sys.exit(1)

        user_id = user.id

        # Counters
        groups_added = 0
        servers_added = 0

        # Get the server groups
        groups = ServerGroup.query.filter_by(user_id=user_id)

        def print_summary():
            # Reads the counters at call time, so it always reports the
            # progress made so far.
            print("Added %d Server Group(s) and %d Server(s)." %
                  (groups_added, servers_added))

        err_msg = _validate_servers_data(data, user.has_role("Administrator"))
        if err_msg is not None:
            print(err_msg)
            print_summary()
            sys.exit(1)

        for server in data["Servers"]:
            obj = data["Servers"][server]

            # Get the group. Create if necessary
            group_id = next(
                (g.id for g in groups if g.name == obj["Group"]), -1)

            if group_id == -1:
                new_group = ServerGroup()
                new_group.name = obj["Group"]
                new_group.user_id = user_id
                db.session.add(new_group)

                try:
                    db.session.commit()
                except Exception as e:
                    print("Error creating server group '%s': %s" %
                          (new_group.name, e))
                    print_summary()
                    sys.exit(1)

                group_id = new_group.id
                groups_added = groups_added + 1
                groups = ServerGroup.query.filter_by(user_id=user_id)

            # Create the server
            new_server = Server()
            new_server.name = obj["Name"]
            new_server.servergroup_id = group_id
            new_server.user_id = user_id
            new_server.ssl_mode = obj["SSLMode"]
            new_server.maintenance_db = obj["MaintenanceDB"]
            for model_attr, json_key in optional_fields:
                setattr(new_server, model_attr, obj.get(json_key, None))

            db.session.add(new_server)

            try:
                db.session.commit()
            except Exception as e:
                print("Error creating server '%s': %s" %
                      (new_server.name, e))
                print_summary()
                sys.exit(1)

            servers_added = servers_added + 1

        print_summary()
def setup_db():
    """Create the configuration database, or migrate an existing one."""
    create_app_data_directory(config)
    app = create_app()

    print("pgAdmin 4 - Application Initialisation")
    print("======================================\n")

    with app.app_context():
        from config import SQLITE_PATH

        if os.path.exists(SQLITE_PATH):
            stored = Version.query.filter_by(name='ConfigDB').first()
            schema_version = stored.value

            # Run the migration when the current schema version is greater
            # than or equal to the one stored in the version table.
            if CURRENT_SCHEMA_VERSION >= schema_version:
                db_upgrade(app)

                # Record the latest schema version, but only if it actually
                # moved forward.
                if CURRENT_SCHEMA_VERSION > schema_version:
                    stored = Version.query.filter_by(name='ConfigDB').first()
                    stored.value = CURRENT_SCHEMA_VERSION
                    db.session.commit()
        else:
            # No database file yet: the first migration run creates it.
            db_upgrade(app)

        # Everywhere except Windows, limit the database file to owner
        # read/write.
        if os.name != 'nt':
            os.chmod(config.SQLITE_PATH, 0o600)
if __name__ == '__main__':
    # Configuration settings. These names are bound at module level on
    # purpose: the functions in this file refer to config, db_upgrade and
    # create_app_data_directory as globals.
    import config
    from pgadmin.model import SCHEMA_VERSION
    from pgadmin.setup import db_upgrade, create_app_data_directory

    parser = argparse.ArgumentParser(description='Setup the pgAdmin config DB')

    exp_group = parser.add_argument_group('Dump server config')
    exp_group.add_argument('--dump-servers', metavar="OUTPUT_FILE",
                           help='Dump the servers in the DB', required=False)
    exp_group.add_argument('--servers', metavar="SERVERS", nargs='*',
                           help='One or more servers to dump', required=False)

    imp_group = parser.add_argument_group('Load server config')
    imp_group.add_argument('--load-servers', metavar="INPUT_FILE",
                           help='Load servers into the DB', required=False)

    # Common args
    parser.add_argument('--sqlite-path', metavar="PATH",
                        help='Dump/load with the specified pgAdmin config DB'
                             ' file. This is particularly helpful when there'
                             ' are multiple pgAdmin configurations. It is also'
                             ' recommended to use this option when running'
                             ' pgAdmin in desktop mode.', required=False)
    parser.add_argument('--user', metavar="USER_NAME",
                        help='Dump/load servers for the specified username',
                        required=False)

    # Unrecognised options are collected in `extra` instead of being
    # rejected.
    args, extra = parser.parse_known_args()

    config.SETTINGS_SCHEMA_VERSION = SCHEMA_VERSION

    if os.environ.get("PGADMIN_TESTING_MODE") == "1":
        config.SQLITE_PATH = config.TEST_SQLITE_PATH

    # What to do? Dumping takes precedence over loading; with neither
    # option given the configuration database is set up.
    if args.dump_servers is not None:
        action = dump_servers
    elif args.load_servers is not None:
        action = load_servers
    else:
        action = None

    if action is None:
        setup_db()
    else:
        try:
            action(args)
        except Exception as e:
            # Report the failure and exit non-zero, like the error paths
            # inside dump_servers/load_servers do. SystemExit raised there
            # is not an Exception subclass and passes through untouched.
            print(str(e))
            sys.exit(1)