Source code for smash.io.model_ddt_io

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from smash.core.model import Model

from smash.core._constant import STRUCTURE_PARAMETERS, STRUCTURE_STATES

from smash.io._error import ReadHDF5MethodError

import os
import errno
import warnings
import h5py
import numpy as np

__all__ = ["save_model_ddt", "read_model_ddt"]


def _default_save_data(structure: str):
    return {
        "setup": ["dt", "end_time", "start_time", "structure"],
        "mesh": ["active_cell", "area", "code", "dx", "flwdir"],
        "input_data": ["mean_prcp", "mean_pet", "qobs"],
        "parameters": STRUCTURE_PARAMETERS[
            structure
        ],  # only calibrated Model param will be stored
        "states": STRUCTURE_STATES[
            structure
        ],  # only initial Model states will be stored
        "output": [
            {
                "fstates": STRUCTURE_STATES[structure]
            },  # only final Model states will be stored
            "qsim",
        ],
    }


def _parse_selected_derived_type_to_hdf5(
    derived_type, list_attr, hdf5_ins, attr_suffix=""
):
    # TODO: clean function for attr_suffix

    for attr in list_attr:
        if isinstance(attr, str):
            try:
                value = getattr(derived_type, attr)

                attr += attr_suffix

                if isinstance(value, np.ndarray):
                    if value.dtype == "object" or value.dtype.char == "U":
                        value = value.astype("S")

                    hdf5_ins.create_dataset(
                        attr,
                        shape=value.shape,
                        dtype=value.dtype,
                        data=value,
                        compression="gzip",
                        chunks=True,
                    )

                else:
                    hdf5_ins.attrs[attr] = value

            except:
                pass

        elif isinstance(attr, dict):
            for derived_type_key, list_attr_imd in attr.items():
                try:
                    derived_type_imd = getattr(derived_type, derived_type_key)

                    _parse_selected_derived_type_to_hdf5(
                        derived_type_imd, list_attr_imd, hdf5_ins
                    )

                except:
                    pass


[docs]def save_model_ddt(model: Model, path: str, sub_data=None, sub_only=False): """ Save some derived data types of the Model object to HDF5 file. This method is considerably lighter than `smash.save_model` method that saves the entire Model object. However, it is not capable of reconstructing the Model object from the saved data file. By default, the following data are stored into the `HDF5 <https://www.hdfgroup.org/solutions/hdf5/>`__ file: - ``dt``, ``end_time``, ``start_time``, ``structure`` from `Model.setup` - ``active_cell``, ``area``, ``code``, ``dx``, ``flwdir`` from `Model.mesh` - ``mean_prcp``, ``mean_pet``, ``qobs`` from `Model.input_data` - ``qsim`` from `Model.output` - The final Model states (depending upon the Model structure) from state derived type of `Model.output` - The initial Model states (depending upon the Model structure) from `Model.states` - The Model parameters (depending upon the Model structure) from `Model.parameters` Subsidiary data can be added by filling in ``sub_data``. Parameters ---------- model : Model The Model object to save derived data types as a HDF5 file. path : str The file path. If the path not end with ``.hdf5``, the extension is automatically added to the file path. sub_data : dict or None, default None Dictionary which indicates the subsidiary data to store into the HDF5 file. .. note:: If not given, no subsidiary data is saved sub_only : bool, default False Allow to only store subsidiary data. See Also -------- read_model_ddt: Read derived data types of the Model object from HDF5 file. Model: Primary data structure of the hydrological model `smash`. Examples -------- >>> setup, mesh = smash.load_dataset("cance") >>> model = smash.Model(setup, mesh) >>> model Structure: 'gr-a' Spatio-Temporal dimension: (x: 28, y: 28, time: 1440) Last update: Initialization Save spatially distributed precipitation in addition to default derived data types of Model >>> smash.save_model_ddt(model, "model_ddt.hdf5", sub_data={"prcp": model.input_data.prcp}) """ if not path.endswith(".hdf5"): path = path + ".hdf5" with h5py.File(path, "w") as f: if not sub_only: save_data = _default_save_data(model.setup.structure) for derived_type_key, list_attr in save_data.items(): derived_type = getattr(model, derived_type_key) if derived_type_key == "states": _parse_selected_derived_type_to_hdf5( derived_type, list_attr, f, attr_suffix="_0" ) else: _parse_selected_derived_type_to_hdf5(derived_type, list_attr, f) if sub_data is not None: for attr, value in sub_data.items(): if (attr in f) or (attr in f.attrs): warnings.warn(f"Ignore updating existing key ({attr})") continue if isinstance(value, np.ndarray): if value.dtype == "object" or value.dtype.char == "U": value = value.astype("S") try: f.create_dataset( attr, shape=value.shape, dtype=value.dtype, data=value, compression="gzip", chunks=True, ) except: warnings.warn(f"Can not store to HDF5: {attr}") else: try: f.attrs[attr] = value except: warnings.warn(f"Can not store to HDF5: {attr}") f.attrs["_save_func"] = "save_model_ddt"
[docs]def read_model_ddt(path: str) -> dict: """ Read derived data types of the Model object from HDF5 file. Parameters ---------- path : str The file path. Returns ------- data : dict A dictionary with derived data types loaded from HDF5 file. Raises ------ FileNotFoundError: If file not found. ReadHDF5MethodError: If file not created with `save_model_ddt`. See Also -------- save_model_ddt: Save some derived data types of the Model object to HDF5 file. Examples -------- >>> setup, mesh = smash.load_dataset("cance") >>> model = smash.Model(setup, mesh) >>> smash.save_model_ddt(model, "model_ddt.hdf5") Read the derived data types from HDF5 file >>> data = smash.read_model_ddt("model_ddt.hdf5") Then, to see the dataset keys >>> data.keys() dict_keys(['active_cell', 'area', 'cft', 'code', 'cp', 'exc', 'flwdir', 'hft', 'hft_0', 'hlr', 'hlr_0', 'hp', 'hp_0', 'lr', 'mean_pet', 'mean_prcp', 'qobs', 'qsim', 'dt', 'dx', 'end_time', 'start_time', 'structure']) And finally, to access to derived data >>> data["mean_prcp"] array([[0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.], [0., 0., 0., ..., 0., 0., 0.]], dtype=float32) """ if os.path.isfile(path): with h5py.File(path) as f: if f.attrs.get("_save_func") == "save_model_ddt": keys = list(f.keys()) values = [ f[key][:].astype("U") if f[key][:].dtype.char == "S" else f[key][:] for key in keys ] attr_keys = list(f.attrs.keys()) attr_keys.remove("_save_func") attr_values = [f.attrs[key] for key in attr_keys] return dict(zip(keys + attr_keys, values + attr_values)) else: raise ReadHDF5MethodError( f"Unable to read '{path}' with 'read_model_ddt' method. The file may not have been created with 'save_model_ddt' method." ) else: raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), path)