Coverage for mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py: 100%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for mlos_bench.services.remote.azure.azure_vm_services.""" 

6 

7from copy import deepcopy 

8from datetime import datetime, timezone 

9from unittest.mock import MagicMock, patch 

10 

11import pytest 

12import requests.exceptions as requests_ex 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.services.remote.azure.azure_auth import AzureAuthService 

16from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService 

17from mlos_bench.tests.services.remote.azure import make_httplib_json_response 

18 

19 

@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_host_deployment_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that polling a host deployment survives transient connection errors.

    The mocked connection pool first answers "Running", then raises two
    connection errors, then answers "Running" again and finally "Succeeded".
    With 2 retries the deployment is reported as SUCCEEDED; with fewer the
    service must report FAILED instead of propagating the exception.
    """

    def _deployment_state(state: str) -> dict:
        # Fresh payload dict per response, shaped like an ARM deployment status.
        return {"properties": {"provisioningState": state}}

    mock_getconn.return_value.getresponse.side_effect = [
        make_httplib_json_response(200, _deployment_state("Running")),
        # Two separate exception instances, one per simulated dropped connection.
        *(
            requests_ex.ConnectionError(
                "Connection aborted",
                OSError(107, "Transport endpoint is not connected"),
            )
            for _ in range(2)
        ),
        make_httplib_json_response(200, _deployment_state("Running")),
        make_httplib_json_response(200, _deployment_state("Succeeded")),
    ]

    deployment_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "deploymentName": "TEST_DEPLOYMENT1",
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
    }
    status, _ = azure_vm_service.wait_host_deployment(
        params=deployment_params,
        is_setup=True,
    )
    assert status == operation_status

63 

64 

def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None:
    """
    Verify that ``$var`` references in deployment template parameters expand recursively.

    ``vmNsg`` refers to ``vmMeta``, which itself refers to the global ``vmName``
    and ``location`` values, so the whole chain must be resolved.
    """
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
        "location": "eastus",
    }
    config = {
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "$location",
            "vmMeta": "$vmName-$location",
            "vmNsg": "$vmMeta-nsg",
        },
    }

    service = AzureVMService(config, global_config, parent=azure_auth_service)
    deploy_params = service.deploy_params

    expected_meta = f'{global_config["vmName"]}-{global_config["location"]}'
    assert deploy_params["location"] == global_config["location"]
    assert deploy_params["vmMeta"] == expected_meta
    # vmNsg must be built from the *expanded* vmMeta value.
    assert deploy_params["vmNsg"] == f'{deploy_params["vmMeta"]}-nsg'

94 

95 

def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None:
    """
    Test loading custom data from a file.

    Supplying ``customData`` inline in ``deploymentTemplateParameters`` (here
    alongside ``customDataFile``) must make the service constructor raise
    ``ValueError``. With only ``customDataFile`` set, the loaded content must
    appear as a non-empty ``deploy_params["customData"]``.
    """
    # Annotated as a plain dict so the nested mutation below type-checks
    # without a ``type: ignore``.
    config: dict = {
        "customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml",
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "eastus2",
        },
    }
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
    }

    # Build the conflicting config outside ``pytest.raises`` so that only the
    # service constructor is expected to raise. deepcopy keeps ``config`` intact.
    config_with_custom_data = deepcopy(config)
    template_params = config_with_custom_data["deploymentTemplateParameters"]
    template_params["customData"] = "DUMMY_CUSTOM_DATA"
    with pytest.raises(ValueError):
        AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service)

    azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
    assert azure_vm_service.deploy_params["customData"]

119 

120 

