Coverage for mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for RemoteEnv benchmark environment via local SSH test services. 

7""" 

8 

9from typing import Dict 

10 

11import os 

12import sys 

13 

14import numpy as np 

15 

16import pytest 

17 

18from mlos_bench.services.config_persistence import ConfigPersistenceService 

19from mlos_bench.tunables.tunable import TunableValue 

20from mlos_bench.tunables.tunable_groups import TunableGroups 

21 

22from mlos_bench.tests import requires_docker 

23from mlos_bench.tests.environments import check_env_success 

24from mlos_bench.tests.services.remote.ssh import SshTestServerInfo 

25 

26if sys.version_info < (3, 10): 

27 from importlib_resources import files 

28else: 

29 from importlib.resources import files 

30 

31 

@requires_docker
def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None:
    """
    Load the remote SSH test environment and check the results it reports
    when pointed at the SSH test server provided by the fixture.
    """
    # Connection details of the test server, handed to the environment as globals.
    ssh_globals: Dict[str, TunableValue] = {
        "ssh_hostname": ssh_test_server.hostname,
        "ssh_port": ssh_test_server.get_port(),
        "ssh_username": ssh_test_server.username,
        "ssh_priv_key_path": ssh_test_server.id_rsa_path,
    }

    # Resolve configs relative to the test config package.
    test_config_dir = str(files("mlos_bench.tests.config"))
    config_loader = ConfigPersistenceService(config={"config_path": [test_config_dir]})
    env_json = config_loader.resolve_path("environments/remote/test_ssh_env.jsonc")
    ssh_env = config_loader.load_environment(
        env_json,
        TunableGroups(),
        global_config=ssh_globals,
        service=config_loader,
    )

    expected = {
        "hostname": ssh_test_server.service_name,
        "username": ssh_test_server.username,
        "score": 0.9,
        "ssh_priv_key_path": np.nan,    # empty strings are returned as "not a number"
        "test_param": "unset",
        "FOO": "unset",
        "ssh_username": "unset",
    }
    check_env_success(
        ssh_env,
        ssh_env.tunable_params,
        expected_results=expected,
        expected_telemetry=[],
    )

    # The downloaded results file must not be left behind in the working directory.
    leftover_csv = os.path.join(os.getcwd(), "output-downloaded.csv")
    assert not os.path.exists(leftover_csv), \
        "output-downloaded.csv should have been cleaned up by temp_dir context"

63 

64 

if __name__ == "__main__":
    # Allow running this test module directly; "-n1" is passed through to pytest.
    pytest_args = ["-n1", __file__]
    pytest.main(pytest_args)