# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""Plugins to perform clustering on realizations within a cube."""
import json
import warnings
from collections import defaultdict
from typing import Any, Optional
import iris
import numpy as np
import pandas as pd
from iris.cube import Cube, CubeList
from iris.util import new_axis, promote_aux_coord_to_dim_coord
from improver import BasePlugin
from improver.clustering.clustering import FitClustering
from improver.regrid.landsea import RegridLandSea
from improver.utilities.cube_manipulation import (
MergeCubes,
enforce_coordinate_ordering,
get_dim_coord_names,
)
from improver.utilities.temporal import (
reset_forecast_reference_time_and_period,
)
try:
    import kmedoids
except ModuleNotFoundError:
    # kmedoids is treated as optional at import time: when it is missing,
    # an empty placeholder class is defined to avoid type hint errors.
    class KMedoids:
        """Empty placeholder used when the kmedoids package is unavailable."""
[docs]
class RealizationClustering(BasePlugin):
"""Class to perform clustering on realizations of a cube. For example, this can be
used to cluster a large number of ensemble members based on their spatial patterns
into a smaller set of distinct clusters. If the input is precipitation forecasts,
the resultant clusters could represent different types of precipitation events.
"""
[docs]
def __init__(self, clustering_method: str, **kwargs: Any) -> None:
    """Initialise the RealizationClustering class.

    Args:
        clustering_method: Name of the clustering method to use
            (e.g. "KMedoids"). The method must be supported by the
            improver.clustering.FitClustering class.
        **kwargs: Additional keyword arguments for the clustering method.
            They are stored unchanged and forwarded when the clustering
            is fitted.
    """
    self.kwargs = kwargs
    self.clustering_method = clustering_method
[docs]
@staticmethod
def _convert_to_2d(array: np.ndarray) -> np.ndarray:
"""Convert an array with arbitrary dimensions to a 2D array by maintaining
the zeroth dimension and flattening all other dimensions.
This prepares the data for clustering algorithms that expect 2D input where
rows are samples (realizations) and columns are features
(e.g. spatial points x forecast periods) so that an array of shape
(18, 4, 100, 100) is converted to shape (18, 40000).
Args:
array: The input array to convert. Can have any number of dimensions.
Returns:
array_2d: The converted 2D array with shape (array.shape[0], -1).
"""
if array.ndim < 2:
msg = "Input array must have at least 2 dimensions."
raise ValueError(msg)
elif array.ndim == 2:
return array.copy()
else:
target_shape = (array.shape[0], -1)
return array.reshape(target_shape)
[docs]
def process(self, cube: Cube) -> Any:
    """Apply the clustering method to the cube.

    The cube data is converted to a 2D array (all dimensions except the
    leading one are flattened) and wrapped in a DataFrame before being
    passed to FitClustering. The leading dimension must be realization.

    Args:
        cube: The input cube to cluster, with the realization dimension
            as the leading dimension.

    Returns:
        The result of the clustering algorithm applied to the input data.

    Raises:
        ValueError: If the leading dimension of the input cube is not
            the realization dimension.
    """
    leading_dim_name = cube.dim_coords[0].name()
    if leading_dim_name != "realization":
        raise ValueError(
            "The leading dimension of the input cube must be "
            "the 'realization' dimension."
        )

    features = self._convert_to_2d(cube.data)
    # One row per realization, labelled by its realization number; the
    # columns are the flattened non-realization dimensions, i.e. the features
    # the clustering algorithm uses to group the realizations.
    row_labels = [
        f"realization_{point}" for point in cube.coord("realization").points
    ]
    feature_table = pd.DataFrame(features, index=row_labels)

    clusterer = FitClustering(self.clustering_method, **self.kwargs)
    return clusterer(feature_table)
[docs]
class RealizationToClusterMatcher(BasePlugin):
"""Match candidate realizations to clusters based on mean squared error (MSE).
In this context, 'candidate realizations' refers to the set of realizations being
considered for assignment to clusters from a secondary input. These are matched
to clusters derived from a primary input by minimizing mean squared error (MSE).
Assigns realizations from a secondary input (e.g. a high-resolution regional
ensemble model) to clusters defined by a primary input (e.g. a coarse global
ensemble model). When multiple candidates compete for the same cluster, only the
candidate with the lowest MSE is assigned; other candidates are not assigned to
any cluster.
Supports both 3D cubes and 4D cubes of dimensions (realization, y, x) and
(realization, forecast_period, y, x) respectively only.
"""
[docs]
def __init__(self) -> None:
    """Initialise the plugin.

    No arguments are taken and no attributes are set.
    """
[docs]
def _mean_squared_error_per_realization(
    self,
    clustered_array: np.ndarray,
    candidate_array: np.ndarray,
    n_realizations: int,
) -> np.ndarray:
    """Calculate the MSE between every candidate realization and every cluster.

    A lower MSE indicates that a candidate realization is a better match to
    a cluster's representative member. NaN values are ignored when averaging.

    For 3D input the squared differences are averaged over the two trailing
    (spatial) axes. For any other input the squared differences are averaged
    over axes 2 and 3 (spatial) first, and the result is then averaged over
    axis 1 (forecast_period).

    Args:
        clustered_array: The clustered array with shape (n_clusters, y, x) or
            (n_clusters, forecast_period, y, x).
        candidate_array: The candidate array with shape (n_realizations, y, x)
            or (n_realizations, forecast_period, y, x).
        n_realizations: The number of leading entries of candidate_array
            to compare against the clusters.

    Returns:
        Array of MSE values with shape (n_realizations, n_clusters), where
        element [i, j] is the MSE between candidate realization i and
        cluster j.
    """

    def _mse_against_clusters(candidate: np.ndarray) -> np.ndarray:
        # Broadcasting compares this one candidate with all cluster medoids.
        squared_error = np.square(clustered_array - candidate)
        if clustered_array.ndim == 3:
            return np.nanmean(squared_error, axis=(1, 2))
        spatial_mean = np.nanmean(squared_error, axis=(2, 3))
        return np.nanmean(spatial_mean, axis=1)

    return np.array(
        [
            _mse_against_clusters(candidate_array[position])
            for position in range(n_realizations)
        ]
    )
