Python: Avoid casting alltogether in favour of just generating the correct objects

This commit is contained in:
Gaute Lindkvist 2020-02-25 11:59:06 +01:00
parent d95f3a349d
commit 86cc2b5c9c
6 changed files with 107 additions and 149 deletions

View File

@ -1807,7 +1807,7 @@ void RiaApplication::generatePythonClasses( const QString& fileName )
QString scriptClassName = RicfObjectCapability::scriptClassNameFromClassKeyword( classKeyword );
if ( scriptClassName.isEmpty() ) scriptClassName = classKeyword;
if ( !classesWritten.count( classKeyword ) )
if ( !classesWritten.count( scriptClassName ) )
{
QString classCode;
if ( scriptSuperClassNames.empty() )
@ -1828,6 +1828,7 @@ void RiaApplication::generatePythonClasses( const QString& fileName )
}
classCode += " Attributes\n";
classCode += " class_keyword (string): the class keyword that uniquely defines a class\n";
for ( auto keyWordValuePair : classesGenerated[classKeyword] )
{
classCode += " " + keyWordValuePair.second.second;
@ -1838,19 +1839,18 @@ void RiaApplication::generatePythonClasses( const QString& fileName )
QString( " __custom_init__ = None #: Assign a custom init routine to be run at __init__\n\n" );
classCode += QString( " def __init__(self, pb2_object=None, channel=None):\n" );
classCode += QString( " self.class_keyword = \"%1\"\n" ).arg( scriptClassName );
if ( !scriptSuperClassNames.empty() )
{
// Parent constructor
classCode +=
QString( " %1.__init__(self, pb2_object, channel)\n" ).arg( scriptSuperClassNames.back() );
// Own attributes. This initializes a lot of attributes to None.
// This means it has to be done before we set any values.
for ( auto keyWordValuePair : classesGenerated[classKeyword] )
{
classCode += keyWordValuePair.second.first;
}
// Parent constructor
classCode +=
QString( " %1.__init__(self, pb2_object, channel)\n" ).arg( scriptSuperClassNames.back() );
}
classCode += QString( " if %1.__custom_init__ is not None:\n" ).arg( scriptClassName );
@ -1858,9 +1858,22 @@ void RiaApplication::generatePythonClasses( const QString& fileName )
.arg( scriptClassName );
out << classCode << "\n";
classesWritten.insert( classKeyword );
classesWritten.insert( scriptClassName );
}
scriptSuperClassNames.push_back( scriptClassName );
}
}
out << "def class_dict():\n";
out << " classes = {}\n";
for ( QString classKeyword : classesWritten )
{
out << QString( " classes['%1'] = %1\n" ).arg( classKeyword );
}
out << " return classes\n\n";
out << "def class_from_keyword(class_keyword):\n";
out << " all_classes = class_dict()\n";
out << " if class_keyword in all_classes.keys():\n";
out << " return all_classes[class_keyword]\n";
out << " return None\n";
}

View File

@ -16,7 +16,7 @@ if resinsight is not None:
for case in cases:
print("Case id: " + str(case.id))
print("Case name: " + case.name)
print("Case type: " + case.type)
print("Case type: " + case.class_keyword)
print("Case file name: " + case.file_path)
print("Case reservoir bounding box:", case.reservoir_boundingbox())
@ -25,10 +25,12 @@ if resinsight is not None:
print("Year: " + str(t.year))
print("Month: " + str(t.month))
coarsening_info = case.coarsening_info()
if coarsening_info:
print("Coarsening information:")
if isinstance(case, rips.EclipseCase):
print ("Getting coarsening info for case: ", case.name, case.id)
coarsening_info = case.coarsening_info()
if coarsening_info:
print("Coarsening information:")
for c in coarsening_info:
print("[{}, {}, {}] - [{}, {}, {}]".format(c.min.x, c.min.y, c.min.z,
c.max.x, c.max.y, c.max.z))
for c in coarsening_info:
print("[{}, {}, {}] - [{}, {}, {}]".format(c.min.x, c.min.y, c.min.z,
c.max.x, c.max.y, c.max.z))

View File

