# Source code for autogen_ext.code_executors.docker_jupyter._jupyter_server

import asyncio
import atexit
import datetime
import io
import json
import logging
import os
import secrets
import uuid
from dataclasses import dataclass
from pathlib import Path
from time import sleep
from types import TracebackType
from typing import Any, Dict, List, Optional, Protocol, Type, Union, cast, runtime_checkable

import aiohttp
import docker
import docker.errors
import requests
import websockets
from requests.adapters import HTTPAdapter, Retry
from typing_extensions import Self


@dataclass
class JupyterConnectionInfo:
    """(Experimental) Connection details for a Jupyter kernel gateway server."""

    host: str
    """`str` - Host of the Jupyter gateway server"""
    use_https: bool
    """`bool` - Whether to use HTTPS"""
    port: Optional[int] = None
    """`Optional[int]` - Port of the Jupyter gateway server. If None, the default port is used"""
    token: Optional[str] = None
    """`Optional[str]` - Token for authentication. If None, no token is used"""


@runtime_checkable
class JupyterConnectable(Protocol):
    """(Experimental) Structural interface for objects that can provide Jupyter
    gateway connection information (e.g. a server wrapper)."""

    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Return the connection information for this connectable."""
        ...


class JupyterClient:
    def __init__(self, connection_info: "JupyterConnectionInfo"):
        """(Experimental) A client for communicating with a Jupyter gateway server.

        Args:
            connection_info (JupyterConnectionInfo): Connection information
        """
        self._connection_info = connection_info
        self._session = requests.Session()
        # Retry transient connection failures: a freshly started gateway may
        # need a moment before it accepts requests.
        retries = Retry(total=5, backoff_factor=0.1)
        self._session.mount("http://", HTTPAdapter(max_retries=retries))
        # Created lazily on first async call so __init__ stays synchronous
        # and the session is bound to the running event loop.
        self._async_session: "aiohttp.ClientSession | None" = None

    async def _ensure_async_session(self) -> "aiohttp.ClientSession":
        # Lazily create the aiohttp session for async requests.
        if self._async_session is None:
            self._async_session = aiohttp.ClientSession()
        return self._async_session

    def _get_headers(self) -> Dict[str, str]:
        # No token configured means the gateway runs without authentication.
        if self._connection_info.token is None:
            return {}
        return {"Authorization": f"token {self._connection_info.token}"}

    def _get_api_base_url(self) -> str:
        protocol = "https" if self._connection_info.use_https else "http"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    def _get_ws_base_url(self) -> str:
        # Fix: mirror the HTTP scheme. A TLS-enabled gateway requires wss://;
        # the previous implementation always used plain ws://, which cannot
        # connect when use_https is True.
        protocol = "wss" if self._connection_info.use_https else "ws"
        port = f":{self._connection_info.port}" if self._connection_info.port else ""
        return f"{protocol}://{self._connection_info.host}{port}"

    async def list_kernel_specs(self) -> Dict[str, Dict[str, str]]:
        """Return the kernel specs known to the gateway.

        NOTE: performs a blocking HTTP request on the retrying sync session.
        """
        response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers())
        return cast(Dict[str, Dict[str, str]], response.json())

    async def list_kernels(self) -> List[Dict[str, str]]:
        """Return the list of currently running kernels.

        NOTE: performs a blocking HTTP request on the retrying sync session.
        """
        response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers())
        return cast(List[Dict[str, str]], response.json())

    async def start_kernel(self, kernel_spec_name: str) -> str:
        """Start a new kernel asynchronously.

        Args:
            kernel_spec_name (str): Name of the kernel spec to start

        Returns:
            str: ID of the started kernel
        """
        session = await self._ensure_async_session()
        async with session.post(
            f"{self._get_api_base_url()}/api/kernels",
            headers=self._get_headers(),
            json={"name": kernel_spec_name},
        ) as response:
            data = await response.json()
            return cast(str, data["id"])

    async def delete_kernel(self, kernel_id: str) -> None:
        """Shut down the kernel with the given id."""
        session = await self._ensure_async_session()
        async with session.delete(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers()
        ) as response:
            response.raise_for_status()

    async def restart_kernel(self, kernel_id: str) -> None:
        """Restart the kernel with the given id."""
        session = await self._ensure_async_session()
        async with session.post(
            f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers()
        ) as response:
            response.raise_for_status()

    async def get_kernel_client(self, kernel_id: str) -> "JupyterKernelClient":
        """Open a websocket connection to the kernel's channels endpoint."""
        ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels"
        # Using websockets library for async websocket connections
        ws = await websockets.connect(ws_url, additional_headers=self._get_headers())
        return JupyterKernelClient(ws)

    async def close(self) -> None:
        """Close the async session"""
        if self._async_session is not None:
            await self._async_session.close()
            self._async_session = None
        self._session.close()