@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("start_host", True),
        ("stop_host", True),
        ("shutdown", True),
        ("deprovision_host", True),
        ("deallocate_host", True),
        ("restart_host", True),
        ("reboot", True),
    ],
)
@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests")
def test_vm_operation_status(
    mock_requests: MagicMock,
    azure_vm_service: AzureVMService,
    operation_name: str,
    accepts_params: bool,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Test VM operation status.

    Each VM operation must raise ``ValueError`` when ``vmName`` is missing, and
    otherwise map the HTTP status code of the mocked POST request to the
    expected ``Status``.
    """
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    mock_response = MagicMock()
    mock_response.status_code = http_status_code
    mock_requests.post.return_value = mock_response

    operation = getattr(azure_vm_service, operation_name)

    # Missing vmName should raise ValueError. The call never returns, so its
    # result is deliberately not bound to anything.
    with pytest.raises(ValueError):
        if accepts_params:
            operation({})
        else:
            operation()

    (status, _) = operation({"vmName": "test-vm"}) if accepts_params else operation()
    assert status == operation_status

163 

164 

@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("provision_host", True),
    ],
)
def test_vm_operation_invalid(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_name: str,
    accepts_params: bool,
) -> None:
    """
    Test VM operation status for an incomplete service config.

    The ``azure_vm_service_remote_exec_only`` fixture presumably lacks the
    deployment settings these operations need (verify against the fixture), so
    calling them must raise ``ValueError`` even when ``vmName`` is supplied.
    """
    operation = getattr(azure_vm_service_remote_exec_only, operation_name)
    # The call is expected to raise, so its result is not bound to anything.
    with pytest.raises(ValueError):
        if accepts_params:
            operation({"vmName": "test-vm"})
        else:
            operation()

180 

181 

@patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep")
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_ready(
    mock_session: MagicMock,
    mock_sleep: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that waiting on a VM operation polls the async URL and stops on success.

    ``time.sleep`` is patched, so the deliberately large poll interval costs
    nothing. The test asserts that sleep was requested with exactly that
    interval, that the status GET went to the async results URL, and that the
    resulting status is a success.
    """
    async_url = "DUMMY_ASYNC_URL"
    poll_interval = 12345

    # The session's GET immediately reports the operation as finished.
    succeeded_response = MagicMock(status_code=200)
    succeeded_response.json.return_value = {"status": "Succeeded"}
    mock_session.return_value.get.return_value = succeeded_response

    status, _ = azure_vm_service.wait_host_operation(
        {
            "asyncResultsUrl": async_url,
            "vmName": "test-vm",
            "pollInterval": poll_interval,
        }
    )

    get_positional_args = mock_session.return_value.get.call_args[0]
    sleep_positional_args = mock_sleep.call_args[0]
    assert get_positional_args == (async_url,)
    assert sleep_positional_args == (poll_interval,)
    assert status.is_succeeded()

210 

211 

@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_timeout(
    mock_session: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that a VM operation stuck in "InProgress" ends with TIMED_OUT.

    ``time.sleep`` is NOT patched here, so this test spends real wall-clock
    time polling. How long depends on the poll timeout of the
    ``azure_vm_service`` fixture, which is not visible here; presumably it is
    short (confirm).
    """
    in_progress_response = MagicMock(status_code=200)
    in_progress_response.json.return_value = {"status": "InProgress"}
    mock_session.return_value.get.return_value = in_progress_response

    status, _ = azure_vm_service.wait_host_operation(
        {"asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", "pollInterval": 1}
    )
    assert status == Status.TIMED_OUT

229 

230 

@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_vm_operation_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that polling a remote VM operation survives transient connection errors.

    The mocked connection pool answers "InProgress", then raises two connection
    errors, then answers "InProgress" and finally "Succeeded". With 2 retries
    the operation is reported as SUCCEEDED; with fewer the service must report
    FAILED instead of propagating the exception.
    """
    mock_getconn.return_value.getresponse.side_effect = [
        make_httplib_json_response(200, {"status": "InProgress"}),
        # Two separate exception instances, one per simulated dropped connection.
        *(
            requests_ex.ConnectionError(
                "Connection aborted",
                OSError(107, "Transport endpoint is not connected"),
            )
            for _ in range(2)
        ),
        make_httplib_json_response(200, {"status": "InProgress"}),
        make_httplib_json_response(200, {"status": "Succeeded"}),
    ]

    wait_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "asyncResultsUrl": "https://DUMMY_ASYNC_URL",
        "vmName": "test-vm",
    }
    status, _ = azure_vm_service.wait_host_operation(params=wait_params)
    assert status == operation_status

272 

273 

@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.PENDING),
        (201, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_status(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Check how the HTTP status of the remote-exec PUT request maps to a Status.

    Both 200 and 201 must be reported as PENDING; 401 and 404 as FAILED.
    """
    fake_response = MagicMock()
    fake_response.status_code = http_status_code
    fake_response.json.return_value = {"fake response": "body as json to dict"}
    mock_requests.put.return_value = fake_response

    exec_config = {
        "vmName": "test-vm",
        "commandName": "TEST_COMMAND",
        "location": "TEST_LOCATION",
    }
    status, _ = azure_vm_service_remote_exec_only.remote_exec(
        ["command_1", "command_2"],
        config=exec_config,
        env_params={},
    )

    assert status == operation_status

311 

312 

@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_output(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
) -> None:
    """
    Check the output and the request body of an asynchronous remote execution.

    A 201 response carrying an ``Azure-AsyncOperation`` header must produce an
    output containing ``asyncResultsUrl``. The JSON body of the PUT request
    must contain the joined script, the environment parameters as
    ``protectedParameters``, and the async-execution settings.
    """
    script = ["command_1", "command_2"]

    accepted_response = MagicMock()
    accepted_response.status_code = 201
    accepted_response.headers = {"Azure-AsyncOperation": "DUMMY_ASYNC_URL"}
    accepted_response.json.return_value = {"fake response": "body as json to dict"}
    mock_requests.put.return_value = accepted_response

    _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(
        script,
        config={
            "vmName": "test-vm",
            "commandName": "TEST_COMMAND",
            "location": "TEST_LOCATION",
        },
        env_params={"param_1": 123, "param_2": "abc"},
    )

    assert "asyncResultsUrl" in cmd_output

    expected_body = {
        "location": "TEST_LOCATION",
        "properties": {
            "source": {"script": "; ".join(script)},
            "protectedParameters": [
                {"name": "param_1", "value": 123},
                {"name": "param_2", "value": "abc"},
            ],
            # NOTE(review): not set in this test; presumably comes from the
            # azure_vm_service_remote_exec_only fixture config -- confirm.
            "timeoutInSeconds": 2,
            "asyncExecution": True,
        },
    }
    put_kwargs = mock_requests.put.call_args[1]
    assert put_kwargs["json"] == expected_body

360 

361 

@pytest.mark.parametrize(
    ("operation_status", "wait_output", "results_output"),
    [
        (
            Status.SUCCEEDED,
            {
                "properties": {
                    "instanceView": {
                        "output": "DUMMY_STDOUT\n",
                        "error": "DUMMY_STDERR\n",
                        "executionState": "Succeeded",
                        "exitCode": 0,
                        "startTime": "2024-01-01T00:00:00+00:00",
                        "endTime": "2024-01-01T00:01:00+00:00",
                    }
                }
            },
            {
                "stdout": ["DUMMY_STDOUT"],
                "stderr": ["DUMMY_STDERR"],
                "exitCode": 0,
                "startTimestamp": datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
                "endTimestamp": datetime(2024, 1, 1, 0, 1, 0, tzinfo=timezone.utc),
            },
        ),
        (Status.PENDING, {}, {}),
        (Status.FAILED, {}, {}),
    ],
)
def test_get_remote_exec_results(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_status: Status,
    wait_output: dict,
    results_output: dict,
) -> None:
    """
    Test getting the results of the remote execution on Azure.

    ``wait_remote_exec_operation`` is replaced by a mock that returns
    ``(operation_status, wait_output)``. ``get_remote_exec_results`` must pass
    ``params`` through to it, return the same status, and convert
    ``wait_output`` into ``results_output``.
    """
    params = {
        "asyncResultsUrl": "DUMMY_ASYNC_URL",
    }

    mock_wait_remote_exec_operation = MagicMock(return_value=(operation_status, wait_output))
    # Scoped patch: the fixture's original method is restored when the block exits.
    with patch.object(
        azure_vm_service_remote_exec_only,
        "wait_remote_exec_operation",
        mock_wait_remote_exec_operation,
    ):
        status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params)

    assert status == operation_status
    assert mock_wait_remote_exec_operation.call_args[0][0] == params
    assert cmd_output == results_output