Source code for xarray_regrid.regrid

from collections.abc import Hashable
from typing import Any, overload

import numpy as np
import xarray as xr

from xarray_regrid.methods import conservative, flox_reduce, interp
from xarray_regrid.utils import format_for_regrid


@xr.register_dataarray_accessor("regrid")
@xr.register_dataset_accessor("regrid")
class Regridder:
    """Regridding xarray datasets and dataarrays.

    Registered as the ``regrid`` accessor on both ``xr.DataArray`` and
    ``xr.Dataset``.

    Available methods:
        linear: linear, bilinear, or higher dimensional linear interpolation
        nearest: nearest-neighbor regridding
        cubic: cubic spline regridding
        conservative: conservative regridding
        most_common: most common value regridder
        least_common: least common value regridder
        stat: area statistics regridder
    """

    def __init__(self, xarray_obj: xr.DataArray | xr.Dataset):
        # The object the accessor is attached to; every method regrids this.
        self._obj = xarray_obj

    def linear(
        self,
        ds_target_grid: xr.Dataset,
        time_dim: str | None = "time",
    ) -> xr.DataArray | xr.Dataset:
        """Regrid to the coords of the target dataset with linear interpolation.

        Args:
            ds_target_grid: Dataset containing the target coordinates.
            time_dim: Name of the time dimension. Defaults to "time". Use `None`
                to force regridding over the time dimension.

        Returns:
            Data regridded to the target dataset coordinates.

        Raises:
            ValueError: If the data and the target grid share no dims or coords.
        """
        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
        return interp.interp_regrid(ds_formatted, ds_target_grid, "linear")

    def nearest(
        self,
        ds_target_grid: xr.Dataset,
        time_dim: str | None = "time",
    ) -> xr.DataArray | xr.Dataset:
        """Regrid to the coords of the target with nearest-neighbor interpolation.

        Args:
            ds_target_grid: Dataset containing the target coordinates.
            time_dim: Name of the time dimension. Defaults to "time". Use `None`
                to force regridding over the time dimension.

        Returns:
            Data regridded to the target dataset coordinates.

        Raises:
            ValueError: If the data and the target grid share no dims or coords.
        """
        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
        return interp.interp_regrid(ds_formatted, ds_target_grid, "nearest")

    def cubic(
        self,
        ds_target_grid: xr.Dataset,
        time_dim: str | None = "time",
    ) -> xr.DataArray | xr.Dataset:
        """Regrid to the coords of the target dataset with cubic interpolation.

        Args:
            ds_target_grid: Dataset containing the target coordinates.
            time_dim: Name of the time dimension. Defaults to "time". Use `None`
                to force regridding over the time dimension.

        Returns:
            Data regridded to the target dataset coordinates.

        Raises:
            ValueError: If the data and the target grid share no dims or coords.
        """
        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
        return interp.interp_regrid(ds_formatted, ds_target_grid, "cubic")

    def conservative(
        self,
        ds_target_grid: xr.Dataset,
        latitude_coord: str | None = None,
        time_dim: str | None = "time",
        skipna: bool = True,
        nan_threshold: float = 1.0,
        output_chunks: dict[Hashable, int] | None = None,
    ) -> xr.DataArray | xr.Dataset:
        """Regrid to the coords of the target dataset with a conservative scheme.

        Args:
            ds_target_grid: Dataset containing the target coordinates.
            latitude_coord: Name of the latitude coord, to be used for applying
                the spherical correction. By default, attempt to infer a
                latitude coordinate as either "latitude" or "lat".
            time_dim: Name of the time dimension. Defaults to "time". Use `None`
                to force regridding over the time dimension.
            skipna: If True, enable handling for NaN values. This adds only a
                small amount of overhead, but can be disabled for optimal
                performance on data without any NaNs.
            nan_threshold: Threshold value that will retain any output points
                containing at least this many non-null input points. The default
                value is 1.0, which will keep output points containing any
                non-null inputs, while a value of 0.0 will only keep output
                points where all inputs are non-null.
            output_chunks: Optional dictionary of explicit chunk sizes for the
                output data. If not provided, the output will be chunked the
                same as the input data.

        Returns:
            Data regridded to the target dataset coordinates.

        Raises:
            ValueError: If `nan_threshold` lies outside [0, 1], or if the data
                and the target grid share no dims or coords.
        """
        # Written as a negated chained comparison so that NaN is rejected too.
        if not 0.0 <= nan_threshold <= 1.0:
            msg = "nan_threshold must be in the range [0, 1]"
            raise ValueError(msg)

        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
        ds_formatted = format_for_regrid(self._obj, ds_target_grid)
        return conservative.conservative_regrid(
            ds_formatted,
            ds_target_grid,
            latitude_coord,
            skipna,
            nan_threshold,
            output_chunks,
        )

    def _mode_regrid(
        self,
        ds_target_grid: xr.Dataset,
        values: np.ndarray,
        time_dim: str | None,
        fill_value: None | Any,
        *,
        anti_mode: bool,
    ) -> xr.DataArray:
        """Shared implementation of the most/least common value regridders."""
        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)

        if isinstance(self._obj, xr.Dataset):
            label = "least" if anti_mode else "most"
            # Adjacent literals only (no commas): the message must be a single
            # string, not a tuple of strings.
            msg = (
                f"The '{label} common value' regridder is not implemented for\n"
                "xarray.Dataset, as it requires specifying the expected labels.\n"
                "Please select only a single variable (as DataArray),\n"
                " and regrid it separately."
            )
            raise ValueError(msg)

        ds_formatted = format_for_regrid(self._obj, ds_target_grid, stats=True)
        return flox_reduce.compute_mode(
            ds_formatted,
            ds_target_grid,
            values,
            time_dim,
            fill_value,
            anti_mode=anti_mode,
        )

    def most_common(
        self,
        ds_target_grid: xr.Dataset,
        values: np.ndarray,
        time_dim: str | None = "time",
        fill_value: None | Any = None,
    ) -> xr.DataArray:
        """Regrid by taking the most common value within the new grid cells.

        To be used for regridding data to a much coarser resolution, not for
        regridding when the source and target grids are of a similar resolution.

        Note that in the case of two unique values with the same count, the
        behaviour is not deterministic, and the resulting "most common" one will
        randomly be either of the two.

        Args:
            ds_target_grid: Target grid dataset
            values: Numpy array containing all labels expected to be in the
                input data. For example, `np.array([0, 2, 4])`, if the data only
                contains the values 0, 2 and 4.
            time_dim: Name of the time dimension. Defaults to "time". Use `None`
                to force regridding over the time dimension.
            fill_value: What value to fill uncovered parts of the target grid.
                By default this will be NaN, and integer type data will be cast
                to float to accommodate this.

        Returns:
            Regridded data.

        Raises:
            ValueError: If called on an `xr.Dataset` (only DataArrays are
                supported), or if the data and the target grid share no dims
                or coords.
        """
        return self._mode_regrid(
            ds_target_grid, values, time_dim, fill_value, anti_mode=False
        )

    def least_common(
        self,
        ds_target_grid: xr.Dataset,
        values: np.ndarray,
        time_dim: str | None = "time",
        fill_value: None | Any = None,
    ) -> xr.DataArray:
        """Regrid by taking the least common value within the new grid cells.

        To be used for regridding data to a much coarser resolution, not for
        regridding when the source and target grids are of a similar resolution.

        Note that in the case of two unique values with the same count, the
        behaviour is not deterministic, and the resulting "least common" one
        will randomly be either of the two.

        Args:
            ds_target_grid: Target grid dataset
            values: Numpy array containing all labels expected to be in the
                input data. For example, `np.array([0, 2, 4])`, if the data only
                contains the values 0, 2 and 4.
            time_dim: Name of the time dimension. Defaults to "time". Use `None`
                to force regridding over the time dimension.
            fill_value: What value to fill uncovered parts of the target grid.
                By default this will be NaN, and integer type data will be cast
                to float to accommodate this.

        Returns:
            Regridded data.

        Raises:
            ValueError: If called on an `xr.Dataset` (only DataArrays are
                supported), or if the data and the target grid share no dims
                or coords.
        """
        return self._mode_regrid(
            ds_target_grid, values, time_dim, fill_value, anti_mode=True
        )

    def stat(
        self,
        ds_target_grid: xr.Dataset,
        method: str,
        time_dim: str | None = "time",
        skipna: bool = False,
        fill_value: None | Any = None,
    ) -> xr.DataArray | xr.Dataset:
        """Regrid data using statistical reductions (e.g. the mean or variance).

        We use flox Aggregations to perform a "groupby" over multiple
        dimensions, which we reduce using the specified method.
        https://flox.readthedocs.io/en/latest/aggregations.html

        Args:
            ds_target_grid: Target grid dataset
            method: One of the following reduction methods: "sum", "mean",
                "var", "std", "median", "min", or "max".
            time_dim: Name of the time dimension. Defaults to "time". Use `None`
                to force regridding over the time dimension.
            skipna: If NaN values should be ignored.
            fill_value: What value to fill uncovered parts of the target grid.
                By default this will be NaN, and integer type data will be cast
                to float to accommodate this.

        Returns:
            Regridded data.

        Raises:
            ValueError: If the data and the target grid share no dims or coords.
        """
        ds_target_grid = validate_input(self._obj, ds_target_grid, time_dim)
        ds_formatted = format_for_regrid(self._obj, ds_target_grid, stats=True)
        return flox_reduce.statistic_reduce(
            ds_formatted, ds_target_grid, time_dim, method, skipna, fill_value
        )
