Source code for mira.sources.acsets.decapodes.deca_expr

from mira.metamodel.decapodes import *
from mira.sources.acsets.decapodes.util import PARTIAL_TIME_DERIVATIVE

__all__ = ["process_decaexpr"]


def get_variables_mapping_decaexpr(decaexpr_json):
    """Build the index -> Variable mapping for a decaexpr model JSON

    Variables are numbered in the order in which
    ``recursively_find_variables_decaexpr_json`` discovers them, and each
    variable name appears only once in the result.

    Parameters
    ----------
    decaexpr_json : dict
        The JSON of a decaexpr model. If it has a "model" key, the value
        stored under that key is used as the model.

    Returns
    -------
    dict[int, Variable]
        The variables keyed by their integer id.
    """
    # Unwrap the model if it is nested under the "model" key
    model_json = (
        decaexpr_json["model"] if "model" in decaexpr_json else decaexpr_json
    )

    # Shared with the generator so that every name is reported only once
    seen_names = set()
    discovered = recursively_find_variables_decaexpr_json(
        model_json, seen_names
    )

    var_dict = {}
    for index, (var_name, var_type) in enumerate(discovered):
        var_dict[index] = Variable(id=index, type=var_type, name=var_name)
    return var_dict


def recursively_find_variables_decaexpr_json(decaexpr_json, yielded_variables):
    """Walk a decaexpr JSON structure and yield each variable once

    Parameters
    ----------
    decaexpr_json : dict | list
        A dictionary or list of dictionaries that represent a decaexpr
    yielded_variables : set
        The names that have already been yielded. Each newly yielded name
        is added to this set, so it is modified in place.

    Yields
    ------
    : tuple[str, str]
        A ``(name, type)`` tuple for every variable name that has not been
        seen before. "Var" nodes get the type "Form0", "Lit" nodes the type
        "Literal" and "Judgement" nodes the value of their "dim" field.

    Raises
    ------
    NotImplementedError
        If a node has an unhandled "_type" or if a value is neither a dict
        nor a list.
    """
    assert isinstance(yielded_variables, set)

    # Lists: visit every element in order
    if isinstance(decaexpr_json, list):
        for element in decaexpr_json:
            yield from recursively_find_variables_decaexpr_json(
                element, yielded_variables
            )
        return

    if not isinstance(decaexpr_json, dict):
        raise NotImplementedError(
            f"Unhandled type: {type(decaexpr_json)}: {decaexpr_json}"
        )

    # Untyped dictionaries: visit every value in order
    if "_type" not in decaexpr_json:
        for element in decaexpr_json.values():
            yield from recursively_find_variables_decaexpr_json(
                element, yielded_variables
            )
        return

    node_type = decaexpr_json["_type"]

    # Leaf nodes that define a variable: 'Var' (base level), 'Lit' (literal,
    # under 'equations') and 'Judgement' (under 'context')
    if node_type in ("Var", "Lit", "Judgement"):
        if node_type == "Var":
            found = decaexpr_json["name"], "Form0"
        elif node_type == "Lit":
            found = decaexpr_json["name"], "Literal"
        else:
            # For a judgement the type comes from the 'dim' field
            found = decaexpr_json["var"]["name"], decaexpr_json["dim"]

        if found[0] not in yielded_variables:
            yield found
            yielded_variables.add(found[0])
        return

    # 'Mult' and 'Plus' hold their operands in the 'args' list
    if node_type in ("Mult", "Plus"):
        for operand in decaexpr_json["args"]:
            yield from recursively_find_variables_decaexpr_json(
                operand, yielded_variables
            )
        return

    # The remaining node types keep their children under fixed keys; the
    # keys are looked up one at a time, right before descending into them
    if node_type == "DecaExpr":
        # Top level, the header is skipped
        child_keys = ("context", "equations")
    elif node_type == "Eq":
        child_keys = ("lhs", "rhs")
    elif node_type == "Tan":
        # Derivative (tangent variable)
        child_keys = ("var",)
    elif node_type == "App1":
        # One argument
        child_keys = ("arg",)
    elif node_type == "App2":
        # Two arguments
        child_keys = ("arg1", "arg2")
    else:
        raise NotImplementedError(
            f"Unhandled variable type: {decaexpr_json['_type']}"
        )

    for child_key in child_keys:
        yield from recursively_find_variables_decaexpr_json(
            decaexpr_json[child_key], yielded_variables
        )


