from collections.abc import Callable, Hashable
from dataclasses import dataclass
from typing import Any, TypedDict, overload
import numpy as np
import pandas as pd
import xarray as xr
[docs]
class InvalidBoundsError(Exception):
    """Raised when the bounds of a grid fail validation."""
[docs]
class CoordHandler(TypedDict):
@dataclass
[docs]
class Grid:
"""Object storing grid information."""
[docs]
def __post_init__(self) -> None:
    """Validate the bounds of the initialized Grid.

    Raises:
        InvalidBoundsError: If the south bound is greater than the north bound,
            or the west bound is greater than the east bound. When both checks
            fail, both problems are listed in the error message.
    """
    problems = []
    if self.south > self.north:
        problems.append("Value of south bound is greater than north bound.")
    if self.west > self.east:
        problems.append("Value of west bound is greater than east bound.")
    if problems:
        # Report every failed check, not just the last one that was evaluated.
        msg = "\n".join([*problems, "Please check the bounds input."])
        raise InvalidBoundsError(msg)
[docs]
def create_regridding_dataset(
    self, lat_name: str = "latitude", lon_name: str = "longitude"
) -> xr.Dataset:
    """Create a dataset describing this grid, to use for regridding.

    Delegates to the module-level ``create_regridding_dataset`` function,
    passing this grid as its ``grid`` argument.

    Args:
        lat_name: Name for the latitudinal coordinate and dimension.
            Defaults to "latitude".
        lon_name: Name for the longitudinal coordinate and dimension.
            Defaults to "longitude".

    Returns:
        A dataset with the latitude and longitude coordinates corresponding to
        this grid. Contains no data variables.
    """
    target_dataset = create_regridding_dataset(
        grid=self, lat_name=lat_name, lon_name=lon_name
    )
    return target_dataset
[docs]
def create_lat_lon_coords(grid: Grid) -> tuple[np.ndarray, np.ndarray]:
    """Create latitude and longitude coordinates based on the provided grid parameters.

    Each axis starts at its lower bound (south / west) and advances by that
    axis' own resolution. The upper bound (north / east) is included when the
    span of the axis is a whole multiple of the resolution.

    Args:
        grid: Grid object providing the ``south``, ``north``, ``west`` and
            ``east`` bounds and the ``resolution_lat`` / ``resolution_lon``
            step sizes.

    Returns:
        Latitude coordinates, longitude coordinates.
    """
    lat_coords = _inclusive_arange(grid.south, grid.north, grid.resolution_lat)
    # The longitude axis must be checked against the longitude resolution;
    # checking it against the latitude resolution gives a wrong east edge
    # whenever the two resolutions differ.
    lon_coords = _inclusive_arange(grid.west, grid.east, grid.resolution_lon)
    return lat_coords, lon_coords


def _inclusive_arange(start: float, stop: float, step: float) -> np.ndarray:
    """Return evenly spaced values from start, including stop if it lies on a step."""
    if np.remainder(stop - start, step) > 0:
        # stop is not reachable in whole steps: end at the last point below it.
        return np.arange(start, stop, step)
    # np.arange excludes its end value, so extend by one step to include stop.
    return np.arange(start, stop + step, step)
[docs]
def create_regridding_dataset(
    grid: Grid, lat_name: str = "latitude", lon_name: str = "longitude"
) -> xr.Dataset:
    """Build the target dataset used for regridding onto a grid.

    Args:
        grid: Grid object containing the bounds and resolution of the cartesian grid.
        lat_name: Name for the latitudinal coordinate and dimension.
            Defaults to "latitude".
        lon_name: Name for the longitudinal coordinate and dimension.
            Defaults to "longitude".

    Returns:
        A dataset with the latitude and longitude coordinates corresponding to the
        specified grid. Contains no data variables.
    """
    latitudes, longitudes = create_lat_lon_coords(grid)
    # Each entry is (dims, values, attrs); the variable name equals its dimension.
    lat_variable = ([lat_name], latitudes, {"units": "degrees_north"})
    lon_variable = ([lon_name], longitudes, {"units": "degrees_east"})
    return xr.Dataset({lat_name: lat_variable, lon_name: lon_variable})
[docs]
def to_intervalindex(coords: np.ndarray) -> pd.IntervalIndex:
    """Turn a 1-d coordinate array into a pandas IntervalIndex.

    The interval boundaries are the midpoints between neighbouring coordinates.
    The two outermost boundaries are placed as far beyond the first and last
    coordinate as the nearest midpoint lies inside them.

    Args:
        coords: 1-d array containing the coordinate values.

    Returns:
        A pandas IntervalIndex containing the intervals corresponding to the input
        coordinates. With fewer than two coordinates, a single interval from
        -inf to inf is returned.
    """
    if len(coords) <= 1:
        # If the target grid has a single point, set search interval to span all space
        return pd.IntervalIndex.from_breaks([-np.inf, np.inf])

    inner_edges = (coords[:-1] + coords[1:]) / 2
    # Extrapolate outer bounds beyond the first and last coordinates
    lower_edge = 2 * coords[0] - inner_edges[0]
    upper_edge = 2 * coords[-1] - inner_edges[-1]
    edges = np.concatenate([[lower_edge], inner_edges, [upper_edge]])
    return pd.IntervalIndex.from_breaks(edges)
