# Public API of this module (the names exported by ``import *``).
__all__ = [
    "Decapode",
    "Variable",
    "TangentVariable",
    "Summation",
    "Op1",
    "Op2",
    "RootVariable",
]
import copy
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List, Mapping
import sympy
# Builders mapping a binary-operator string of an Op2 to the corresponding
# sympy arithmetic. The dotted spellings (presumably Julia-style element-wise
# operators) are expanded exactly like their plain counterparts.
_OP2_SYMPY_BUILDERS = {
    "/": lambda a, b: a / b,
    "./": lambda a, b: a / b,
    "*": lambda a, b: a * b,
    ".*": lambda a, b: a * b,
    "+": lambda a, b: a + b,
    ".+": lambda a, b: a + b,
    "-": lambda a, b: a - b,
    ".-": lambda a, b: a - b,
    "^": lambda a, b: a**b,
    ".^": lambda a, b: a**b,
}


def expand_variable(variable, var_produced_map):
    """
    Recursively build a sympy expression for a variable.

    Parameters
    ----------
    variable : Variable
        The variable to expand.
    var_produced_map : Dict[int, Op1 | Op2 | Summation]
        Mapping of variable ids to the operation that produces the variable.
        The mapping is only read, never modified.

    Returns
    -------
    sympy.Expr
        ``variable.expression`` if it is already set. Otherwise, a plain
        ``sympy.Symbol`` named after the variable when no operation produces
        it, or the expression obtained by recursively expanding the inputs
        of the producing operation.

    Raises
    ------
    TypeError
        If the producing operation is not an Op1, Op2 or Summation.
    """
    # Compare against None rather than relying on truthiness: a precomputed
    # sympy zero is falsy (and would be silently re-expanded), and a sympy
    # relational raises TypeError when converted to bool.
    if variable.expression is not None:
        return variable.expression

    var_prod = var_produced_map.get(variable.id)
    if var_prod is None:
        # Not produced by any operation: a free symbol.
        return sympy.Symbol(variable.name)

    if isinstance(var_prod, Op1):
        # Unary operations become an uninterpreted sympy function call.
        return sympy.Function(var_prod.function_str)(
            expand_variable(var_prod.src, var_produced_map)
        )

    if isinstance(var_prod, Op2):
        arg1 = expand_variable(var_prod.proj1, var_produced_map)
        arg2 = expand_variable(var_prod.proj2, var_produced_map)
        builder = _OP2_SYMPY_BUILDERS.get(var_prod.function_str)
        if builder is None:
            # Unknown binary operators become an uninterpreted function.
            return sympy.Function(var_prod.function_str)(arg1, arg2)
        return builder(arg1, arg2)

    if isinstance(var_prod, Summation):
        return sympy.Add(
            *[
                expand_variable(summand, var_produced_map)
                for summand in var_prod.summands
            ]
        )

    raise TypeError(
        f"Unsupported producing operation for variable {variable.id!r}: "
        f"{type(var_prod).__name__}"
    )
class Decapode:
    """
    MIRA's internal representation of a decapode compute graph or decaexpr
    JSON.

    On construction, every variable is given a sympy ``expression`` by
    recursively expanding the operations that produce it. A variable that is
    produced by two operations is replaced by a :class:`RootVariable` that
    holds one expression per producing operation. The operations and tangent
    variables are then re-pointed at these new variable objects.
    """

    def __init__(self, variables, op1s, op2s, summations, tangent_variables):
        """
        Create a Decapode based off multiple mappings of different parts of
        a Decapode.

        Parameters
        ----------
        variables : Dict[int,Variable]
            Mapping of Variables.
        op1s : Dict[int,Op1]
            Mapping of Op1s (Operation 1s).
        op2s : Dict[int,Op2]
            Mapping of Op2s (Operation 2s).
        summations : Dict[int,Summation]
            Mapping of Summations.
        tangent_variables : Dict[int,TangentVariable]
            Mapping of TangentVariables.
        """
        self.variables = variables
        self.op1s = op1s
        self.op2s = op2s
        self.summations = summations
        self.tangent_variables = tangent_variables

        # Index each variable id by the operation that produces it. When a
        # second producer is found, the variable is moved out of
        # var_produced_map and recorded (with both producers) as a root.
        # NOTE(review): a third producer of the same variable would be
        # re-inserted into var_produced_map and not recorded in
        # root_variable_map; confirm whether that case can occur.
        var_produced_map = {}
        root_variable_map = defaultdict(list)
        for ops, res_attr in (
            (self.op1s, "tgt"),
            (self.op2s, "res"),
            (self.summations, "sum"),
        ):
            for op in ops.values():
                produced_var = getattr(op, res_attr)
                if produced_var.id not in var_produced_map:
                    var_produced_map[produced_var.id] = op
                else:
                    one_op = var_produced_map.pop(produced_var.id)
                    root_variable_map[produced_var.id] = [one_op, op]

        # Deep-copy so the caller's Variable objects are not mutated.
        new_vars = {}
        for var_id, var in copy.deepcopy(self.variables).items():
            if var_id not in root_variable_map:
                var.expression = expand_variable(var, var_produced_map)
                new_vars[var_id] = var
                continue

            # identifiers must be passed by keyword. RootVariable redeclares
            # `expression`, which keeps its inherited position (id, type,
            # name, expression, identifiers), so a fourth positional argument
            # would be stored as the expression.
            root_var = RootVariable(
                var_id, var.type, var.name, identifiers=var.identifiers
            )
            plain_var = root_var.get_variable()
            for idx, producing_op in enumerate(root_variable_map[var_id]):
                # expand_variable only reads the map, so a shallow copy with
                # this root's producer filled in is sufficient.
                temp_var_map = dict(var_produced_map)
                temp_var_map[var_id] = producing_op
                root_var.expression[idx] = expand_variable(
                    plain_var, temp_var_map
                )
            new_vars[var_id] = root_var

        self.update_vars(new_vars)

    def update_vars(self, variables):
        """
        Replace the variable mapping and re-point every operation and tangent
        variable at the Variable objects in ``variables``.

        Parameters
        ----------
        variables : Dict[int,Variable]
            Mapping of variable ids to the Variables that should replace the
            ones currently referenced. Every id referenced by an operation or
            tangent variable is looked up here (for a dict, a missing id
            raises ``KeyError``).
        """
        self.variables = variables
        for ops, var_args in (
            (self.op1s, ("src", "tgt")),
            (self.op2s, ("proj1", "proj2", "res")),
            (self.summations, ("summands", "sum")),
            (self.tangent_variables, ("incl_var",)),
        ):
            for op in ops.values():
                for var_arg in var_args:
                    var_attr = getattr(op, var_arg)
                    if isinstance(var_attr, Variable):
                        setattr(op, var_arg, variables[var_attr.id])
                    elif isinstance(var_attr, list):
                        # e.g. Summation.summands holds a list of Variables.
                        setattr(
                            op,
                            var_arg,
                            [variables[var.id] for var in var_attr],
                        )