# Name prefix used for the result variable of each supported binary operator
_BINARY_OP_NAME_PREFIXES = {"*": "mult", "+": "add", "-": "sub", "/": "div"}


def _next_index(lookup):
    """Return an integer key that is not yet used in ``lookup``"""
    # len(lookup) is only safe while the keys are contiguous; entries can be
    # removed from the variable lookup between equations, so go past the
    # largest existing key instead
    return max(lookup, default=-1) + 1


def _numbered_name(prefix, variable_lookup):
    """Return an unused variable name of the form ``<prefix>_<number>``"""
    names = [var.name for var in variable_lookup.values()]
    number = sum(name.startswith(prefix) for name in names) + 1
    # The count alone can point at a name that is still in use when an
    # earlier variable with the same prefix has been removed
    taken = set(names)
    while f"{prefix}_{number}" in taken:
        number += 1
    return f"{prefix}_{number}"


def _add_variable(name, var_type, variable_lookup, var_name_to_index):
    """Create a Variable, register it in both lookups and return it"""
    new_var_ix = _next_index(variable_lookup)
    variable = Variable(id=new_var_ix, type=var_type, name=name)
    variable_lookup[new_var_ix] = variable
    var_name_to_index[name] = new_var_ix
    return variable


def expand_equations(
    decaexpr_equations_json,
    variable_lookup,
    op2s_lookup,
    op1s_lookup,
    tangent_variables_lookup,
    summations_lookup,
    var_name_to_index,
) -> Variable:
    """Expand the equations in a decaexpr JSON to its components

    The expression is expanded recursively. Every operation that is found
    gets a new result variable of type "infer" and is registered in the
    matching lookup table. All lookup tables are modified in place.

    Parameters
    ----------
    decaexpr_equations_json : dict
        One expression node (one side of an equation or a sub-expression).
        Its "_type" must be one of "Var", "Lit", "App1", "App2", "Tan",
        "Mult" or "Plus".
    variable_lookup : dict[int, Variable]
        The lookup table for the variables
    op2s_lookup : dict[int, Op2]
        The lookup table for the binary operations
    op1s_lookup : dict[int, Op1]
        The lookup table for the unary operations
    tangent_variables_lookup : dict[int, TangentVariable]
        The lookup table for the tangent variables
    summations_lookup : dict[int, Summation]
        The lookup table for the summations
    var_name_to_index : dict[str, int]
        The lookup table from variable name to variable index

    Returns
    -------
    Variable
        The variable that holds the value of the expression

    Raises
    ------
    NotImplementedError
        If the node type or a binary operator is not handled
    ValueError
        If a "Mult" node has no arguments
    """

    def expand(sub_expression_json):
        # Recurse with the same lookup tables
        return expand_equations(
            sub_expression_json,
            variable_lookup,
            op2s_lookup,
            op1s_lookup,
            tangent_variables_lookup,
            summations_lookup,
            var_name_to_index,
        )

    def add_result_variable(name):
        # Results of operations have no declared type
        return _add_variable(
            name, "infer", variable_lookup, var_name_to_index
        )

    def add_op1(function_str, src, tgt):
        new_op1_ix = _next_index(op1s_lookup)
        op1s_lookup[new_op1_ix] = Op1(
            id=new_op1_ix, src=src, tgt=tgt, function_str=function_str
        )

    def add_op2(function_str, proj1, proj2, res):
        new_op2_ix = _next_index(op2s_lookup)
        op2s_lookup[new_op2_ix] = Op2(
            id=new_op2_ix,
            proj1=proj1,
            proj2=proj2,
            res=res,
            function_str=function_str,
        )

    _type = decaexpr_equations_json["_type"]

    if _type in {"Var", "Lit"}:
        var_name = decaexpr_equations_json["name"]
        if var_name not in var_name_to_index:
            # Not declared before, create a new variable
            _add_variable(
                var_name,
                "Constant" if _type == "Lit" else "Form0",
                variable_lookup,
                var_name_to_index,
            )
        return variable_lookup[var_name_to_index[var_name]]

    if _type == "App2":
        # Binary operation
        arg1 = expand(decaexpr_equations_json["arg1"])
        arg2 = expand(decaexpr_equations_json["arg2"])
        op2 = decaexpr_equations_json["f"]
        name_prefix = _BINARY_OP_NAME_PREFIXES.get(op2)
        if name_prefix is None:
            raise NotImplementedError(
                f"Unhandled binary operation: {op2}"
            )

        # Create new variable that is the result of the binary operation
        result = add_result_variable(
            _numbered_name(name_prefix, variable_lookup)
        )
        add_op2(op2, arg1, arg2, result)
        return result

    if _type == "App1":
        # Unary operation; apply a function to an argument
        arg = expand(decaexpr_equations_json["arg"])
        op1 = decaexpr_equations_json["f"]

        # Create new variable that is the result of the unary operation
        result = add_result_variable(f"{op1}({arg.name})")
        add_op1(op1, arg, result)
        return result

    if _type == "Tan":
        # Time derivative
        arg = expand(decaexpr_equations_json["var"])

        # Create new variable that is the result of the derivative
        result = add_result_variable(
            f"{PARTIAL_TIME_DERIVATIVE}({arg.name})"
        )
        add_op1(PARTIAL_TIME_DERIVATIVE, arg, result)

        # Add tangent variable - the result of the derivative
        new_tangent_var_ix = _next_index(tangent_variables_lookup)
        tangent_variables_lookup[new_tangent_var_ix] = TangentVariable(
            id=new_tangent_var_ix, incl_var=result
        )
        return result

    if _type == "Mult":
        # Multiply the arguments together from the left; every pairwise
        # product gets its own result variable and binary operation
        factors = decaexpr_equations_json["args"]
        if not factors:
            raise ValueError("A 'Mult' expression needs at least one argument")

        # With a single argument there is nothing to multiply and the
        # expanded argument itself is the result
        product = expand(factors[0])
        for factor_json in factors[1:]:
            factor = expand(factor_json)
            result = add_result_variable(
                _numbered_name("mult", variable_lookup)
            )
            add_op2("*", product, factor, result)
            product = result
        return product

    if _type == "Plus":
        # In decapode:
        #  - the Σ table specifies the result of the sums in the equation
        #  - the summand table specifies the terms in the sum(s), which sum
        #    they belong to is specified by the summation value which
        #    references one of the sums in the Σ table
        summand_list = [
            expand(summand_json)
            for summand_json in decaexpr_equations_json["args"]
        ]

        # Create new variable that is the result of the addition; number it
        # by the existing "sum" names so that every sum gets its own name
        result = add_result_variable(_numbered_name("sum", variable_lookup))

        new_sum_ix = _next_index(summations_lookup)
        summations_lookup[new_sum_ix] = Summation(
            id=new_sum_ix,
            summands=summand_list,
            sum=result,
        )
        return result

    raise NotImplementedError(f"Unhandled equation type: {_type}")


