Source code for asdf_astropy.converters.transform.core

import abc

from asdf.extension import Converter

from asdf_astropy.converters.utils import import_type


def parameter_to_value(param):
    """
    Return a model parameter's value in a form suitable for serialization.

    A parameter that carries a unit is turned into an
    `astropy.units.Quantity`; a unitless parameter yields its plain value.

    Parameters
    ----------
    param : astropy.modeling.Parameter
        The parameter to convert.

    Returns
    -------
    astropy.units.Quantity or float
        ``Quantity`` when ``param.unit`` is set, otherwise ``param.value``.
    """
    from astropy import units as u

    # Unitless parameters are the common case: hand back the raw value.
    if param.unit is None:
        return param.value
    return u.Quantity(param)


# One converter, UnitsMappingConverter, does not inherit from
# this class.  When adding features here, consider also
# updating UnitsMappingConverter.
# This class is used by other packages, e.g., gwcs, to implement
# converters for custom models.  Keep that in mind when modifying
# this code.
class TransformConverterBase(Converter):
    """
    ABC for transform/model converters.

    Handles common properties after concrete converter sets model-specific
    properties.
    """

    @abc.abstractmethod
    def to_yaml_tree_transform(self, model, tag, ctx):
        """
        Convert a model's parameters into a dict suitable for ASDF serialization.

        Common model properties such as name and inverse will be handled by
        this base class.

        Parameters
        ----------
        model : astropy.modeling.Model
            The model instance to convert.
        tag : str
            The tag identifying the YAML type that `astropy.modeling.Model`
            should be converted into.
        ctx : asdf.asdf.SerializationContext
            The context of the current serialization request.

        Returns
        -------
        dict
            ASDF node.
        """

    @abc.abstractmethod
    def from_yaml_tree_transform(self, node, tag, ctx):
        """
        Convert an ASDF node into an instance of the appropriate model class.

        The implementing class need only instantiate the model and set
        parameter values; common model properties such as name and inverse
        will be handled by this base class.

        Parameters
        ----------
        node : dict
            The ASDF node to convert.
        tag : str
            The tag identifying the YAML type of the node.
        ctx : asdf.asdf.SerializationContext
            The context of the current serialization request.

        Returns
        -------
        astropy.modeling.Model
            The resulting model instance.
        """

    def to_yaml_tree(self, model, tag, ctx):
        """
        Serialize ``model`` into an ASDF node.

        The subclass builds the model-specific node via
        `to_yaml_tree_transform`. This method then adds the properties common
        to all models:

        - name, inputs and outputs;
        - a user-assigned inverse;
        - a user-assigned bounding box;
        - non-default ``fixed``/``bounds`` constraints;
        - ``input_units_equivalencies``.

        Parameters
        ----------
        model : astropy.modeling.Model
            The model instance to convert.
        tag : str
            The tag selected for this model.
        ctx : asdf.asdf.SerializationContext
            The context of the current serialization request.

        Returns
        -------
        dict
            ASDF node.
        """
        from astropy.modeling.core import CompoundModel

        node = self.to_yaml_tree_transform(model, tag, ctx)

        if model.name is not None:
            node["name"] = model.name

        node["inputs"] = list(model.inputs)
        node["outputs"] = list(model.outputs)

        # Don't bother serializing analytic inverses provided
        # by the model; only an inverse the user assigned is written.
        if getattr(model, "_user_inverse", None) is not None:
            node["inverse"] = model._user_inverse

        self._serialize_bounding_box(model, node)

        # model / parameter constraints
        if not isinstance(model, CompoundModel):
            fixed_nondefaults = {k: f for k, f in model.fixed.items() if f}
            if fixed_nondefaults:
                node["fixed"] = fixed_nondefaults
            # A bound is non-default when either end is set.  Compare against
            # None instead of relying on truthiness: ``any((0, None))`` is
            # False, which would silently drop a lower bound of zero (and
            # truth-testing a unit-carrying bound is ambiguous).
            bounds_nondefaults = {
                k: b
                for k, b in model.bounds.items()
                if any(limit is not None for limit in b)
            }
            if bounds_nondefaults:
                node["bounds"] = bounds_nondefaults

        # model input_units_equivalencies
        if not isinstance(model, CompoundModel) and model.input_units_equivalencies:
            node["input_units_equivalencies"] = model.input_units_equivalencies

        return node

    def _serialize_bounding_box(self, model, node):
        """Write ``node["bounding_box"]`` if the user assigned a bounding box."""
        from astropy.modeling.bounding_box import CompoundBoundingBox, ModelBoundingBox

        # ignore any default bounding_box
        if (bbox := model._user_bounding_box) is not None:
            if isinstance(bbox, ModelBoundingBox):
                self._serialize_bbox(model, node)
            elif isinstance(bbox, CompoundBoundingBox):
                self._serialize_cbbox(model, node)

    def _serialize_bbox(self, model, node):
        """Store a validated `ModelBoundingBox` for ``model`` in ``node``."""
        from astropy.modeling.bounding_box import ModelBoundingBox

        bbox = model.bounding_box
        # Ask validate() to keep the ignored inputs when there are any.
        if len(bbox.ignored) > 0:
            kwargs = {"_preserve_ignore": True}
        else:
            kwargs = {}

        node["bounding_box"] = ModelBoundingBox.validate(model, bbox, **kwargs)

    def _serialize_cbbox(self, model, node):
        """Store the model's compound bounding box in ``node`` unchanged."""
        node["bounding_box"] = model.bounding_box

    def from_yaml_tree(self, node, tag, ctx):
        """
        Deserialize an ASDF node into a model.

        The subclass builds the model via `from_yaml_tree_transform`. This
        method then restores the common properties written by `to_yaml_tree`.

        This is a generator: the model is yielded before its inverse is
        attached.  Presumably this lets asdf resolve nodes that refer back
        to the model, via asdf's generator-based converter protocol.

        Parameters
        ----------
        node : dict
            The ASDF node to convert.
        tag : str
            The tag identifying the YAML type of the node.
        ctx : asdf.asdf.SerializationContext
            The context of the current serialization request.

        Yields
        ------
        astropy.modeling.Model
            The resulting model instance.
        """
        from astropy.modeling.core import CompoundModel

        model = self.from_yaml_tree_transform(node, tag, ctx)

        if "name" in node:
            model.name = node["name"]

        if "inputs" in node:
            model.inputs = tuple(node["inputs"])

        if "outputs" in node:
            model.outputs = tuple(node["outputs"])

        self._deserialize_bounding_box(model, node)

        param_and_model_constraints = {}
        for constraint in ["fixed", "bounds"]:
            if constraint in node:
                param_and_model_constraints[constraint] = node[constraint]
        model._initialize_constraints(param_and_model_constraints)

        # this still writes eqs. for compound, but operates on each sub model
        if "input_units_equivalencies" in node and not isinstance(model, CompoundModel):
            model.input_units_equivalencies = node["input_units_equivalencies"]

        yield model

        if "inverse" in node:
            model.inverse = node["inverse"]

    def _deserialize_bounding_box(self, model, node):
        """
        Restore ``node["bounding_box"]`` (if present) onto ``model``.

        A list is assigned directly.  A callable is called with the model and
        the result assigned; presumably it is a factory produced by the
        bounding-box converters.  Anything else raises ``TypeError``.
        """
        if "bounding_box" in node:
            bounding_box = node["bounding_box"]
            if isinstance(bounding_box, list):
                model.bounding_box = bounding_box
            elif callable(bounding_box):
                model.bounding_box = bounding_box(model)
            else:
                msg = f"Cannot form bounding_box from: {bounding_box}"
                raise TypeError(msg)
class SimpleTransformConverter(TransformConverterBase):
    """
    Converter that writes every model parameter verbatim.

    Suitable for models whose handling does not depend on the tag version.

    Parameters
    ----------
    tags : list of str
        Tag patterns handled by this converter.
    model_type_name : str
        Fully-qualified name of the model class.
    """

    def __init__(self, tags, model_type_name):
        self._tags = tags
        self._model_type_name = model_type_name
        # Resolved on first access of ``model_type``.
        self._model_type = None

    @property
    def tags(self):
        return self._tags

    @property
    def types(self):
        return [self._model_type_name]

    @property
    def model_type(self):
        # Importing the model class lazily keeps loading the extension cheap.
        if self._model_type is not None:
            return self._model_type
        self._model_type = import_type(self._model_type_name)
        return self._model_type

    def to_yaml_tree_transform(self, model, tag, ctx):
        """Map each parameter name to its (possibly unit-bearing) value."""
        return {
            name: parameter_to_value(getattr(model, name))
            for name in model.param_names
        }

    def from_yaml_tree_transform(self, node, tag, ctx):
        """Build the model from whichever of its parameters appear in ``node``."""
        cls = self.model_type
        present = {name: node[name] for name in cls.param_names if name in node}
        return cls(**present)