@overload
def validate_input(
    data: xr.Dataset,
    ds_target_grid: xr.Dataset,
    time_dim: str | None,
) -> xr.Dataset: ...


@overload
def validate_input(
    data: xr.DataArray,
    ds_target_grid: xr.Dataset,
    time_dim: str | None,
) -> xr.Dataset: ...


def validate_input(
    data: xr.DataArray | xr.Dataset,
    ds_target_grid: xr.Dataset,
    time_dim: str | None,
) -> xr.Dataset:
    """Check that `data` can be regridded onto `ds_target_grid`.

    If `time_dim` is given and present among the target grid's coordinates, the
    target grid is reduced to its first entry along that dimension (when it is
    an actual dimension) and `reset_coords()` is applied, so the returned target
    grid no longer carries `time_dim` as a coordinate.

    Args:
        data: The data that is going to be regridded.
        ds_target_grid: Dataset containing the target coordinates.
        time_dim: Name of the time dimension, or `None` to leave the target
            grid untouched.

    Returns:
        The (possibly time-reduced) target grid dataset.

    Raises:
        ValueError: If the data and the target grid have no dimension names in
            common, or no coordinate names in common.
    """
    if time_dim is not None and time_dim in ds_target_grid.coords:
        # A scalar (non-dimension) time coordinate, e.g. left behind by a
        # caller's own `.isel(time=0)`, cannot be indexed again: `isel` raises
        # for dimensions that do not exist. Only index when it is a real dim.
        if time_dim in ds_target_grid.dims:
            ds_target_grid = ds_target_grid.isel({time_dim: 0})
        ds_target_grid = ds_target_grid.reset_coords()

    if set(data.dims).isdisjoint(ds_target_grid.dims):
        msg = (
            "None of the target dims are in the data:\n"
            " regridding is not possible.\n"
            f"Target dims: {list(ds_target_grid.dims)}\n"
            f"Source dims: {list(data.dims)}"
        )
        raise ValueError(msg)

    if set(data.coords).isdisjoint(ds_target_grid.coords):
        msg = (
            "None of the target coords are in the data:\n"
            " regridding is not possible.\n"
            f"Target coords: {ds_target_grid.coords}\n"
            f"Dataset coords: {data.coords}"
        )
        raise ValueError(msg)

    return ds_target_grid