"""ASDF converters for astropy compound models (``asdf_astropy.converters.transform.compound``)."""

from asdf_astropy.converters.helpers import get_tag_name

from .core import TransformConverterBase

# The converter class is the only public name exported by this module.
__all__ = [
    "CompoundConverter",
]


# One row per supported compound operator:
#   (CompoundModel.op symbol, ASDF transform tag base name, Model method name)
# Both lookup tables below are derived from this single table so they cannot
# drift out of sync with each other.
_COMPOUND_OPERATORS = (
    ("+", "add", "__add__"),
    ("-", "subtract", "__sub__"),
    ("*", "multiply", "__mul__"),
    ("/", "divide", "__truediv__"),
    ("**", "power", "__pow__"),
    ("|", "compose", "__or__"),
    ("&", "concatenate", "__and__"),
    ("fix_inputs", "fix_inputs", "fix_inputs"),
)

# Operator symbol -> tag base name (used when picking the tag to write).
_OPERATOR_TO_TAG_NAME = {symbol: tag_name for symbol, tag_name, _ in _COMPOUND_OPERATORS}

# Tag base name -> name of the callable used to rebuild the compound on read.
_TAG_NAME_TO_MODEL_METHOD = {tag_name: method for _, tag_name, method in _COMPOUND_OPERATORS}


class CompoundConverter(TransformConverterBase):
    """
    ASDF serialization support for CompoundModel.

    Handles the binary compound operators (add, subtract, multiply, divide,
    power, compose, concatenate) as well as ``fix_inputs``.
    """

    tags = (
        "tag:stsci.edu:asdf/transform/add-*",
        "tag:stsci.edu:asdf/transform/subtract-*",
        "tag:stsci.edu:asdf/transform/multiply-*",
        "tag:stsci.edu:asdf/transform/divide-*",
        "tag:stsci.edu:asdf/transform/power-*",
        "tag:stsci.edu:asdf/transform/compose-*",
        "tag:stsci.edu:asdf/transform/concatenate-*",
        "tag:stsci.edu:asdf/transform/fix_inputs-*",
    )
    types = ("astropy.modeling.core.CompoundModel",)

    @staticmethod
    def _describe_operand(operand):
        """Return an operand's ASDF tag if it carries one, else its type name."""
        # Unconverted tagged nodes expose ``_tag``; plain dicts/lists/scalars
        # do not, and reading ``_tag`` on them would raise AttributeError and
        # mask the intended TypeError.
        return getattr(operand, "_tag", type(operand).__name__)

    def select_tag(self, model, tags, ctx):
        """
        Choose the tag used to serialize ``model``.

        Parameters
        ----------
        model : CompoundModel
            The compound model being written; ``model.op`` selects the tag.
        tags : iterable of str
            Candidate tag URIs offered for this converter.
        ctx
            Serialization context (unused here).

        Returns
        -------
        str
            The first tag in ``tags`` whose base name matches ``model.op``.

        Raises
        ------
        KeyError
            If ``model.op`` is not one of the supported compound operators.
        ValueError
            If no tag in ``tags`` corresponds to ``model.op``.
        """
        tag_name = _OPERATOR_TO_TAG_NAME[model.op]
        # The extension will never include two tags with the same name but
        # different version, so the first matching tag is the only one.
        tag = next((t for t in tags if get_tag_name(t) == tag_name), None)
        if tag is None:
            # A bare ``next`` would leak StopIteration here, which is opaque to
            # callers and is turned into RuntimeError inside generator frames.
            msg = (
                f"No tag available for compound operator '{model.op}' "
                f"(expected a '{tag_name}' tag)"
            )
            raise ValueError(msg)
        return tag

    def to_yaml_tree_transform(self, model, tag, ctx):
        """
        Build the YAML node for ``model``.

        The node holds a single ``forward`` entry ``[left, right]``. When
        ``model.right`` is a dict (the ``fix_inputs`` case), it is stored as a
        mapping with parallel ``keys`` and ``values`` lists.
        """
        left = model.left
        if isinstance(model.right, dict):
            right = {
                "keys": list(model.right.keys()),
                "values": list(model.right.values()),
            }
        else:
            right = model.right
        return {"forward": [left, right]}

    def from_yaml_tree_transform(self, node, tag, ctx):
        """
        Reconstruct a CompoundModel from its YAML node.

        Parameters
        ----------
        node : dict
            Mapping with a ``forward`` list of ``[left, right]`` operands.
        tag : str
            The tag the node was read with; its base name selects the operator.
        ctx
            Deserialization context (unused here).

        Returns
        -------
        CompoundModel
            The rebuilt compound model.

        Raises
        ------
        TypeError
            If the left operand is not a Model, or the right operand is not a
            Model (or, for ``fix_inputs``, not a dict).
        """
        from astropy.modeling.core import CompoundModel, Model

        oper = _TAG_NAME_TO_MODEL_METHOD[get_tag_name(tag)]

        left = node["forward"][0]
        if not isinstance(left, Model):
            msg = f"Unknown left model type '{self._describe_operand(left)}'"
            raise TypeError(msg)

        right = node["forward"][1]
        # ``fix_inputs`` stores its right operand as a keys/values mapping;
        # every other operator expects a Model on the right.
        if (oper == "fix_inputs" and not isinstance(right, dict)) or (
            oper != "fix_inputs" and not isinstance(right, Model)
        ):
            msg = f"Unknown right model type '{self._describe_operand(right)}'"
            raise TypeError(msg)

        if oper == "fix_inputs":
            right = dict(zip(right["keys"], right["values"]))
            # Built directly rather than through a Model method lookup.
            return CompoundModel("fix_inputs", left, right)
        return getattr(left, oper)(right)