@dataclass
class DataItem:
    """A single rich display payload produced by a kernel."""

    # MIME type of the payload, e.g. "image/png" or "text/html".
    mime_type: str
    # Payload content, exactly as received from the kernel message.
    data: str


@dataclass
class ExecutionResult:
    """Outcome of executing a code block on a kernel."""

    # True on success; False on kernel error or timeout.
    is_ok: bool
    # Concatenated plain-text output, or an "ERROR: ..." description.
    output: str
    # Rich (non-text) outputs such as images and HTML.
    data_items: List[DataItem]
class JupyterKernelClient:
    """An asynchronous client for communicating with a Jupyter kernel."""

    def __init__(self, websocket: websockets.ClientConnection) -> None:
        # A fresh session id per client lets the kernel correlate our messages.
        self._session_id = uuid.uuid4().hex
        self._websocket = websocket

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        # Context-manager exit closes the websocket regardless of exceptions.
        await self.stop()

    async def stop(self) -> None:
        """Close the underlying websocket connection."""
        await self._websocket.close()

    async def _send_message(self, *, content: Dict[str, Any], channel: str, message_type: str) -> str:
        """Send one Jupyter-protocol message and return its generated msg_id.

        Callers use the returned id to match replies via ``parent_header.msg_id``.
        """
        # NOTE(review): timestamp is naive local time; the Jupyter message spec
        # generally expects ISO 8601 UTC — confirm the gateway accepts this.
        timestamp = datetime.datetime.now().isoformat()
        message_id = uuid.uuid4().hex
        message = {
            "header": {
                "username": "autogen",
                "version": "5.0",
                "session": self._session_id,
                "msg_id": message_id,
                "msg_type": message_type,
                "date": timestamp,
            },
            "parent_header": {},
            "channel": channel,
            "content": content,
            "metadata": {},
            "buffers": {},
        }
        await self._websocket.send(json.dumps(message))
        return message_id

    async def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[Dict[str, Any]]:
        """Receive and decode the next message; return None on timeout.

        A ``timeout_seconds`` of None waits indefinitely.
        """
        try:
            if timeout_seconds is not None:
                data = await asyncio.wait_for(self._websocket.recv(), timeout=timeout_seconds)
            else:
                data = await self._websocket.recv()
            # The websocket may deliver either bytes or str frames.
            if isinstance(data, bytes):
                return cast(Dict[str, Any], json.loads(data.decode("utf-8")))
            return cast(Dict[str, Any], json.loads(data))
        except asyncio.TimeoutError:
            return None

    async def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool:
        """Send a kernel_info_request and wait for its reply.

        Returns True when the matching kernel_info_reply arrives, False on
        timeout. NOTE(review): the timeout applies per received message, not
        cumulatively over the whole wait — confirm this is intended.
        """
        message_id = await self._send_message(content={}, channel="shell", message_type="kernel_info_request")
        while True:
            message = await self._receive_message(timeout_seconds)
            # This means we timed out with no new messages.
            if message is None:
                return False
            if (
                message.get("parent_header", {}).get("msg_id") == message_id
                and message["msg_type"] == "kernel_info_reply"
            ):
                return True

    async def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult:
        """Execute ``code`` on the kernel and collect its output.

        Accumulates stream/text output and rich display data until the kernel
        reports an idle status for this request; returns an error result on
        kernel error or timeout. NOTE(review): ``timeout_seconds`` is applied
        per message received, not to the whole execution.
        """
        message_id = await self._send_message(
            content={
                "code": code,
                "silent": False,
                "store_history": True,
                "user_expressions": {},
                "allow_stdin": False,
                "stop_on_error": True,
            },
            channel="shell",
            message_type="execute_request",
        )

        text_output: List[str] = []
        data_output: List[DataItem] = []
        while True:
            message = await self._receive_message(timeout_seconds)
            if message is None:
                return ExecutionResult(
                    is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[]
                )

            # Ignore messages that are not for this execution.
            if message.get("parent_header", {}).get("msg_id") != message_id:
                continue

            msg_type = message["msg_type"]
            content = message["content"]
            if msg_type in ["execute_result", "display_data"]:
                # Rich output: keep plain text inline, capture images/HTML as
                # data items, and JSON-serialize anything else.
                for data_type, data in content["data"].items():
                    if data_type == "text/plain":
                        text_output.append(data)
                    elif data_type.startswith("image/") or data_type == "text/html":
                        data_output.append(DataItem(mime_type=data_type, data=data))
                    else:
                        text_output.append(json.dumps(data))
            elif msg_type == "stream":
                # stdout/stderr stream chunks.
                text_output.append(content["text"])
            elif msg_type == "error":
                # Output is an error.
                return ExecutionResult(
                    is_ok=False,
                    output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}",
                    data_items=[],
                )

            # Idle status for our request marks the end of execution output.
            if msg_type == "status" and content["execution_state"] == "idle":
                break

        return ExecutionResult(
            is_ok=True, output="\n".join([str(output) for output in text_output]), data_items=data_output
        )