[docs]
def overlap(a: pd.IntervalIndex, b: pd.IntervalIndex) -> np.ndarray:
    """Compute the pairwise overlap between two sets of intervals.

    Args:
        a: Pandas IntervalIndex containing the first set of intervals.
        b: Pandas IntervalIndex containing the second set of intervals.

    Returns:
        2D numpy array of shape (len(a), len(b)). Element [i, j] is the length of
        the intersection of a[i] and b[j], or 0 where they do not overlap. The
        values are not normalized by this function.
    """
    # TODO: newaxis on B and transpose is MUCH faster on benchmark.
    #  likely due to it being the bigger dimension.
    #  size(a) > size(b) leads to better perf than size(b) > size(a)
    b_right_column = b.right.to_numpy()[:, np.newaxis]
    b_left_column = b.left.to_numpy()[:, np.newaxis]
    # Broadcasting a row against a column yields a (len(b), len(a)) table.
    upper_edges = np.minimum(a.right.to_numpy(), b_right_column)
    lower_edges = np.maximum(a.left.to_numpy(), b_left_column)
    # Disjoint intervals give a negative difference; clamp those to zero.
    lengths: np.ndarray = np.maximum(upper_edges - lower_edges, 0)
    return lengths.T
[docs]
def normalize_overlap(overlap: np.ndarray) -> np.ndarray:
    """Scale overlap values so that they sum to 1.0 along the first axis.

    Sums that are exactly zero are replaced by a tiny positive number before
    dividing, so the corresponding entries stay 0 instead of dividing by zero.
    The input array is not modified.
    """
    totals: np.ndarray = overlap.sum(axis=0)
    zero_total = totals == 0
    totals[zero_total] = 1e-12  # Avoid dividing by 0.
    normalized = overlap / totals
    return normalized  # type: ignore
[docs]
def create_dot_dataarray(
    weights: np.ndarray,
    coord: str,
    target_coords: np.ndarray,
    source_coords: np.ndarray,
) -> xr.DataArray:
    """Wrap a weights array in a DataArray that can be used in a dot product with xr.dot.

    The first dimension is named ``coord`` and labelled with ``source_coords``;
    the second is named ``target_<coord>`` and labelled with ``target_coords``.
    """
    target_dim = f"target_{coord}"
    labels = {
        coord: source_coords,
        target_dim: target_coords,
    }
    return xr.DataArray(data=weights, dims=[coord, target_dim], coords=labels)
[docs]
def common_coords(
    data1: xr.DataArray | xr.Dataset,
    data2: xr.DataArray | xr.Dataset,
    remove_coord: str | None = None,
) -> list[Hashable]:
    """Return a list of the coords which two datasets/arrays have in common.

    If ``remove_coord`` is among the shared coords it is left out of the result.
    """
    shared = set(data1.coords) & set(data2.coords)
    # discard() is a no-op when remove_coord is not one of the shared coords.
    shared.discard(remove_coord)
    return list(shared)
[docs]
def call_on_dataset(
    func: Callable[..., xr.Dataset],
    obj: xr.DataArray | xr.Dataset,
    *args: Any,
    **kwargs: Any,
) -> xr.DataArray | xr.Dataset:
    """Call a Dataset-only function on either a Dataset or a DataArray.

    A DataArray is wrapped in a temporary single-variable dataset before the
    call, and a Dataset result is unwrapped back into a DataArray carrying the
    original name.

    Raises:
        TypeError: If a DataArray was passed and the resulting Dataset holds
            more than one data variable.
    """
    if not isinstance(obj, xr.DataArray):
        return func(obj, *args, **kwargs)

    placeholder_name = "_UNNAMED_ARRAY"
    # A dataset variable needs a name; unnamed arrays get a placeholder.
    tmp_name = placeholder_name if obj.name is None else obj.name
    result = func(obj.to_dataset(name=tmp_name), *args, **kwargs)

    if not isinstance(result, xr.Dataset):
        return result
    if len(result.data_vars) > 1:
        raise TypeError(
            "Trying to convert Dataset with more than one data variable to DataArray"
        )
    # Restore the original name (None for arrays that were unnamed).
    only_variable = next(iter(result.data_vars.values()))
    return only_variable.rename(obj.name)
