Source code for herethere.there.client

"""herethere.there.client"""
from __future__ import annotations
import asyncio
from contextlib import AbstractAsyncContextManager
import sys
from typing import List, Optional, TextIO

import asyncssh

from herethere.everywhere.config import ConnectionConfig
from herethere.everywhere.logging import logger


class PersistentConnection(AbstractAsyncContextManager):
    """SSH connection async context manager with automatic reconnection.

    Entering the context yields an ``asyncssh`` client connection. The cached
    connection is reused if it still answers the remote ``ping`` command;
    otherwise it is closed and a new one is established from ``config``.
    Leaving the context does *not* close the connection, so the next
    ``async with`` can reuse it.
    """

    def __init__(self):
        # Connection parameters; applied via ``configure``.
        self.config: Optional[ConnectionConfig] = None
        # Cached SSH connection, or None when not connected.
        self.connection: Optional[asyncssh.SSHClientConnection] = None

    async def __aenter__(self) -> asyncssh.SSHClientConnection:
        if await self.check_connection():
            return self.connection
        # Drop a stale connection (no-op when there is none) and dial again.
        self.close()
        return await self.reconnect()

    async def __aexit__(self, *exc_info):
        # Intentionally keep the connection open for reuse; exceptions raised
        # inside the ``async with`` body are not suppressed.
        return None

    def close(self):
        """Close the current connection, if any, and forget it.

        asyncssh errors raised while closing are logged and ignored, since the
        connection is being discarded anyway.
        """
        if self.connection:
            try:
                self.connection.close()
            except asyncssh.Error as exc:
                logger.debug("Ignoring error while closing SSH connection: %s", exc)
        self.connection = None

    async def configure(
        self, config: ConnectionConfig
    ) -> asyncssh.SSHClientConnection:
        """Apply a new connection config and connect with it.

        Any existing connection is closed first.

        Returns:
            The connection produced by entering this context manager.

        Raises:
            RuntimeError: if ``config`` is falsy (see ``reconnect``).
        """
        self.close()
        self.config = config
        return await self.__aenter__()

    async def check_connection(self) -> bool:
        """Return True if the current connection answers the remote ``ping`` command."""
        if not self.connection:
            return False
        try:
            await self.connection.run("ping", check=True)
        except asyncssh.Error:
            logger.debug("SSH connection ping failed.")
            return False
        return True

    async def reconnect(self) -> asyncssh.SSHClientConnection:
        """Establish a new connection from ``self.config`` and cache it.

        Raises:
            RuntimeError: if no config has been applied yet.
        """
        if not self.config:
            # RuntimeError (an Exception subclass) keeps existing
            # ``except Exception`` callers working while being catchable
            # more precisely.
            raise RuntimeError("Connection is not configured.")
        # NOTE(review): in asyncssh, ``known_hosts=None`` disables server host
        # key validation; kept as-is, but worth revisiting.
        self.connection = await asyncssh.connect(**self.config.asdict, known_hosts=None)
        return self.connection


class Client:
    """Client for remote interpreter.

    Wraps a ``PersistentConnection`` and offers helpers to run python code
    (in the foreground or a background thread) or shell commands on the
    remote side, plus SFTP upload.
    """

    def __init__(self):
        self.connection = PersistentConnection()

    async def copy(self) -> Client:
        """Return a new client connected with this client's connection config."""
        client = Client()
        await client.connect(self.connection.config)
        return client

    async def connect(self, config: ConnectionConfig):
        """Connect to remote, replacing any existing connection."""
        await self.connection.configure(config)

    async def disconnect(self):
        """Disconnect from the remote."""
        self.connection.close()

    async def runcode(
        self,
        code: str,
        stdout: Optional[TextIO] = None,
        stderr: Optional[TextIO] = None,
    ) -> None:
        """Execute python code on the remote side.

        Remote output is relayed to ``stdout`` / ``stderr``
        (``sys.stdout`` / ``sys.stderr`` when omitted).
        """
        await self._execute_code("code", code, stdout, stderr)

    async def runcode_background(
        self,
        code: str,
        stdout: Optional[TextIO] = None,
        stderr: Optional[TextIO] = None,
    ) -> None:
        """Execute python code in a separate thread on the remote side.

        Remote output is relayed to ``stdout`` / ``stderr``
        (``sys.stdout`` / ``sys.stderr`` when omitted).
        """
        await self._execute_code("background", code, stdout, stderr)

    async def shell(
        self,
        code: str,
        stdout: Optional[TextIO] = None,
        stderr: Optional[TextIO] = None,
    ) -> None:
        """Execute shell command on the remote side.

        Remote output is relayed to ``stdout`` / ``stderr``
        (``sys.stdout`` / ``sys.stderr`` when omitted).
        """
        await self._execute_code("shell", code, stdout, stderr)

    async def upload(self, localpaths: List[str], remotepath) -> None:
        """Upload files and directories to remote via SFTP.

        Directories are copied recursively; progress is reported to
        ``sftp_progress_handler``.
        """
        async with self.connection as ssh:
            async with ssh.start_sftp_client() as sftp:
                await sftp.put(
                    localpaths=localpaths,
                    remotepath=remotepath,
                    recurse=True,
                    progress_handler=self.sftp_progress_handler,
                )

    def sftp_progress_handler(self, *args, **kwargs):
        """SFTP uploading progress handler; logs progress at debug level."""
        logger.debug("SFTP progress: %s %s", args, kwargs)

    async def _execute_code(
        self,
        command: str,
        code: str,
        stdout: Optional[TextIO] = None,
        stderr: Optional[TextIO] = None,
    ) -> None:
        """Execute command with a code on the remote side.

        Args:
            command: remote command to start (e.g. "code", "background",
                "shell" as used by the public helpers).
            code: text written to the remote process's stdin.
            stdout: stream receiving remote stdout; defaults to sys.stdout.
            stderr: stream receiving remote stderr; defaults to sys.stderr.
        """
        if stdout is None:
            stdout = sys.stdout
        if stderr is None:
            stderr = sys.stderr
        async with self.connection as ssh:
            async with ssh.create_process(command) as process:
                process.stdin.write(code)
                # Drain both streams concurrently. Reading stderr only after
                # stdout hits EOF can stall: asyncssh pauses the whole channel
                # once too much unread data is buffered, so a large stderr
                # backlog would block further stdout data.
                await asyncio.gather(
                    self._relay_stream(process.stdout, stdout),
                    self._relay_stream(process.stderr, stderr),
                )

    @staticmethod
    async def _relay_stream(reader, writer: TextIO) -> None:
        """Copy ``reader`` to ``writer`` line by line until EOF (empty read)."""
        while True:
            line = await reader.readline()
            if not line:
                break
            writer.write(line)