[docs]
def _validate_cube_dimensions(
    self, clusters_cube: Cube, candidate_cube: Cube
) -> None:
    """Check that the clustered and candidate cubes have matching dimensions.

    Two checks are made, in order: the cubes must have the same number of
    dimensions, and get_dim_coord_names must return equal results for both
    cubes.

    Args:
        clusters_cube: The clustered cube.
        candidate_cube: The candidate cube.

    Raises:
        ValueError: If the number of cube dimensions doesn't match.
        ValueError: If the dimension coordinate names don't match.
    """
    n_cluster_dims = clusters_cube.ndim
    n_candidate_dims = candidate_cube.ndim
    if n_cluster_dims != n_candidate_dims:
        raise ValueError(
            f"Clustered cube has {n_cluster_dims} dimensions but candidate "
            f"cube has {n_candidate_dims} dimensions. Both cubes must have "
            "the same number of dimensions (either 3D or 4D)."
        )

    cluster_dim_names = get_dim_coord_names(clusters_cube)
    candidate_dim_names = get_dim_coord_names(candidate_cube)
    if cluster_dim_names != candidate_dim_names:
        raise ValueError(
            "Clustered and candidate cubes must have the same dimension "
            "coordinates in the same order. "
            f"Clustered cube dimensions: {cluster_dim_names}, "
            f"Candidate cube dimensions: {candidate_dim_names}"
        )
[docs]
def _validate_forecast_period_coords(
    self, clusters_cube: Cube, candidate_cube: Cube
) -> None:
    """Check that 4D cubes have identical forecast_period points.

    Nothing is checked unless the clustered cube has exactly 4 dimensions.

    Args:
        clusters_cube: The clustered cube.
        candidate_cube: The candidate cube.

    Raises:
        ValueError: If the forecast_period points of the clustered and
            candidate cubes are not equal.
    """
    if clusters_cube.ndim != 4:
        return

    cluster_periods = clusters_cube.coord("forecast_period").points
    candidate_periods = candidate_cube.coord("forecast_period").points
    if np.array_equal(cluster_periods, candidate_periods):
        return

    raise ValueError(
        "Forecast period coords must match between clustered and "
        f"candidate cubes. Clustered: {cluster_periods}, "
        f"Candidate: {candidate_periods}"
    )
[docs]
def assign_clusters(
    self, realization_cluster_mse: np.ndarray
) -> tuple[list[int], list[int]]:
    """Assign clusters to candidate realizations using greedy MSE minimization.

    Realizations are visited in descending order of their "MSE cost": the sum
    over clusters of the difference between that cluster's MSE and the
    realization's minimum MSE. High-cost realizations (no cluster that they
    are clearly "well matched" to) are processed first; low-cost realizations
    (a strong "preference" for one cluster) are processed last.

    Each realization picks the cluster with its lowest (non-NaN) MSE. It takes
    the cluster if the cluster is unassigned or if its MSE is strictly lower
    than that of the current holder; otherwise the current holder keeps it.
    Once the number of realizations still to be processed no longer exceeds
    the number of unassigned clusters, already-assigned clusters are masked
    out so that the remaining realizations choose among unassigned clusters.

    Note: This greedy algorithm is chosen for its relative simplicity and
    deterministic behaviour. An optimal assignment algorithm (such as the
    Hungarian algorithm) could guarantee a globally optimal solution.

    Args:
        realization_cluster_mse: The MSE array with shape
            (n_realizations, n_clusters).

    Returns:
        Tuple of (cluster_indices, realization_indices):
            - cluster_indices: Cluster indices that were assigned (may be
              fewer than n_clusters), sorted in ascending order.
            - realization_indices: The realization index assigned to each
              of those clusters (one per assigned cluster).
        Both lists hold native Python ints and both are empty when no
        assignment could be made (no realizations, no clusters, or only
        NaN MSE values).
    """
    n_realizations = realization_cluster_mse.shape[0]
    n_clusters = realization_cluster_mse.shape[1]
    # Nothing can be matched; returning early also avoids reducing over a
    # zero-length cluster axis, which numpy rejects.
    if n_realizations == 0 or n_clusters == 0:
        return [], []

    # Cost for each realization (sum of differences from its minimum MSE).
    min_mse_array = np.min(realization_cluster_mse, axis=1, keepdims=True)
    mse_array_cost = np.sum(realization_cluster_mse - min_mse_array, axis=1)
    # Process realizations in descending order of cost (highest cost first),
    # e.g. realization_order might be [3, 1, 0, 2].
    realization_order = np.argsort(mse_array_cost)[::-1]

    # cluster_to_realization maps a cluster index to the realization currently
    # holding it, e.g. {1: 3}; cluster_to_mse maps a cluster index to the MSE
    # of that holder, e.g. {1: 10000}.
    cluster_to_realization: dict[int, int] = {}
    cluster_to_mse: dict[int, float] = {}

    for loop_idx, realization_idx in enumerate(realization_order):
        assigned_clusters = list(cluster_to_realization)
        clusters_remaining = n_clusters - len(assigned_clusters)
        n_realizations_remaining = n_realizations - loop_idx

        mse_values = realization_cluster_mse[realization_idx].copy()
        # If there are at least as many unassigned clusters as remaining
        # realizations, prevent competition for already-assigned clusters by
        # setting their MSE to inf.
        if n_realizations_remaining <= clusters_remaining:
            mse_values[assigned_clusters] = np.inf

        # nanargmin raises on an all-NaN row, so skip such realizations.
        if np.all(np.isnan(mse_values)):
            continue

        cluster_idx = int(np.nanargmin(mse_values))
        if mse_values[cluster_idx] < cluster_to_mse.get(cluster_idx, np.inf):
            cluster_to_mse[cluster_idx] = mse_values[cluster_idx]
            cluster_to_realization[cluster_idx] = int(realization_idx)

    # Building the lists from the sorted keys (rather than unzipping pairs)
    # also works when no cluster was assigned at all.
    cluster_indices = sorted(cluster_to_realization)
    realization_indices = [cluster_to_realization[idx] for idx in cluster_indices]
    return cluster_indices, realization_indices
