import collections
import numpy as np
import xarray as xr
from clisops.core.average import average_over_dims as average
from clisops.ops import subset
from clisops.parameter import collection_parameter
from clisops.parameter import dimension_parameter
from clisops.parameter import time_components_parameter
from clisops.parameter import time_parameter
from clisops.project_utils import derive_ds_id
from clisops.utils.dataset_utils import open_xr_dataset
from rook.utils.decadal_fixes import apply_decadal_fixes, decadal_fix_calendar
from . import normalise
from .base import Operation
# Lookup from the standard name requested through ``dims`` to the name of the
# coordinate that ``Concat._calculate`` concatenates along. Names missing from
# this table resolve to ``None`` via ``.get``.
coord_by_standard_name = {"realization": "realization"}
[docs]
def drop_time_bnds(ds: xr.Dataset) -> xr.Dataset:
    """Return ``ds`` without its ``time_bnds`` variable.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to inspect.

    Returns
    -------
    xr.Dataset
        The result of ``ds.drop_vars("time_bnds")`` when ``"time_bnds"`` is
        among ``ds.variables``; otherwise ``ds`` itself.
    """
    # Guard clause: nothing to remove, hand the input back as-is.
    if "time_bnds" not in ds.variables:
        return ds
    return ds.drop_vars("time_bnds")
[docs]
def patched_normalise(collection):
    """Open every file of each dataset and join the files along ``time``.

    Each file path is opened with ``open_xr_dataset`` and the opened dataset
    is passed through ``decadal_fix_calendar`` (called with ``None`` as its
    first argument) before the per-file datasets are concatenated.

    Parameters
    ----------
    collection : mapping
        Maps a dataset identifier to an iterable of file paths.

    Returns
    -------
    collections.OrderedDict
        Maps each dataset identifier, in the iteration order of
        ``collection``, to the result of ``xr.concat(..., dim="time")`` over
        that dataset's files.
    """

    def _open_with_calendar_fix(path):
        # Open first, then apply the calendar fix to the opened dataset.
        return decadal_fix_calendar(None, open_xr_dataset(path))

    merged = collections.OrderedDict()
    for ds_key, paths in collection.items():
        # All files of one dataset are opened before they are concatenated.
        per_file = list(map(_open_with_calendar_fix, paths))
        merged[ds_key] = xr.concat(per_file, dim="time")
    return merged
