# Names exported by ``from <module> import *``.
__all__ = [
    'get_parseable_expression',
    'revert_parseable_expression',
    'safe_parse_expr',
    'SympyExprStr',
    'sanity_check_tm',
]
import sympy
import re
import unicodedata
def get_parseable_expression(s: str) -> str:
    """Return an expression that can be parsed using sympy.

    Constructs that sympy's parser rejects are replaced by reserved
    ``XX...`` marker strings (undone by ``revert_parseable_expression``):

    * ``lambda`` -> ``XXlambdaXX``
    * a ``.`` that is followed by a non-digit -> ``XX_XX`` (a dot followed
      by a digit, as in ``1.5``, is left alone)
    * a superscript ``^{...}`` -> ``XXCXX{_...}`` (the braces are then
      escaped by the next rule)
    * ``{`` / ``}`` -> ``XXCBO`` / ``XXCBC``

    The input is NFKC-normalized before any of the above is applied.

    :param s: The raw expression string.
    :returns: The escaped expression string.
    """
    # Normalize first: NFKC maps compatibility characters such as the
    # full-width full stop (U+FF0E) and full-width braces (U+FF5B/U+FF5D)
    # to ASCII '.', '{' and '}'. Normalizing after the substitutions below
    # would reintroduce exactly the characters they are meant to remove.
    s = unicodedata.normalize('NFKC', s)
    # Handle lambda, which cannot be parsed by sympy (Python keyword).
    s = s.replace('lambda', 'XXlambdaXX')
    # Handle dots, which also cannot be parsed; the lookahead keeps
    # decimal points (dot followed by a digit) intact.
    s = re.sub(r'\.(?=\D)', 'XX_XX', s)
    # Handle superscripts, which are not allowed in sympy.
    s = re.sub(r"\^\{(.*?)\}", r"XXCXX{_\1}", s)
    # Handle curly braces, which are not allowed in sympy.
    s = s.replace('{', 'XXCBO').replace('}', 'XXCBC')
    return s
def revert_parseable_expression(s: str) -> str:
    """Undo the escaping performed by ``get_parseable_expression``.

    Each reserved marker string is replaced by the text it stands for. The
    replacements run one after another in a fixed order: superscript
    marker, opening brace, closing brace, dot, ``lambda``.

    NOTE(review): the superscript escape in ``get_parseable_expression``
    inserts an ``_`` after the opening brace (``^{2}`` -> ``XXCXX{_2}``),
    and nothing here removes it, so ``^{2}`` comes back as ``^{_2}``.
    Confirm whether that is intended.

    :param s: A string that may contain escape markers.
    :returns: The string with every marker replaced by its original text.
    """
    marker_table = (
        ('XXCXX', '^'),
        ('XXCBO', '{'),
        ('XXCBC', '}'),
        ('XX_XX', '.'),
        ('XXlambdaXX', 'lambda'),
    )
    for marker, original_text in marker_table:
        s = s.replace(marker, original_text)
    return s
def safe_parse_expr(s: str, local_dict=None) -> sympy.Expr:
    """Parse ``s`` with sympy after escaping constructs it cannot handle.

    Both the expression and every key of ``local_dict`` are passed through
    ``get_parseable_expression``. This lets names such as ``lambda``, or
    names containing dots or braces, be parsed and still match their
    ``local_dict`` entries. Unless a name is mapped to another object via
    ``local_dict``, symbols in the result carry the *escaped* name (e.g.
    ``XXlambdaXX``); ``revert_parseable_expression`` recovers the original
    spelling.

    :param s: The expression string to parse.
    :param local_dict: Optional mapping from (unescaped) names to objects
        handed to ``sympy.parse_expr``. A falsy value (``None`` or an empty
        mapping) means no local names.
    :returns: The result of ``sympy.parse_expr`` on the escaped string.
    """
    escaped_locals = None
    if local_dict:
        escaped_locals = {
            get_parseable_expression(name): value
            for name, value in local_dict.items()
        }
    escaped_expression = get_parseable_expression(s)
    return sympy.parse_expr(escaped_expression, local_dict=escaped_locals)
class SympyExprStr(sympy.Expr):
    """A sympy ``Expr`` wrapper intended for use as a pydantic field type.

    It exposes the pydantic-v1-style custom type hooks
    (``__get_validators__`` / ``__modify_schema__``). A validated value is
    an instance holding the given value as its single argument, and
    ``str``/``repr`` show only that inner value.
    """

    @classmethod
    def __get_validators__(cls):
        # Yield the validators to apply, in order (only one here).
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Coerce ``v`` into a ``SympyExprStr``.

        Existing instances are returned unchanged. Python floats and ints
        (including bools, which are ints) are first converted to
        ``sympy.Float`` / ``sympy.Integer``. Any other value is wrapped
        as-is.
        """
        if isinstance(v, cls):
            return v
        if isinstance(v, float):
            v = sympy.Float(v)
        elif isinstance(v, int):
            v = sympy.Integer(v)
        return cls(v)

    @classmethod
    def __modify_schema__(cls, field_schema):
        # Describe the field as a plain string in the generated JSON schema.
        field_schema.update(type="string", example="2*x")

    def __str__(self):
        # The slice assumes the inherited str() has the form
        # "<ClassName>(<inner>)": drop the class name and the opening
        # parenthesis at the front, and the closing parenthesis at the end.
        rendered = super().__str__()
        prefix_length = len(self.__class__.__name__) + 1
        return rendered[prefix_length:-1]

    def __repr__(self):
        # Same text as str(): just the wrapped expression.
        return str(self)
def sanity_check_tm(tm):
    """Apply a short sanity check to a template model.

    Checks that:

    * the model has at least one template;
    * every template has a (truthy) rate law;
    * every free symbol of each rate law names a concept, a parameter, or
      the model's time variable (when ``tm.time`` is set);
    * every concept has an entry in ``tm.initials``.

    Failures raise ``AssertionError`` explicitly rather than through
    ``assert`` statements, so the checks still run under ``python -O``.

    :param tm: The template model to check. It is assumed to expose
        ``templates``, ``parameters``, ``initials``, ``time`` and
        ``get_concepts_name_map()`` as used below.
    :raises AssertionError: If any of the checks fails.
    """
    if not tm.templates:
        raise AssertionError("template model has no templates")
    all_concept_names = set(tm.get_concepts_name_map())
    all_parameter_names = set(tm.parameters)
    # A new set, so adding the time name does not touch the concept names.
    all_symbols = all_concept_names | all_parameter_names
    if tm.time:
        all_symbols.add(tm.time.name)
    for index, template in enumerate(tm.templates):
        if not template.rate_law:
            raise AssertionError(f"template {index} has no rate law")
        # Presumably rate_law is a SympyExprStr-like wrapper whose first
        # argument is the actual sympy expression.
        for symbol in template.rate_law.args[0].free_symbols:
            if symbol.name not in all_symbols:
                raise AssertionError(f"missing symbol: {symbol.name}")
    all_initial_names = {init.concept.name for init in tm.initials.values()}
    # Report every concept without an initial value, not just the first.
    missing_initials = all_concept_names - all_initial_names
    if missing_initials:
        raise AssertionError(
            f"concepts without initial values: {sorted(missing_initials)}"
        )