Coverage for mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py: 100%

109 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tests for mlos_bench.services.remote.azure.azure_vm_services 

7""" 

8 

9from copy import deepcopy 

10from unittest.mock import MagicMock, patch 

11 

12import pytest 

13import requests.exceptions as requests_ex 

14 

15from mlos_bench.environments.status import Status 

16 

17from mlos_bench.services.remote.azure.azure_auth import AzureAuthService 

18from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService 

19 

20from mlos_bench.tests.services.remote.azure import make_httplib_json_response 

21 

22 

@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_host_deployment_retry(mock_getconn: MagicMock,
                                    total_retries: int,
                                    operation_status: Status,
                                    azure_vm_service: AzureVMService) -> None:
    """
    Test retries of the host deployment operation.

    The connection pool is mocked at ``HTTPConnectionPool._get_conn`` and replays
    this sequence of responses: "Running", two aborted connections, "Running"
    again, then "Succeeded".

    With ``requestTotalRetries=2`` the deployment wait is expected to end in
    ``SUCCEEDED``. With 1 or 0 retries it must fail gracefully with ``FAILED``
    rather than raise.
    """
    def _aborted() -> requests_ex.ConnectionError:
        # A fresh exception object per simulated failure.
        return requests_ex.ConnectionError(
            "Connection aborted", OSError(107, "Transport endpoint is not connected"))

    def _provisioning(state: str) -> object:
        # A fresh HTTP 200 JSON response carrying the given provisioning state.
        return make_httplib_json_response(200, {"properties": {"provisioningState": state}})

    mock_getconn.return_value.getresponse.side_effect = [
        _provisioning("Running"),
        _aborted(),
        _aborted(),
        _provisioning("Running"),
        _provisioning("Succeeded"),
    ]

    status, _ = azure_vm_service.wait_host_deployment(
        params={
            "pollInterval": 0.1,
            "requestTotalRetries": total_retries,
            "deploymentName": "TEST_DEPLOYMENT1",
            "subscription": "TEST_SUB1",
            "resourceGroup": "TEST_RG1",
        },
        is_setup=True,
    )
    assert status == operation_status

57 

58 

def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None:
    """
    Test expanding template params recursively.

    ``vmNsg`` refers to ``$vmMeta``, which itself refers to ``$vmName`` and
    ``$location`` from the global config. All three must be fully expanded in
    the service's ``deploy_params``.
    """
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
        "location": "eastus",
    }
    config = {
        "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc",
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "$location",
            "vmMeta": "$vmName-$location",
            "vmNsg": "$vmMeta-nsg",
        },
    }

    service = AzureVMService(config, global_config, parent=azure_auth_service)
    expanded = service.deploy_params

    vm_name = global_config["vmName"]
    location = global_config["location"]
    assert expanded["location"] == location
    assert expanded["vmMeta"] == f"{vm_name}-{location}"
    assert expanded["vmNsg"] == f'{expanded["vmMeta"]}-nsg'

82 

83 

def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None:
    """
    Test loading custom data from a file.

    A config that sets ``customDataFile`` and also passes ``customData``
    directly in ``deploymentTemplateParameters`` must be rejected with a
    ``ValueError``. With only ``customDataFile`` set, the constructed service
    must expose a non-empty ``deploy_params['customData']``.
    """
    config = {
        "customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml",
        "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc",
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "eastus2",
        },
    }
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
    }

    # Build the conflicting config outside pytest.raises so that only the
    # constructor call is allowed to satisfy the ValueError expectation.
    config_with_custom_data = deepcopy(config)
    config_with_custom_data['deploymentTemplateParameters']['customData'] = "DUMMY_CUSTOM_DATA"  # type: ignore[index]
    with pytest.raises(ValueError):
        AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service)

    azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
    assert azure_vm_service.deploy_params['customData']

107 

108 

@pytest.mark.parametrize(
    ("operation_name", "accepts_params"), [
        ("start_host", True),
        ("stop_host", True),
        ("shutdown", True),
        ("deprovision_host", True),
        ("deallocate_host", True),
        ("restart_host", True),
        ("reboot", True),
    ])
@pytest.mark.parametrize(
    ("http_status_code", "operation_status"), [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ])
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests")
def test_vm_operation_status(mock_requests: MagicMock,  # pylint: disable=too-many-arguments
                             azure_vm_service: AzureVMService,
                             operation_name: str,
                             accepts_params: bool,
                             http_status_code: int,
                             operation_status: Status) -> None:
    """
    Test VM operation status.

    Every VM operation must raise ``ValueError`` when ``vmName`` is missing.
    Given a ``vmName``, the status code of the mocked ``requests.post``
    response must map to the expected ``Status``.
    """
    # NOTE: the pylint pragma sits on the ``def`` line so that it applies to
    # this function only. On its own line at module level it would disable the
    # check for the rest of the module.
    mock_response = MagicMock()
    mock_response.status_code = http_status_code
    mock_requests.post.return_value = mock_response

    operation = getattr(azure_vm_service, operation_name)
    with pytest.raises(ValueError):
        # Missing vmName should raise ValueError.
        if accepts_params:
            operation({})
        else:
            operation()

    (status, _) = operation({"vmName": "test-vm"}) if accepts_params else operation()
    assert status == operation_status

147 

148 

@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("provision_host", True),
    ],
)
def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService,
                              operation_name: str,
                              accepts_params: bool) -> None:
    """
    Test VM operation status for an incomplete service config.

    The remote-exec-only service fixture is expected to refuse the operation
    with a ``ValueError``, even though ``vmName`` is supplied.
    """
    vm_operation = getattr(azure_vm_service_remote_exec_only, operation_name)
    with pytest.raises(ValueError):
        if accepts_params:
            vm_operation({"vmName": "test-vm"})
        else:
            vm_operation()

162 

163 

@patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep")
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock,
                                 azure_vm_service: AzureVMService) -> None:
    """
    Test waiting for the completion of the remote VM operation.

    The mocked session's GET on the async results URL immediately reports
    "Succeeded". The test checks three things:
    - the URL was polled;
    - ``time.sleep`` received the configured ``pollInterval``;
    - the returned status is a success.
    """
    results_url = "DUMMY_ASYNC_URL"
    poll_interval = 12345

    succeeded_response = MagicMock(status_code=200)
    succeeded_response.json.return_value = {"status": "Succeeded"}
    session_get = mock_session.return_value.get
    session_get.return_value = succeeded_response

    status, _ = azure_vm_service.wait_host_operation({
        "asyncResultsUrl": results_url,
        "vmName": "test-vm",
        "pollInterval": poll_interval,
    })

    assert session_get.call_args[0] == (results_url,)
    assert mock_sleep.call_args[0] == (poll_interval,)
    assert status.is_succeeded()

191 

192 

@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_timeout(mock_session: MagicMock,
                                   azure_vm_service: AzureVMService) -> None:
    """
    Test the time out of the remote VM operation.

    The mocked status endpoint always answers "InProgress", so the wait must
    give up and report ``Status.TIMED_OUT``.
    """
    # NOTE(review): time.sleep is not patched here. The wait presumably sleeps
    # for real until the timeout configured on the fixture; confirm it stays short.
    in_progress_response = MagicMock(status_code=200)
    in_progress_response.json.return_value = {"status": "InProgress"}
    mock_session.return_value.get.return_value = in_progress_response

    status, _ = azure_vm_service.wait_host_operation({
        "asyncResultsUrl": "DUMMY_ASYNC_URL",
        "vmName": "test-vm",
        "pollInterval": 1,
    })
    assert status == Status.TIMED_OUT

214 

215 

@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_vm_operation_retry(mock_getconn: MagicMock,
                                 total_retries: int,
                                 operation_status: Status,
                                 azure_vm_service: AzureVMService) -> None:
    """
    Test the retries of the remote VM operation.

    The connection pool is mocked at ``HTTPConnectionPool._get_conn`` and replays
    this sequence of responses: "InProgress", two aborted connections,
    "InProgress" again, then "Succeeded".

    With ``requestTotalRetries=2`` the wait is expected to end in
    ``SUCCEEDED``. With 1 or 0 retries it must fail gracefully with ``FAILED``
    rather than raise.
    """
    def _aborted() -> requests_ex.ConnectionError:
        # A fresh exception object per simulated failure.
        return requests_ex.ConnectionError(
            "Connection aborted", OSError(107, "Transport endpoint is not connected"))

    def _status(value: str) -> object:
        # A fresh HTTP 200 JSON response carrying the given operation status.
        return make_httplib_json_response(200, {"status": value})

    mock_getconn.return_value.getresponse.side_effect = [
        _status("InProgress"),
        _aborted(),
        _aborted(),
        _status("InProgress"),
        _status("Succeeded"),
    ]

    status, _ = azure_vm_service.wait_host_operation(
        params={
            "pollInterval": 0.1,
            "requestTotalRetries": total_retries,
            "asyncResultsUrl": "https://DUMMY_ASYNC_URL",
            "vmName": "test-vm",
        },
    )
    assert status == operation_status

248 

249 

@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService,
                            http_status_code: int, operation_status: Status) -> None:
    """
    Test waiting for completion of the remote execution on Azure.

    The status code of the mocked ``requests.post`` response must map to the
    expected ``Status`` returned by ``remote_exec``.
    """
    fake_response = MagicMock(status_code=http_status_code)
    fake_response.json.return_value = {
        "fake response": "body as json to dict",
    }
    mock_requests.post.return_value = fake_response

    commands = ["command_1", "command_2"]
    status, _ = azure_vm_service_remote_exec_only.remote_exec(
        commands, config={"vmName": "test-vm"}, env_params={})

    assert status == operation_status

275 

276 

@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_headers_output(mock_requests: MagicMock,
                                    azure_vm_service_remote_exec_only: AzureVMService) -> None:
    """
    Check if HTTP headers from the remote execution on Azure are correct.

    The test checks two things:
    - the ``Azure-AsyncOperation`` header of a 202 response is surfaced as
      ``asyncResultsUrl`` in the command output;
    - ``env_params`` are sent as the ``parameters`` list of the RunShellScript
      request body.
    """
    results_url = "DUMMY_ASYNC_URL"
    output_key = "asyncResultsUrl"
    commands = ["command_1", "command_2"]

    accepted_response = MagicMock(status_code=202, headers={"Azure-AsyncOperation": results_url})
    accepted_response.json.return_value = {
        "fake response": "body as json to dict",
    }
    mock_requests.post.return_value = accepted_response

    _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(
        commands,
        config={"vmName": "test-vm"},
        env_params={
            "param_1": 123,
            "param_2": "abc",
        },
    )

    assert output_key in cmd_output
    assert cmd_output[output_key] == results_url

    sent_body = mock_requests.post.call_args[1]["json"]
    assert sent_body == {
        "commandId": "RunShellScript",
        "script": commands,
        "parameters": [
            {"name": "param_1", "value": 123},
            {"name": "param_2", "value": "abc"},
        ],
    }

313 

314 

@pytest.mark.parametrize(
    ("operation_status", "wait_output", "results_output"), [
        (
            Status.SUCCEEDED,
            {
                "properties": {
                    "output": {
                        "value": [
                            {"message": "DUMMY_STDOUT_STDERR"},
                        ]
                    }
                }
            },
            {"stdout": "DUMMY_STDOUT_STDERR"}
        ),
        (Status.PENDING, {}, {}),
        (Status.FAILED, {}, {}),
    ])
def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status,
                                 wait_output: dict, results_output: dict) -> None:
    """
    Test getting the results of the remote execution on Azure.

    ``wait_host_operation`` is stubbed to return ``(operation_status, wait_output)``.
    The test checks three things:
    - ``get_remote_exec_results`` passes ``params`` through to the stub;
    - it propagates the stubbed status;
    - it converts the wait output into the expected command output (for example
      ``properties.output.value[0].message`` becomes ``stdout`` on success).
    """
    params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"}

    # patch.object restores the real method on exit, so the stub cannot leak
    # into other tests that share the fixture instance.
    with patch.object(azure_vm_service_remote_exec_only, "wait_host_operation",
                      return_value=(operation_status, wait_output)) as mock_wait_host_operation:
        status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params)

    assert status == operation_status
    assert mock_wait_host_operation.call_args[0][0] == params
    assert cmd_output == results_output