Acquire a lock regardless of the authentication sources while getting the database server connection.

This commit is contained in:
Khushboo Vashi 2022-02-07 11:25:08 +05:30 committed by Akshay Joshi
parent 9cc2985d13
commit a7ee4e5909
3 changed files with 131 additions and 94 deletions

View File

@ -25,6 +25,7 @@ from pgadmin.utils import u_encode, file_quote, fs_encoding, \
get_complete_file_path, get_storage_directory, IS_WIN get_complete_file_path, get_storage_directory, IS_WIN
from pgadmin.browser.server_groups.servers.utils import does_server_exists from pgadmin.browser.server_groups.servers.utils import does_server_exists
from pgadmin.utils.constants import KERBEROS from pgadmin.utils.constants import KERBEROS
from pgadmin.utils.locker import ConnectionLocker
import pytz import pytz
from dateutil import parser from dateutil import parser
@ -274,14 +275,18 @@ class BatchProcess(object):
str(cmd) str(cmd)
) )
# Make a copy of environment, and add new variables to support # Acquiring lock while copying the environment from the parent process
env = os.environ.copy() # for the child process
with ConnectionLocker(_is_kerberos_conn=False):
# Make a copy of environment, and add new variables to support
env = os.environ.copy()
env['PROCID'] = self.id env['PROCID'] = self.id
env['OUTDIR'] = self.log_dir env['OUTDIR'] = self.log_dir
env['PGA_BGP_FOREGROUND'] = "1" env['PGA_BGP_FOREGROUND'] = "1"
if config.SERVER_MODE and session and \ if config.SERVER_MODE and session and \
session['auth_source_manager']['current_source'] == \ session['auth_source_manager']['current_source'] == \
KERBEROS: KERBEROS and 'KRB5CCNAME' in session:
env['KRB5CCNAME'] = session['KRB5CCNAME'] env['KRB5CCNAME'] = session['KRB5CCNAME']
if self.env: if self.env:

View File