[docs]
def process(
    self,
    clusters_cube: Cube,
    candidate_cube: Cube,
) -> tuple[list[int], list[int]]:
    """Assign candidate realizations to clusters by mean squared error (MSE).

    Takes a cube of clustered realizations (e.g. from a global model) and a
    cube of candidate realizations (e.g. from a higher-resolution model) and
    assigns each cluster to the candidate realization with the lowest MSE for
    that cluster. When several candidates compete for the same cluster, only
    the one with the lowest MSE is assigned; the others are not assigned to
    any cluster.

    Supports both 3D cubes (realization, y, x) and 4D cubes
    (realization, forecast_period, y, x). For 4D cubes, both inputs must
    have matching forecast_period coordinates.

    Args:
        clusters_cube: The cube containing clustered realizations (e.g. from
            KMedoids clustering). Shape: (n_clusters, y, x) or
            (n_clusters, forecast_period, y, x).
        candidate_cube: The input cube with realizations to assign to
            clusters. Shape: (n_realizations, y, x) or
            (n_realizations, forecast_period, y, x).

    Returns:
        Tuple of (cluster_indices, realization_indices):
            cluster_indices: List of cluster indices that were assigned.
                May have length < n_clusters if there are fewer candidates
                than clusters.
            realization_indices: List of realization indices assigned to
                each cluster (one per assigned cluster).

    Raises:
        ValueError: If the cubes fail the dimension or forecast_period
            validation.
    """
    # Enforce the same dimension order on both cubes. The return value of
    # enforce_coordinate_ordering is not used here.
    for cube in (clusters_cube, candidate_cube):
        enforce_coordinate_ordering(
            cube, ["realization", "forecast_period", "y", "x"]
        )

    n_candidates = len(candidate_cube.coord("realization").points)

    self._validate_cube_dimensions(clusters_cube, candidate_cube)
    self._validate_forecast_period_coords(clusters_cube, candidate_cube)

    mse_per_candidate = self._mean_squared_error_per_realization(
        clusters_cube.data,
        candidate_cube.data,
        n_candidates,
    )
    cluster_indices, realization_indices = self.assign_clusters(mse_per_candidate)
    return cluster_indices, realization_indices
[docs]
class RealizationClusterAndMatch(BasePlugin):
"""Cluster primary input realizations and match secondary inputs to clusters.
This plugin performs KMedoids clustering on a primary input, then matches
secondary input realizations to the resulting clusters based on mean squared
error. When multiple secondary inputs are provided, their order in the hierarchy
determines their precedence: inputs listed earlier (leftmost) in the
secondary_inputs dictionary have higher priority and can overwrite matches from
later (lower-priority) ones for overlapping forecast periods. In other words, the
first (leftmost) secondary input in the dictionary has the highest precedence, and
later ones have lower precedence. See the Args section of the __init__ docstring
for details on how the hierarchy is specified and used.
See Also:
For a practical usage example, see:
doc/source/examples/realization_cluster_and_match_example_data.py.
"""
[docs]
def __init__(
self,
hierarchy: dict[str, str | dict[str, list[int]]],
model_id_attr: str,
clustering_method: str,
target_grid_name: str | None = None,
regrid_mode: str = "esmf-area-weighted",
regrid_for_clustering: bool = True,
regrid_kwargs: dict[str, Any] | None = None,
cycletime: str | None = None,
**kwargs: Any,
) -> None:
"""Initialise the clustering and matching class.
Args:
hierarchy: The hierarchy of inputs defining the primary input, which is
clustered, and secondary inputs, which are matched to each cluster.
The order of the secondary_inputs is used as the priority for matching.
The list values for secondary inputs specify forecast periods in hours.
A two-element list [start, end] will be expanded to include all hours
in that inclusive range. Lists with other lengths are treated as
explicit lists of forecast period hours. All values will be
automatically converted to seconds to match the forecast_period
coordinate units in the input cubes::
{
"primary_input": "input1",
"secondary_inputs": {"input2": [0, 24], "input3": [0, 6]},
}
In this example, input2 will use forecast periods in the range
0 to 24 hours inclusive (i.e., any forecast periods between 0 and
86400 seconds), and input3 will use the range 0 to 6 hours
(0 to 86400 seconds). For lead times, where secondary inputs are
not provided the primary input will be used. Only forecast periods
that actually exist in the input cubes within these ranges will be
processed.
model_id_attr: The model ID attribute used to identify different models
within the input cubes.
target_grid_name: The name of the target grid cube for regridding. Only
required if regrid_for_clustering is True.
clustering_method: The clustering method to use.
regrid_mode: The regridding mode to use. Default is
"esmf-area-weighted". See RegridLandSea for available modes.
regrid_for_clustering: If True, regrid all cubes (primary and secondary)
to the target grid before clustering and matching. This regridding
step speeds up the computation by reducing the data size and,
importantly, emphasises larger-scale spatial features in the data,
rather than small-scale detail. This helps the clustering focus on the
most relevant broad patterns rather than being dominated by
fine-scale noise. If False, clustering and matching are performed
on the original grids without regridding. Default is True.
regrid_kwargs: Additional keyword arguments to pass to RegridLandSea.
Common options include:
- mdtol (float): Tolerance of missing data (default 1).
- extrapolation_mode (str): Mode to fill regions outside domain.
- landmask (Cube): Land-sea mask for mask-aware regridding.
- landmask_vicinity (float): Radius for coastline search.
cycletime:
The forecast_reference_time on the input cubes will be reset to
this value. The forecast periods will be adjusted accordingly with
the validity times kept fixed. cycletime should be provided in the
format YYYYMMDDTHHMMZ (e.g., 20240101T0000Z). If not provided, the
forecast_reference_time on the input cubes will be left unchanged.
**kwargs: Additional arguments for the clustering method.
Raises:
ValueError: If regrid_for_clustering is True but target_grid_name is None.
NotImplementedError: If the clustering method is not supported
(currently only KMedoids is supported).
"""
self.hierarchy = hierarchy
self.model_id_attr = model_id_attr
self.target_grid_name = target_grid_name
self.clustering_method = clustering_method
self.regrid_mode = regrid_mode
self.regrid_for_clustering = regrid_for_clustering
self.regrid_kwargs = regrid_kwargs if regrid_kwargs is not None else {}
self.cycletime = cycletime
self.kwargs = kwargs
if regrid_for_clustering and target_grid_name is None:
msg = (
"target_grid_name must be provided when regrid_for_clustering is True."
)
raise ValueError(msg)
if clustering_method != "KMedoids":
msg = (
"Currently only KMedoids clustering is supported for the clustering "
"and matching of realizations."
)
raise NotImplementedError(msg)
[docs]
@staticmethod
def _expand_forecast_period_range(fp_range: list[int]) -> list[int]:
"""Expand a forecast period range [start, end] to a list of integers.
Args:
fp_range: A list containing either [start, end] values defining a range
in hours, or a list of specific forecast period hours.
Returns:
If fp_range has 2 elements, returns integers from start to end inclusive
in steps of 1 hour. Otherwise, returns the list as-is.
Raises:
ValueError: If start > end (when 2 elements provided).
"""
if len(fp_range) == 2:
start, end = fp_range
if start > end:
msg = f"Forecast period range start ({start}) must be <= end ({end})"
raise ValueError(msg)
return list(range(start, end + 1, 1))
else:
# Return as-is for lists with != 2 elements
return fp_range
[docs]
@staticmethod
def _convert_hours_to_seconds(hours: list[int]) -> list[int]:
"""Convert a list of hours to seconds.
Args:
hours: List of forecast period values in hours.
Returns:
List of forecast period values in seconds.
"""
return [h * 3600 for h in hours]
[docs]
def _select_realizations_for_kmedoid_clusters(
    self, primary_cube: Cube, clustering_result: "kmedoids.KMedoids"
) -> Cube:
    """Select the medoid realizations identified by the clustering result.

    The primary cube is indexed by the medoid indices, the realization
    points of the selection are renumbered 0..n_clusters-1, and two mapping
    attributes are attached to the returned cube:

    - "primary_input_realizations_to_clusters": cluster label -> list of the
      primary realization numbers given that label, ordered by label.
    - "primary_input_realization_to_cluster_medoid": cluster index -> the
      primary realization number of that cluster's medoid.

    Args:
        primary_cube: The input cube to select realizations from.
        clustering_result: The result of the clustering. Its medoid_indices_
            and labels_ attributes are read.

    Returns:
        The clustered cube, with one realization per cluster.

    Raises:
        ValueError: If the number of clusters is greater than the number of
            realizations in the input cube.
    """
    medoid_indices = clustering_result.medoid_indices_
    n_clusters = len(medoid_indices)
    realization_points = primary_cube.coord("realization").points
    n_realizations = len(realization_points)
    if n_clusters > n_realizations:
        raise ValueError(
            f"The number of clusters {n_clusters} is expected to be less than "
            f"the number of realizations {n_realizations}. "
            "Please reduce the number of clusters."
        )

    # Select the realizations corresponding to the medoid indices and
    # renumber them so that realization n is the medoid of cluster n.
    cube_clustered = primary_cube[medoid_indices]
    cube_clustered.coord("realization").points = range(n_clusters)
    promote_aux_coord_to_dim_coord(cube_clustered, "realization")

    # Group the primary realization numbers by the label they were given.
    # Plain dicts and native ints are used so the result can be serialized.
    members_by_cluster = {}
    for position, label in enumerate(clustering_result.labels_):
        members_by_cluster.setdefault(int(label), []).append(
            int(realization_points[position])
        )
    cube_clustered.attributes["primary_input_realizations_to_clusters"] = {
        label: members_by_cluster[label] for label in sorted(members_by_cluster)
    }

    # Record which primary realization is the medoid of each cluster.
    cube_clustered.attributes["primary_input_realization_to_cluster_medoid"] = {
        int(cluster_idx): int(realization_points[medoid_idx])
        for cluster_idx, medoid_idx in enumerate(medoid_indices)
    }
    return cube_clustered
