Coverage for mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for RemoteEnv benchmark environment via local SSH test services."""

6 

7import os 

8import sys 

9from typing import Dict 

10 

11import numpy as np 

12import pytest 

13 

14from mlos_bench.services.config_persistence import ConfigPersistenceService 

15from mlos_bench.tests import requires_docker 

16from mlos_bench.tests.environments import check_env_success 

17from mlos_bench.tests.services.remote.ssh import SshTestServerInfo 

18from mlos_bench.tunables.tunable import TunableValue 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20 

21if sys.version_info < (3, 10): 

22 from importlib_resources import files 

23else: 

24 from importlib.resources import files 

25 

26 

@requires_docker
def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None:
    """
    Load the SSH test environment config and validate its results.

    The connection details of the ``ssh_test_server`` fixture are passed to the
    environment as global config values, the environment is loaded from
    ``environments/remote/test_ssh_env.jsonc``, and ``check_env_success`` is used
    to compare its output against the expected results (with no telemetry).
    Finally, the test asserts that no ``output-downloaded.csv`` file is left in
    the current working directory.
    """
    # Connection parameters for the SSH test server, exposed as globals.
    ssh_globals: Dict[str, TunableValue] = {
        "ssh_hostname": ssh_test_server.hostname,
        "ssh_port": ssh_test_server.get_port(),
        "ssh_username": ssh_test_server.username,
        "ssh_priv_key_path": ssh_test_server.id_rsa_path,
    }

    # Config loader rooted at the test config package directory.
    test_config_dir = str(files("mlos_bench.tests.config"))
    config_loader = ConfigPersistenceService(config={"config_path": [test_config_dir]})

    env_config_file = config_loader.resolve_path("environments/remote/test_ssh_env.jsonc")
    remote_env = config_loader.load_environment(
        env_config_file,
        TunableGroups(),
        global_config=ssh_globals,
        service=config_loader,
    )

    tunables = remote_env.tunable_params
    expected = {
        "hostname": ssh_test_server.service_name,
        "username": ssh_test_server.username,
        "score": 0.9,
        "ssh_priv_key_path": np.nan,  # empty strings are returned as "not a number"
        "test_param": "unset",
        "FOO": "unset",
        "ssh_username": "unset",
    }
    check_env_success(
        remote_env,
        tunables,
        expected_results=expected,
        expected_telemetry=[],
    )

    # The downloaded file must not be left behind in the working directory.
    leftover_path = os.path.join(os.getcwd(), "output-downloaded.csv")
    leftover_exists = os.path.exists(leftover_path)
    assert (
        not leftover_exists
    ), "output-downloaded.csv should have been cleaned up by temp_dir context"

65 

66 

if __name__ == "__main__":
    # When executed directly as a script, hand this file to pytest with "-n1".
    pytest_args = ["-n1", __file__]
    pytest.main(pytest_args)