Source code for improver.calibration.rainforest_compilation

# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.

"""RainForests model compilation plugin."""

import pathlib
from pathlib import Path

from improver import BasePlugin
from improver.calibration import treelite_packages_available

# Suffix (compared case-insensitively) required on input LightGBM model files.
LIGHTGBM_EXTENSION: str = ".txt"
# Suffix (compared case-insensitively) required on compiled Treelite outputs.
TREELITE_EXTENSION: str = ".so"


class CompileRainForestsModel(BasePlugin):
    """Class to compile RainForests tree models.

    Each LightGBM model file referenced in a RainForests model configuration
    is loaded with Treelite and exported as a compiled library with tl2cgen.
    """

    def __init__(
        self,
        model_config_dict: dict[int, dict[str, dict[str, str]]],
        toolchain: str = "gcc",
        verbose: bool = False,
        parallel_comp: int = 0,
    ) -> None:
        """Initialise the options used when compiling models.

        Args:
            model_config_dict:
                Dictionary describing the high-level RainForests model
                structure;
                - top level key describes the lead-hour,
                - next level key describes the threshold,
                - corresponding values locate the associated model file.
            toolchain:
                Toolchain to use for Treelite model compilation. 'gcc'
                (default), 'msvc', 'clang' or a specific variation of clang
                or gcc (e.g. 'gcc-7').
            verbose:
                Print verbose output during compilation.
            parallel_comp:
                Enables parallel compilation to reduce time and memory
                consumption. Value is the number of processes to use.
                Defaults to 0 (no parallel compilation).

        Dictionary is of format::

            {
                "24": {
                    "0.000010": {
                        "lightgbm_model": "<path_to_lightgbm_model_object>",
                        "treelite_model": "<path_to_treelite_model_object>"
                    },
                    "0.000050": {
                        "lightgbm_model": "<path_to_lightgbm_model_object>",
                        "treelite_model": "<path_to_treelite_model_object>"
                    },
                    "0.000100": {
                        "lightgbm_model": "<path_to_lightgbm_model_object>",
                        "treelite_model": "<path_to_treelite_model_object>"
                    },
                }
            }

        The keys specify the lead times and model threshold values, while the
        associated values are the path to the corresponding tree-model objects
        for that lead time and threshold. This plugin only reads the nested
        values (the "lightgbm_model" and "treelite_model" entries), so the
        lead-time and threshold keys themselves are not interpreted here.

        Raises:
            ModuleNotFoundError: If the Treelite packages are not available.
        """
        treelite_available = treelite_packages_available()
        if not treelite_available:
            raise ModuleNotFoundError("Could not find TreeLite module")

        self.config = model_config_dict
        self.toolchain = toolchain
        self.verbose = verbose
        # Parameters forwarded to tl2cgen.export_lib. The attribute name's
        # spelling ("treelight") is kept for backward compatibility.
        self.treelight_params = {"parallel_comp": parallel_comp, "quantize": 1}

    def process(self, allow_missing: bool = False) -> None:
        """Compile all configured LightGBM models with Treelite.

        Iterates through all lead times and thresholds in the model config
        dictionary and compiles the corresponding LightGBM models to Treelite
        predictors.

        Args:
            allow_missing:
                If False (default), raise an error if any LightGBM model
                files are missing, before any model is compiled. If True,
                any missing LightGBM files are skipped.

        Raises:
            ValueError: If allow_missing is False and one or more configured
                LightGBM model files do not exist.
        """
        # Flatten {lead_time: {threshold: entry}} into a list of entries.
        model_entries = [
            threshold_config
            for lead_time_config in self.config.values()
            for threshold_config in lead_time_config.values()
        ]

        if not allow_missing:
            # Validate that all models have been trained before compiling any,
            # so nothing is compiled when a model file is missing.
            missing_model_paths = [
                entry["lightgbm_model"]
                for entry in model_entries
                if not Path(entry["lightgbm_model"]).is_file()
            ]
            if missing_model_paths:
                raise ValueError(f"Model file(s) not found: {missing_model_paths}")

        for entry in model_entries:
            self._compile_model(
                Path(entry["lightgbm_model"]),
                Path(entry["treelite_model"]),
            )

    def _compile_model(
        self, lightgbm_path: pathlib.Path, output_path: pathlib.Path
    ) -> None:
        """Compile a LightGBM model with Treelite.

        If ``lightgbm_path`` does not point to an existing file, the model is
        skipped silently. ``process`` relies on this when ``allow_missing`` is
        True.

        Args:
            lightgbm_path:
                Path to LightGBM Booster file (``str`` is also accepted).
            output_path:
                Path where the compiled Treelite predictor file will be
                created (``str`` is also accepted). Missing parent
                directories are created.

        Raises:
            ValueError: If ``lightgbm_path`` does not have the LightGBM file
                extension, or ``output_path`` does not have the Treelite file
                extension.
        """
        # Imported lazily: Treelite is optional and its availability is only
        # checked when the plugin is constructed.
        import tl2cgen
        import treelite

        lightgbm_path = Path(lightgbm_path)
        output_path = Path(output_path)

        # A missing input model is not an error here; there is nothing to compile.
        if not lightgbm_path.is_file():
            return

        # Validate the file extensions of both paths.
        if lightgbm_path.suffix.lower() != LIGHTGBM_EXTENSION:
            raise ValueError(f"Input path must have extension {LIGHTGBM_EXTENSION}")
        if output_path.suffix.lower() != TREELITE_EXTENSION:
            raise ValueError(f"Output path must have extension {TREELITE_EXTENSION}")

        output_path.parent.mkdir(parents=True, exist_ok=True)

        model = treelite.frontend.load_lightgbm_model(lightgbm_path)
        tl2cgen.export_lib(
            model,
            libpath=output_path,
            toolchain=self.toolchain,
            verbose=self.verbose,
            params=self.treelight_params,
        )