"""herethere.there.client"""
from __future__ import annotations
import asyncio
import contextlib
import sys
from contextlib import AbstractAsyncContextManager
from typing import TextIO
import asyncssh
from herethere.everywhere.config import ConnectionConfig
from herethere.everywhere.logging import logger
from herethere.everywhere.values import loads_value
class ConnectionNotConfiguredError(Exception):
    """Connection configuration is missing.

    Raised when a connection has to be established but no connection
    config has been applied yet.
    """
class PersistentConnection(AbstractAsyncContextManager):
    """SSH connection async context manager with automatic reconnection.

    Entering the context returns the cached connection when it still answers
    a ``ping`` command, and otherwise opens a new one from the stored config.
    Leaving the context does nothing, so the connection stays open for the
    next use; call :meth:`close` to drop it explicitly.
    """

    def __init__(self):
        # Settings used by reconnect(); None until configure() is called.
        self.config: ConnectionConfig | None = None
        # Cached connection handle; None when nothing is open.
        self.connection: asyncssh.SSHClientConnection | None = None

    async def __aenter__(self):
        return await self.ensure_connected()

    async def ensure_connected(self):
        """Return an active SSH connection, reconnecting if needed.

        Raises:
            ConnectionNotConfiguredError: a new connection is required but no
                config has been applied.
        """
        if await self.check_connection():
            return self.connection
        # The cached connection (if any) failed the ping: discard it before
        # opening a replacement.
        if self.connection:
            self.close()
        return await self.reconnect()

    async def __aexit__(self, *exc_info):
        # Intentionally a no-op: the connection is kept open between uses and
        # exceptions raised inside the ``async with`` body are not suppressed.
        return None

    def close(self) -> None:
        """Close current connection.

        Closing is best-effort: an ``asyncssh.Error`` raised while closing is
        logged and ignored. The cached handle is dropped in every case so a
        later ``ensure_connected()`` starts from a clean state.
        """
        # Detach the handle first so it is cleared even if close() raises
        # something other than asyncssh.Error.
        connection, self.connection = self.connection, None
        if connection:
            try:
                connection.close()
            except asyncssh.Error as exc:
                logger.debug("Ignoring error while closing SSH connection: %s", exc)

    async def configure(self, config: ConnectionConfig):
        """Apply new connection config.

        Any existing connection is closed, the config is stored, and a new
        connection is established and returned.
        """
        self.close()
        self.config = config
        return await self.ensure_connected()

    async def check_connection(self) -> bool:
        """Check connection is active.

        Returns:
            True when a connection exists and the remote ``ping`` command
            completes successfully, False otherwise.
        """
        if not self.connection:
            return False
        try:
            # check=True makes a non-zero exit status raise as well.
            # NOTE(review): assumes the remote side implements a `ping`
            # command -- confirm against the server.
            await self.connection.run("ping", check=True)
        except asyncssh.Error:
            logger.debug("SSH connection ping failed.")
            return False
        return True

    async def reconnect(self):
        """Establish connection.

        Returns:
            The newly opened connection, which is also cached on the instance.

        Raises:
            ConnectionNotConfiguredError: no config has been applied.
        """
        if not self.config:
            raise ConnectionNotConfiguredError("Connection is not configured.")
        # NOTE(security): known_hosts=None turns off server host key
        # verification in asyncssh. Kept for backward compatibility; review
        # whether host key checking should be enabled for untrusted networks.
        self.connection = await asyncssh.connect(**self.config.asdict, known_hosts=None)
        return self.connection
