Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py: 79%

136 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection of Service functions for managing VMs on Azure.""" 

6 

7import json 

8import logging 

9from datetime import datetime 

10from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 

11 

12import requests 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.services.base_service import Service 

16from mlos_bench.services.remote.azure.azure_deployment_services import ( 

17 AzureDeploymentService, 

18) 

19from mlos_bench.services.types.host_ops_type import SupportsHostOps 

20from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning 

21from mlos_bench.services.types.os_ops_type import SupportsOSOps 

22from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec 

23from mlos_bench.util import merge_parameters 

24 

25_LOG = logging.getLogger(__name__) 

26 

27 

28class AzureVMService( 

29 AzureDeploymentService, 

30 SupportsHostProvisioning, 

31 SupportsHostOps, 

32 SupportsOSOps, 

33 SupportsRemoteExec, 

34): 

35 """Helper methods to manage VMs on Azure.""" 

36 

37 # pylint: disable=too-many-ancestors 

38 

39 # Azure Compute REST API calls as described in 

40 # https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines 

41 

42 # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start 

43 _URL_START = ( 

44 "https://management.azure.com" 

45 "/subscriptions/{subscription}" 

46 "/resourceGroups/{resource_group}" 

47 "/providers/Microsoft.Compute" 

48 "/virtualMachines/{vm_name}" 

49 "/start" 

50 "?api-version=2022-03-01" 

51 ) 

52 

53 # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off 

54 _URL_STOP = ( 

55 "https://management.azure.com" 

56 "/subscriptions/{subscription}" 

57 "/resourceGroups/{resource_group}" 

58 "/providers/Microsoft.Compute" 

59 "/virtualMachines/{vm_name}" 

60 "/powerOff" 

61 "?api-version=2022-03-01" 

62 ) 

63 

64 # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate 

65 _URL_DEALLOCATE = ( 

66 "https://management.azure.com" 

67 "/subscriptions/{subscription}" 

68 "/resourceGroups/{resource_group}" 

69 "/providers/Microsoft.Compute" 

70 "/virtualMachines/{vm_name}" 

71 "/deallocate" 

72 "?api-version=2022-03-01" 

73 ) 

74 

75 # TODO: This is probably the more correct URL to use for the deprovision operation. 

76 # However, previous code used the deallocate URL above, so for now, we keep 

77 # that and handle that change later. 

78 # See Also: #498 

79 _URL_DEPROVISION = _URL_DEALLOCATE 

80 

81 # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/delete 

82 # _URL_DEPROVISION = ( 

83 # "https://management.azure.com" 

84 # "/subscriptions/{subscription}" 

85 # "/resourceGroups/{resource_group}" 

86 # "/providers/Microsoft.Compute" 

87 # "/virtualMachines/{vm_name}" 

88 # "/delete" 

89 # "?api-version=2022-03-01" 

90 # ) 

91 

92 # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart 

93 _URL_REBOOT = ( 

94 "https://management.azure.com" 

95 "/subscriptions/{subscription}" 

96 "/resourceGroups/{resource_group}" 

97 "/providers/Microsoft.Compute" 

98 "/virtualMachines/{vm_name}" 

99 "/restart" 

100 "?api-version=2022-03-01" 

101 ) 

102 

103 # From: 

104 # https://learn.microsoft.com/en-us/rest/api/compute/virtual-machine-run-commands/create-or-update 

105 _URL_REXEC_RUN = ( 

106 "https://management.azure.com" 

107 "/subscriptions/{subscription}" 

108 "/resourceGroups/{resource_group}" 

109 "/providers/Microsoft.Compute" 

110 "/virtualMachines/{vm_name}" 

111 "/runcommands/{command_name}" 

112 "?api-version=2024-07-01" 

113 ) 

114 _URL_REXEC_RESULT = ( 

115 "https://management.azure.com" 

116 "/subscriptions/{subscription}" 

117 "/resourceGroups/{resource_group}" 

118 "/providers/Microsoft.Compute" 

119 "/virtualMachines/{vm_name}" 

120 "/runcommands/{command_name}" 

121 "?$expand=instanceView&api-version=2024-07-01" 

122 ) 

123 

def __init__(
    self,
    config: Optional[Dict[str, Any]] = None,
    global_config: Optional[Dict[str, Any]] = None,
    parent: Optional[Service] = None,
    methods: Union[Dict[str, Callable], List[Callable], None] = None,
):
    """
    Create a new instance of Azure VM services proxy.

    Registers this class's host provisioning, host operation, OS operation,
    and remote execution methods with the base service. If the config has a
    "customDataFile" entry, the file is read and its contents are stored as
    the "customData" deployment parameter.

    Parameters
    ----------
    config : dict
        Free-format dictionary that contains the benchmark environment
        configuration.
    global_config : dict
        Free-format dictionary of global parameters.
    parent : Service
        Parent service that can provide mixin functions.
    methods : Union[Dict[str, Callable], List[Callable], None]
        New methods to register with the service.

    Raises
    ------
    ValueError
        If "customDataFile" is set in the config while the deployment
        parameters already contain a non-empty "customData".
    """
    local_methods: List[Callable] = [
        # SupportsHostProvisioning
        self.provision_host,
        self.deprovision_host,
        self.deallocate_host,
        self.wait_host_deployment,
        # SupportsHostOps
        self.start_host,
        self.stop_host,
        self.restart_host,
        self.wait_host_operation,
        # SupportsOSOps
        self.shutdown,
        self.reboot,
        self.wait_os_operation,
        # SupportsRemoteExec
        self.remote_exec,
        self.get_remote_exec_results,
    ]
    super().__init__(
        config,
        global_config,
        parent,
        self.merge_methods(methods, local_methods),
    )

    # Convenience: customData may be read from a file instead of being
    # embedded in a json config file.
    # Note: ARM templates expect this data to be base64 encoded, but that
    # can be done using the `base64()` string function inside the ARM template.
    custom_data_path = self.config.get("customDataFile")
    self._custom_data_file = custom_data_path
    if not custom_data_path:
        return

    if self._deploy_params.get("customData"):
        raise ValueError("Both customDataFile and customData are specified.")

    self._custom_data_file = self.config_loader_service.resolve_path(custom_data_path)
    with open(self._custom_data_file, encoding="utf-8") as custom_data_fh:
        self._deploy_params["customData"] = custom_data_fh.read()

187 

def _set_default_params(self, params: dict) -> dict:  # pylint: disable=no-self-use
    """
    Fill in a default "deploymentName" derived from "vmName" when possible.

    Deriving the name from the VM name is a common convention and can save
    some config work for the caller.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs. Updated in place.

    Returns
    -------
    params : dict
        The same dictionary object that was passed in.
    """
    # Nothing to do if the name is already set or cannot be derived.
    if "deploymentName" in params or "vmName" not in params:
        return params

    default_name = f"""{params["vmName"]}-deployment"""
    params["deploymentName"] = default_name
    _LOG.info("deploymentName missing from params. Defaulting to '%s'.", default_name)
    return params

200 

def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
    """
    Wait for a pending deployment operation on an Azure VM to finish.

    Thin wrapper that forwards both arguments to `_wait_deployment`.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
    is_setup : bool
        If True, wait for VM being deployed; otherwise, wait for successful
        deprovisioning.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_wait_deployment`.
        Presumably Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
        and the result is info on the operation runtime if SUCCEEDED,
        otherwise {} - confirm against `_wait_deployment`.
    """
    status_and_result = self._wait_deployment(params, is_setup=is_setup)
    return status_and_result

221 

def wait_host_operation(self, params: dict) -> Tuple[Status, dict]:
    """
    Wait for a pending operation on an Azure VM to leave the RUNNING state.

    Polls `_check_operation_status` through `_wait_while` for as long as it
    reports Status.RUNNING.

    Parameters
    ----------
    params: dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Must contain "vmName". If "deploymentName" is missing, it is set
        in place to "<vmName>-deployment".
        Presumably must also have the "asyncResultsUrl" key to get the
        results - confirm against `_check_operation_status`.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_wait_while`.
        Presumably Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
        and the result is info on the operation runtime if SUCCEEDED,
        otherwise {} - confirm against `_wait_while`.

    Raises
    ------
    KeyError
        If "vmName" is not present in `params`.
    """
    vm_name = params["vmName"]
    _LOG.info("Wait for operation on VM %s", vm_name)
    # Try and provide a semi sane default for the deploymentName.
    # The key must be given explicitly: a single-argument `setdefault` would
    # treat the generated name as the *key* and store None under it.
    params.setdefault("deploymentName", f"{vm_name}-deployment")
    return self._wait_while(self._check_operation_status, Status.RUNNING, params)

245 

def wait_remote_exec_operation(self, params: dict) -> Tuple["Status", dict]:
    """
    Wait for a pending remote execution on an Azure VM to leave the RUNNING state.

    Polls `_check_remote_exec_status` through `_wait_while` for as long as it
    reports Status.RUNNING.

    Parameters
    ----------
    params: dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Must contain "commandName" and "vmName" (both are logged).
        Should have the "asyncResultsUrl" key to get the results;
        `_check_remote_exec_status` reports Status.PENDING without it.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_wait_while`.
        Presumably Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
        - confirm against `_wait_while`.
    """
    command_name = params["commandName"]
    vm_name = params["vmName"]
    _LOG.info("Wait for run command %s on VM %s", command_name, vm_name)
    return self._wait_while(self._check_remote_exec_status, Status.RUNNING, params)

267 

def wait_os_operation(self, params: dict) -> Tuple["Status", dict]:
    """
    Wait for a pending OS operation; delegates to `wait_host_operation`.

    Parameters
    ----------
    params: dict
        Flat dictionary of (key, value) pairs of tunable parameters,
        passed through unchanged.

    Returns
    -------
    result : (Status, dict)
        Whatever `wait_host_operation` returns for the same `params`.
    """
    host_result = self.wait_host_operation(params)
    return host_result

270 

def provision_host(self, params: dict) -> Tuple[Status, dict]:
    """
    Check if Azure VM is ready. Deploy a new VM, if necessary.

    Thin wrapper that forwards `params` to `_provision_resource`.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        HostEnv tunables are variable parameters that, together with the
        HostEnv configuration, are sufficient to provision a VM.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_provision_resource`.
        Presumably the result is the input `params` plus the parameters
        extracted from the response JSON, or {} if the status is FAILED,
        and Status is one of {PENDING, SUCCEEDED, FAILED} - confirm against
        `_provision_resource`.
    """
    status_and_result = self._provision_resource(params)
    return status_and_result

290 

def deprovision_host(self, params: dict) -> Tuple[Status, dict]:
    """
    Deprovision the VM on Azure.

    Note: `_URL_DEPROVISION` is currently an alias of `_URL_DEALLOCATE`, so
    this request deallocates the VM rather than deleting it.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Combined with a copy of the service config via `merge_parameters`,
        with "subscription", "resourceGroup", "deploymentName", and
        "vmName" passed as required keys.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_azure_rest_api_post_helper`.
        Presumably the result is always {} and Status is one of
        {PENDING, SUCCEEDED, FAILED} - confirm against the helper.
    """
    params = self._set_default_params(params)
    required_keys = ["subscription", "resourceGroup", "deploymentName", "vmName"]
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=required_keys,
    )
    vm_name = config["vmName"]
    _LOG.info("Deprovision VM: %s", vm_name)
    _LOG.info("Deprovision deployment: %s", config["deploymentName"])
    # TODO: Properly deprovision *all* resources specified in the ARM template.
    url = self._URL_DEPROVISION.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        vm_name=vm_name,
    )
    return self._azure_rest_api_post_helper(config, url)

328 

def deallocate_host(self, params: dict) -> Tuple[Status, dict]:
    """
    Deallocate the VM on Azure by shutting it down then releasing the compute
    resources.

    Note: This can cause the VM to arrive on a new host node when it's
    restarted, which may have different performance characteristics.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Combined with a copy of the service config via `merge_parameters`,
        with "subscription", "resourceGroup", and "vmName" passed as
        required keys.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_azure_rest_api_post_helper`.
        Presumably the result is always {} and Status is one of
        {PENDING, SUCCEEDED, FAILED} - confirm against the helper.
    """
    params = self._set_default_params(params)
    required_keys = ["subscription", "resourceGroup", "vmName"]
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=required_keys,
    )
    vm_name = config["vmName"]
    _LOG.info("Deallocate VM: %s", vm_name)
    url = self._URL_DEALLOCATE.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        vm_name=vm_name,
    )
    return self._azure_rest_api_post_helper(config, url)

367 

def start_host(self, params: dict) -> Tuple[Status, dict]:
    """
    Start the VM on Azure.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Combined with a copy of the service config via `merge_parameters`,
        with "subscription", "resourceGroup", and "vmName" passed as
        required keys.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_azure_rest_api_post_helper`.
        Presumably the result is always {} and Status is one of
        {PENDING, SUCCEEDED, FAILED} - confirm against the helper.
    """
    params = self._set_default_params(params)
    required_keys = ["subscription", "resourceGroup", "vmName"]
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=required_keys,
    )
    vm_name = config["vmName"]
    # The caller-supplied params (after defaults) are logged along with the VM name.
    _LOG.info("Start VM: %s :: %s", vm_name, params)
    url = self._URL_START.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        vm_name=vm_name,
    )
    return self._azure_rest_api_post_helper(config, url)

402 

def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
    """
    Stop the VM on Azure by sending a request to its powerOff endpoint.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Combined with a copy of the service config via `merge_parameters`,
        with "subscription", "resourceGroup", and "vmName" passed as
        required keys.
    force : bool
        If True, force stop the Host/VM.
        NOTE: this implementation accepts the flag but never reads it, so
        the request is the same either way.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_azure_rest_api_post_helper`.
        Presumably the result is always {} and Status is one of
        {PENDING, SUCCEEDED, FAILED} - confirm against the helper.
    """
    params = self._set_default_params(params)
    required_keys = ["subscription", "resourceGroup", "vmName"]
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=required_keys,
    )
    vm_name = config["vmName"]
    _LOG.info("Stop VM: %s", vm_name)
    url = self._URL_STOP.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        vm_name=vm_name,
    )
    return self._azure_rest_api_post_helper(config, url)

439 

def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
    """
    Shut down the OS; delegates to `stop_host`.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
    force : bool
        Passed through to `stop_host` unchanged.

    Returns
    -------
    result : (Status, dict)
        Whatever `stop_host` returns for the same arguments.
    """
    stop_result = self.stop_host(params, force)
    return stop_result

442 

def restart_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]:
    """
    Reboot the VM on Azure by sending a request to its restart endpoint.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Combined with a copy of the service config via `merge_parameters`,
        with "subscription", "resourceGroup", and "vmName" passed as
        required keys.
    force : bool
        If True, force restart the Host/VM.
        NOTE: this implementation accepts the flag but never reads it, so
        the request is the same either way.

    Returns
    -------
    result : (Status, dict)
        The (Status, result) pair returned by `_azure_rest_api_post_helper`.
        Presumably the result is always {} and Status is one of
        {PENDING, SUCCEEDED, FAILED} - confirm against the helper.
    """
    params = self._set_default_params(params)
    required_keys = ["subscription", "resourceGroup", "vmName"]
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=required_keys,
    )
    vm_name = config["vmName"]
    _LOG.info("Reboot VM: %s", vm_name)
    url = self._URL_REBOOT.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        vm_name=vm_name,
    )
    return self._azure_rest_api_post_helper(config, url)

479 

def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]:
    """
    Reboot the OS; delegates to `restart_host`.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
    force : bool
        Passed through to `restart_host` unchanged.

    Returns
    -------
    result : (Status, dict)
        Whatever `restart_host` returns for the same arguments.
    """
    restart_result = self.restart_host(params, force)
    return restart_result

482 

def remote_exec(
    self,
    script: Iterable[str],
    config: dict,
    env_params: dict,
) -> Tuple[Status, dict]:
    """
    Run a command on Azure VM.

    Submits the script as an asynchronous run command via an HTTP PUT and
    returns without waiting for the script to finish.

    Parameters
    ----------
    script : Iterable[str]
        A list of lines to execute as a script on a remote VM.
        Any iterable is accepted, including one-shot iterators; the lines
        are joined with "; " in the request.
    config : dict
        Flat dictionary of (key, value) pairs of the Environment parameters.
        They usually come from `const_args` and `tunable_params`
        properties of the Environment.
        Combined with a copy of the service config via `merge_parameters`,
        with "subscription", "resourceGroup", "vmName", "commandName", and
        "location" passed as required keys.
    env_params : dict
        Parameters to pass as *shell* environment variables into the script.
        This is usually a subset of `config` with some possible conversions.
        Sent as the run command's "protectedParameters".

    Returns
    -------
    result : (Status, dict)
        A pair of Status and result.
        On HTTP 200 or 201: Status.PENDING and the merged config plus an
        "asyncResultsUrl" entry pointing at the run command's results.
        Otherwise: Status.FAILED and {}.
    """
    config = self._set_default_params(config)
    config = merge_parameters(
        dest=self.config.copy(),
        source=config,
        required_keys=[
            "subscription",
            "resourceGroup",
            "vmName",
            "commandName",
            "location",
        ],
    )

    # Materialize the script once: it is used for both the log message and
    # the request body, and a one-shot iterator would otherwise be exhausted
    # by the logging call, submitting an empty script.
    script_lines = list(script)

    if _LOG.isEnabledFor(logging.INFO):
        _LOG.info("Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script_lines))

    json_req = {
        "location": config["location"],
        "properties": {
            "source": {"script": "; ".join(script_lines)},
            "protectedParameters": [
                {"name": key, "value": val} for (key, val) in env_params.items()
            ],
            "timeoutInSeconds": int(self._poll_timeout),
            "asyncExecution": True,
        },
    }

    url = self._URL_REXEC_RUN.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        vm_name=config["vmName"],
        command_name=config["commandName"],
    )

    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

    response = requests.put(
        url,
        json=json_req,
        headers=self._get_headers(),
        timeout=self._request_timeout,
    )

    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug(
            "Response: %s\n%s",
            response,
            json.dumps(response.json(), indent=2) if response.content else "",
        )
    else:
        _LOG.info("Response: %s", response)

    if response.status_code not in {200, 201}:
        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    # The command was accepted; hand back the URL to poll for its results.
    results_url = self._URL_REXEC_RESULT.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        vm_name=config["vmName"],
        command_name=config["commandName"],
    )
    return (
        Status.PENDING,
        {**config, "asyncResultsUrl": results_url},
    )

578 

def _check_remote_exec_status(self, params: dict) -> Tuple[Status, dict]:
    """
    Check the status of a pending remote execution on an Azure VM.

    Issues a single GET request to the results URL and maps the run
    command's "executionState" to a Status.

    Parameters
    ----------
    params: dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Must have the "asyncResultsUrl" key to get the results.
        If the key is not present, return Status.PENDING.

    Returns
    -------
    result : (Status, dict)
        A pair of Status and result.
        Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}:
        RUNNING on a read timeout or while the state is "Running" or
        "Pending"; FAILED on any other request error, a non-200 response,
        or any other execution state.
        Result is the parsed response JSON if SUCCEEDED, otherwise {}.
    """
    results_url = params.get("asyncResultsUrl")
    if results_url is None:
        return Status.PENDING, {}

    session = self._get_session(params)
    try:
        response = session.get(results_url, timeout=self._request_timeout)
    except requests.exceptions.ReadTimeout:
        # A slow response is not treated as a failure; the caller may poll again.
        _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, results_url)
        return Status.RUNNING, {}
    except requests.exceptions.RequestException as ex:
        _LOG.exception("Error in request checking operation status", exc_info=ex)
        return (Status.FAILED, {})

    if _LOG.isEnabledFor(logging.DEBUG):
        response_body = json.dumps(response.json(), indent=2) if response.content else ""
        _LOG.debug("Response: %s\n%s", response, response_body)

    if response.status_code == 200:
        payload = response.json()
        instance_view = payload.get("properties", {}).get("instanceView", {})
        state = instance_view.get("executionState")
        if state == "Succeeded":
            return Status.SUCCEEDED, payload
        if state in {"Running", "Pending"}:
            return Status.RUNNING, {}

    # Non-200 response, or an execution state that is neither in progress nor succeeded.
    _LOG.error("Response: %s :: %s", response, response.text)
    return Status.FAILED, {}

630 

def get_remote_exec_results(self, config: dict) -> Tuple[Status, dict]:
    """
    Get the results of the asynchronously running command.

    Waits for the run command via `wait_remote_exec_operation`, then
    extracts its output from the "instanceView" of the response.

    Parameters
    ----------
    config : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Passed unchanged to `wait_remote_exec_operation`.
        Should have the "asyncResultsUrl" key to get the results.

    Returns
    -------
    result : (Status, dict)
        A pair of Status and result.
        If waiting did not succeed, the (Status, result) pair from
        `wait_remote_exec_operation` is returned as is.
        Otherwise Status is SUCCEEDED when the execution state is
        "Succeeded" and the exit code is 0, else FAILED, and the dict has
        "stdout" and "stderr" (lists of lines), "exitCode",
        "startTimestamp", and "endTimestamp" (datetime) entries.

    Raises
    ------
    KeyError
        If the instance view has no "startTime" or "endTime" entry.
    """
    _LOG.info("Check the results on VM: %s", config.get("vmName"))
    status, result = self.wait_remote_exec_operation(config)
    _LOG.debug("Result: %s :: %s", status, result)
    if not status.is_succeeded():
        # TODO: Extract the telemetry and status from stdout, if available
        return (status, result)

    instance_view = result.get("properties", {}).get("instanceView", {})
    exit_code = instance_view.get("exitCode")
    stdout_lines = instance_view.get("output", "").strip().split("\n")
    stderr_lines = instance_view.get("error", "").strip().split("\n")

    # A command only counts as successful if Azure reports it as succeeded
    # and the script itself exited with code 0.
    ran_ok = instance_view.get("executionState") == "Succeeded" and exit_code == 0
    status = Status.SUCCEEDED if ran_ok else Status.FAILED

    parsed_results = {
        "stdout": stdout_lines,
        "stderr": stderr_lines,
        "exitCode": exit_code,
        "startTimestamp": datetime.fromisoformat(instance_view["startTime"]),
        "endTimestamp": datetime.fromisoformat(instance_view["endTime"]),
    }
    return (status, parsed_results)