Coverage for mlos_bench/mlos_bench/tests/storage/sql/__init__.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-14 00:55 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for mlos_bench sql storage.""" 

6 

7from dataclasses import dataclass 

8from subprocess import run 

9 

# Names and shared connection settings of the SQL test DB servers.
# See Also: docker-compose.yml

MYSQL_TEST_SERVER_NAME: str = "mysql-mlos-bench-server"
PGSQL_TEST_SERVER_NAME: str = "postgres-mlos-bench-server"

SQL_TEST_SERVER_DATABASE: str = "mlos_bench"
SQL_TEST_SERVER_PASSWORD: str = "password"

18 

19 

20@dataclass 

21class SqlTestServerInfo: 

22 """ 

23 A data class for SqlTestServerInfo. 

24 

25 See Also 

26 -------- 

27 mlos_bench.tests.services.remote.ssh.SshTestServerInfo 

28 """ 

29 

30 compose_project_name: str 

31 service_name: str 

32 hostname: str 

33 _port: int | None = None 

34 

35 @property 

36 def username(self) -> str: 

37 """Gets the username.""" 

38 usernames = { 

39 MYSQL_TEST_SERVER_NAME: "root", 

40 PGSQL_TEST_SERVER_NAME: "postgres", 

41 } 

42 return usernames[self.service_name] 

43 

44 @property 

45 def password(self) -> str: 

46 """Gets the password.""" 

47 return SQL_TEST_SERVER_PASSWORD 

48 

49 @property 

50 def database(self) -> str: 

51 """Gets the database.""" 

52 return SQL_TEST_SERVER_DATABASE 

53 

54 def get_port(self, uncached: bool = False) -> int: 

55 """ 

56 Gets the port that the DB test server is listening on. 

57 

58 Note: this value can change when the service restarts so we can't rely on 

59 the DockerServices. 

60 """ 

61 if self._port is None or uncached: 

62 default_ports = { 

63 MYSQL_TEST_SERVER_NAME: 3306, 

64 PGSQL_TEST_SERVER_NAME: 5432, 

65 } 

66 default_port = default_ports[self.service_name] 

67 port_cmd = run( 

68 ( 

69 f"docker compose -p {self.compose_project_name} " 

70 f"port {self.service_name} {default_port}" 

71 ), 

72 shell=True, 

73 check=True, 

74 capture_output=True, 

75 ) 

76 self._port = int(port_cmd.stdout.decode().strip().split(":")[1]) 

77 return self._port