Coverage for mlos_bench/mlos_bench/tests/storage/test_storage_pickling.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Test pickling and unpickling of Storage, and restoring Experiment and Trial by id.""" 

6import pickle 

7import sys 

8from datetime import datetime 

9from typing import Literal 

10 

11import pytest 

12from pytz import UTC 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.storage.base_storage import Storage 

16from mlos_bench.tests.storage.sql.fixtures import PERSISTENT_SQL_STORAGE_FIXTURES 

17from mlos_bench.tunables.tunable_groups import TunableGroups 

18 

19 

# TODO: When we introduce ParallelTrialScheduler warn at config startup time
# that it is incompatible with sqlite storage on Windows.
@pytest.mark.skipif(
    sys.platform == "win32",
    reason="Windows doesn't support multiple processes accessing the same file.",
)
@pytest.mark.parametrize(
    "persistent_storage",
    # TODO: Improve this test to support non-sql backends eventually as well.
    list(PERSISTENT_SQL_STORAGE_FIXTURES),
)
def test_storage_pickle_restore_experiment_and_trial(
    persistent_storage: Storage,
    tunable_groups: TunableGroups,
) -> None:
    """Round-trip a Storage object through pickle and reload its contents by id.

    An Experiment holding a single RUNNING Trial is written through the
    original storage object. The storage is then pickled and unpickled, and
    the Experiment and Trial are looked up by id on the unpickled copy and
    compared field by field against the originals.
    """
    concrete_cls = type(persistent_storage)
    # The fixture must hand back a concrete backend, not the abstract base.
    assert issubclass(concrete_cls, Storage)
    assert concrete_cls is not Storage

    targets: dict[str, Literal["min", "max"]] = {"metric": "min"}
    original_exp = persistent_storage.experiment(
        experiment_id="experiment_id",
        trial_id=0,
        root_env_config="dummy_env.json",
        description="Pickle test experiment",
        tunables=tunable_groups,
        opt_targets=targets,
    )
    with original_exp:
        original_trial = original_exp.new_trial(tunable_groups)
        created_trial_id = original_trial.trial_id
        original_trial.set_trial_runner(1)
        original_trial.update(Status.RUNNING, datetime.now(UTC))

    # Serialize and immediately deserialize the storage object.
    clone = pickle.loads(pickle.dumps(persistent_storage))
    assert isinstance(clone, concrete_cls)

    # Look the experiment up by id through the unpickled storage.
    reloaded_exp = clone.get_experiment_by_id(
        experiment_id=original_exp.experiment_id,
        tunables=tunable_groups,
        opt_targets=targets,
    )
    assert reloaded_exp is not None
    assert reloaded_exp is not original_exp
    experiment_fields = (
        "experiment_id",
        "description",
        "root_env_config",
        "tunables",
        "opt_targets",
    )
    for field_name in experiment_fields:
        assert getattr(reloaded_exp, field_name) == getattr(original_exp, field_name), field_name

    with reloaded_exp:
        # trial_id is expected to have been restored from storage by __enter__.
        assert reloaded_exp.trial_id == original_exp.trial_id

        # Look the trial up by id and compare it against the one created above.
        reloaded_trial = reloaded_exp.get_trial_by_id(created_trial_id)
        assert reloaded_trial is not None
        assert reloaded_trial is not original_trial
        for field_name in ("trial_id", "experiment_id", "tunables", "status"):
            assert getattr(reloaded_trial, field_name) == getattr(
                original_trial, field_name
            ), field_name
        assert reloaded_trial.config() == original_trial.config()
        assert reloaded_trial.trial_runner_id == original_trial.trial_runner_id