Coverage for mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py: 100%
108 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-20 00:44 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tests for mlos_bench.services.remote.azure.azure_vm_services."""
7from copy import deepcopy
8from datetime import datetime, timezone
9from unittest.mock import MagicMock, patch
11import pytest
12import requests.exceptions as requests_ex
14from mlos_bench.environments.status import Status
15from mlos_bench.services.remote.azure.azure_auth import AzureAuthService
16from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService
17from mlos_bench.tests.services.remote.azure import make_httplib_json_response
@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_host_deployment_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that ``wait_host_deployment`` tolerates transient connection drops.

    The mocked connection pool replies with a ``Running`` poll, then two aborted
    connections, then ``Running`` and finally ``Succeeded``. With
    ``requestTotalRetries`` of 2 the deployment is expected to succeed; with
    fewer retries the service must report a graceful ``FAILED`` status.
    """
    # Two distinct exception instances, one per dropped connection.
    dropped_connections = [
        requests_ex.ConnectionError(
            "Connection aborted",
            OSError(107, "Transport endpoint is not connected"),
        )
        for _ in range(2)
    ]
    poll_replies = [
        make_httplib_json_response(200, {"properties": {"provisioningState": state}})
        for state in ("Running", "Running", "Succeeded")
    ]
    # Order: one good poll, the two dropped connections, then the remaining polls.
    mock_getconn.return_value.getresponse.side_effect = [
        poll_replies[0],
        *dropped_connections,
        *poll_replies[1:],
    ]

    deployment_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "deploymentName": "TEST_DEPLOYMENT1",
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
    }
    status, _ = azure_vm_service.wait_host_deployment(params=deployment_params, is_setup=True)
    assert status == operation_status
def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None:
    """
    Verify that deployment template parameters are expanded recursively.

    ``vmNsg`` refers to ``$vmMeta``, which itself refers to ``$vmName`` and
    ``$location`` from the global config, so the final value depends on
    nested substitution.
    """
    template_params = {
        "location": "$location",
        "vmMeta": "$vmName-$location",
        "vmNsg": "$vmMeta-nsg",
    }
    config = {
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": template_params,
    }
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
        "location": "eastus",
    }

    service = AzureVMService(config, global_config, parent=azure_auth_service)
    expanded = service.deploy_params
    expected_meta = f'{global_config["vmName"]}-{global_config["location"]}'

    assert expanded["location"] == global_config["location"]
    assert expanded["vmMeta"] == expected_meta
    assert expanded["vmNsg"] == f'{expanded["vmMeta"]}-nsg'
def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None:
    """
    Test loading custom data from a file.

    With ``customDataFile`` set, the service is expected to populate
    ``deploy_params["customData"]`` itself. Also supplying ``customData``
    directly in ``deploymentTemplateParameters`` must be rejected with
    ``ValueError``.
    """
    config = {
        "customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml",
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "eastus2",
        },
    }
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
    }

    # Build the conflicting config *outside* pytest.raises so that only the
    # constructor call is expected to raise. deepcopy keeps `config` untouched
    # for the positive case below.
    config_with_custom_data = deepcopy(config)
    custom_template_params = config_with_custom_data["deploymentTemplateParameters"]
    assert isinstance(custom_template_params, dict)  # narrows `object` for the type checker
    custom_template_params["customData"] = "DUMMY_CUSTOM_DATA"
    with pytest.raises(ValueError):
        AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service)

    azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
    assert azure_vm_service.deploy_params["customData"]