[docs]
def _ensure_forecast_period_is_dimension(self, cube: Cube) -> Cube:
"""Ensure forecast_period is a dimension coordinate and realization is first.

If forecast_period exists but spans no data dimension (i.e. it is a scalar
coordinate), a new axis is created for it using new_axis, and
enforce_coordinate_ordering is called with ["realization"] to make
realization the leading dimension. The time coordinate is then associated
with the forecast_period dimension (to avoid it being scalar), provided
that the time and forecast_period coordinates have the same shape.

Args:
cube: The cube to check and potentially modify.

Returns:
The cube with forecast_period as a dimension coordinate (if it exists),
time associated with the forecast_period dimension, and realization as
the first dimension. This is a new cube when new_axis was applied,
otherwise the input cube.
"""
# An empty coord_dims result means the coordinate exists but does not span
# any data dimension.
if cube.coords("forecast_period") and not cube.coord_dims("forecast_period"):
cube = new_axis(cube, "forecast_period")
# NOTE(review): the nesting of the next call was lost together with this
# file's indentation; presumably it belongs to the branch above, restoring
# realization as the leading dimension after the new axis is added -- confirm
# before re-indenting.
enforce_coordinate_ordering(cube, ["realization"])
# Ensure time coordinate is associated with forecast_period dimension
if cube.coords("time") and cube.coords("forecast_period"):
fp_dim = cube.coord_dims("forecast_period")
time_dims = cube.coord_dims("time")
# If time is scalar or not associated with forecast_period dimension
if not time_dims or time_dims != fp_dim:
time_coord = cube.coord("time")
fp_coord = cube.coord("forecast_period")
# Only reassociate if time coord shape matches forecast_period shape
if time_coord.shape == fp_coord.shape:
# Remove time as a coordinate and re-add it associated with
# forecast_period
cube.remove_coord("time")
cube.add_aux_coord(time_coord, fp_dim)
return cube
[docs]
def _initialise_matched_cubes_with_primary(
    self, clustered_primary_cube: Cube
) -> CubeList:
    """Build the initial matched cubes from the clustered primary cube.

    Starting from the primary input for every forecast period means there is
    always a full set of realizations to work with as a base, which can then
    be selectively replaced by secondary inputs.

    Args:
        clustered_primary_cube: The clustered primary cube containing all
            forecast periods.

    Returns:
        A CubeList containing one cube per forecast period from the clustered
        primary cube, each passed through
        _ensure_forecast_period_is_dimension.
    """
    period_slices = clustered_primary_cube.slices_over("forecast_period")
    return CubeList(
        [
            self._ensure_forecast_period_is_dimension(period_slice)
            for period_slice in period_slices
        ]
    )