def replace_variable(replacement: Variable,
                     to_replace: Variable,
                     variable_lookup,
                     name_to_variable_index,
                     op2s_lookup,
                     op1s_lookup,
                     tangent_variables_lookup,
                     summations_lookup):
    """Substitute one variable for another in all the data structures

    The replaced variable is removed from the variable and name lookups,
    and every reference to it (matched by id) in the operations, tangent
    variables and summations is pointed to the replacement. Everything is
    modified in place and nothing is returned.

    Parameters
    ----------
    replacement : Variable
        The variable to replace the other variable with
    to_replace : Variable
        The variable to be replaced
    variable_lookup : dict[int, Variable]
        The lookup table for the variables
    name_to_variable_index : dict[str, int]
        The lookup table for the variable names
    op2s_lookup : dict[int, Op2]
        The lookup table for the binary operations
    op1s_lookup : dict[int, Op1]
        The lookup table for the unary operations
    tangent_variables_lookup : dict[int, TangentVariable]
        The lookup table for the tangent variables
    summations_lookup : dict[int, Summation]
        The lookup table for the summations
    """

    def swap(holder, attribute):
        # Point the attribute to the replacement if it currently refers to
        # the variable that is being replaced
        if getattr(holder, attribute).id == to_replace.id:
            setattr(holder, attribute, replacement)

    # Drop the replaced variable from the variable and name lookups
    del variable_lookup[to_replace.id]
    del name_to_variable_index[to_replace.name]

    # Binary operations: both inputs and the result
    for binary_op in op2s_lookup.values():
        for attribute in ("proj1", "proj2", "res"):
            swap(binary_op, attribute)

    # Unary operations: source and target
    for unary_op in op1s_lookup.values():
        for attribute in ("src", "tgt"):
            swap(unary_op, attribute)

    for tangent in tangent_variables_lookup.values():
        swap(tangent, "incl_var")

    # Summations: every matching summand as well as the sum itself
    for summation in summations_lookup.values():
        for position, term in enumerate(summation.summands):
            if term.id == to_replace.id:
                summation.summands[position] = replacement
        swap(summation, "sum")


