Source code for rook.utils.ops.consolidate
"""Consolidate file paths for each dataset in a collection."""
import re
from pathlib import Path
from clisops.exceptions import InvalidCollection
from clisops.project_utils import derive_ds_id, dset_to_filepaths, get_project_name
from clisops.utils.dataset_utils import open_xr_dataset
from clisops.utils.file_utils import FileMapper
from loguru import logger
from rook.catalog import get_catalog
from rook.io.datasets import (
DatasetFormat,
DatasetSource,
Transport,
detect_format,
detect_transport,
)
from .helpers import wrap_sequence
def _bypasses_catalog(value):
    """Tell whether a directly supplied source skips project resolution.

    The value is wrapped in an id-less ``DatasetSource`` and inspected:
    the result is true when the detected format is anything other than
    NetCDF, or when the detected transport is S3.
    """
    probe = DatasetSource(dataset_id=None, paths=value)
    # Anything that is not NetCDF bypasses immediately; transport is only
    # consulted for NetCDF sources.
    if detect_format(probe) is not DatasetFormat.NETCDF:
        return True
    return detect_transport(probe) is Transport.S3
def _resolve_file_source(dset, time_param):
    """Build a ``DatasetSource`` for an input resolved without a catalog.

    File paths come from ``dset_to_filepaths`` (forced). When a time
    parameter is supplied the paths are filtered by
    ``get_files_matching_time_range``.

    Raises:
        Exception: if the time filter leaves no files.
    """
    paths = dset_to_filepaths(dset, force=True)

    if time_param:
        paths = get_files_matching_time_range(time_param, paths)
        if not paths:
            raise Exception(f"No files found in given time range for {dset}")

    # A FileMapper input gets no dataset id; anything else is identified
    # by its string form.
    if isinstance(dset, FileMapper):
        source_id = None
    else:
        source_id = str(dset)

    return DatasetSource(dataset_id=source_id, paths=paths)
def _resolve_catalog_sources(dset, catalog, time_param):
    """Look one input up in its catalog and return the matching sources.

    Returns:
        A tuple with one ``DatasetSource`` per entry of the search
        result's ``files()`` mapping.

    Raises:
        Exception: the dataset is in the catalog but nothing matches the
            requested time.
        InvalidCollection: the dataset is not in the catalog at all.
    """
    ds_id = derive_ds_id(dset)
    matches = catalog.search(collection=ds_id, time=time_param)

    if len(matches) == 0:
        # Search again without the time constraint to tell "wrong time
        # range" apart from "unknown dataset".
        untimed = catalog.search(collection=ds_id, time=None)
        if len(untimed) > 0:
            raise Exception(f"No files found in given time range for {dset}")
        raise InvalidCollection(f"{dset} is not in the list of available data.")

    logger.info(f"Found {len(matches)} files")

    resolved = []
    for matched_id, paths in matches.files().items():
        resolved.append(DatasetSource(dataset_id=matched_id, paths=paths))
    return tuple(resolved)
def _catalog_for(dset, catalogs):
    """Fetch the catalog of the project a dataset input belongs to.

    ``catalogs`` is a mapping keyed by project name that acts as a cache:
    a catalog is obtained from ``get_catalog`` only the first time its
    project is seen, and stored in the mapping for later calls.
    """
    name = get_project_name(dset)
    if name in catalogs:
        return catalogs[name]

    found = get_catalog(name)
    catalogs[name] = found
    return found
[docs]
def to_year(time_string):
    """Parse the leading year of a dash-separated time string as an int.

    Everything before the first ``-`` (or the whole string when there is
    no dash) is converted with ``int``.
    """
    year_text, _, _ = time_string.partition("-")
    return int(year_text)
[docs]
def get_year(value, default):
    """Return the year of ``value``, or ``default`` when ``value`` is falsy.

    A truthy ``value`` is handed to ``to_year``; ``None`` and the empty
    string fall back to ``default``.
    """
    return to_year(value) if value else default