[docs]
def _update_cluster_sources(
    self,
    cluster_sources: dict[int, dict[str, list[int]]],
    cluster_indices: list[int],
    candidate_name: str,
    fp: int,
) -> None:
    """Record that a secondary input now provides a forecast period.

    For each of the given clusters, the forecast period is removed from the
    primary input's list (the primary entry is deleted once its list is
    empty) and added to the secondary input's list, so that cluster_sources
    keeps a record of which model provided the data for each cluster at each
    forecast period.

    Args:
        cluster_sources: Dictionary tracking which input was used for each
            cluster at each forecast period. Modified in-place.
            Format: {cluster_idx: {model_name: [fp1, fp2, ...]}}
        cluster_indices: List of cluster indices being updated. Clusters not
            yet present in cluster_sources are added.
        candidate_name: Name of the secondary input being added
            e.g. 'secondary_input1'.
        fp: Forecast period value in seconds.
    """
    primary_name = self.hierarchy["primary_input"]

    for cluster_idx in cluster_indices:
        sources = cluster_sources.setdefault(cluster_idx, {})

        # The primary input no longer provides this forecast period.
        if primary_name in sources:
            primary_periods = sources[primary_name]
            if fp in primary_periods:
                primary_periods.remove(fp)
            # Drop the primary entry once it has no forecast periods left.
            if not primary_periods:
                del sources[primary_name]

        # The secondary input does; avoid recording the period twice.
        candidate_periods = sources.setdefault(candidate_name, [])
        if fp not in candidate_periods:
            candidate_periods.append(fp)
[docs]
def _maybe_regrid_candidate_cube(
    self, candidate_cube: Cube, target_grid_cube: Cube
) -> Cube:
    """Regrid the candidate cube only when regrid_for_clustering is True.

    Args:
        candidate_cube: The input candidate Cube to potentially regrid.
        target_grid_cube: The target grid Cube to regrid onto if regridding
            is enabled.

    Returns:
        The result of applying RegridLandSea (configured with regrid_mode
        and regrid_kwargs) to the candidate and target grid cubes if
        regrid_for_clustering is True, otherwise the original candidate Cube.
    """
    if not self.regrid_for_clustering:
        return candidate_cube

    regridder = RegridLandSea(regrid_mode=self.regrid_mode, **self.regrid_kwargs)
    return regridder(candidate_cube, target_grid_cube)
[docs]
def compact_secondary_mapping(
self,
secondary_input_realizations_to_clusters: dict[
str, dict[int, dict[int, list[int]]]
],
) -> dict[str, dict[int, list[dict[str, list[list[int]] | list[int]]]]]:
"""
Compact the mapping of secondary input realizations to clusters by grouping
forecast periods for each unique realization assignment per cluster.
Args:
secondary_input_realizations_to_clusters: A nested dictionary mapping
secondary input names to forecast periods, then to cluster indices,
then to lists of realization indices:
{
secondary_input_name: {
forecast_period: {
cluster_idx: [realization_index]
}
}
}
Returns:
A compacted dictionary mapping each secondary input name to a dictionary
of cluster indices, each containing a list of dicts with:
- "realization": the realization index assigned.
- "forecast_periods": a sorted list of forecast periods.
Example:
{
secondary_input_name: {
cluster_idx: [
{
"realization": 3,
"forecast_periods": [3600, 7200, 10800]
}
]
}
}
Note:
Only one realization is assigned to each cluster for each forecast
period.
"""
compact = {}
for sec_name, fp_dict in secondary_input_realizations_to_clusters.items():
cluster_map = {}
# (cluster_idx, realization) -> list of forecast_periods
temp = {}
for fp, cluster_dict in fp_dict.items():
for cluster_idx, realizations in cluster_dict.items():
realization = tuple(realizations)
key = (cluster_idx, realization)
temp.setdefault(key, []).append(fp)
for (cluster_idx, realization), fps in temp.items():
cluster_map.setdefault(cluster_idx, []).append(
{
"realization": realization[0],
"forecast_periods": sorted(fps),
}
)
compact[sec_name] = cluster_map
return compact
[docs]
def track_secondary_realizations_to_clusters(
    self,
    secondary_input_realizations_to_clusters: dict[
        str, dict[int, dict[int, list[int]]]
    ],
    cluster_indices: list[int],
    realization_indices: list[int],
    candidate_name: str,
    fp: int,
    candidate_cube: Cube,
) -> None:
    """Track which secondary realizations contributed to each cluster.

    The provided dictionary is updated in-place for the given secondary
    input and forecast period. Forecast periods, cluster indices and
    realization numbers are all stored as native Python ints for
    serialization compatibility.

    Args:
        secondary_input_realizations_to_clusters: Nested dictionary to
            update, mapping secondary input names to forecast periods, then
            to cluster indices, then to lists of realization indices::

                {
                    secondary_input_name: {
                        forecast_period: {
                            cluster_idx: [realization_indices]
                        }
                    }
                }

        cluster_indices: List of cluster indices assigned for this forecast
            period.
        realization_indices: List of positions along the candidate cube's
            realization coordinate that were assigned to each cluster,
            paired element-wise with cluster_indices.
        candidate_name: Name of the secondary input/model.
        fp: Forecast period (in seconds).
        candidate_cube: The candidate Cube whose realization coordinate
            points are looked up using realization_indices.

    Returns:
        None. The dictionary is updated in-place.
    """
    period = int(fp)
    # Entries for the input name and forecast period are created even when
    # there are no cluster assignments to record.
    by_period = secondary_input_realizations_to_clusters.setdefault(
        candidate_name, {}
    )
    by_cluster = by_period.setdefault(period, {})

    for cluster_idx, realization_idx in zip(cluster_indices, realization_indices):
        members = by_cluster.setdefault(int(cluster_idx), [])
        # Store the realization number from the coordinate, not the position.
        realization_number = candidate_cube.coord("realization").points[
            realization_idx
        ]
        members.append(int(realization_number))