@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("start_host", True),
        ("stop_host", True),
        ("shutdown", True),
        ("deprovision_host", True),
        ("deallocate_host", True),
        ("restart_host", True),
        ("reboot", True),
    ],
)
@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests")
def test_vm_operation_status(
    mock_requests: MagicMock,
    azure_vm_service: AzureVMService,
    operation_name: str,
    accepts_params: bool,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Map the HTTP status of a VM operation POST to a ``Status``.

    Each operation is first called without ``vmName`` (which must raise
    ``ValueError``), then with a valid ``vmName``; the returned status must
    match the parametrized expectation for the mocked HTTP status code.
    """
    # pylint: disable=too-many-arguments,too-many-positional-arguments
    mock_requests.post.return_value = MagicMock(status_code=http_status_code)

    operation = getattr(azure_vm_service, operation_name)

    # Missing vmName should raise ValueError
    with pytest.raises(ValueError):
        if accepts_params:
            (status, _) = operation({})
        else:
            (status, _) = operation()

    if accepts_params:
        (status, _) = operation({"vmName": "test-vm"})
    else:
        (status, _) = operation()
    assert status == operation_status
@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("provision_host", True),
    ],
)
def test_vm_operation_invalid(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_name: str,
    accepts_params: bool,
) -> None:
    """
    Operations that need a full deployment config must reject a partial one.

    The ``azure_vm_service_remote_exec_only`` fixture lacks what these
    operations need, so invoking them is expected to raise ``ValueError``.
    """
    operation = getattr(azure_vm_service_remote_exec_only, operation_name)
    call_args = ({"vmName": "test-vm"},) if accepts_params else ()
    with pytest.raises(ValueError):
        # Unpack the result too, exactly as a real caller would.
        _, _ = operation(*call_args)
@patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep")
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_ready(
    mock_session: MagicMock,
    mock_sleep: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Poll an async VM operation that is already complete.

    ``time.sleep`` is mocked, so the test checks that the service sleeps for the
    requested ``pollInterval``, fetches ``asyncResultsUrl`` through the HTTP
    session, and reports success.
    """
    async_url = "DUMMY_ASYNC_URL"
    poll_interval = 12345

    session_get = mock_session.return_value.get
    session_get.return_value = MagicMock(status_code=200)
    session_get.return_value.json.return_value = {"status": "Succeeded"}

    status, _ = azure_vm_service.wait_host_operation(
        {
            "asyncResultsUrl": async_url,
            "vmName": "test-vm",
            "pollInterval": poll_interval,
        }
    )

    assert session_get.call_args[0] == (async_url,)
    assert mock_sleep.call_args[0] == (poll_interval,)
    assert status.is_succeeded()
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_timeout(
    mock_session: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Report ``TIMED_OUT`` when the async operation never leaves ``InProgress``.

    NOTE: unlike ``test_wait_vm_operation_ready``, ``time.sleep`` is not mocked
    here, so this test waits in real time until the service's poll timeout
    (presumably configured by the ``azure_vm_service`` fixture) expires.
    """
    in_progress = MagicMock(status_code=200)
    in_progress.json.return_value = {"status": "InProgress"}
    mock_session.return_value.get.return_value = in_progress

    status, _ = azure_vm_service.wait_host_operation(
        {"asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", "pollInterval": 1}
    )
    assert status == Status.TIMED_OUT
@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_vm_operation_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that polling a VM operation survives transient connection drops.

    The mocked pool returns ``InProgress``, two aborted connections, then
    ``InProgress`` and ``Succeeded``. Two retries are enough to reach success;
    fewer must end in a graceful ``FAILED`` status.
    """
    # Two separate exception objects, one per dropped connection.
    aborted = [
        requests_ex.ConnectionError(
            "Connection aborted",
            OSError(107, "Transport endpoint is not connected"),
        )
        for _ in range(2)
    ]
    polls = [
        make_httplib_json_response(200, {"status": state})
        for state in ("InProgress", "InProgress", "Succeeded")
    ]
    mock_getconn.return_value.getresponse.side_effect = [polls[0], *aborted, *polls[1:]]

    operation_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "asyncResultsUrl": "https://DUMMY_ASYNC_URL",
        "vmName": "test-vm",
    }
    status, _ = azure_vm_service.wait_host_operation(params=operation_params)
    assert status == operation_status
@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.PENDING),
        (201, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_status(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Map the HTTP status of the remote-exec PUT request to a ``Status``.

    Accepted requests (200/201) leave the command ``PENDING``; client errors
    (401/404) are reported as ``FAILED``.
    """
    put_response = MagicMock(status_code=http_status_code)
    put_response.json.return_value = {"fake response": "body as json to dict"}
    mock_requests.put.return_value = put_response

    exec_config = {
        "vmName": "test-vm",
        "commandName": "TEST_COMMAND",
        "location": "TEST_LOCATION",
    }
    status, _ = azure_vm_service_remote_exec_only.remote_exec(
        ["command_1", "command_2"],
        config=exec_config,
        env_params={},
    )

    assert status == operation_status
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_output(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
) -> None:
    """
    Check the output and the request body of a remote execution on Azure.

    The ``Azure-AsyncOperation`` response header must surface as
    ``asyncResultsUrl`` in the command output. The JSON sent with the PUT must
    join the script lines with ``"; "`` and pass ``env_params`` as
    ``protectedParameters``.
    """
    script = ["command_1", "command_2"]

    put_response = MagicMock(
        status_code=201,
        headers={"Azure-AsyncOperation": "DUMMY_ASYNC_URL"},
    )
    put_response.json.return_value = {"fake response": "body as json to dict"}
    mock_requests.put.return_value = put_response

    _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(
        script,
        config={
            "vmName": "test-vm",
            "commandName": "TEST_COMMAND",
            "location": "TEST_LOCATION",
        },
        env_params={
            "param_1": 123,
            "param_2": "abc",
        },
    )

    assert "asyncResultsUrl" in cmd_output

    expected_body = {
        "location": "TEST_LOCATION",
        "properties": {
            "source": {"script": "; ".join(script)},
            "protectedParameters": [
                {"name": "param_1", "value": 123},
                {"name": "param_2", "value": "abc"},
            ],
            # presumably derived from the fixture's poll timeout -- confirm in conftest
            "timeoutInSeconds": 2,
            "asyncExecution": True,
        },
    }
    assert mock_requests.put.call_args.kwargs["json"] == expected_body
@pytest.mark.parametrize(
    ("operation_status", "wait_output", "results_output"),
    [
        (
            Status.SUCCEEDED,
            {
                "properties": {
                    "instanceView": {
                        "output": "DUMMY_STDOUT\n",
                        "error": "DUMMY_STDERR\n",
                        "executionState": "Succeeded",
                        "exitCode": 0,
                        "startTime": "2024-01-01T00:00:00+00:00",
                        "endTime": "2024-01-01T00:01:00+00:00",
                    }
                }
            },
            {
                "stdout": ["DUMMY_STDOUT"],
                "stderr": ["DUMMY_STDERR"],
                "exitCode": 0,
                "startTimestamp": datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
                "endTimestamp": datetime(2024, 1, 1, 0, 1, 0, tzinfo=timezone.utc),
            },
        ),
        (Status.PENDING, {}, {}),
        (Status.FAILED, {}, {}),
    ],
)
def test_get_remote_exec_results(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_status: Status,
    wait_output: dict,
    results_output: dict,
) -> None:
    """
    Test getting the results of the remote execution on Azure.

    ``wait_remote_exec_operation`` is patched to return a canned
    ``(status, output)`` pair. The test checks that ``get_remote_exec_results``
    forwards ``params`` to it and propagates the status. It also checks the
    conversion of a succeeded ``instanceView`` payload into the expected dict:
    stdout/stderr split into lines, the exit code, and timezone-aware timestamps.
    """
    params = {
        "asyncResultsUrl": "DUMMY_ASYNC_URL",
    }

    # patch.object scopes the replacement to this block and restores the
    # original method on exit (a bare setattr would leave it patched).
    with patch.object(
        azure_vm_service_remote_exec_only,
        "wait_remote_exec_operation",
        return_value=(operation_status, wait_output),
    ) as mock_wait_remote_exec_operation:
        status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params)

    assert status == operation_status
    assert mock_wait_remote_exec_operation.call_args[0][0] == params
    assert cmd_output == results_output