1) Added RLS Policy support in Schema Diff. Fixes #5601

2) Fixed 'cant execute empty query' issue when remove the value of
   'USING' or 'WITH CHECK' option of RLS Policy. Fixes #5631
This commit is contained in:
Pradip Parkale
2020-07-01 12:44:28 +05:30
committed by Akshay Joshi
parent 3102a8d24b
commit 979f806161
18 changed files with 245 additions and 86 deletions

View File

@@ -1194,16 +1194,6 @@ class FunctionView(PGChildNodeView, DataTypeReader, SchemaDiffObjectCompare):
@staticmethod
def _prepare_final_dict(data, old_data, chngd_variables, del_variables,
all_ids_dict):
# To compare old and new variables, preparing name :
# value dict
# if 'variables' in data and 'changed' in data['variables']:
# for v in data['variables']['changed']:
# chngd_variables[v['name']] = v['value']
#
# if 'variables' in data and 'added' in data['variables']:
# for v in data['variables']['added']:
# chngd_variables[v['name']] = v['value']
# In case of schema diff we don't want variables from
# old data
@@ -1319,13 +1309,6 @@ class FunctionView(PGChildNodeView, DataTypeReader, SchemaDiffObjectCompare):
# Deleted Variables
FunctionView._delete_variable_in_edit_mode(data, del_variables)
FunctionView._merge_variable_changes(data, chngd_variables)
# if 'variables' in data and 'changed' in data['variables']:
# for v in data['variables']['changed']:
# chngd_variables[v['name']] = v['value']
#
# if 'variables' in data and 'added' in data['variables']:
# for v in data['variables']['added']:
# chngd_variables[v['name']] = v['value']
# Prepare final dict
FunctionView._prepare_final_dict(data, old_data, chngd_variables,

View File

@@ -1671,6 +1671,9 @@ class TableView(BaseTableView, DataTypeReader, VacuumSettings,
self.manager.version >= 120000:
sub_modules.append('compound_trigger')
if self.manager.version >= 90500:
sub_modules.append('row_security_policy')
if tid:
status, data = self._fetch_properties(did, scid, tid)

View File

@@ -13,7 +13,7 @@ import simplejson as json
from functools import wraps
import pgadmin.browser.server_groups.servers.databases as databases
from flask import render_template, request, jsonify
from flask import render_template, request, jsonify, current_app
from flask_babelex import gettext
from pgadmin.browser.collection import CollectionNodeModule
from pgadmin.browser.utils import PGChildNodeView
@@ -24,6 +24,8 @@ from config import PG_DEFAULT_DRIVER
from pgadmin.browser.server_groups.servers.databases.schemas.tables. \
row_security_policies import utils as row_security_policies_utils
from pgadmin.utils.compile_template_name import compile_template_path
from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry
from pgadmin.tools.schema_diff.directory_compare import directory_diff
class RowSecurityModule(CollectionNodeModule):
@@ -124,26 +126,24 @@ class RowSecurityView(PGChildNodeView):
- This function will used to create all the child node within that
collection. Here it will create all the policy nodes.
* properties(gid, sid, did, rg_id)
* properties(gid, sid, did, plid)
- This function will show the properties of the selected policy node
* create(gid, sid, did, rg_id)
* create(gid, sid, did, plid)
- This function will create the new policy object
* update(gid, sid, did, rg_id)
* update(gid, sid, did, plid)
- This function will update the data for the selected policy node
* delete(self, gid, sid, rg_id):
* delete(self, gid, sid, plid):
- This function will drop the policy object
* msql(gid, sid, did, rg_id)
* msql(gid, sid, did, plid)
- This function is used to return modified sql for the selected
policy node
* get_sql(data, rg_id)
- This function will generate sql from model data
* sql(gid, sid, did, rg_id):
* sql(gid, sid, did, plid):
- This function will generate sql to show in sql pane for the selected
policy node.
"""
@@ -226,7 +226,8 @@ class RowSecurityView(PGChildNodeView):
# fetch schema name by schema id
sql = render_template("/".join(
[self.template_path, 'properties.sql']), tid=tid)
[self.template_path, 'properties.sql']), schema=self.schema,
tid=tid)
status, res = self.conn.execute_dict(sql)
if not status:
@@ -297,7 +298,7 @@ class RowSecurityView(PGChildNodeView):
properties tab
"""
status, data = self._fetch_properties(plid)
status, data = self._fetch_properties(did, scid, tid, plid)
if not status:
return data
@@ -306,7 +307,7 @@ class RowSecurityView(PGChildNodeView):
status=200
)
def _fetch_properties(self, plid):
def _fetch_properties(self, did, scid, tid, plid):
"""
This function is used to fetch the properties of the specified object
:param plid:
@@ -314,7 +315,7 @@ class RowSecurityView(PGChildNodeView):
"""
sql = render_template("/".join(
[self.template_path, 'properties.sql']
), plid=plid, datlastsysoid=self.datlastsysoid)
), plid=plid, scid=scid, datlastsysoid=self.datlastsysoid)
status, res = self.conn.execute_dict(sql)
if not status:
@@ -413,7 +414,7 @@ class RowSecurityView(PGChildNodeView):
)
try:
sql, name = row_security_policies_utils.get_sql(self.conn, data,
did,
did, scid,
tid, plid,
self.datlastsysoid,
self.schema,
@@ -438,7 +439,7 @@ class RowSecurityView(PGChildNodeView):
return internal_server_error(errormsg=str(e))
@check_precondition
def delete(self, gid, sid, did, scid, tid, plid=None):
def delete(self, gid, sid, did, scid, tid, plid=None, only_sql=False):
"""
This function will drop the policy object
:param plid: policy id
@@ -495,6 +496,8 @@ class RowSecurityView(PGChildNodeView):
cascade=cascade,
result=result
)
if only_sql:
return sql
status, res = self.conn.execute_scalar(sql)
if not status:
return internal_server_error(errormsg=res)
@@ -515,7 +518,7 @@ class RowSecurityView(PGChildNodeView):
data = dict(request.args)
sql, name = row_security_policies_utils.get_sql(self.conn, data, did,
tid, plid,
scid, tid, plid,
self.datlastsysoid,
self.schema,
self.table)
@@ -534,19 +537,21 @@ class RowSecurityView(PGChildNodeView):
def sql(self, gid, sid, did, scid, tid, plid):
"""
This function will generate sql to render into the sql panel
Args:
gid: Server Group ID
sid: Server ID
did: Database ID
scid: Schema ID
tid: Table ID
plid: policy ID
"""
status, res_data = self._fetch_properties(plid)
if not status:
return res_data
res_data['schema'] = self.schema
res_data['table'] = self.table
SQL = row_security_policies_utils.get_reverse_engineered_sql(
self.conn, self.schema, self.table, did, scid, tid, plid,
self.datlastsysoid)
sql = render_template("/".join(
[self.template_path, 'create.sql']),
data=res_data, display_comments=True)
return ajax_response(response=sql)
return ajax_response(response=SQL)
@check_precondition
def dependents(self, gid, sid, did, scid, tid, plid):
@@ -588,5 +593,138 @@ class RowSecurityView(PGChildNodeView):
status=200
)
@check_precondition
def get_sql_from_diff(self, gid, sid, did, scid, tid, plid, data=None,
diff_schema=None, drop_req=False):
sql = ''
if data:
data['schema'] = self.schema
data['table'] = self.table
sql, name = row_security_policies_utils.get_sql(
self.conn, data, did, scid, tid, plid, self.datlastsysoid,
self.schema, self.table)
sql = sql.strip('\n').strip(' ')
elif diff_schema:
schema = diff_schema
sql = row_security_policies_utils.get_reverse_engineered_sql(
self.conn, schema,
self.table, did, scid, tid, plid,
self.datlastsysoid,
template_path=None, with_header=False)
drop_sql = ''
if drop_req:
drop_sql = '\n' + self.delete(gid=1, sid=sid, did=did,
scid=scid, tid=tid,
plid=plid, only_sql=True)
if drop_sql != '':
sql = drop_sql + '\n\n' + sql
return sql
@check_precondition
def fetch_objects_to_compare(self, sid, did, scid, tid, oid=None):
"""
This function will fetch the list of all the policies for
specified schema id.
:param sid: Server Id
:param did: Database Id
:param scid: Schema Id
:param oid: Policy Id
:return:
"""
res = dict()
if not oid:
SQL = render_template("/".join([self.template_path,
'nodes.sql']), tid=tid)
status, policies = self.conn.execute_2darray(SQL)
if not status:
current_app.logger.error(policies)
return False
for row in policies['rows']:
status, data = self._fetch_properties(did, scid, tid,
row['oid'])
if status:
res[row['name']] = data
else:
status, data = self._fetch_properties(did, scid, tid, oid)
if not status:
current_app.logger.error(data)
return False
res = data
return res
def ddl_compare(self, **kwargs):
"""
This function returns the DDL/DML statements based on the
comparison status.
:param kwargs:
:return:
"""
src_params = kwargs.get('source_params')
tgt_params = kwargs.get('target_params')
source = kwargs.get('source')
target = kwargs.get('target')
target_schema = kwargs.get('target_schema')
comp_status = kwargs.get('comp_status')
diff = ''
if comp_status == 'source_only':
diff = self.get_sql_from_diff(gid=src_params['gid'],
sid=src_params['sid'],
did=src_params['did'],
scid=src_params['scid'],
tid=src_params['tid'],
plid=source['oid'],
diff_schema=target_schema)
elif comp_status == 'target_only':
diff = self.delete(gid=1,
sid=tgt_params['sid'],
did=tgt_params['did'],
scid=tgt_params['scid'],
tid=tgt_params['tid'],
plid=target['oid'],
only_sql=True)
elif comp_status == 'different':
diff_dict = directory_diff(
source, target, difference={}
)
if 'event' in diff_dict:
delete_sql = self.get_sql_from_diff(gid=1,
sid=tgt_params['sid'],
did=tgt_params['did'],
scid=tgt_params['scid'],
tid=tgt_params['tid'],
plid=target['oid'],
drop_req=True)
diff = self.get_sql_from_diff(gid=src_params['gid'],
sid=src_params['sid'],
did=src_params['did'],
scid=src_params['scid'],
tid=src_params['tid'],
plid=source['oid'],
diff_schema=target_schema)
return delete_sql + diff
diff = self.get_sql_from_diff(gid=tgt_params['gid'],
sid=tgt_params['sid'],
did=tgt_params['did'],
scid=tgt_params['scid'],
tid=tgt_params['tid'],
plid=target['oid'],
data=diff_dict)
return '\n' + diff
SchemaDiffRegistry(blueprint.node_type, RowSecurityView, 'table')
RowSecurityView.register_node_view(blueprint)

View File

@@ -80,6 +80,10 @@ define('pgadmin.node.row_security_policy', [
name: undefined,
policyowner: 'public',
event: 'ALL',
using: undefined,
using_orig: undefined,
withcheck: undefined,
withcheck_orig: undefined,
},
schema: [{
id: 'name', label: gettext('Name'), cell: 'string',
@@ -111,7 +115,7 @@ define('pgadmin.node.row_security_policy', [
control: 'sql-field', visible: true, group: gettext('Commands'),
},
{
id: 'withcheck', label: gettext('With Check'), deps: ['withcheck', 'event'],
id: 'withcheck', label: gettext('With check'), deps: ['withcheck', 'event'],
type: 'text', mode: ['create', 'edit', 'properties'],
control: 'sql-field', visible: true, group: gettext('Commands'),
disabled: 'disableWithCheck',
@@ -135,7 +139,6 @@ define('pgadmin.node.row_security_policy', [
validate: function(keys) {
var msg;
this.errorModel.clear();
// If nothing to validate
if (keys && keys.length == 0) {
return null;
@@ -147,6 +150,16 @@ define('pgadmin.node.row_security_policy', [
this.errorModel.set('name', msg);
return msg;
}
if (!this.isNew() && !_.isNull(this.get('using_orig')) && this.get('using_orig') != '' && String(this.get('using')).replace(/^\s+|\s+$/g, '') == ''){
msg = gettext('"USING" can not be empty once the value is set');
this.errorModel.set('using', msg);
return msg;
}
if (!this.isNew() && !_.isNull(this.get('withcheck_orig')) && this.get('withcheck_orig') != '' && String(this.get('withcheck')).replace(/^\s+|\s+$/g, '') == ''){
msg = gettext('"Withcheck" can not be empty once the value is set');
this.errorModel.set('withcheck', msg);
return msg;
}
return null;
},
disableWithCheck: function(m){

View File

@@ -1,4 +1,4 @@
-- POLICY: policy_1 ON public.test_rls_policy
-- POLICY: policy_1
-- DROP POLICY policy_1 ON public.test_rls_policy;

View File

@@ -0,0 +1,11 @@
-- POLICY: all_event_policy
-- DROP POLICY all_event_policy ON public.test_rls_policy;
CREATE POLICY all_event_policy
ON public.test_rls_policy
FOR ALL
TO public
USING (true)
WITH CHECK (true);

View File

@@ -1,4 +1,4 @@
-- POLICY: insert_policy ON public.test_rls_policy
-- POLICY: insert_policy
-- DROP POLICY insert_policy ON public.test_rls_policy;

View File

@@ -1,4 +1,4 @@
-- POLICY: test ON public.test_rls_policy
-- POLICY: test
-- DROP POLICY test ON public.test_rls_policy;

View File

@@ -1,4 +1,4 @@
-- POLICY: select_policy ON public.test_rls_policy
-- POLICY: select_policy
-- DROP POLICY select_policy ON public.test_rls_policy;

View File

@@ -36,7 +36,8 @@
"data": {
"name": "select_policy",
"event": "SELECT",
"policyowner": "public"
"policyowner": "public",
"schema": "public"
},
"expected_sql_file": "create_select_policy.sql"
},
@@ -48,7 +49,8 @@
"data": {
"name": "insert_policy",
"event": "INSERT",
"policyowner": "public"
"policyowner": "public",
"schema": "public"
},
"expected_sql_file": "create_insert_policy.sql"
},
@@ -58,7 +60,8 @@
"endpoint": "NODE-row_security_policy.obj",
"sql_endpoint": "NODE-row_security_policy.sql_id",
"data": {
"name": "test"
"name": "test",
"schema": "public"
},
"expected_sql_file": "create_public_policy.sql"
},
@@ -74,6 +77,21 @@
"expected_sql_file": "alter_policy.sql",
"expected_msql_file": "alter_policy_msql.sql"
},
{
"type": "create",
"name": "Create RLS policy for event 'ALL'",
"endpoint": "NODE-row_security_policy.obj",
"sql_endpoint": "NODE-row_security_policy.sql_id",
"data": {
"name": "all_event_policy",
"event": "ALL",
"policyowner": "public",
"schema": "public",
"using": "true",
"withcheck": "true"
},
"expected_sql_file": "create_all_event_policy.sql"
},
{
"type": "delete",
"name": "Drop policy",

View File

@@ -61,22 +61,21 @@ def get_parent(conn, tid, template_path=None):
@get_template_path
def get_sql(conn, data, did, tid, plid, datlastsysoid, schema, table,
mode=None, template_path=None):
def get_sql(conn, data, did, scid, tid, plid, datlastsysoid, schema, table,
template_path=None):
"""
This function will generate sql from model data
"""
if plid is not None:
sql = render_template("/".join(
[template_path, 'properties.sql']), plid=plid)
sql = render_template("/".join([template_path, 'properties.sql']),
schema=schema, plid=plid, scid=scid)
status, res = conn.execute_dict(sql)
if not status:
return internal_server_error(errormsg=res)
if len(res['rows']) == 0:
raise ObjectGone(_('Could not find the index in the table.'))
raise ObjectGone(_('Could not find the policy in the table.'))
old_data = dict(res['rows'][0])
old_data['schema'] = schema
@@ -95,7 +94,7 @@ def get_sql(conn, data, did, tid, plid, datlastsysoid, schema, table,
@get_template_path
def get_reverse_engineered_sql(conn, schema, table, did, tid, plid,
def get_reverse_engineered_sql(conn, schema, table, did, scid, tid, plid,
datlastsysoid,
template_path=None, with_header=True):
"""
@@ -114,21 +113,22 @@ def get_reverse_engineered_sql(conn, schema, table, did, tid, plid,
:return:
"""
SQL = render_template("/".join(
[template_path, 'properties.sql']), plid=plid)
[template_path, 'properties.sql']), plid=plid, scid=scid)
status, res = conn.execute_dict(SQL)
if not status:
raise Exception(res)
if len(res['rows']) == 0:
raise ObjectGone(_('Could not find the index in the table.'))
raise ObjectGone(_('Could not find the policy in the table.'))
data = dict(res['rows'][0])
# Adding parent into data dict, will be using it while creating sql
data['schema'] = schema
data['table'] = table
SQL, name = get_sql(conn, data, did, tid, None, datlastsysoid, schema,
SQL, name = get_sql(conn, data, did, scid, tid, None, datlastsysoid,
schema,
table)
if with_header:

View File

@@ -254,7 +254,7 @@ class SchemaDiffTableCompare(SchemaDiffObjectCompare):
target_params['diff_data'] = diff_dict
diff = self.get_sql_from_table_diff(**target_params)
ignore_sub_modules = ['column', 'constraints', 'row_security_policy']
ignore_sub_modules = ['column', 'constraints']
if self.manager.version < 100000:
ignore_sub_modules.append('partition')
if self.manager.server_type == 'pg' or self.manager.version < 120000:

View File

@@ -1247,21 +1247,13 @@ define('pgadmin.node.table', [
this.errorModel.set('partition_keys', msg);
return msg;
}
if (this.get('rlspolicy') && this.isNew()){
Alertify.confirm(
this.errorModel.unset('partition_keys');
if (this.get('rlspolicy') && this.isNew() && this.changed.rlspolicy){
Alertify.alert(
gettext('Check Policy?'),
gettext('Check if any policy exist. If no policy exists for the table, a default-deny policy is used, meaning that no rows are visible or can be modified'),
function() {
self.close();
return true;
},
function() {
// Do nothing.
return true;
}
gettext('Please check if any policy exist. If no policy exists for the table, a default-deny policy is used, meaning that no rows are visible or can be modified by other users')
);
}
this.errorModel.unset('partition_keys');
return null;
},
// We will disable everything if we are under catalog node

View File

@@ -1,14 +1,9 @@
{# CREATE POLICY Statement #}
-- POLICY: {{ conn|qtIdent(data.name) }} ON {{ conn|qtIdent(data.schema, data.table) }}
-- DROP POLICY {{ conn|qtIdent(data.name) }} ON {{ conn|qtIdent(data.schema, data.table) }};
{% set add_semicolon_after = 'to' %}
{% if data.withcheck is defined and data.withcheck != None and data.withcheck != '' %}
{% set add_semicolon_after = 'with_check' %}
{% elif data.using is defined and data.using != None and data.using != '' %}
{% set add_semicolon_after = 'using' %}
{% endif %}
CREATE POLICY {{ conn|qtIdent(data.name) }}
ON {{conn|qtIdent(data.schema, data.table)}}
{% if data.event %}

View File

@@ -3,15 +3,19 @@ SELECT
pl.polname AS name,
rw.cmd AS event,
rw.qual AS using,
rw.qual AS using_orig,
rw.with_check AS withcheck,
rw.with_check AS withcheck_orig,
array_to_string(rw.roles::name[], ', ') AS policyowner
FROM
pg_policy pl
JOIN pg_policies rw ON pl.polname=rw.policyname
JOIN pg_namespace n ON n.nspname=rw.schemaname
WHERE
{% if plid %}
pl.oid = {{ plid }}
pl.oid = {{ plid }} and n.oid = {{ scid }};
{% endif %}
{% if tid %}
pl.polrelid = {{ tid }}
{% endif %};
pl.polrelid = {{ tid }};
{% endif %}

View File

@@ -9,7 +9,7 @@ ALTER POLICY {{ o_data.name }} ON {{conn|qtIdent(o_data.schema, o_data.table)}}
{#####################################################}
{## Change policy using condition ##}
{#####################################################}
{% if data.using and o_data.withcheck != data.using %}
{% if data.using and o_data.using != data.using %}
ALTER POLICY {{ o_data.name }} ON {{conn|qtIdent(o_data.schema, o_data.table)}}
USING ({{ data.using }});
{% endif %}

View File

@@ -535,7 +535,7 @@ class BaseTableView(PGChildNodeView, BasePartitionTable):
for row in rset['rows']:
policy_sql = row_security_policies_utils. \
get_reverse_engineered_sql(
self.conn, schema, table, did, tid, row['oid'],
self.conn, schema, table, did, scid, tid, row['oid'],
self.datlastsysoid,
template_path=None, with_header=json_resp)
policy_sql = u"\n" + policy_sql