[docs]
def get_years_from_file(fpath):
    """Return the set of years a file covers.

    The years are first read from the file name: the last
    underscore-separated part of the stem is split on ``-`` and every
    piece that starts with at least four digits contributes its first
    four digits as a year. When more than one distinct year is found,
    every year from the smallest to the largest (inclusive) is returned.

    Only when the file name yields no year is the file opened with
    ``open_xr_dataset``; if the dataset has a ``time`` attribute the
    years of that axis are collected. The dataset is closed again before
    returning.

    Returns:
        A set of integer years; empty when neither the file name nor the
        dataset provides any.
    """
    time_comps = Path(fpath).stem.split("_")[-1].split("-")
    years = {int(tm[:4]) for tm in time_comps if re.match(r"^\d{4,}", tm)}

    if len(years) > 1:
        # Start and end year found: fill in the years in between.
        years = set(range(min(years), max(years) + 1))

    if not years:
        ds = open_xr_dataset(fpath)
        try:
            if hasattr(ds, "time"):
                years = {int(yr) for yr in ds.time.dt.year}
        finally:
            # Release the file handle; the years are already plain ints.
            ds.close()

    return years
[docs]
def get_files_matching_time_range(time_param, file_paths):
    """Keep the files whose years overlap the requested time.

    Behaviour depends on ``time_param.type``:

    * ``"none"``: ``file_paths`` is returned unchanged.
    * ``"interval"``: a file is kept when its year span overlaps the
      years of ``time_param.get_bounds()``; a missing bound is treated
      as open-ended.
    * ``"series"``: a file is kept when it shares at least one year with
      the ``"time_values"`` entry of ``time_param.asdict()``.

    Any other type yields an empty list. Files for which
    ``get_years_from_file`` finds no year are dropped in both the
    interval and the series case.

    Returns:
        ``file_paths`` itself for type ``"none"``, otherwise a new list
        of the matching paths in their original order.
    """
    if time_param.type == "none":
        return file_paths

    logger.info(f"Testing {len(file_paths)} files in time range: ...")
    files_in_time_range = []

    if time_param.type == "interval":
        tp_start, tp_end = time_param.get_bounds()
        # Sentinel years stand in for an unbounded start or end.
        req_start_year = get_year(tp_start, default=-99999999)
        req_end_year = get_year(tp_end, default=999999999)

        for fpath in file_paths:
            years = get_years_from_file(fpath)
            # min()/max() raise ValueError on an empty set; a file with no
            # detectable years is skipped, as the series branch does.
            if not years:
                continue
            if min(years) <= req_end_year and max(years) >= req_start_year:
                files_in_time_range.append(fpath)

    elif time_param.type == "series":
        req_years = {to_year(tm) for tm in time_param.asdict().get("time_values", [])}

        for fpath in file_paths:
            years = get_years_from_file(fpath)
            if req_years.intersection(years):
                files_in_time_range.append(fpath)

    logger.info(f"Kept {len(files_in_time_range)} files")
    return files_in_time_range
[docs]
def consolidate(collection, **kwargs):
    """Resolve every dataset of a collection into dataset sources.

    ``collection.value`` is normalised with ``wrap_sequence`` and each
    entry is handled in one of three ways:

    * a non-``FileMapper`` entry for which ``_bypasses_catalog`` is true
      is passed through as an id-less ``DatasetSource``;
    * an entry whose project has a catalog is resolved through that
      catalog, which may contribute several sources;
    * anything else (``FileMapper`` inputs, or projects without a
      catalog) is resolved to file paths directly.

    The optional ``time`` keyword is forwarded to the resolvers.

    Returns:
        A tuple of ``DatasetSource`` objects in input order.
    """
    time_param = kwargs.get("time")
    catalogs = {}  # per-call cache of project catalogs
    sources = []

    for dset in wrap_sequence(collection.value):
        is_mapped = isinstance(dset, FileMapper)

        if not is_mapped and _bypasses_catalog(dset):
            sources.append(DatasetSource(dataset_id=None, paths=dset))
            continue

        # FileMapper inputs never consult a catalog.
        catalog = None if is_mapped else _catalog_for(dset, catalogs)
        if not catalog:
            sources.append(_resolve_file_source(dset, time_param))
        else:
            sources.extend(_resolve_catalog_sources(dset, catalog, time_param))

    return tuple(sources)