# TODO: Inherit from Concept?
@dataclass
class Variable:
    """
    A single variable of a Decapode in MIRA's internal representation.

    Attributes
    ----------
    id : int
        The id of the variable.
    type : str
        The type of the variable.
    name : str
        The name of the variable.
    expression : sympy.Expr
        The sympy expression of the variable, or None if it has not been
        computed yet.
    identifiers : Mapping[str,str]
        Namespaces mapped to the identifiers associated with the variable.
    """

    id: int
    type: str
    name: str
    # None until an expression has been assigned.
    expression: sympy.Expr = None
    # Fresh dict per instance, so identifiers are never shared by default.
    identifiers: Mapping[str, str] = field(default_factory=dict)
@dataclass
class RootVariable(Variable):
    """
    A variable that is the output of a unary (derivative) operation and, at
    the same time, the output of a series of unary and binary operations.

    Attributes
    ----------
    expression : list[sympy.Expr]
        Two-element list holding both expressions associated with the root
        variable: one built up from the unary (derivative) operation and one
        built up from the series of unary and binary operations.

    Notes
    -----
    Redeclaring ``expression`` here keeps the position it has in
    :class:`Variable`. The positional field order is therefore still
    ``id, type, name, expression, identifiers``, so pass ``identifiers``
    by keyword.
    """

    expression: List[sympy.Expr] = field(default_factory=lambda: [None] * 2)

    def get_variable(self):
        """
        Return a plain :class:`Variable` sharing this root variable's id,
        type, name and identifiers, with no expression set.
        """
        plain = Variable(
            id=self.id,
            type=self.type,
            name=self.name,
            identifiers=self.identifiers,
        )
        return plain
@dataclass
class TangentVariable:
    """
    A tangent variable of a Decapode in MIRA's internal representation.

    Attributes
    ----------
    id : int
        The id of the tangent variable.
    incl_var : Variable
        The variable produced by the derivative operation that this tangent
        variable is associated with.
    """

    id: int
    incl_var: Variable
@dataclass
class Summation:
    """
    A summation node of a Decapode in MIRA's internal representation.

    Attributes
    ----------
    id : int
        The id of the summation.
    summands : list[Variable]
        The variables that are added together.
    sum : Variable
        The variable holding the result of the summation.
    """

    id: int
    summands: List[Variable]
    sum: Variable
@dataclass
class Op1:
    """
    A unary operation of a Decapode in MIRA's internal representation.

    Attributes
    ----------
    id : int
        The id of the operation.
    src : Variable
        The input variable of the operation.
    tgt : Variable
        The output variable of the operation.
    function_str : str
        The operator applied by the operation.
    """

    id: int
    src: Variable
    tgt: Variable
    function_str: str
@dataclass
class Op2:
    """
    A binary operation of a Decapode in MIRA's internal representation.

    Attributes
    ----------
    id : int
        The id of the operation.
    proj1 : Variable
        The first input variable of the operation.
    proj2 : Variable
        The second input variable of the operation.
    res : Variable
        The output variable of the operation.
    function_str : str
        The operator applied by the operation.
    """

    id: int
    proj1: Variable
    proj2: Variable
    res: Variable
    function_str: str