@ -19,14 +19,13 @@ import rips.generated.Properties_pb2 as Properties_pb2
import rips.generated.Properties_pb2_grpc as Properties_pb2_grpc
import rips.generated.NNCProperties_pb2 as NNCProperties_pb2
import rips.generated.NNCProperties_pb2_grpc as NNCProperties_pb2_grpc
from rips.generated.pdm_objects import Case
from rips.generated.pdm_objects import Case, EclipseCase, GeoMechCase
import rips.project
from rips.grid import Grid
from rips.pdmobject import add_method, PdmObject
from rips.view import View
from rips.contour_map import ContourMap, ContourMapType
from rips.well_bore_stability_plot import WellBoreStabilityPlot, WbsParameters
from rips.simulation_well import SimulationWell
@ -49,26 +48,27 @@ Attributes:
@add_method(Case)
def __custom_init__(self, pb2_object, channel):
self.__case_stub = Case_pb2_grpc.CaseStub(self._channel)
self.__request = Case_pb2.CaseRequest(id=self.id)
info = self.__case_stub.GetCaseInfo(self.__request)
self.__properties_stub = Properties_pb2_grpc.PropertiesStub(self._channel)
self.__nnc_properties_stub = NNCProperties_pb2_grpc.NNCPropertiesStub(self._channel)
# Public properties
self.type = info.type
self.chunk_size = 8160
@add_method(Case)
def __grid_count(self):
"""Get number of grids in the case"""
try:
return self.__case_stub.GetGridCount(self.__request).count
return self.__case_stub.GetGridCount(self.__request()).count
except grpc.RpcError as exception:
if exception.code() == grpc.StatusCode.NOT_FOUND:
return 0
return 0
@add_method(Case)
def __request(self):
return Case_pb2.CaseRequest(id=self.id)
@add_method(Case)
def __generate_property_input_iterator(self, values_iterator, parameters):
chunk = Properties_pb2.PropertyInputChunk()
@ -146,7 +146,7 @@ def cell_count(self, porosity_model="MATRIX_MODEL"):
reservoir_cell_count: total number of reservoir cells
"""
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Case_pb2.CellInfoRequest(case_request=self.__request,
request = Case_pb2.CellInfoRequest(case_request=self.__request(),
porosity_model=porosity_model_enum)
return self.__case_stub.GetCellCount(request)
@ -164,7 +164,7 @@ def cell_info_for_active_cells_async(self, porosity_model="MATRIX_MODEL"):
See cell_info_for_active_cells() for detalis on the **CellInfo** class.
"""
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Case_pb2.CellInfoRequest(case_request=self.__request,
request = Case_pb2.CellInfoRequest(case_request=self.__request(),
porosity_model=porosity_model_enum)
return self.__case_stub.GetCellInfoForActiveCells(request)
@ -223,14 +223,14 @@ def time_steps(self):
"""
return self.__case_stub.GetTimeSteps(self.__request).dates
return self.__case_stub.GetTimeSteps(self.__request()).dates
@add_method(Case)
def reservoir_boundingbox(self):
"""Get the reservoir bounding box
Returns: A class with six double members: min_x, max_x, min_y, max_y, min_z, max_z
"""
return self.__case_stub.GetReservoirBoundingBox(self.__request)
return self.__case_stub.GetReservoirBoundingBox(self.__request())
@add_method(Case)
def reservoir_depth_range(self):
@ -243,19 +243,12 @@ def reservoir_depth_range(self):
@add_method(Case)
def days_since_start(self):
"""Get a list of decimal values representing days since the start of the simulation"""
return self.__case_stub.GetDaysSinceStart(self.__request).day_decimals
return self.__case_stub.GetDaysSinceStart(self.__request()).day_decimals
@add_method(Case)
def views(self):
"""Get a list of views belonging to a case"""
pdm_objects = self.descendants("View")
view_list = []
for pdm_object in pdm_objects:
view_object = pdm_object.cast(View)
view_list.append(view_object)
return view_list
return self.descendants(View)
@add_method(Case)
def view(self, view_id):
@ -279,21 +272,6 @@ def create_view(self):
self._execute_command(createView=Cmd.CreateViewRequest(
caseId=self.id)).createViewResult.viewId)
@add_method(Case)
def contour_maps(self, map_type=ContourMapType.ECLIPSE):
""" Get a list of all contour maps belonging to a project
Arguments:
map_type (enum): ContourMapType.ECLIPSE or ContourMapType.GEO_MECH
"""
pdm_objects = self.descendants(ContourMapType.get_identifier(map_type))
contour_maps = []
for pdm_object in pdm_objects:
contour_maps.append(ContourMap(pdm_object, map_type))
return contour_maps
@add_method(Case)
def export_snapshots_of_all_views(self, prefix="", export_folder=""):
""" Export snapshots for all views in the case
@ -539,7 +517,7 @@ def available_properties(self,
property_type_enum = Properties_pb2.PropertyType.Value(property_type)
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Properties_pb2.AvailablePropertiesRequest(
case_request=self.__request,
case_request=self.__request(),
property_type=property_type_enum,
porosity_model=porosity_model_enum,
)
@ -567,7 +545,7 @@ def active_cell_property_async(self,
property_type_enum = Properties_pb2.PropertyType.Value(property_type)
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Properties_pb2.PropertyRequest(
case_request=self.__request,
case_request=self.__request(),
property_type=property_type_enum,
property_name=property_name,
time_step=time_step,
@ -624,7 +602,7 @@ def selected_cell_property_async(self,
property_type_enum = Properties_pb2.PropertyType.Value(property_type)
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Properties_pb2.PropertyRequest(
case_request=self.__request,
case_request=self.__request(),
property_type=property_type_enum,
property_name=property_name,
time_step=time_step,
@ -684,7 +662,7 @@ def grid_property_async(
property_type_enum = Properties_pb2.PropertyType.Value(property_type)
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Properties_pb2.PropertyRequest(
case_request=self.__request,
case_request=self.__request(),
property_type=property_type_enum,
property_name=property_name,
time_step=time_step,
@ -743,7 +721,7 @@ def set_active_cell_property_async(
property_type_enum = Properties_pb2.PropertyType.Value(property_type)
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Properties_pb2.PropertyRequest(
case_request=self.__request,
case_request=self.__request(),
property_type=property_type_enum,
property_name=property_name,
time_step=time_step,
@ -774,7 +752,7 @@ def set_active_cell_property(
property_type_enum = Properties_pb2.PropertyType.Value(property_type)
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Properties_pb2.PropertyRequest(
case_request=self.__request,
case_request=self.__request(),
property_type=property_type_enum,
property_name=property_name,
time_step=time_step,
@ -808,7 +786,7 @@ def set_grid_property(
property_type_enum = Properties_pb2.PropertyType.Value(property_type)
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Properties_pb2.PropertyRequest(
case_request=self.__request,
case_request=self.__request(),
property_type=property_type_enum,
property_name=property_name,
time_step=time_step,
@ -912,7 +890,7 @@ def active_cell_centers_async(
Loop through the chunks and then the values within the chunk to get all values.
"""
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Case_pb2.CellInfoRequest(case_request=self.__request,
request = Case_pb2.CellInfoRequest(case_request=self.__request(),
porosity_model=porosity_model_enum)
return self.__case_stub.GetCellCenterForActiveCells(request)
@ -949,7 +927,7 @@ def active_cell_corners_async(
Loop through the chunks and then the values within the chunk to get all values.
"""
porosity_model_enum = Case_pb2.PorosityModelType.Value(porosity_model)
request = Case_pb2.CellInfoRequest(case_request=self.__request,
request = Case_pb2.CellInfoRequest(case_request=self.__request(),
porosity_model=porosity_model_enum)
return self.__case_stub.GetCellCornersForActiveCells(request)
@ -979,7 +957,7 @@ def selected_cells_async(self):
An iterator to a chunk object containing an array of cells.
Loop through the chunks and then the cells within the chunk to get all cells.
"""
return self.__case_stub.GetSelectedCells(self.__request)
return self.__case_stub.GetSelectedCells(self.__request())
@add_method(Case)
def selected_cells(self):
@ -1003,13 +981,13 @@ def coarsening_info(self):
A list of CoarseningInfo objects with two Vec3i min and max objects
for each entry.
"""
return self.__case_stub.GetCoarseningInfoArray(self.__request).data
return self.__case_stub.GetCoarseningInfoArray(self.__request()).data
@add_method(Case)
def available_nnc_properties(self):
"""Get a list of available NNC properties
"""
return self.__nnc_properties_stub.GetAvailableNNCProperties(self.__request).properties
return self.__nnc_properties_stub.GetAvailableNNCProperties(self.__request()).properties
@add_method(Case)
def nnc_connections_async(self):
@ -1018,7 +996,7 @@ def nnc_connections_async(self):
An iterator to a chunk object containing an array NNCConnection objects.
Loop through the chunks and then the connection within the chunk to get all connections.
"""
return self.__nnc_properties_stub.GetNNCConnections(self.__request)
return self.__nnc_properties_stub.GetNNCConnections(self.__request())
@add_method(Case)
def nnc_connections(self):

View File

@ -15,7 +15,7 @@ import rips.generated.PdmObject_pb2 as PdmObject_pb2
import rips.generated.PdmObject_pb2_grpc as PdmObject_pb2_grpc
import rips.generated.Commands_pb2 as Cmd
import rips.generated.Commands_pb2_grpc as CmdRpc
from rips.generated.pdm_objects import PdmObject
from rips.generated.pdm_objects import PdmObject, class_from_keyword
def camel_to_snake(name):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
@ -57,7 +57,7 @@ def __custom_init__(self, pb2_object, channel):
self._pb2_object = pb2_object
else:
self._pb2_object = PdmObject_pb2.PdmObject(class_keyword=self.__class__.__name__)
self.class_keyword = self._pb2_object.class_keyword
self._channel = channel
if self.pb2_object() is not None and self.channel() is not None:
@ -69,7 +69,6 @@ def __custom_init__(self, pb2_object, channel):
snake_keyword = camel_to_snake(camel_keyword)
setattr(self, snake_keyword, self.__get_grpc_value(camel_keyword))
self.__keyword_translation[snake_keyword] = camel_keyword
self._superclasses = self.superclasses()
@add_method(PdmObject)
def copy_from(self, object):
@ -80,15 +79,9 @@ def copy_from(self, object):
value = getattr(object, attribute)
# This is crucial to avoid overwriting methods
if not callable(value):
setattr(self, attribute, value)
@add_method(PdmObject)
def cast(self, class_definition):
if class_definition.__name__ == self.class_keyword() or class_definition.__name__ in self._superclasses:
new_object = class_definition(self.pb2_object(), self.channel())
new_object.copy_from(self)
return new_object
return None
setattr(self, attribute, value)
if self.__custom_init__ is not None:
self.__custom_init__(self._pb2_object, self._channel)
@add_method(PdmObject)
def warnings(self):
@ -119,11 +112,6 @@ def address(self):
return self._pb2_object.address
@add_method(PdmObject)
def class_keyword(self):
"""Get the class keyword in the ResInsight Data Model for the given PdmObject"""
return self._pb2_object.class_keyword
@add_method(PdmObject)
def set_visible(self, visible):
"""Set the visibility of the object in the ResInsight project tree"""
@ -137,7 +125,7 @@ def visible(self):
@add_method(PdmObject)
def print_object_info(self):
"""Print the structure and data content of the PdmObject"""
print("=========== " + self.class_keyword() + " =================")
print("=========== " + self.class_keyword + " =================")
print("Object Attributes: ")
for snake_kw, camel_kw in self.__keyword_translation.items():
print(" " + snake_kw + " [" + type(getattr(self, snake_kw)).__name__ +
@ -215,37 +203,39 @@ def __makelist(self, list_string):
return values
@add_method(PdmObject)
def descendants(self, class_keyword_or_class):
def __from_pb2_to_pdm_objects(self, pb2_object_list, super_class_definition):
pdm_object_list = []
for pb2_object in pb2_object_list:
child_class_definition = class_from_keyword(pb2_object.class_keyword)
if child_class_definition is None:
child_class_definition = super_class_definition
pdm_object = child_class_definition(pb2_object=pb2_object, channel=self.channel())
pdm_object_list.append(pdm_object)
return pdm_object_list
@add_method(PdmObject)
def descendants(self, class_definition):
"""Get a list of all project tree descendants matching the class keyword
Arguments:
class_keyword_or_class[str/Class]: A class keyword matching the type of class wanted or a Class definition
class_definition[class]: A class definition matching the type of class wanted
Returns:
A list of PdmObjects matching the keyword provided
A list of PdmObjects matching the class_definition
"""
class_definition = PdmObject
class_keyword = ""
if isinstance(class_keyword_or_class, str):
class_keyword = class_keyword_or_class
else:
assert(inspect.isclass(class_keyword_or_class))
class_keyword = class_keyword_or_class.__name__
class_definition = class_keyword_or_class
assert(inspect.isclass(class_definition))
request = PdmObject_pb2.PdmDescendantObjectRequest(
object=self._pb2_object, child_keyword=class_keyword)
object_list = self._pdm_object_stub.GetDescendantPdmObjects(
request).objects
child_list = []
for pb2_object in object_list:
pdm_object = PdmObject(pb2_object=pb2_object, channel=self.channel())
if class_definition.__name__ == PdmObject.__name__:
child_list.append(pdm_object)
else:
casted_object = pdm_object.cast(class_definition)
if casted_object:
child_list.append(casted_object)
return child_list
class_keyword = class_definition.__name__
try:
request = PdmObject_pb2.PdmDescendantObjectRequest(
object=self._pb2_object, child_keyword=class_keyword)
object_list = self._pdm_object_stub.GetDescendantPdmObjects(
request).objects
return self.__from_pb2_to_pdm_objects(object_list, class_definition)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
return [] # Valid empty result
raise e
@add_method(PdmObject)
def children(self, child_field, class_definition=PdmObject):
@ -259,44 +249,33 @@ def children(self, child_field, class_definition=PdmObject):
child_field=child_field)
try:
object_list = self._pdm_object_stub.GetChildPdmObjects(request).objects
child_list = []
for pb2_object in object_list:
pdm_object = PdmObject(pb2_object=pb2_object, channel=self.channel())
if class_definition.__name__ == PdmObject.__name__:
child_list.append(pdm_object)
else:
child_list.append(pdm_object.cast(class_definition))
return child_list
return self.__from_pb2_to_pdm_objects(object_list, class_definition)
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
return []
raise e
@add_method(PdmObject)
def ancestor(self, class_keyword_or_class):
def ancestor(self, class_definition):
"""Find the first ancestor that matches the provided class_keyword
Arguments:
class_keyword_or_class[str/Class]: A class keyword matching the type of class wanted or a Class definition
class_definition[class]: A class definition matching the type of class wanted
"""
class_definition = PdmObject
class_keyword = ""
if isinstance(class_keyword_or_class, str):
class_keyword = class_keyword_or_class
else:
assert(inspect.isclass(class_keyword_or_class))
class_keyword = class_keyword_or_class.__name__
class_definition = class_keyword_or_class
assert(inspect.isclass(class_definition))
class_keyword = class_definition.__name__
request = PdmObject_pb2.PdmParentObjectRequest(
object=self._pb2_object, parent_keyword=class_keyword)
try:
pb2_object = self._pdm_object_stub.GetAncestorPdmObject(request)
pdm_object = PdmObject(pb2_object=pb2_object,
channel=self._channel)
if class_definition.__name__ == PdmObject.__name__:
return pdm_object
else:
return pdm_object.cast(class_definition)
child_class_definition = class_from_keyword(pb2_object.class_keyword)
if child_class_definition is None:
child_class_definition = class_definition
pdm_object = child_class_definition(pb2_object=pb2_object, channel=self.channel())
return pdm_object
except grpc.RpcError as e:
if e.code() == grpc.StatusCode.NOT_FOUND:
return None
@ -312,14 +291,3 @@ def update(self):
self._pdm_object_stub.UpdateExistingPdmObject(self._pb2_object)
else:
raise Exception("Object is not connected to GRPC service so cannot update ResInsight")
@add_method(PdmObject)
def superclasses(self):
names = []
mod = importlib.import_module("rips.generated.pdm_objects")
for name, obj in inspect.getmembers(mod):
if (inspect.isclass(obj) and name == self.class_keyword()):
class_hierarchy = inspect.getmro(obj)
for cls in class_hierarchy:
names.append(cls.__name__)
return names

View File

@ -122,7 +122,7 @@ def test_PdmObject(rips_instance, initialize_test):
case = rips_instance.project.load_case(path=case_path)
assert(case.id == 0)
assert(case.address() is not 0)
assert(case.class_keyword() == "EclipseCase")
assert(case.class_keyword == "EclipseCase")
@pytest.mark.skipif(sys.platform.startswith('linux'), reason="Brugge is currently exceptionally slow on Linux")
def test_brugge_0010(rips_instance, initialize_test):

View File

@ -23,11 +23,9 @@ def test_well_log_plots(rips_instance, initialize_test):
plots = project.plots()
well_log_plots = []
for plot in plots:
well_log_plot = plot.cast(rips.WellLogPlot)
if well_log_plot is not None:
well_log_plot.print_object_info()
assert(well_log_plot.depth_type == "MEASURED_DEPTH")
well_log_plots.append(well_log_plot)
if isinstance(plot, rips.WellLogPlot):
assert(plot.depth_type == "MEASURED_DEPTH")
well_log_plots.append(plot)
assert(len(well_log_plots) == 2)
with tempfile.TemporaryDirectory(prefix="rips") as tmpdirname:
@ -45,10 +43,9 @@ def test_well_log_plots(rips_instance, initialize_test):
assert(len(files) == 2)
plots2 = project.plots()
for plot2 in plots2:
well_log_plot2 = plot2.cast(rips.WellLogPlot)
if well_log_plot2 is not None:
assert(well_log_plot2.depth_type == "TRUE_VERTICAL_DEPTH_RKB")
for plot2 in plots2:
if isinstance(plot2, rips.WellLogPlot):
assert(plot2.depth_type == "TRUE_VERTICAL_DEPTH_RKB")
@pytest.mark.skipif(sys.platform.startswith('linux'), reason="Brugge is currently exceptionally slow on Linux")
def test_loadGridCaseGroup(rips_instance, initialize_test):