class Client:
    """Client for remote interpreter.

    Wraps a :class:`PersistentConnection` and exposes the remote commands
    (``code``, ``background``, ``shell``, ``value``) plus SFTP transfers.
    """

    def __init__(self):
        self.connection = PersistentConnection()

    async def copy(self) -> Client:
        """Return a new client connected with this client's config.

        The copy opens its own connection; if this client has no config,
        connecting the copy fails the same way an unconfigured connect does.
        """
        client = Client()
        await client.connect(self.connection.config)
        return client

    async def connect(self, config: ConnectionConfig):
        """Connect to remote."""
        await self.connection.configure(config)

    async def disconnect(self):
        """Disconnect from the remote."""
        self.connection.close()

    async def runcode(
        self,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ) -> None:
        """Execute python code on the remote side.

        Remote output is forwarded to ``stdout``/``stderr`` (defaulting to
        ``sys.stdout``/``sys.stderr``); nothing is returned.
        """
        await self._execute_code("code", code, stdout, stderr)

    async def runcode_background(
        self,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ) -> None:
        """Execute Python code in a separate thread on the remote side.

        Remote output is forwarded to ``stdout``/``stderr`` (defaulting to
        ``sys.stdout``/``sys.stderr``); nothing is returned.
        """
        await self._execute_code("background", code, stdout, stderr)

    async def shell(
        self,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ) -> None:
        """Execute shell command on the remote side.

        Remote output is forwarded to ``stdout``/``stderr`` (defaulting to
        ``sys.stdout``/``sys.stderr``); nothing is returned.
        """
        await self._execute_code("shell", code, stdout, stderr)

    async def get(self, expression: str):
        """Evaluate a Python expression remotely and return its Python value.

        The expression is sent on stdin of the remote ``value`` command and
        the command's stdout is decoded with ``loads_value``.

        Raises:
            RuntimeError: the remote command produced no stdout; the message
                includes the remote stderr when there is any.
        """
        async with self.connection as ssh:
            async with ssh.create_process("value") as process:
                process.stdin.write(expression)
                process.stdin.write_eof()
                message = await process.stdout.read()
                # wait() collects whatever output is still pending, so the
                # stderr text has to be taken from its result: reading
                # process.stderr afterwards would come back empty.
                completed = await process.wait()
        if not message:
            stderr = completed.stderr
            message = "Remote value command returned no output."
            if stderr:
                message += f"\nRemote stderr:\n{stderr}"
            raise RuntimeError(message)
        return loads_value(message)

    async def upload(self, localpaths: list[str], remotepath) -> None:
        """Upload files and directories to remote via SFTP."""
        async with self.connection as ssh:
            async with ssh.start_sftp_client() as sftp:
                await sftp.put(
                    localpaths=localpaths,
                    remotepath=remotepath,
                    recurse=True,
                    progress_handler=self.sftp_progress_handler,
                )

    async def download(self, remotepaths: list[str], localpath) -> None:
        """Download files and directories from remote via SFTP."""
        async with self.connection as ssh:
            async with ssh.start_sftp_client() as sftp:
                await sftp.get(
                    remotepaths=remotepaths,
                    localpath=localpath,
                    recurse=True,
                    # Explicit transfer settings: no sparse-file handling,
                    # 256 KiB blocks, at most 16 requests in flight.
                    sparse=False,
                    block_size=256 * 1024,
                    max_requests=16,
                    progress_handler=self.sftp_progress_handler,
                )

    def sftp_progress_handler(self, srcpath, dstpath, copied, total):
        """Log SFTP transfer progress."""
        # A zero/unknown total is reported as 100% to avoid dividing by zero.
        percent = copied / total * 100 if total else 100
        logger.debug(
            "SFTP progress: %s -> %s: %s/%s bytes (%.1f%%)",
            srcpath,
            dstpath,
            copied,
            total,
            percent,
        )

    async def _execute_code(
        self,
        command: str,
        code: str,
        stdout: TextIO | None = None,
        stderr: TextIO | None = None,
    ):
        """Execute command with a code on the remote side.

        ``code`` is written to the remote command's stdin, and its stdout and
        stderr are streamed line by line into ``stdout``/``stderr``
        (``sys.stdout``/``sys.stderr`` when omitted). On cancellation the
        remote process is terminated before the cancellation is re-raised.
        """
        if stdout is None:
            stdout = sys.stdout
        if stderr is None:
            stderr = sys.stderr

        async def forward_output(reader, writer):
            # Stream line-by-line for long-running commands such as
            # `%there log`. readline() also returns the final partial
            # line at EOF, so output without a trailing newline is kept.
            while data := await reader.readline():
                writer.write(data)
                if hasattr(writer, "flush"):
                    writer.flush()

        async with self.connection as ssh:
            async with ssh.create_process(command) as process:
                process.stdin.write(code)
                # Remote handlers read the submitted code from stdin. Signal
                # end-of-input so they can start or finish execution instead
                # of waiting for more code.
                process.stdin.write_eof()
                try:
                    await asyncio.gather(
                        forward_output(process.stdout, stdout),
                        forward_output(process.stderr, stderr),
                    )
                    await process.wait()
                except asyncio.CancelledError:
                    process.terminate()
                    # Try to close the remote channel cleanly, but keep
                    # cancellation bounded from the caller's perspective.
                    with contextlib.suppress(Exception):
                        await asyncio.wait_for(process.wait(), timeout=1)
                    raise