def process_decaexpr(decaexpr_json) -> Decapode:
    """Process a DecaExpr JSON into a Decapode object.

    Both sides of every equation are expanded into variables and
    operations. Since the two sides of an equation denote the same
    quantity, the variable of one side is then replaced everywhere by the
    variable of the other side.

    Parameters
    ----------
    decaexpr_json : JSON
        The DecaExpr JSON of a model. The model is read from the "model"
        key if it is present, otherwise the JSON itself is used as the
        model.

    Returns
    -------
    :
        The corresponding MIRA Decapode object.
    """
    # Accept the same two shapes as get_variables_mapping_decaexpr
    decaexpr_json_model = (
        decaexpr_json["model"] if "model" in decaexpr_json else decaexpr_json
    )
    variables = get_variables_mapping_decaexpr(decaexpr_json_model)
    name_to_variable_index = {v.name: k for k, v in variables.items()}

    # The side whose type comes first in this list keeps its variable.
    # Every node type that expand_equations can return a variable for has
    # to be listed, otherwise the priority lookup below fails.
    equation_type_priority = [
        "Lit", "Var", "Tan", "App1", "App2", "Mult", "Plus"
    ]
    op1s_lookup = {}
    op2_lookup = {}
    tangent_variables_lookup = {}
    summations_lookup = {}

    # Expand each side of the equation(s) into its components
    for equation_json in decaexpr_json_model["equations"]:
        lhs_var = expand_equations(
            equation_json["lhs"],
            variable_lookup=variables,
            op1s_lookup=op1s_lookup,
            op2s_lookup=op2_lookup,
            tangent_variables_lookup=tangent_variables_lookup,
            summations_lookup=summations_lookup,
            var_name_to_index=name_to_variable_index,
        )
        rhs_var = expand_equations(
            equation_json["rhs"],
            variable_lookup=variables,
            op1s_lookup=op1s_lookup,
            op2s_lookup=op2_lookup,
            tangent_variables_lookup=tangent_variables_lookup,
            summations_lookup=summations_lookup,
            var_name_to_index=name_to_variable_index,
        )

        lhs_type = equation_json["lhs"]["_type"]
        lhs_priority = equation_type_priority.index(lhs_type)
        rhs_type = equation_json["rhs"]["_type"]
        rhs_priority = equation_type_priority.index(rhs_type)

        # Same priority, choose the left hand side
        if lhs_priority <= rhs_priority:
            prio_var, del_var = lhs_var, rhs_var
        else:
            prio_var, del_var = rhs_var, lhs_var

        # Both sides are already the same variable; replacing it with
        # itself would remove it from the lookups while it is still in use
        if prio_var.id == del_var.id:
            continue

        # Replace the variable that is not the priority variable
        replace_variable(
            replacement=prio_var,
            to_replace=del_var,
            variable_lookup=variables,
            name_to_variable_index=name_to_variable_index,
            op2s_lookup=op2_lookup,
            op1s_lookup=op1s_lookup,
            tangent_variables_lookup=tangent_variables_lookup,
            summations_lookup=summations_lookup,
        )

    return Decapode(
        variables=variables,
        op1s=op1s_lookup,
        op2s=op2_lookup,
        summations=summations_lookup,
        tangent_variables=tangent_variables_lookup,
    )