# Source code for herethere.there.client

"""herethere.there.client"""

from __future__ import annotations

import asyncio
import contextlib
import sys
from contextlib import AbstractAsyncContextManager
from typing import TextIO

import asyncssh

from herethere.everywhere.config import ConnectionConfig
from herethere.everywhere.logging import logger


class ConnectionNotConfiguredError(Exception):
    """Signal that a connection was requested before any config was applied.

    Raised by ``PersistentConnection.reconnect`` when no configuration
    has been supplied yet.
    """

    pass


class PersistentConnection(AbstractAsyncContextManager):
    """Reusable SSH connection exposed as an async context manager.

    Entering ``async with`` yields a live ``asyncssh`` connection. The
    remote ``ping`` command is run first, and a new connection is opened
    if that check fails. Leaving the ``async with`` block does NOT close
    the connection, so it is shared across blocks until ``close()`` or
    ``configure()`` is called.
    """

    def __init__(self):
        # Connection settings; remain None until configure() is called.
        self.config: ConnectionConfig | None = None
        # The current SSH connection, or None while disconnected.
        self.connection: asyncssh.SSHClientConnection | None = None

    async def __aenter__(self):
        return await self.ensure_connected()

    async def __aexit__(self, *exc_info):
        # Deliberately a no-op: the connection outlives the ``async with`` block.
        return None

    async def ensure_connected(self):
        """Return a live SSH connection, replacing a dead one if necessary.

        Raises:
            ConnectionNotConfiguredError: no configuration has been applied,
                so a replacement connection cannot be opened. Errors from
                ``asyncssh.connect`` also propagate.
        """
        if not await self.check_connection():
            if self.connection:
                # The existing connection failed its ping; drop it first.
                self.close()
            return await self.reconnect()
        return self.connection

    def close(self):
        """Close the current connection, if any, and forget it."""
        current = self.connection
        if current:
            try:
                current.close()
            except asyncssh.Error:
                # Best effort: a failing close must not keep the stale handle.
                pass
        self.connection = None

    async def configure(self, config: ConnectionConfig):
        """Replace the connection settings and connect using them.

        Any existing connection is closed first so the new settings always
        take effect. Returns the freshly opened connection.
        """
        self.close()
        self.config = config
        return await self.ensure_connected()

    async def check_connection(self) -> bool:
        """Return True when the current connection answers the remote ``ping``.

        Returns False when there is no connection, or when running ``ping``
        (with ``check=True``) raises an ``asyncssh.Error``.
        """
        if not self.connection:
            return False
        try:
            await self.connection.run("ping", check=True)
        except asyncssh.Error:
            logger.debug("SSH connection ping failed.")
            return False
        return True

    async def reconnect(self):
        """Open a new SSH connection from ``self.config`` and store it.

        Raises:
            ConnectionNotConfiguredError: ``configure()`` has not supplied a config.
        """
        if not self.config:
            raise ConnectionNotConfiguredError("Connection is not configured.")
        connect_kwargs = self.config.asdict
        # NOTE(review): known_hosts=None turns off asyncssh's server host-key
        # validation; confirm this is acceptable for the deployment.
        self.connection = await asyncssh.connect(**connect_kwargs, known_hosts=None)
        return self.connection


class Client:
    """Client for remote interpreter.

    Wraps a ``PersistentConnection`` and sends code to named remote
    commands (``code``, ``background``, ``shell``). Remote stdout and
    stderr are streamed back to local text streams.
    """

    def __init__(self):
        self.connection = PersistentConnection()

    async def copy(self) -> Client:
        """Return a new client connected with this client's configuration.

        Raises:
            ConnectionNotConfiguredError: this client was never configured
                (propagated from the new client's connection attempt).
        """
        client = Client()
        await client.connect(self.connection.config)
        return client

    async def connect(self, config: ConnectionConfig):
        """Apply *config* and open the SSH connection to the remote."""
        await self.connection.configure(config)

    async def disconnect(self):
        """Disconnect from the remote."""
        self.connection.close()

    async def runcode(
        self,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ) -> None:
        """Execute python code on the remote side.

        Args:
            code: Python source sent to the remote ``code`` command.
            stdout: Stream receiving remote stdout (default ``sys.stdout``).
            stderr: Stream receiving remote stderr (default ``sys.stderr``).
        """
        await self._execute_code("code", code, stdout, stderr)

    async def runcode_background(
        self,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ) -> None:
        """Execute python code in a separate thread on the remote side.

        Args:
            code: Python source sent to the remote ``background`` command.
            stdout: Stream receiving remote stdout (default ``sys.stdout``).
            stderr: Stream receiving remote stderr (default ``sys.stderr``).
        """
        await self._execute_code("background", code, stdout, stderr)

    async def shell(
        self,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ) -> None:
        """Execute shell command on the remote side.

        Args:
            code: Command text sent to the remote ``shell`` command.
            stdout: Stream receiving remote stdout (default ``sys.stdout``).
            stderr: Stream receiving remote stderr (default ``sys.stderr``).
        """
        await self._execute_code("shell", code, stdout, stderr)

    async def upload(self, localpaths: list[str], remotepath) -> None:
        """Upload files and directories to remote via SFTP.

        Directories are copied recursively (``recurse=True``), and progress
        is reported through ``sftp_progress_handler``.
        """
        async with self.connection as ssh:
            async with ssh.start_sftp_client() as sftp:
                await sftp.put(
                    localpaths=localpaths,
                    remotepath=remotepath,
                    recurse=True,
                    progress_handler=self.sftp_progress_handler,
                )

    def sftp_progress_handler(self, *args, **kwargs):
        """SFTP uploading progress handler; logs each update at debug level."""
        logger.debug("SFTP progress: %s %s", args, kwargs)

    async def _execute_code(
        self,
        command: str,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ):
        """Execute command with a code on the remote side.

        Starts the remote *command*, writes *code* to its stdin, and closes
        stdin. It then streams remote stdout/stderr line by line to the
        given writers, flushing each line when the writer supports it, and
        waits for the remote process to exit. If the caller cancels, the
        remote process is asked to terminate, given up to one second to
        finish, and the ``CancelledError`` is re-raised.
        """
        if stdout is None:
            stdout = sys.stdout
        if stderr is None:
            stderr = sys.stderr
        async with self.connection as ssh:
            async with ssh.create_process(command) as process:
                process.stdin.write(code)
                # Remote handlers read the submitted code from stdin. Signal
                # end-of-input so they can start or finish execution instead
                # of waiting for more code.
                process.stdin.write_eof()

                async def forward_output(reader, writer):
                    # Stream line-by-line for long-running commands such as
                    # `%there log`. readline() also returns the final partial
                    # line at EOF, so output without a trailing newline is kept.
                    while data := await reader.readline():
                        writer.write(data)
                        if hasattr(writer, "flush"):
                            writer.flush()

                try:
                    await asyncio.gather(
                        forward_output(process.stdout, stdout),
                        forward_output(process.stderr, stderr),
                    )
                    await process.wait()
                except asyncio.CancelledError:
                    # terminate() may fail if the channel is already closing;
                    # such an error must not replace the cancellation below.
                    with contextlib.suppress(Exception):
                        process.terminate()
                    # Try to close the remote channel cleanly, but keep
                    # cancellation bounded from the caller's perspective.
                    with contextlib.suppress(Exception):
                        await asyncio.wait_for(process.wait(), timeout=1)
                    raise