@ -18,7 +18,6 @@ import select
import datetime import datetime
from collections import deque from collections import deque
import psycopg2 import psycopg2
import threading
from flask import g, current_app, session from flask import g, current_app, session
from flask_babel import gettext from flask_babel import gettext
from flask_security import current_user from flask_security import current_user
@ -41,8 +40,7 @@ from pgadmin.utils import csv
from pgadmin.utils.master_password import get_crypt_key from pgadmin.utils.master_password import get_crypt_key
from io import StringIO from io import StringIO
from pgadmin.utils.constants import KERBEROS from pgadmin.utils.constants import KERBEROS
from pgadmin.utils.locker import ConnectionLocker
lock = threading.Lock()
_ = gettext _ = gettext
@ -179,7 +177,6 @@ class Connection(BaseConnection):
self.reconnecting = False self.reconnecting = False
self.use_binary_placeholder = use_binary_placeholder self.use_binary_placeholder = use_binary_placeholder
self.array_to_string = array_to_string self.array_to_string = array_to_string
super(Connection, self).__init__() super(Connection, self).__init__()
def as_dict(self): def as_dict(self):
@ -318,47 +315,35 @@ class Connection(BaseConnection):
os.environ['PGAPPNAME'] = '{0} - {1}'.format( os.environ['PGAPPNAME'] = '{0} - {1}'.format(
config.APP_NAME, conn_id) config.APP_NAME, conn_id)
if config.SERVER_MODE and \ with ConnectionLocker(manager.kerberos_conn):
session['auth_source_manager']['current_source'] == \ pg_conn = psycopg2.connect(
KERBEROS and 'KRB5CCNAME' in session\ host=manager.local_bind_host if manager.use_ssh_tunnel
and manager.kerberos_conn: else manager.host,
lock.acquire() hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
environ['KRB5CCNAME'] = session['KRB5CCNAME'] else manager.hostaddr,
port=manager.local_bind_port if manager.use_ssh_tunnel
else manager.port,
database=database,
user=user,
password=password,
async_=self.async_,
passfile=get_complete_file_path(passfile),
sslmode=manager.ssl_mode,
sslcert=get_complete_file_path(manager.sslcert),
sslkey=get_complete_file_path(manager.sslkey),
sslrootcert=get_complete_file_path(manager.sslrootcert),
sslcrl=get_complete_file_path(manager.sslcrl),
sslcompression=True if manager.sslcompression else False,
service=manager.service,
connect_timeout=manager.connect_timeout
)
pg_conn = psycopg2.connect( # If connection is asynchronous then we will have to wait
host=manager.local_bind_host if manager.use_ssh_tunnel # until the connection is ready to use.
else manager.host, if self.async_ == 1:
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel self._wait(pg_conn)
else manager.hostaddr,
port=manager.local_bind_port if manager.use_ssh_tunnel
else manager.port,
database=database,
user=user,
password=password,
async_=self.async_,
passfile=get_complete_file_path(passfile),
sslmode=manager.ssl_mode,
sslcert=get_complete_file_path(manager.sslcert),
sslkey=get_complete_file_path(manager.sslkey),
sslrootcert=get_complete_file_path(manager.sslrootcert),
sslcrl=get_complete_file_path(manager.sslcrl),
sslcompression=True if manager.sslcompression else False,
service=manager.service,
connect_timeout=manager.connect_timeout
)
# If connection is asynchronous then we will have to wait
# until the connection is ready to use.
if self.async_ == 1:
self._wait(pg_conn)
if config.SERVER_MODE and \
session['auth_source_manager']['current_source'] == \
KERBEROS:
environ['KRB5CCNAME'] = ''
except psycopg2.Error as e: except psycopg2.Error as e:
environ['KRB5CCNAME'] = ''
manager.stop_ssh_tunnel() manager.stop_ssh_tunnel()
if e.pgerror: if e.pgerror:
msg = e.pgerror msg = e.pgerror
@ -376,11 +361,6 @@ class Connection(BaseConnection):
) )
) )
return False, msg return False, msg
finally:
if config.SERVER_MODE and \
session['auth_source_manager']['current_source'] == \
KERBEROS and lock.locked():
lock.release()
# Overwrite connection notice attr to support # Overwrite connection notice attr to support
# more than 50 notices at a time # more than 50 notices at a time
@ -1408,26 +1388,27 @@ WHERE db.datname = current_database()""")
return False, return_value return False, return_value
try: try:
pg_conn = psycopg2.connect( with ConnectionLocker(manager.kerberos_conn):
host=manager.local_bind_host if manager.use_ssh_tunnel pg_conn = psycopg2.connect(
else manager.host, host=manager.local_bind_host if manager.use_ssh_tunnel
hostaddr=manager.local_bind_host if manager.use_ssh_tunnel else manager.host,
else manager.hostaddr, hostaddr=manager.local_bind_host if manager.use_ssh_tunnel
port=manager.local_bind_port if manager.use_ssh_tunnel else manager.hostaddr,
else manager.port, port=manager.local_bind_port if manager.use_ssh_tunnel
database=self.db, else manager.port,
user=manager.user, database=self.db,
password=password, user=manager.user,
passfile=get_complete_file_path(manager.passfile), password=password,
sslmode=manager.ssl_mode, passfile=get_complete_file_path(manager.passfile),
sslcert=get_complete_file_path(manager.sslcert), sslmode=manager.ssl_mode,
sslkey=get_complete_file_path(manager.sslkey), sslcert=get_complete_file_path(manager.sslcert),
sslrootcert=get_complete_file_path(manager.sslrootcert), sslkey=get_complete_file_path(manager.sslkey),
sslcrl=get_complete_file_path(manager.sslcrl), sslrootcert=get_complete_file_path(manager.sslrootcert),
sslcompression=True if manager.sslcompression else False, sslcrl=get_complete_file_path(manager.sslcrl),
service=manager.service, sslcompression=True if manager.sslcompression else False,
connect_timeout=manager.connect_timeout service=manager.service,
) connect_timeout=manager.connect_timeout
)
except psycopg2.Error as e: except psycopg2.Error as e:
if e.pgerror: if e.pgerror:
@ -1710,30 +1691,31 @@ Failed to reset the connection to the server due to following error:
.decode() .decode()
try: try:
pg_conn = psycopg2.connect( with ConnectionLocker(self.manager.kerberos_conn):
host=self.manager.local_bind_host if pg_conn = psycopg2.connect(
self.manager.use_ssh_tunnel else self.manager.host, host=self.manager.local_bind_host if
hostaddr=self.manager.local_bind_host if self.manager.use_ssh_tunnel else self.manager.host,
self.manager.use_ssh_tunnel else hostaddr=self.manager.local_bind_host if
self.manager.hostaddr, self.manager.use_ssh_tunnel else
port=self.manager.local_bind_port if self.manager.hostaddr,
self.manager.use_ssh_tunnel else self.manager.port, port=self.manager.local_bind_port if
database=self.db, self.manager.use_ssh_tunnel else self.manager.port,
user=self.manager.user, database=self.db,
password=password, user=self.manager.user,
passfile=get_complete_file_path(self.manager.passfile), password=password,
sslmode=self.manager.ssl_mode, passfile=get_complete_file_path(self.manager.passfile),
sslcert=get_complete_file_path(self.manager.sslcert), sslmode=self.manager.ssl_mode,
sslkey=get_complete_file_path(self.manager.sslkey), sslcert=get_complete_file_path(self.manager.sslcert),
sslrootcert=get_complete_file_path( sslkey=get_complete_file_path(self.manager.sslkey),
self.manager.sslrootcert sslrootcert=get_complete_file_path(
), self.manager.sslrootcert
sslcrl=get_complete_file_path(self.manager.sslcrl), ),
sslcompression=True if self.manager.sslcompression sslcrl=get_complete_file_path(self.manager.sslcrl),
else False, sslcompression=True if self.manager.sslcompression
service=self.manager.service, else False,
connect_timeout=self.manager.connect_timeout service=self.manager.service,
) connect_timeout=self.manager.connect_timeout
)
# Get the cursor and run the query # Get the cursor and run the query
cur = pg_conn.cursor() cur = pg_conn.cursor()

View File

@ -0,0 +1,50 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2022, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""
Kerberos Environment Locker class
"""
from threading import Lock
from os import environ
from flask import session, current_app
import config
from pgadmin.utils.constants import KERBEROS
class ConnectionLocker:
"""Implementing lock while setting/unsetting
the Kerberos environ variables."""
lock = Lock()
def __init__(self, _is_kerberos_conn=False):
self.is_kerberos_conn = _is_kerberos_conn
def __enter__(self):
if config.SERVER_MODE:
current_app.logger.info("Waiting for a lock.")
self.lock.acquire()
current_app.logger.info("Acquired a lock.")
if session['auth_source_manager']['current_source'] == \
KERBEROS and 'KRB5CCNAME' in session \
and self.is_kerberos_conn:
environ['KRB5CCNAME'] = session['KRB5CCNAME']
else:
environ.pop('KRB5CCNAME', None)
return self
def __exit__(self, type, value, traceback):
if config.SERVER_MODE:
environ.pop('KRB5CCNAME', None)
if self.lock.locked():
current_app.logger.info("Released a lock.")
self.lock.release()