[docs]
def process(self, cubes: CubeList) -> Cube:
    """Cluster the primary input and match the secondary inputs to it.

    The primary input realizations are clustered first. Secondary inputs are
    then matched to the resulting clusters: inputs categorised as having
    "full" realizations are processed in a first pass, and inputs
    categorised as "partial" in a second pass. The primary and secondary
    inputs are identified through the configured hierarchy.

    Args:
        cubes: The input CubeList containing all primary and secondary input
            cubes required for clustering and matching. Each cube must have the
            model_id_attr attribute set to identify its source/model. For each
            model (primary and secondary), include all forecast periods and
            realizations that should be considered for matching or replacement.
            When regridding for clustering is enabled, the target grid cube
            must also be present and is located by name.

            Expected input shapes::

                2D: (y, x)
                    for single realization, single forecast period fields.
                3D: (realization, y, x)
                    for multiple realizations at a single forecast period.
                4D: (realization, forecast_period, y, x)
                    for multiple realizations and multiple forecast periods.

            The leading dimension must always be realization if present.
            For 4D cubes, the second dimension must be forecast_period.

    Returns:
        The matched cube containing all secondary inputs matched to clusters.
        The output cube will have realization and forecast_period as leading
        dimensions (if present in the input), followed by spatial dimensions
        (y, x). The returned cube carries these JSON string attributes:

        - 'primary_input_realizations_to_clusters': tracks which primary input
          realizations were assigned to each cluster.
        - 'primary_input_realization_to_cluster_medoid': copied from the
          attribute of the same name on the clustered primary cube.
        - 'secondary_input_realizations_to_clusters': tracks which secondary
          input realization was assigned to each cluster per forecast period.
        - 'cluster_sources': tracks which input model provided the final data
          for each cluster-forecast_period pairing, in the form
          {cluster_idx: {model_name: [fp1, fp2, ...]}}.

    Raises:
        ValueError: If no primary cube is found with the specified
            model_id_attr.
        ValueError: If regridding for clustering is requested but the target
            grid cube is not found in the input cubes.

    Warns:
        UserWarning: If no secondary inputs were categorised for matching, in
            which case only the clustered primary input is returned.
    """
    # Re-base inputs onto the requested cycle time; cubes without a
    # forecast_reference_time coordinate are left untouched.
    if self.cycletime is not None:
        for input_cube in cubes:
            if input_cube.coords("forecast_reference_time"):
                reset_forecast_reference_time_and_period(input_cube, self.cycletime)

    # Locate and merge the primary input.
    primary_cubes = cubes.extract(
        iris.AttributeConstraint(
            **{self.model_id_attr: self.hierarchy["primary_input"]}
        )
    )
    if not primary_cubes:
        raise ValueError(
            f"No primary cube found with {self.model_id_attr}="
            f"{self.hierarchy['primary_input']}"
        )
    primary_cube = MergeCubes()(primary_cubes)
    enforce_coordinate_ordering(primary_cube, ["realization"])

    # The target grid is only needed when clustering on a regridded field.
    target_grid_cube = None
    if self.regrid_for_clustering:
        try:
            target_grid_cube = cubes.extract_cube(self.target_grid_name)
        except iris.exceptions.ConstraintMismatchError:
            raise ValueError(
                f"Target grid cube '{self.target_grid_name}' not found in input "
                "cubes for regridding."
            )

    clustered_primary, regridded_clustered_primary = self.cluster_primary_input(
        primary_cube, target_grid_cube
    )

    # Keep hold of the primary mappings so they can be re-applied to the result.
    primary_attributes = clustered_primary.attributes
    realizations_to_clusters = primary_attributes[
        "primary_input_realizations_to_clusters"
    ]
    realization_to_medoid = primary_attributes[
        "primary_input_realization_to_cluster_medoid"
    ]
    n_clusters = len(clustered_primary.coord("realization").points)

    # Split the secondary inputs into those with full and partial realizations.
    full_inputs, partial_inputs = self._categorise_secondary_inputs(
        cubes, n_clusters, primary_cube
    )
    if not (full_inputs or partial_inputs):
        warnings.warn(
            "No secondary inputs have forecast periods that overlap with the "
            f"primary input '{self.hierarchy['primary_input']}'. "
            "Only the clustered primary input will be returned.",
            UserWarning,
        )

    # Forecast period -> set of realization indices that have been replaced.
    replaced_realizations = {}

    # The clustered primary cube is the starting point for every forecast
    # period, so a full set of realizations is always available.
    matched_cubes = self._initialise_matched_cubes_with_primary(clustered_primary)

    # {cluster_idx: {model_name: [fp1, fp2, ...]}}, initially crediting the
    # primary input for every cluster. Each cluster gets its own list.
    primary_name = self.hierarchy["primary_input"]
    cluster_sources = {
        cluster_idx: {
            primary_name: list(clustered_primary.coord("forecast_period").points)
        }
        for cluster_idx in range(n_clusters)
    }

    # Filled in by the two passes below.
    secondary_realizations_to_clusters = {}

    # Both passes receive the same working state after their own input list.
    shared_state = (
        cubes,
        target_grid_cube,
        regridded_clustered_primary,
        replaced_realizations,
        matched_cubes,
        cluster_sources,
        secondary_realizations_to_clusters,
    )
    # First pass: full realization inputs replace entire forecast period cubes.
    self._process_full_realization_inputs(full_inputs, *shared_state)
    # Second pass: partial realization inputs selectively replace specific
    # realizations within existing cubes.
    self._process_partial_realization_inputs(partial_inputs, *shared_state)

    squeezed_cubes = [iris.util.squeeze(matched) for matched in matched_cubes]
    result_cube = MergeCubes()(CubeList(squeezed_cubes))

    # Cast keys and forecast periods to native Python ints, as numpy integer
    # types cannot be serialised by json.dumps.
    serialisable_sources = {}
    for cluster_idx, sources in cluster_sources.items():
        serialisable_sources[int(cluster_idx)] = {
            name: [int(period) for period in periods]
            for name, periods in sources.items()
        }

    # Dictionaries are stored on the cube as JSON strings.
    json_payloads = {
        "primary_input_realizations_to_clusters": realizations_to_clusters,
        "primary_input_realization_to_cluster_medoid": realization_to_medoid,
        "secondary_input_realizations_to_clusters": self.compact_secondary_mapping(
            secondary_realizations_to_clusters
        ),
        "cluster_sources": serialisable_sources,
    }
    for attribute_name, payload in json_payloads.items():
        result_cube.attributes[attribute_name] = json.dumps(payload)

    return result_cube
