Reverse engineer serial columns when generating ERD for database/table. #6958

This commit is contained in:
Aditya Toshniwal 2023-11-23 15:50:54 +05:30 committed by GitHub
parent 9611e06dcf
commit 115208c8d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 9 deletions

View File

@ -17,6 +17,8 @@ from pgadmin.browser.server_groups.servers.databases.schemas.utils \
import DataTypeReader import DataTypeReader
from pgadmin.browser.server_groups.servers.utils import parse_priv_from_db, \ from pgadmin.browser.server_groups.servers.utils import parse_priv_from_db, \
parse_priv_to_db parse_priv_to_db
from pgadmin.browser.server_groups.servers.databases.utils \
import make_object_name
from functools import wraps from functools import wraps
@ -225,7 +227,8 @@ def parse_options_for_column(db_variables):
@get_template_path @get_template_path
def get_formatted_columns(conn, tid, data, other_columns, def get_formatted_columns(conn, tid, data, other_columns,
table_or_type, template_path=None): table_or_type, template_path=None,
with_serial=False):
""" """
This function will iterate and return formatted data for all This function will iterate and return formatted data for all
the columns. the columns.
@ -254,6 +257,28 @@ def get_formatted_columns(conn, tid, data, other_columns,
col['inheritedfrom' + table_or_type] = \ col['inheritedfrom' + table_or_type] = \
other_col['inheritedfrom'] other_col['inheritedfrom']
if with_serial:
# Here we assume if a column is serial
serial_seq_name = make_object_name(
data['name'], col['name'], 'seq')
# replace the escaped quotes for comparison
defval = (col.get('defval', '') or '').replace("''", "'").\
replace('""', '"')
if serial_seq_name in defval and defval.startswith("nextval('")\
and col['typname'] in ('integer', 'smallint', 'bigint'):
serial_type = {
'integer': 'serial',
'smallint': 'smallserial',
'bigint': 'bigserial'
}[col['typname']]
col['displaytypname'] = serial_type
col['cltype'] = serial_type
col['typname'] = serial_type
col['defval'] = ''
data['columns'] = all_columns data['columns'] = all_columns
if 'columns' in data and len(data['columns']) > 0: if 'columns' in data and len(data['columns']) > 0:

View File

@ -164,7 +164,7 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
return wrap return wrap
def _formatter(self, did, scid, tid, data): def _formatter(self, did, scid, tid, data, with_serial_cols=False):
""" """
Args: Args:
data: dict of query result data: dict of query result
@ -234,7 +234,8 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
# columns properties.sql, so we need to set template path # columns properties.sql, so we need to set template path
data = column_utils.get_formatted_columns(self.conn, tid, data = column_utils.get_formatted_columns(self.conn, tid,
data, other_columns, data, other_columns,
table_or_type) table_or_type,
with_serial=with_serial_cols)
self._add_constrints_to_output(data, did, tid) self._add_constrints_to_output(data, did, tid)
@ -493,7 +494,7 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
return condition return condition
def fetch_tables(self, sid, did, scid, tid=None): def fetch_tables(self, sid, did, scid, tid=None, with_serial_cols=False):
""" """
This function will fetch the list of all the tables This function will fetch the list of all the tables
and will be used by schema diff. and will be used by schema diff.
@ -502,6 +503,7 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
:param did: Database Id :param did: Database Id
:param scid: Schema Id :param scid: Schema Id
:param tid: Table Id :param tid: Table Id
:param with_serial_cols: Boolean
:return: Table dataset :return: Table dataset
""" """
@ -513,7 +515,8 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
data = BaseTableView.properties( data = BaseTableView.properties(
self, 0, sid, did, scid, tid, res=data, self, 0, sid, did, scid, tid, res=data,
return_ajax_response=False with_serial_cols=with_serial_cols,
return_ajax_response=False,
) )
return True, data return True, data
@ -534,6 +537,7 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
if status: if status:
data = BaseTableView.properties( data = BaseTableView.properties(
self, 0, sid, did, scid, row['oid'], res=data, self, 0, sid, did, scid, row['oid'], res=data,
with_serial_cols=with_serial_cols,
return_ajax_response=False return_ajax_response=False
) )
@ -1750,6 +1754,7 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
""" """
res = kwargs.get('res') res = kwargs.get('res')
return_ajax_response = kwargs.get('return_ajax_response', True) return_ajax_response = kwargs.get('return_ajax_response', True)
with_serial_cols = kwargs.get('with_serial_cols', False)
data = res['rows'][0] data = res['rows'][0]
@ -1768,7 +1773,8 @@ class BaseTableView(PGChildNodeView, BasePartitionTable, VacuumSettings):
'vacuum_settings_str' 'vacuum_settings_str'
].replace('=', ' = ') ].replace('=', ' = ')
data = self._formatter(did, scid, tid, data) data = self._formatter(did, scid, tid, data,
with_serial_cols=with_serial_cols)
# Fetch partition of this table if it is partitioned table. # Fetch partition of this table if it is partitioned table.
if 'is_partitioned' in data and data['is_partitioned']: if 'is_partitioned' in data and data['is_partitioned']:

View File

@ -126,3 +126,26 @@ def get_attributes_from_db_info(manager, kwargs):
return datistemplate, datallowconn return datistemplate, datallowconn
else: else:
return False, True return False, True
def make_object_name(name1: str, name2: str, label: str) -> str:
"""
This function is python port for makeObjectName in postgres.
https://github.com/postgres/postgres/blob/master/src/backend/commands/indexcmds.c
It is used by postgres to generate index name for auto index,
sequence name for serial columns.
:param name1: generally table name
:param name2: generally column name
:param label: a suffix
:return: name string
"""
namedatalen: int = 63
result = '{0}_{1}_{2}'.format(name1, name2, label)
while len(result) > namedatalen:
if len(name1) > len(name2):
name1 = name1[:-1]
else:
name2 = name2[:-1]
result = '{0}_{1}_{2}'.format(name1, name2, label)
return result

View File

@ -289,7 +289,7 @@ class TableNodeWidgetRaw extends React.Component {
<Box margin="auto 0"> <Box margin="auto 0">
<span data-test="column-name">{col.name}</span>&nbsp; <span data-test="column-name">{col.name}</span>&nbsp;
{this.state.show_details && {this.state.show_details &&
<span data-test="column-type">{cltype}</span>} <span data-test="column-type">{cltype + (col.colconstype == 'i' ? ` (${gettext('IDENTITY')})`:'')}</span>}
</Box> </Box>
</Box> </Box>
<Box marginLeft="auto" padding="0" minHeight="0" display="flex" alignItems="center"> <Box marginLeft="auto" padding="0" minHeight="0" display="flex" alignItems="center">

View File

@ -40,7 +40,8 @@ class ERDTableView(BaseTableView, DataTypeReader):
all_tables = [] all_tables = []
for row in schemas['rows']: for row in schemas['rows']:
status, res = \ status, res = \
BaseTableView.fetch_tables(self, sid, did, row['oid']) BaseTableView.fetch_tables(self, sid, did, row['oid'],
with_serial_cols=True)
if not status: if not status:
return status, res return status, res
@ -53,7 +54,8 @@ class ERDTableView(BaseTableView, DataTypeReader):
tid=None, related={}, maxdepth=0, currdepth=0): tid=None, related={}, maxdepth=0, currdepth=0):
status, res = \ status, res = \
BaseTableView.fetch_tables(self, sid, did, scid, tid=tid) BaseTableView.fetch_tables(self, sid, did, scid, tid=tid,
with_serial_cols=True)
if not status: if not status:
return status, res return status, res