[docs]
class Concat(Operation):
    """Concatenate the datasets of a collection along one dimension.

    ``_calculate`` opens and normalises every dataset, applies the decadal
    fixes, concatenates the results along the dimension selected by the first
    entry of ``dims``, optionally averages over that dimension, and finally
    passes the result through ``subset`` with the ``time`` and
    ``time_components`` parameters.
    """

    def _resolve_params(self, collection, **params):
        """Wrap the raw inputs in parameter objects and store them on ``self``.

        Sets ``self.collection`` to a ``CollectionParameter`` and
        ``self.params`` to a dict with the keys ``time``, ``time_components``,
        ``dims``, ``apply_average`` (default ``False``) and
        ``ignore_undetected_dims``.

        Parameters
        ----------
        collection
            Raw collection input, wrapped in ``CollectionParameter``.
        **params
            Raw keyword inputs; only the keys listed above are read.
        """
        time = time_parameter.TimeParameter(params.get("time"))
        time_components = time_components_parameter.TimeComponentsParameter(
            params.get("time_components")
        )
        dims = dimension_parameter.DimensionParameter(params.get("dims"))
        collection = collection_parameter.CollectionParameter(collection)

        self.collection = collection
        self.params = {
            "time": time,
            "time_components": time_components,
            "dims": dims,
            "apply_average": params.get("apply_average", False),
            "ignore_undetected_dims": params.get("ignore_undetected_dims"),
        }

    def _calculate(self):
        """Run the concatenation and return the populated result set.

        Returns
        -------
        normalise.ResultSet
            Result set with the value returned by ``subset`` added under the
            key ``"output"``.

        Raises
        ------
        ValueError
            If no ``dims`` value was supplied, or if its first entry has no
            mapping in ``coord_by_standard_name``.
        """
        config = {
            "output_type": self._output_type,
            "output_dir": self._output_dir,
            "split_method": self._split_method,
            "file_namer": self._file_namer,
        }
        self.params.update(config)

        new_collection = collections.OrderedDict()
        for dset in self.collection:
            ds_id = derive_ds_id(dset)
            new_collection[ds_id] = dset.file_paths

        norm_collection = patched_normalise(new_collection)

        # ``vars()`` hands the locals bound so far to the result set, so the
        # local names above this line are part of what gets recorded; do not
        # rename them or bind new locals before this call.
        rs = normalise.ResultSet(vars())

        # Resolve and validate the concat dimension before running the fixes,
        # so an unusable request fails early with a clear message instead of
        # an opaque error from indexing ``None`` or from ``xr.concat``.
        dims = dimension_parameter.DimensionParameter(
            self.params.get("dims", None)
        ).value
        if not dims:
            raise ValueError(
                "concat requires 'dims' to name the dimension to "
                "concatenate along, but none was provided."
            )
        standard_name = dims[0]
        dim = coord_by_standard_name.get(standard_name)
        if dim is None:
            raise ValueError(
                f"Cannot concatenate along {standard_name!r}: supported "
                f"dimensions are {sorted(coord_by_standard_name)}."
            )

        # NOTE(review): "output_dir" is always present in self.params after
        # the update above, so the "." fallback only applies if the key is
        # removed elsewhere - confirm whether a None value is intended here.
        output_dir = self.params.get("output_dir", ".")
        datasets = [
            apply_decadal_fixes(ds_id, ds, output_dir=output_dir)
            for ds_id, ds in norm_collection.items()
        ]

        processed_ds = xr.concat(datasets, dim)
        # Re-assign the concat coordinate with its values cast to int32 and
        # tag it with the requested standard name.
        processed_ds = processed_ds.assign_coords(
            {dim: (dim, np.array(processed_ds[dim].values, dtype="int32"))}
        )
        processed_ds.coords[dim].attrs = {"standard_name": standard_name}
        processed_ds = drop_time_bnds(processed_ds)

        if self.params.get("apply_average", False):
            processed_ds = average(processed_ds, dims=[dim])

        outputs = subset(
            processed_ds,
            time=self.params.get("time", None),
            time_components=self.params.get("time_components", None),
            output_type="nc",
        )
        rs.add("output", outputs)
        return rs
def _concat(
    collection,
    time=None,
    time_components=None,
    dims=None,
    ignore_undetected_dims=False,
    output_dir=None,
    output_type="netcdf",
    split_method="time:auto",
    file_namer="standard",
    apply_fixes=True,
    apply_average=False,
):
    """Build a ``Concat`` operation from the arguments and run it.

    Every argument is forwarded to ``Concat`` as a keyword argument of the
    same name, in declaration order.

    Returns
    -------
    The value returned by ``Concat._calculate``.
    """
    operation = Concat(
        collection=collection,
        time=time,
        time_components=time_components,
        dims=dims,
        ignore_undetected_dims=ignore_undetected_dims,
        output_dir=output_dir,
        output_type=output_type,
        split_method=split_method,
        file_namer=file_namer,
        apply_fixes=apply_fixes,
        apply_average=apply_average,
    )
    return operation._calculate()
[docs]
def concat(
    collection,
    time=None,
    time_components=None,
    dims=None,
    ignore_undetected_dims=False,
    output_dir=None,
    output_type="netcdf",
    split_method="time:auto",
    file_namer="standard",
    apply_fixes=True,
    apply_average=False,
):
    """Public entry point for the concat operation.

    All arguments are handed unchanged, by keyword, to ``_concat``, which
    runs the ``Concat`` operation.

    Parameters
    ----------
    collection
        Collection of datasets to concatenate.
    time, time_components
        Forwarded to the operation, which passes them on to ``subset``.
    dims
        Forwarded to the operation; its first entry selects the dimension to
        concatenate along.
    ignore_undetected_dims
        Forwarded to the operation unchanged.
    output_dir, output_type, split_method, file_namer
        Output settings forwarded to the operation.
    apply_fixes
        Forwarded to the operation unchanged.
    apply_average
        When truthy, the operation averages over the concatenated dimension.

    Returns
    -------
    The value returned by ``_concat``.
    """
    return _concat(
        collection=collection,
        time=time,
        time_components=time_components,
        dims=dims,
        ignore_undetected_dims=ignore_undetected_dims,
        output_dir=output_dir,
        output_type=output_type,
        split_method=split_method,
        file_namer=file_namer,
        apply_fixes=apply_fixes,
        apply_average=apply_average,
    )