[docs]
class RealizationSelection(BasePlugin):
    """Plugin to select realizations based on clustering results.

    This plugin is intended to be used with the output from the
    RealizationClusterAndMatch plugin. A typical use case is where
    RealizationClusterAndMatch has performed clustering and matching using a
    subset of forecast periods (for computational efficiency or other reasons),
    but you wish to apply the resulting cluster assignments to any forecast
    period. The RealizationSelection plugin enables this by selecting and
    relabelling realizations from the original forecast cubes according to the
    cluster mapping attributes stored in the cluster cube output by
    RealizationClusterAndMatch.

    To use this plugin, provide as input the same forecast cubes as were
    supplied to RealizationClusterAndMatch (but strictly only at a single
    forecast period), together with the cluster cube output from
    RealizationClusterAndMatch.
    """

    def __init__(
        self,
        forecast_period: int,
        model_id_attr: str = "mosg__model_configuration",
    ) -> None:
        """
        Initialise the RealizationSelection plugin.

        Args:
            forecast_period: The forecast period (in seconds) to use for
                interrogating the cluster mapping attributes in order to select
                the appropriate realizations.
            model_id_attr: The name of the cube attribute used to identify the
                model source.
        """
        self.forecast_period = forecast_period
        self.model_id_attr = model_id_attr

    def split_cubes_forecast_and_cluster(
        self, cubes: CubeList
    ) -> tuple[CubeList, Cube]:
        """
        Split a CubeList into forecast cubes and the cluster cube.

        The cluster cube is identified by the presence of the
        "primary_input_realization_to_cluster_medoid" attribute. Every other
        cube is treated as a forecast cube. If more than one cube carries the
        attribute, the last one encountered is used as the cluster cube.

        Args:
            cubes: CubeList of input cubes expected to contain forecast
                cubes and a cluster cube.

        Returns:
            Tuple of (forecast_cubes, cluster_cube):
                - forecast_cubes: CubeList of forecast cubes.
                - cluster_cube: The cluster cube.

        Raises:
            ValueError: If no cluster cube is found.
        """
        cluster_cube = None
        forecast_cubes = CubeList()
        for cube in cubes:
            if "primary_input_realization_to_cluster_medoid" in cube.attributes:
                cluster_cube = cube
            else:
                forecast_cubes.append(cube)
        if cluster_cube is None:
            raise ValueError(
                "No cluster cube found in input cubes "
                "(missing 'primary_input_realization_to_cluster_medoid' attribute)."
            )
        return forecast_cubes, cluster_cube

    def parse_mapping_attributes(
        self, cluster_cube: Cube
    ) -> tuple[
        dict[str, int], Optional[dict[str, dict[int, list[dict[str, list[int]]]]]]
    ]:
        """
        Parse and decode the mapping attributes from the cluster cube.

        Attributes stored as JSON strings are decoded; attributes that are
        already dictionaries are returned unchanged.

        Args:
            cluster_cube: The cube output from RealizationClusterAndMatch,
                containing the mapping attributes as JSON-encoded strings.

        Returns:
            A tuple containing:
                - primary_map: Dictionary mapping cluster index (as string)
                  to medoid realization index (int).
                - secondary_map: Dictionary mapping secondary input names to
                  cluster mappings, where each cluster index maps to a list of
                  dicts with "realization" and "forecast_periods" entries.
                  None if the cluster cube has no
                  "secondary_input_realizations_to_clusters" attribute.

        Raises:
            TypeError: If the primary mapping attribute is missing, or if
                either mapping attribute is not in the expected format
                (str or dict).
        """
        primary_map = cluster_cube.attributes.get(
            "primary_input_realization_to_cluster_medoid"
        )
        secondary_map = cluster_cube.attributes.get(
            "secondary_input_realizations_to_clusters"
        )

        if isinstance(primary_map, str):
            primary_map = json.loads(primary_map)
        elif not isinstance(primary_map, dict):
            raise TypeError(
                f"Expected primary_map to be str or dict, got {type(primary_map)}"
            )

        # The secondary mapping is optional, so None is passed through.
        if isinstance(secondary_map, str):
            secondary_map = json.loads(secondary_map)
        elif secondary_map is not None and not isinstance(secondary_map, dict):
            raise TypeError(
                "Expected secondary_map to be str, dict, or None, "
                f"got {type(secondary_map)}"
            )
        return primary_map, secondary_map

    def validate_common_validity_time(self, forecast_cubes: CubeList) -> None:
        """
        Validate that all forecast cubes share a common validity time.

        Only the first point of each cube's time coordinate is compared.

        Args:
            forecast_cubes: CubeList of forecast cubes.

        Raises:
            ValueError: If forecast cubes do not share a common validity time.
        """
        unique_validity_times = {
            cube.coord("time").cell(0).point for cube in forecast_cubes
        }
        if len(unique_validity_times) > 1:
            raise ValueError(
                "Forecast cubes must share a common validity time (time coordinate)."
            )

    def find_nearest_secondary_mapping_fp(
        self, mapping_fps: Optional[set[int]], fp: int
    ) -> tuple[int, bool]:
        """
        Find the nearest forecast period in the secondary mapping to the
        requested forecast period.

        Args:
            mapping_fps: Set of forecast periods (in seconds) available in the
                secondary mapping. May be empty or None.
            fp: The forecast period (in seconds) for which to find the nearest
                mapping.

        Returns:
            A tuple containing:
                - nearest_fp: The forecast period from mapping_fps closest to fp
                  (or fp if mapping_fps is empty or None). If two forecast
                  periods are equally close, the smaller one is returned.
                - use_secondary: Boolean indicating whether the secondary
                  mapping should be used (True if fp is less than or equal to
                  the maximum in mapping_fps, else False).
        """
        if not mapping_fps:
            return fp, False

        # Sets have no meaningful ordering, so break ties on the forecast
        # period itself to keep the selection reproducible.
        nearest_fp = min(
            mapping_fps, key=lambda mapped_fp: (abs(mapped_fp - fp), mapped_fp)
        )
        use_secondary = fp <= max(mapping_fps)
        return nearest_fp, use_secondary

    def build_cluster_to_selection(
        self,
        nearest_fp: int,
        use_secondary: bool,
        secondary_map: Optional[dict[str, dict[int, list[dict[str, list[int]]]]]],
        primary_map: dict[str, int],
        cluster_cube: Cube,
    ) -> dict[int, tuple[str, int]]:
        """
        Build a mapping from cluster index to (model name, realization index)
        for selection.

        Args:
            nearest_fp: The forecast period (in seconds) from the secondary
                mapping closest to the requested forecast period.
            use_secondary: Whether to use the secondary mapping (True) or fall
                back to the primary mapping (False). Determined by the
                find_nearest_secondary_mapping_fp method.
            secondary_map: Dictionary mapping secondary input names to cluster
                mappings, where each cluster index maps to a list of dicts with
                "realization" and "forecast_periods" entries. May be None or
                empty, in which case only the primary mapping is used.
            primary_map: Dictionary mapping cluster index (as string) to medoid
                realization index (int).
            cluster_cube: The cluster cube output from
                RealizationClusterAndMatch. Used to determine the model name
                for the primary input when assigning realizations to clusters.

        Returns:
            Dictionary mapping cluster index (int) to a tuple of
            (model name, realization index).
        """
        cluster_to_selection = {}

        # Use secondary mapping if appropriate. Where several secondary inputs
        # provide an entry for the same cluster, the one encountered last wins.
        if use_secondary and secondary_map:
            for model_name, cluster_dict in secondary_map.items():
                for cluster_idx_str, cluster_list in cluster_dict.items():
                    cluster_idx = int(cluster_idx_str)
                    for entry in cluster_list:
                        if nearest_fp in entry["forecast_periods"]:
                            cluster_to_selection[cluster_idx] = (
                                model_name,
                                entry["realization"],
                            )

        # The primary model name must be read from the same attribute that is
        # later used to pick out the forecast cubes, so prefer the configured
        # model_id_attr. The previously hard-coded attribute name is kept as a
        # fallback for backward compatibility.
        attributes = cluster_cube.attributes
        primary_model_name = attributes.get(
            self.model_id_attr,
            attributes.get("mosg__model_configuration", "primary_input"),
        )

        # Fill in any clusters not covered by secondary inputs using the medoid
        # mapping.
        for cluster_idx_str, realization in primary_map.items():
            cluster_idx = int(cluster_idx_str)
            if cluster_idx not in cluster_to_selection:
                cluster_to_selection[cluster_idx] = (primary_model_name, realization)
        return cluster_to_selection

    def select_realizations_for_clusters(
        self,
        cluster_to_selection: dict[int, tuple[str, int]],
        forecast_cubes: CubeList,
    ) -> list[Cube]:
        """
        Select and relabel realizations from the input cubes according to the
        cluster-to-selection mapping.

        Clusters are processed in ascending cluster index order. If several
        forecast cubes match a model name, the first is used.

        Args:
            cluster_to_selection: Dictionary mapping cluster index (int) to a
                tuple of (model name, realization index).
            forecast_cubes: CubeList of input forecast cubes, each for a single
                forecast period.

        Returns:
            A list of Cube objects, each containing a single realization
            relabelled to the cluster index.

        Raises:
            ValueError: If no forecast cube is found for a specified model name.
            ValueError: If the requested realization is not present in the
                forecast cube for the specified model.
        """
        # Several clusters commonly draw on the same model, so only search the
        # forecast cubes once per model name.
        model_cube_by_name = {}
        selected_cubes = []
        for cluster_idx in sorted(cluster_to_selection):
            model_name, realization_index = cluster_to_selection[cluster_idx]

            if model_name not in model_cube_by_name:
                model_cubes = forecast_cubes.extract(
                    iris.AttributeConstraint(**{self.model_id_attr: model_name})
                )
                if not model_cubes:
                    raise ValueError(
                        f"No forecast cube found for model '{model_name}'"
                    )
                model_cube_by_name[model_name] = model_cubes[0]
            model_cube = model_cube_by_name[model_name]

            selected = model_cube.extract(
                iris.Constraint(realization=realization_index)
            )
            # Cube.extract returns None when nothing matches the constraint.
            if selected is None:
                raise ValueError(
                    f"Realization {realization_index} not found in forecast cube "
                    f"for model '{model_name}'"
                )
            selected.coord("realization").points = [cluster_idx]
            selected_cubes.append(selected)
        return selected_cubes

    def process(self, cubes: CubeList) -> Cube:
        """
        Select realizations from input forecast cubes according to cluster
        assignments defined by the cluster_cube's attributes.

        Args:
            cubes: Input cubes, including forecast cubes and a cluster cube.
                The forecast cubes are from all source models for a common
                validity time and with each containing a "realization"
                coordinate that contributed to the clustering. Each cube must
                have the model_id_attr attribute set to identify its source
                model. The cluster cube is output from
                RealizationClusterAndMatch, containing the cluster mapping
                attributes. The cluster cube is identified by the presence of
                the "primary_input_realization_to_cluster_medoid" attribute.

        Returns:
            A merged Cube containing the selected realizations, with realization
            indices matching the cluster indices in cluster_cube.

        Raises:
            ValueError: If no cluster cube is found, if the forecast cubes do
                not share a common validity time, or if a required model or
                realization cannot be found in the forecast cubes.
            TypeError: If the mapping attributes on the cluster cube are not in
                the expected format.
        """
        forecast_cubes, cluster_cube = self.split_cubes_forecast_and_cluster(cubes)
        self.validate_common_validity_time(forecast_cubes)
        primary_map, secondary_map = self.parse_mapping_attributes(cluster_cube)

        # Gather every forecast period referenced anywhere in the secondary
        # mapping, across all models and clusters.
        mapping_fps = set()
        if secondary_map:
            for cluster_dict in secondary_map.values():
                for cluster_list in cluster_dict.values():
                    for entry in cluster_list:
                        mapping_fps.update(entry["forecast_periods"])

        nearest_fp, use_secondary = self.find_nearest_secondary_mapping_fp(
            mapping_fps, self.forecast_period
        )
        cluster_to_selection = self.build_cluster_to_selection(
            nearest_fp, use_secondary, secondary_map, primary_map, cluster_cube
        )
        selected_cubes = self.select_realizations_for_clusters(
            cluster_to_selection, forecast_cubes
        )
        return MergeCubes()(CubeList(selected_cubes))