@overload
@overload
def format_for_regrid(
    obj: xr.DataArray,
    target: xr.Dataset,
    stats: bool = False,
) -> xr.DataArray: ...
def format_for_regrid(
    obj: xr.DataArray | xr.Dataset,
    target: xr.Dataset,
    stats: bool = False,
) -> xr.DataArray | xr.Dataset:
    """Pre-format the input data so it is ready for regridding.

    Coordinates whose lower-cased name is one of the accepted latitude or
    longitude names are sorted (on both ``obj`` and ``target``) and passed
    through their formatting function. Currently this handles padding of
    spherical geometry if lat/lon coordinates can be inferred and the domain
    size requires boundary padding.

    Args:
        obj: Data to be regridded.
        target: Dataset describing the target grid.
        stats: If True, latitude formatting is skipped.

    Returns:
        A formatted copy of ``obj``, of the same kind as the input.
    """
    # Special-cased coordinates with accepted names and formatting function
    coord_handlers: dict[str, CoordHandler] = {
        "lat": {"names": ["lat", "latitude"], "func": format_lat},
        "lon": {"names": ["lon", "longitude"], "func": format_lon},
    }
    # Latitude padding adds a duplicate value which will undesirably
    # alter statistical aggregations
    if stats:
        del coord_handlers["lat"]

    # Identify coordinates that need to be formatted; if several coordinates
    # match the same handler, the last one encountered is used.
    formatted_coords = {}
    for coord_type, handler in coord_handlers.items():
        accepted_names = handler["names"]
        for coord_name in map(str, obj.coords):
            if coord_name.lower() in accepted_names:
                formatted_coords[coord_type] = coord_name

    # Apply formatting
    result = obj.copy()
    for coord_type, coord in formatted_coords.items():
        # Make sure formatted coords are sorted
        result = ensure_monotonic(result, coord)
        target = ensure_monotonic(target, coord)
        formatter = coord_handlers[coord_type]["func"]
        result = formatter(result, target, formatted_coords)

        # Coerce back to a single chunk if that's what was passed
        if isinstance(obj, xr.DataArray):
            if len(obj.chunksizes.get(coord, ())) == 1:
                result = result.chunk({coord: -1})
        elif isinstance(obj, xr.Dataset):
            for var in result.data_vars:
                if len(obj[var].chunksizes.get(coord, ())) == 1:
                    result[var] = result[var].chunk({coord: -1})
    return result
[docs]
def coord_is_covered(
    obj: xr.DataArray | xr.Dataset, target: xr.Dataset, coord: Hashable
) -> bool:
    """Check whether the source coord fully covers the target coord.

    The source counts as covering the target only if it extends past both the
    smallest and the largest target value by at least the largest spacing
    between consecutive target values.
    """
    target_values = target[coord]
    source_values = obj[coord]
    # Largest step between consecutive target values, used as the margin.
    margin = target_values.diff(coord).max().values
    covers_low_end = source_values.min() <= target_values.min() - margin
    covers_high_end = source_values.max() >= target_values.max() + margin
    return bool(covers_low_end.item() and covers_high_end.item())
@overload
[docs]
def ensure_monotonic(obj: xr.DataArray, coord: Hashable) -> xr.DataArray: ...
@overload
def ensure_monotonic(obj: xr.Dataset, coord: Hashable) -> xr.Dataset: ...
def ensure_monotonic(
    obj: xr.DataArray | xr.Dataset, coord: Hashable
) -> xr.DataArray | xr.Dataset:
    """Return the object with a monotonically increasing, duplicate-free index
    for the given coordinate.

    Sorting and dropping duplicates are only performed when the index needs
    it, because they require reindexing which can be expensive.
    """
    result = obj
    if not result.indexes[coord].is_monotonic_increasing:
        result = result.sortby(coord)
    # Check uniqueness on the (possibly re-sorted) object.
    if result.indexes[coord].is_unique:
        return result
    return result.drop_duplicates(coord)
@overload
[docs]
def update_coord(
    obj: xr.DataArray, coord: Hashable, coord_vals: np.ndarray
) -> xr.DataArray: ...
@overload
def update_coord(
    obj: xr.Dataset, coord: Hashable, coord_vals: np.ndarray
) -> xr.Dataset: ...
def update_coord(
    obj: xr.DataArray | xr.Dataset, coord: Hashable, coord_vals: np.ndarray
) -> xr.DataArray | xr.Dataset:
    """Replace the values of a coordinate, ensuring indexes stay in sync.

    The attributes of the existing coordinate are carried over to the
    coordinate holding the new values.
    """
    # Remember the attrs first so they can be put back on the new coordinate.
    saved_attrs = obj.coords[coord].attrs
    updated = obj.assign_coords({coord: coord_vals})
    updated.coords[coord].attrs = saved_attrs
    return updated