Source code for mira.modeling.amr.regnet

"""This module implements generation into Petri net models which are defined
at https://github.com/DARPA-ASKEM/Model-Representations/tree/main/regnet.
"""

__all__ = ["AMRRegNetModel", "ModelSpecification",
           "template_model_to_regnet_json"]

import json
import logging
from copy import deepcopy
from collections import defaultdict
from typing import Dict, List, Optional, Union

import sympy
from pydantic import BaseModel, Field

from mira.metamodel import *

from .. import Model, is_production, is_conversion
from .utils import add_metadata_annotations

logger = logging.getLogger(__name__)

SCHEMA_VERSION = '0.2'
SCHEMA_URL = ('https://raw.githubusercontent.com/DARPA-ASKEM/'
              'Model-Representations/regnet_v%s/regnet/'
              'regnet_schema.json') % SCHEMA_VERSION


[docs]class AMRRegNetModel: """A class representing a PetriNet model.""" def __init__(self, model: Model): """Instantiate a regnet model from a generic transition model. Parameters ---------- model: The pre-compiled transition model """ self.states = [] self.transitions = [] self.parameters = [] self.model_name = model.template_model.annotations.name if \ model.template_model.annotations and \ model.template_model.annotations.name else "Model" self.model_description = model.template_model.annotations.description \ if model.template_model.annotations and \ model.template_model.annotations.description else self.model_name self.rates = [] self.observables = [] self.time = None self.metadata = {} self._states_by_id = {} vmap = {} for key, var in model.variables.items(): # Use the variable's concept name if possible but fall back # on the key otherwise vmap[key] = name = var.concept.name or str(key) state_data = { 'id': name, 'name': name, 'grounding': { 'identifiers': {k: v for k, v in var.concept.identifiers.items() if k != 'biomodels.species'}, 'context': var.concept.context, }, } initial = var.data.get('expression') if initial is not None: # Here, initial is a SympyExprStr, and if its # value is a float, we export it as a float, # otherwise we export it as a string try: initial_float = float(initial.args[0]) state_data['initial'] = initial_float except TypeError: state_data['initial'] = str(initial) self.states.append(state_data) self._states_by_id[name] = state_data edge_id = 1 # It's possible that something is naturally degraded/replicated # by multiple transitions so we need to collect and aggregate rates intrinsic_by_var = defaultdict(list) for transition in model.transitions.values(): # Regnets cannot represent conversions (only # production/degradation) so we skip these if is_conversion(transition.template): continue # Now let's look for intrinsic growth/decay which is represented # at the vertex level in regnets natdeg = isinstance(transition.template, NaturalDegradation) natrep = isinstance(transition.template, NaturalReplication) contprod = isinstance(transition.template, ControlledProduction) # Natural degradation corresponds to an inherent negative # sign on the state so we have special handling for it if natdeg or natrep or (contprod and (transition.control[0].key == transition.produced[0].key)): if natdeg: var = vmap[transition.consumed[0].key] sign = False elif natrep: var = vmap[transition.produced[0].key] sign = True elif contprod: var = vmap[transition.produced[0].key] sign = True state_for_var = self._states_by_id.get(var) if transition.template.rate_law: pnames = transition.template.get_parameter_names() # We just choose an arbitrary one deterministically rate_const = sorted(pnames)[0] if pnames else None if state_for_var: state_for_var['rate_constant'] = rate_const if state_for_var: state_for_var['sign'] = sign if transition.template.rate_law: rate_law = transition.template.rate_law.args[0] intrinsic_by_var[var].append(rate_law) # Beyond these, we can assume that the transition is a # form of replication or degradation corresponding to # a regular transition in the regnet framework # Possibilities are: # - ControlledReplication / GroupedControlledProduction # - ControlledDegradation / GroupedControlledDegradation # - ControlledProduction if the control and produced concept are not the same else: # If we have multiple controls then the thing that replicates # is both a control and a produced variable. if len(transition.control) > 1: # GroupedControlledProduction if is_production(transition.template) or is_replication( transition.template ): indep_ctrl = {c.key for c in transition.control} - { transition.produced[0].key } # There is one corner case where both controllers are also # the same as the produced variable, in which case. if not indep_ctrl: indep_ctrl = {transition.produced[0].key} for index, controller in enumerate(indep_ctrl): self.create_edge( transition, vmap[controller], vmap[transition.produced[0].key], edge_id, False if index == 0 else True ) edge_id += 1 else: # GroupedControlledDegradation indep_ctrl = {c.key for c in transition.control} - { transition.consumed[0].key } # There is one corner case where both controllers are also # the same as the consumed variable, in which case. if not indep_ctrl: indep_ctrl = {transition.consumed[0].key} for index, controller in enumerate(indep_ctrl): self.create_edge( transition, vmap[controller], vmap[transition.consumed[0].key], edge_id, False if index == 0 else True ) edge_id += 1 else: # ControlledProduction if produced and controller are not the same # ControlledDegradation target = vmap[ transition.consumed[0].key if transition.consumed else transition.produced[0].key ] self.create_edge( transition, vmap[transition.control[0].key], target, edge_id, False ) edge_id += 1 for var, rates in intrinsic_by_var.items(): rate_law = sum(rates) self.rates.append({ 'target': var, 'expression': str(rate_law), 'expression_mathml': expression_to_mathml(rate_law) }) for key, param in model.parameters.items(): if param.placeholder: continue param_dict = {'id': str(key)} if param.value is not None: param_dict['value'] = param.value if not param.distribution: pass elif param.distribution.type is None: logger.warning("can not add distribution without type: %s", param.distribution) else: param_dict['distribution'] = { 'type': param.distribution.type, 'parameters': param.distribution.parameters, } self.parameters.append(param_dict) for key, observable in model.observables.items(): display_name = observable.observable.display_name \ if observable.observable.display_name \ else observable.observable.name obs_data = { 'id': observable.observable.name, 'name': display_name, 'expression': str(observable.observable.expression), 'expression_mathml': expression_to_mathml( observable.observable.expression.args[0]), } self.observables.append(obs_data) if model.template_model.time: self.time = {'id': model.template_model.time.name} if model.template_model.time.units: self.time['units'] = { 'expression': str(model.template_model.time.units.expression), 'expression_mathml': expression_to_mathml( model.template_model.time.units.expression.args[0]), } else: self.time = None add_metadata_annotations(self.metadata, model)
[docs] def create_edge(self, transition, source, target, edge_id, duplicate): """Create and append a transition dictionary to the list of transitions Parameters ---------- transition : Transition The Transition object source : str The name of the source of the transition target : str The name of the target of the transition edge_id : int The id to assign to the transition duplicate : bool A boolean that tells us whether the transition we are processing has already been processed at least once. This is for the purpose of not adding duplicate rate laws to the output amr. """ tid = f"t{edge_id}" transition_dict = {"id": tid} transition_dict["source"] = source transition_dict["target"] = target transition_dict["sign"] = is_production( transition.template ) or is_replication(transition.template) # if transition.template.rate_law: # If we are processing a duplicate rate, set the rate to 0 rate_law = transition.template.rate_law.args[0] if not duplicate \ else safe_parse_expr('0') pnames = transition.template.get_parameter_names() # We just choose an arbitrary one deterministically rate_const = sorted(pnames)[0] if pnames else None transition_dict["properties"] = { "name": tid, "rate_constant": rate_const, } self.rates.append( { "target": tid, "expression": str(rate_law), "expression_mathml": expression_to_mathml(rate_law), } ) self.transitions.append(transition_dict)
[docs] def to_json( self, name: str = None, description: str = None, model_version: str = None ): """Return a JSON dict structure of the Petri net model. Parameters ---------- name : The name of the model. Defaults to the model name of the original template model of the input Model instance, or "Model" if no name is available. description : The description of the model. Defaults to the description of the original template model of the input Model instance, or the model name if no description is available. model_version : The version of the model. Defaults to 0.1 Returns ------- : JSON A JSON representation of the Petri net model. """ return { 'header': { 'name': name or self.model_name, 'schema': SCHEMA_URL, 'schema_name': 'regnet', 'description': description or self.model_description, 'model_version': model_version or '0.1', }, 'model': { 'vertices': self.states, 'edges': self.transitions, 'parameters': self.parameters, }, 'semantics': {'ode': { 'rates': self.rates, 'observables': self.observables, 'time': self.time if self.time else {'id': 't'} }}, 'metadata': self.metadata, }
[docs] def to_pydantic( self, name: str = None, description: str = None, model_version: str = None ) -> "ModelSpecification": """Return a Pydantic model specification of the Petri net model. Parameters ---------- name : The name of the model. Defaults to the model name of the original template model of the input Model instance, or "Model" if no name is available. description : The description of the model. Defaults to the description of the original template model of the input Model instance, or the model name if no description is available. model_version : The version of the model. Defaults to 0.1 Returns ------- : A Pydantic model specification of the Petri net model. """ return ModelSpecification( header=Header( name=name or self.model_name, schema=SCHEMA_URL, schema_name='regnet', description=description or self.model_description, model_version=model_version or '0.1', ), model=RegNetModel( vertices=[State.parse_obj(s) for s in self.states], edges=[Transition.parse_obj(t) for t in self.transitions], parameters=[Parameter.from_dict(p) for p in self.parameters], ), semantics=Ode( ode=OdeSemantics( rates=[Rate.parse_obj(r) for r in self.rates], observables=[Observable.parse_obj(o) for o in self.observables], time=Time.parse_obj(self.time) if self.time else Time(id='t') ) ), metadata=self.metadata, )
[docs] def to_json_str(self, **kwargs): """Return a JSON string representation of the Petri net model. Parameters ---------- **kwargs : Keyword arguments to be passed to json.dumps Returns ------- : A JSON string representation of the Petri net model. """ return json.dumps(self.to_json(), **kwargs)
[docs] def to_json_file( self, fname: str, name: str = None, description: str = None, model_version: str = None, **kwargs ): """Write the Petri net model to a JSON file. Parameters ---------- fname : The file name to write to. name : The name of the model. Defaults to the model name of the original template model of the input Model instance, or "Model" if no name is available. description : The description of the model. Defaults to the description of the original template model of the input Model instance, or the model name if no description is available. model_version : The version of the model. Defaults to 0.1 **kwargs : Keyword arguments to be passed to json.dump """ js = self.to_json(name=name, description=description, model_version=model_version) with open(fname, 'w') as fh: json.dump(js, fh, **kwargs)
[docs]def template_model_to_regnet_json(tm: TemplateModel): """Convert a template model to a RegNet JSON dict. Parameters ---------- tm : The template model to convert. Returns ------- A JSON dict representing the RegNet model. """ return AMRRegNetModel(Model(tm)).to_json()
class Initial(BaseModel): expression: Union[str, float] expression_mathml: str class TransitionProperties(BaseModel): name: Optional[str] grounding: Optional[Dict] rate: Optional[Dict] class Rate(BaseModel): expression: str expression_mathml: str class Distribution(BaseModel): type: str parameters: Dict class State(BaseModel): id: str name: str initial: Optional[Initial] = None class Transition(BaseModel): id: str input: List[str] output: List[str] properties: Optional[TransitionProperties] class Parameter(BaseModel): id: str value: Optional[float] = None description: Optional[str] = None distribution: Optional[Distribution] = None @classmethod def from_dict(cls, d): d = deepcopy(d) d['id'] = str(d['id']) return cls.parse_obj(d) class RegNetModel(BaseModel): vertices: List[State] edges: List[Transition] parameters: List[Parameter] class Header(BaseModel): name: str schema_name: str schema_url: str = Field(..., alias='schema') description: str model_version: str class OdeSemantics(BaseModel): rates: List[Rate] time: Optional[Time] observables: List[Observable] class Ode(BaseModel): ode: Optional[OdeSemantics]
[docs]class ModelSpecification(BaseModel): """A Pydantic model specification of the model.""" header: Header properties: Optional[Dict] model: RegNetModel semantics: Optional[Ode] metadata: Optional[Dict]