ray.train.trainer.BaseTrainer.restore
- classmethod BaseTrainer.restore(path: Union[str, os.PathLike], storage_filesystem: Optional[pyarrow._fs.FileSystem] = None, datasets: Optional[Dict[str, Union[Dataset, Callable[[], Dataset]]]] = None, preprocessor: Optional[Preprocessor] = None, scaling_config: Optional[ray.train.ScalingConfig] = None, **kwargs) -> BaseTrainer
Restores a Train experiment from a previously interrupted/failed run.
Restore should be used for experiment-level fault tolerance in the event that the head node crashes (e.g., OOM or some other runtime error) or the entire cluster goes down (e.g., network error affecting all nodes).
The following example can be paired with implementing job retry using Ray Jobs to produce a Train experiment that will attempt to resume on both experiment-level and trial-level failures:
    import os
    import ray
    from ray import train
    from ray.train.trainer import BaseTrainer

    experiment_name = "unique_experiment_name"
    storage_path = os.path.expanduser("~/ray_results")
    experiment_dir = os.path.join(storage_path, experiment_name)

    # Define some dummy inputs for demonstration purposes
    datasets = {"train": ray.data.from_items([{"a": i} for i in range(10)])}

    class CustomTrainer(BaseTrainer):
        def training_loop(self):
            pass

    if CustomTrainer.can_restore(experiment_dir):
        trainer = CustomTrainer.restore(
            experiment_dir, datasets=datasets
        )
    else:
        trainer = CustomTrainer(
            datasets=datasets,
            run_config=train.RunConfig(
                name=experiment_name,
                storage_path=storage_path,
                # Tip: You can also enable retries on failure for
                # worker-level fault tolerance
                failure_config=train.FailureConfig(max_failures=3),
            ),
        )

    result = trainer.fit()
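
The job-retry pairing mentioned above can be illustrated with a small plain-Python sketch. It is not part of Ray's API: `run_with_retries`, `build_and_fit`, and `flaky_build_and_fit` are hypothetical names introduced here, and `build_and_fit` stands in for a function that would wrap the `can_restore` / `restore` / `fit` logic from the example. A simulated failure is used so the sketch runs without a cluster.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def run_with_retries(build_and_fit: Callable[[], T], max_attempts: int = 3) -> T:
    """Call build_and_fit until it succeeds or max_attempts is exhausted.

    build_and_fit is a hypothetical callable; in a real script it would
    restore the trainer if possible, otherwise create it, then call fit().
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    last_error: Optional[Exception] = None
    for attempt in range(1, max_attempts + 1):
        try:
            return build_and_fit()
        except Exception as exc:  # a real script might catch narrower errors
            last_error = exc
            print(f"attempt {attempt} failed: {exc!r}")
    raise RuntimeError(f"gave up after {max_attempts} attempts") from last_error


# Stand-in for a training entry point that fails once, then succeeds.
calls = {"count": 0}


def flaky_build_and_fit() -> str:
    calls["count"] += 1
    if calls["count"] == 1:
        raise ConnectionError("simulated cluster failure")
    return "result"


outcome = run_with_retries(flaky_build_and_fit, max_attempts=3)
```

An in-process loop like this only helps while the driver process itself survives. If the head node or the whole cluster goes down, the retry has to happen one level up, by resubmitting the job. Because the entry point checks `can_restore` before constructing a new trainer, each resubmission resumes the existing experiment instead of starting over.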
- Parameters
path – The path to the experiment directory of the training run to restore. This can be a local path or a remote URI if the experiment was uploaded to the cloud.
datasets – Re-specified datasets used in the original training run. This must include all the datasets that were passed in the original trainer constructor.
scaling_config – Optionally re-specified scaling config. This can be modified to be different from the original spec.
**kwargs – Other optionally re-specified arguments, passed in by subclasses.
- Raises
ValueError – If not all of the datasets from the original run were re-supplied on restore.
- Returns
A restored instance of the class that is calling this method.
- Return type