class DockerJupyterServer(JupyterConnectable):
    """(Experimental) Runs a Jupyter kernel gateway in a Docker container and
    exposes its connection information."""

    DEFAULT_DOCKERFILE = """FROM quay.io/jupyter/docker-stacks-foundation

SHELL ["/bin/bash", "-o", "pipefail", "-c"]

USER ${NB_UID}
RUN mamba install --yes jupyter_kernel_gateway ipykernel && \
    mamba clean --all -f -y && \
    fix-permissions "${CONDA_DIR}" && \
    fix-permissions "/home/${NB_USER}"

ENV TOKEN="UNSET"
CMD python -m jupyter kernelgateway --KernelGatewayApp.ip=0.0.0.0 \
    --KernelGatewayApp.port=8888 \
    --KernelGatewayApp.auth_token="${TOKEN}" \
    --JupyterApp.answer_yes=true \
    --JupyterWebsocketPersonality.list_kernels=true

EXPOSE 8888

WORKDIR "${HOME}"
"""

    class GenerateToken:
        # Sentinel type: pass an instance as ``token`` to request a random token.
        pass

    def __init__(
        self,
        *,
        custom_image_name: Optional[str] = None,
        container_name: Optional[str] = None,
        auto_remove: bool = True,
        stop_container: bool = True,
        docker_env: Optional[Dict[str, str]] = None,
        expose_port: int = 8888,
        token: Optional[Union[str, GenerateToken]] = None,
        work_dir: Union[Path, str] = "/workspace",
        bind_dir: Optional[Union[Path, str]] = None,
    ):
        """Start a Jupyter kernel gateway server in a Docker container.

        Args:
            custom_image_name: Custom Docker image to use. If None, builds and uses bundled image.
            container_name: Name for the Docker container. Auto-generated if None.
            auto_remove: If True, container will be deleted when stopped.
            stop_container: If True, container stops on program exit or when context manager exits.
            docker_env: Additional environment variables for the container.
            expose_port: Port to expose for Jupyter connection.
            token: Authentication token. If GenerateToken, creates random token. Empty for no auth.
            work_dir: Working directory inside the container.
            bind_dir: Local directory to bind to container's work_dir.

        Raises:
            ValueError: If a custom image does not exist, the container fails
                to start, or no container id can be obtained.
        """
        # Generate container name if not provided
        container_name = container_name or f"autogen-jupyterkernelgateway-{uuid.uuid4()}"

        # Initialize Docker client
        client = docker.from_env()

        # Set up bind directory if specified
        self._bind_dir: Optional[Path] = None
        if bind_dir:
            self._bind_dir = Path(bind_dir) if isinstance(bind_dir, str) else bind_dir
            self._bind_dir.mkdir(exist_ok=True)
            # World-writable so the container's non-root jupyter user can write.
            os.chmod(bind_dir, 0o777)

        # Determine and prepare Docker image
        image_name = custom_image_name or "autogen-jupyterkernelgateway"
        if not custom_image_name:
            try:
                client.images.get(image_name)
            except docker.errors.ImageNotFound:
                # Build default image if not found
                here = Path(__file__).parent
                dockerfile = io.BytesIO(self.DEFAULT_DOCKERFILE.encode("utf-8"))
                logging.info(f"Building image {image_name}...")
                client.images.build(path=str(here), fileobj=dockerfile, tag=image_name)
                logging.info(f"Image {image_name} built successfully")
        else:
            # Verify custom image exists
            try:
                client.images.get(image_name)
            except docker.errors.ImageNotFound as err:
                raise ValueError(f"Custom image {image_name} does not exist") from err

        if docker_env is None:
            docker_env = {}
        if token is None:
            token = DockerJupyterServer.GenerateToken()

        # Set up authentication token
        self._token = secrets.token_hex(32) if isinstance(token, DockerJupyterServer.GenerateToken) else token

        # Prepare environment variables
        env = {"TOKEN": self._token}
        env.update(docker_env)

        # Define volume configuration if bind directory is specified
        volumes = {str(self._bind_dir): {"bind": str(work_dir), "mode": "rw"}} if self._bind_dir else None

        # Start the container
        container = client.containers.run(
            image_name,
            detach=True,
            auto_remove=auto_remove,
            environment=env,
            publish_all_ports=True,
            name=container_name,
            volumes=volumes,
            working_dir=str(work_dir),
        )

        # Wait for container to be ready
        self._wait_for_ready(container)

        # Store container information
        self._container = container
        self._port = int(container.ports[f"{expose_port}/tcp"][0]["HostPort"])
        self._container_id = container.id
        self._expose_port = expose_port

        if self._container_id is None:
            raise ValueError("Failed to obtain container id.")

        # Define cleanup function
        def cleanup() -> None:
            try:
                assert self._container_id is not None
                inner_container = client.containers.get(self._container_id)
                inner_container.stop()
            except docker.errors.NotFound:
                # Container already gone (e.g. auto-removed) — nothing to stop.
                pass
            # Fix: deregister from inside the hook after it has run, so a
            # manual stop() does not leave a stale atexit entry. The previous
            # code called atexit.unregister before ever registering, a no-op.
            atexit.unregister(cleanup)

        # Register cleanup if container should be stopped automatically
        if stop_container:
            atexit.register(cleanup)

        self._cleanup_func = cleanup
        self._stop_container = stop_container

    @property
    def connection_info(self) -> JupyterConnectionInfo:
        """Connection info pointing at the locally published container port."""
        return JupyterConnectionInfo(host="127.0.0.1", use_https=False, port=self._port, token=self._token)

    def _wait_for_ready(self, container: Any, timeout: int = 60, stop_time: float = 0.1) -> None:
        """Poll the container until its status is "running" or ``timeout`` elapses.

        Raises:
            ValueError: If the container is still not running after ``timeout`` seconds.
        """
        elapsed_time = 0.0
        while container.status != "running" and elapsed_time < timeout:
            sleep(stop_time)
            elapsed_time += stop_time
            container.reload()
        if container.status != "running":
            raise ValueError("Container failed to start")

    async def stop(self) -> None:
        """Stop the container without blocking the event loop."""
        # Fix: get_running_loop() instead of the deprecated get_event_loop();
        # we are guaranteed to be inside a running loop in a coroutine.
        loop = asyncio.get_running_loop()
        # Docker SDK calls are blocking — run the cleanup in a worker thread.
        await loop.run_in_executor(None, self._cleanup_func)

    async def get_client(self) -> JupyterClient:
        """Create a JupyterClient configured for this server."""
        return JupyterClient(self.connection_info)

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        # Leaving the context stops the container (honoring stop_container
        # via the registered cleanup semantics).
        await self.stop()