Coverage for bzfs_main / util / connection.py: 99%
315 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 12:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 12:49 +0000
1# Copyright 2024 Wolfgang Hoschek AT mac DOT com
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15"""Efficient thread-safe SSH command client; See run_ssh_command() and refresh_ssh_connection_if_necessary() and class
16ConnectionPool and class ConnectionLease.
18Can be configured to reuse multiplexed SSH connections for low latency, even on fresh process startup, for example leading to
19ballpark 3-5ms total time for running `/bin/echo hello` end-to-end over SSH on LAN, which requires two (sequential) network
20round trips (one for CHANNEL_OPEN, plus a subsequent one for CHANNEL_REQUEST).
21Has zero dependencies beyond the standard OpenSSH client CLI (`ssh`); also works with `hpnssh`. The latter uses larger TCP
22window sizes for best throughput over high speed long distance networks, aka paths with large bandwidth-delay product. Also
23see https://youtu.be/fcHXOgl3dis?t=473 and https://gist.github.com/rapier1/325de17bbb85f1ce663ccb866ce22639
25Example usage:
27import logging
28from subprocess import DEVNULL, PIPE
29from bzfs_main.util.connection import ConnectionPool, create_simple_minijob, create_simple_miniremote
30from bzfs_main.util.retry import Retry, RetryPolicy, call_with_retries
32log = logging.getLogger(__name__)
33remote = create_simple_miniremote(log=log, ssh_user_host="alice@127.0.0.1")
34connection_pool = ConnectionPool(remote, connpool_name="example")
35try:
36 job = create_simple_minijob()
37 retry_policy = RetryPolicy(
38 max_retries=10,
39 min_sleep_secs=0,
40 initial_max_sleep_secs=0.125,
41 max_sleep_secs=10,
42 max_elapsed_secs=60,
43 )
45 def run_cmd(retry: Retry) -> str:
46 with connection_pool.connection() as conn:
47 stdout: str = conn.run_ssh_command(
48 cmd=["echo", "hello"], job=job, check=True, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True
49 ).stdout
50 return stdout
52 stdout = call_with_retries(fn=run_cmd, policy=retry_policy, log=log)
53 print(f"stdout: {stdout}")
54finally:
55 connection_pool.shutdown()
56"""
58from __future__ import (
59 annotations,
60)
61import contextlib
62import copy
63import logging
64import os
65import shlex
66import socket
67import stat
68import subprocess
69import threading
70import time
71from collections.abc import (
72 Iterator,
73)
74from dataclasses import (
75 dataclass,
76)
77from subprocess import (
78 DEVNULL,
79 PIPE,
80)
81from typing import (
82 Any,
83 Final,
84 Protocol,
85 final,
86 runtime_checkable,
87)
89from bzfs_main.util.connection_lease import (
90 ConnectionLease,
91 ConnectionLeaseManager,
92)
93from bzfs_main.util.retry import (
94 RetryableError,
95)
96from bzfs_main.util.utils import (
97 LOG_TRACE,
98 SHELL_CHARS_AND_SLASH,
99 SmallPriorityQueue,
100 Subprocesses,
101 die,
102 get_home_directory,
103 list_formatter,
104 sha256_urlsafe_base64,
105 stderr_to_str,
106)
# constants:
# NOTE(review): presumably the well-known connection pool names used as keys of the `capacities` dict passed to
# ConnectionPools and as `connpool_name` for ConnectionPool -- confirm against callers.
SHARED: Final[str] = "shared"
DEDICATED: Final[str] = "dedicated"
113#############################################################################
@runtime_checkable
class MiniJob(Protocol):
    """Minimal Job interface required by the connections module; for loose coupling.

    Any object exposing these three attributes satisfies the protocol; see create_simple_minijob() for a ready-made
    implementation.
    """

    # Deadline that timeout() compares against time.monotonic_ns(); None means "never time out".
    timeout_nanos: int | None  # timestamp aka instant in time
    # Length of the timeout window; only reported in the TimeoutExpired raised by timeout().
    timeout_duration_nanos: int | None  # duration (not a timestamp); for logging only
    # Executor whose subprocess_run() is used to launch the ssh CLI and local commands.
    subprocesses: Subprocesses
123#############################################################################
@runtime_checkable
class MiniParams(Protocol):
    """Minimal Params interface used by the connections module; for loose coupling.

    Carries the settings that are shared by all remotes of a run: the logger and the ssh executable.
    """

    log: logging.Logger  # receives the "Executing: ..." and ssh master diagnostics messages
    ssh_program: str  # name or path of executable; "hpnssh" is also valid
132#############################################################################
@runtime_checkable
class MiniRemote(Protocol):
    """Minimal Remote interface used by the connections module; for loose coupling.

    Describes one endpoint (local or reachable via ssh) plus the SSH multiplexing settings that Connection and
    ConnectionPool consult; see create_simple_miniremote() for a ready-made implementation.
    """

    params: MiniParams
    location: str  # "src" or "dst"
    ssh_user_host: str  # use the empty string to indicate local mode (no ssh)
    ssh_extra_opts: tuple[str, ...]  # scanned for -v style flags when starting an ssh master
    reuse_ssh_connection: bool  # if False, no ssh master connection is created or refreshed
    ssh_control_persist_secs: int  # value passed as -oControlPersist when starting an ssh master
    ssh_control_persist_margin_secs: int  # subtracted from ssh_control_persist_secs to decide when to re-check the master
    ssh_exit_on_shutdown: bool  # if True, ConnectionPool does not create a ConnectionLeaseManager
    ssh_socket_dir: str  # root directory handed to the ConnectionLeaseManager

    def is_ssh_available(self) -> bool:
        """Return True if the ssh client program required for this remote is available on the local host."""

    def local_ssh_command(self, socket_file: str | None) -> tuple[list[str], str | None]:
        """Returns the ssh CLI command to run locally in order to talk to the remote host; This excludes the (trailing)
        command to run on the remote host, which will be appended later; also returns the effective ControlPath used by the
        ssh CLI command, or ``None`` when SSH multiplexing is not active."""

    def cache_namespace(self) -> str:
        """Returns cache namespace string which is a stable, unique directory component for caches that distinguishes
        endpoints by username+host+port+ssh_config_file where applicable, and uses '-' when no user/host is present (local
        mode)."""
161#############################################################################
def create_simple_miniremote(
    *,
    log: logging.Logger,
    ssh_user_host: str = "",  # option passed to `ssh` CLI; empty string indicates local mode
    ssh_port: int | None = None,  # option passed to `ssh -p` CLI
    ssh_extra_opts: list[str] | None = None,  # optional args passed to `ssh` CLI
    ssh_verbose: bool = False,  # option passed to `ssh -v` CLI
    ssh_config_file: str = "",  # option passed to `ssh -F` CLI; path to ssh_config(5) file; e.g /path/to/homedir/.ssh/config
    ssh_cipher: str = "^aes256-gcm@openssh.com",  # option passed to `ssh -c` CLI
    ssh_connect_timeout_secs: int | None = None,  # option passed to `ssh -oConnectTimeout=N`; default is system TCP timeout
    ssh_program: str = "ssh",  # name or path of CLI executable; "hpnssh" is also valid
    reuse_ssh_connection: bool = True,
    ssh_control_persist_secs: int = 600,
    ssh_control_persist_margin_secs: int = 2,
    ssh_socket_dir: str = os.path.join(get_home_directory(), ".ssh", "bzfs"),
    location: str = "dst",
) -> MiniRemote:
    """Builds and returns an immutable, self-contained MiniRemote from plain keyword arguments.

    Validates the arguments (raising ValueError on the first violation), assembles the ssh CLI option list, and wraps
    everything into a frozen dataclass that implements the MiniRemote protocol.
    """

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniParams(MiniParams):
        log: logging.Logger
        ssh_program: str

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniRemote(MiniRemote):
        params: MiniParams
        location: str  # "src" or "dst"
        ssh_user_host: str
        ssh_extra_opts: tuple[str, ...]
        reuse_ssh_connection: bool
        ssh_control_persist_secs: int
        ssh_control_persist_margin_secs: int
        ssh_exit_on_shutdown: bool
        ssh_socket_dir: str
        ssh_port: int | None
        ssh_config_file: str
        ssh_config_file_hash: str

        def is_ssh_available(self) -> bool:
            return True

        def local_ssh_command(self, socket_file: str | None) -> tuple[list[str], str | None]:
            if not self.ssh_user_host:
                return [], None  # local mode
            argv: list[str] = [self.params.ssh_program, *self.ssh_extra_opts]
            if self.reuse_ssh_connection and socket_file:
                # multiplexing is active: point ssh at the control socket and report it as the effective ControlPath
                return [*argv, "-S", socket_file, self.ssh_user_host], socket_file
            return [*argv, self.ssh_user_host], None

        def cache_namespace(self) -> str:
            if self.ssh_user_host:
                return f"{self.ssh_user_host}#{self.ssh_port or ''}#{self.ssh_config_file_hash}"
            return "-"  # local mode

    def validate_userhost(userhost: str) -> None:
        """Rejects empty, option-like ('-...'), '..'-containing, whitespace- or shell-char-containing [user@]host values."""
        without_at: str = userhost.replace("@", "", 1)
        has_forbidden_char: bool = any(ch.isspace() or ch in SHELL_CHARS_AND_SLASH for ch in without_at)
        if not without_at or userhost.startswith("-") or ".." in userhost or has_forbidden_char:
            raise ValueError(f"Invalid [user@]host: '{userhost}'")

    # argument validation; the order of checks determines which error is reported first
    if log is None:
        raise ValueError("log must not be None")
    if not ssh_program:
        raise ValueError("ssh_program must be a non-empty string")
    if location not in ("src", "dst"):
        raise ValueError("location must be 'src' or 'dst'")
    if ssh_user_host:
        validate_userhost(ssh_user_host)
    if ssh_control_persist_secs < 1:
        raise ValueError("ssh_control_persist_secs must be >= 1")
    params: MiniParams = SimpleMiniParams(log=log, ssh_program=ssh_program)

    opts: list[str]
    if ssh_extra_opts is None:
        # disable interactive password prompts and X11 forwarding and pseudo-terminal allocation
        opts = ["-oBatchMode=yes", "-oServerAliveInterval=0", "-x", "-T"]
    else:
        opts = list(ssh_extra_opts)  # copy, so the caller's list is never mutated
    if ssh_verbose:
        opts.append("-v")
    if ssh_config_file:
        opts.extend(["-F", ssh_config_file])
    if ssh_cipher:
        opts.extend(["-c", ssh_cipher])
    if ssh_port is not None:
        opts.extend(["-p", str(ssh_port)])
    if ssh_connect_timeout_secs is not None:
        opts.append(f"-oConnectTimeout={max(0, ssh_connect_timeout_secs)}s")

    config_hash: str = ""
    if ssh_config_file:
        config_hash = sha256_urlsafe_base64(os.path.abspath(ssh_config_file), padding=False)

    return SimpleMiniRemote(
        params=params,
        location=location,
        ssh_user_host=ssh_user_host,
        ssh_extra_opts=tuple(opts),
        reuse_ssh_connection=reuse_ssh_connection,
        ssh_control_persist_secs=ssh_control_persist_secs,
        ssh_control_persist_margin_secs=ssh_control_persist_margin_secs,
        ssh_exit_on_shutdown=False,
        ssh_socket_dir=ssh_socket_dir,
        ssh_port=ssh_port,
        ssh_config_file=ssh_config_file,
        ssh_config_file_hash=config_hash,
    )
def create_simple_minijob(
    *, timeout_duration_secs: float | None = None, subprocesses: Subprocesses | None = None
) -> MiniJob:
    """Builds and returns an immutable MiniJob; the deadline (if any) starts counting from the time of this call."""

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniJob(MiniJob):
        timeout_nanos: int | None  # timestamp aka instant in time
        timeout_duration_nanos: int | None  # duration (not a timestamp); for logging only
        subprocesses: Subprocesses

    duration_nanos: int | None
    deadline_nanos: int | None
    if timeout_duration_secs is None:
        duration_nanos = None
        deadline_nanos = None  # no deadline
    else:
        duration_nanos = int(timeout_duration_secs * 1_000_000_000)
        deadline_nanos = time.monotonic_ns() + duration_nanos
    if subprocesses is None:
        subprocesses = Subprocesses()
    return SimpleMiniJob(timeout_nanos=deadline_nanos, timeout_duration_nanos=duration_nanos, subprocesses=subprocesses)
285#############################################################################
def timeout(job: MiniJob) -> float | None:
    """Returns the seconds remaining until the job's deadline, or None if the job has no deadline; Raises
    subprocess.TimeoutExpired if the deadline has already been reached."""
    deadline_nanos: int | None = job.timeout_nanos
    if deadline_nanos is None:
        return None  # never raise a timeout
    assert job.timeout_duration_nanos is not None
    remaining_nanos: int = deadline_nanos - time.monotonic_ns()
    if remaining_nanos > 0:
        return remaining_nanos / 1_000_000_000  # seconds
    raise subprocess.TimeoutExpired("_timeout", timeout=job.timeout_duration_nanos / 1_000_000_000)
def squote(remote: MiniRemote, arg: str) -> str:
    """Returns arg shell-quoted via shlex.quote if the remote is reached over ssh, else returns arg unchanged."""
    assert arg is not None
    if remote.ssh_user_host:
        return shlex.quote(arg)
    return arg  # local mode: no intermediate shell, hence no quoting
def dquote(arg: str) -> str:
    """Prefixes each backslash, double quote, dollar and backtick with a backslash, then wraps the result in double
    quotes; For an example how to safely construct and quote complex shell pipeline commands for use over SSH, see
    replication.py:_prepare_zfs_send_receive()"""
    # single pass over the string; every special char maps to "\" + itself
    escapes: dict[int, str] = {ord(ch): "\\" + ch for ch in '\\"$`'}
    return '"' + arg.translate(escapes) + '"'
312#############################################################################
@dataclass(order=True, repr=False)
@final
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection.

    Instances compare by (``_free``, ``_last_modified``) via the dataclass-generated ordering methods; ConnectionPool
    relies on that ordering to pick the connection to hand out next.
    """

    _free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    _last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot; also ensures no dupes

    def __init__(
        self,
        remote: MiniRemote,
        max_concurrent_ssh_sessions_per_tcp_connection: int,
        *,
        lease: ConnectionLease | None = None,
    ) -> None:
        """Precomputes the local ssh command for the given remote; if a lease is given, its socket path is offered to
        ``remote.local_ssh_command()`` as the control socket to use."""
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = remote
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._free: int = max_concurrent_ssh_sessions_per_tcp_connection
        self._last_modified: int = 0  # monotonically increasing
        self._last_refresh_time: int = 1 - (1 << 150)  # negative infinity for all practical purposes
        self._lock: Final[threading.Lock] = threading.Lock()
        self._reuse_ssh_connection: Final[bool] = remote.reuse_ssh_connection
        self._connection_lease: Final[ConnectionLease | None] = lease
        ssh_cmd, ssh_socket_path = remote.local_ssh_command(
            None if self._connection_lease is None else self._connection_lease.socket_path
        )
        self._ssh_socket_path: Final[str | None] = ssh_socket_path
        self._ssh_cmd: Final[list[str]] = ssh_cmd
        self._ssh_cmd_quoted: Final[list[str]] = [shlex.quote(item) for item in self._ssh_cmd]

    @property
    def ssh_cmd(self) -> list[str]:
        """Returns a copy of the local ssh command (empty list in local mode)."""
        return self._ssh_cmd.copy()

    @property
    def ssh_cmd_quoted(self) -> list[str]:
        """Returns a copy of the local ssh command with each item passed through shlex.quote()."""
        return self._ssh_cmd_quoted.copy()

    def __repr__(self) -> str:
        return str({"free": self._free})

    def run_ssh_command(
        self,
        cmd: list[str],
        *,
        job: MiniJob,
        loglevel: int = logging.INFO,
        is_dry: bool = False,
        **kwargs: Any,  # optional low-level keyword args to be forwarded to subprocess.run()
    ) -> subprocess.CompletedProcess:
        """Runs the given CLI cmd via ssh on the given remote, and returns CompletedProcess including stdout and stderr.

        The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
        (``remote.local_ssh_command(...)[0]``) and the command to run on the given remote host (``cmd``).

        Note: When executing on a remote host (remote.ssh_user_host is set), cmd arguments are pre-quoted with shlex.quote to
        safely traverse the ssh "remote shell" boundary, as ssh concatenates argv into a single remote shell string. In local
        mode (no remote.ssh_user_host) argv is executed directly without an intermediate shell.

        If ``is_dry`` is True the command is only logged and a CompletedProcess with returncode 0 is returned without
        running anything. Raises ValueError if ``cmd`` is empty; the job's timeout is enforced via timeout(job).
        """
        if not cmd:
            raise ValueError("run_ssh_command requires a non-empty cmd list")
        log: logging.Logger = self._remote.params.log
        quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
        ssh_cmd: list[str] = self._ssh_cmd
        if self._remote.ssh_user_host:
            self.refresh_ssh_connection_if_necessary(job)
            cmd = quoted_cmd  # only the remote case passes through a shell, hence only it needs quoted args
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(loglevel, msg, list_formatter(self._ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return subprocess.CompletedProcess(ssh_cmd + cmd, returncode=0, stdout=None, stderr=None)
        else:
            sp: Subprocesses = job.subprocesses
            return sp.subprocess_run(ssh_cmd + cmd, timeout=timeout(job), log=log, **kwargs)

    def refresh_ssh_connection_if_necessary(self, job: MiniJob) -> None:
        """Maintain or create an ssh master connection for low latency reuse.

        No-op in local mode or if connection reuse is disabled. Otherwise, under the per-connection lock: returns
        immediately if the master was refreshed recently and its control socket (if any) is usable; else runs
        `ssh -O check`, and if that fails starts a new master via `ssh -M -oControlPersist=...`. Raises RetryableError if
        the new master cannot be started.
        """
        remote: MiniRemote = self._remote
        p: MiniParams = remote.params
        log: logging.Logger = p.log
        if not remote.ssh_user_host:
            return  # we're in local mode; no ssh required
        if not remote.is_ssh_available():
            die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
        if not remote.reuse_ssh_connection:
            return

        # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
        # 'ssh -S -M -oControlPersist=90s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
        # and https://chessman7.substack.com/p/how-ssh-multiplexing-reuses-master
        control_limit_nanos: int = (remote.ssh_control_persist_secs - remote.ssh_control_persist_margin_secs) * 1_000_000_000
        socket_path: str | None = self._ssh_socket_path
        with self._lock:
            if time.monotonic_ns() < self._last_refresh_time + control_limit_nanos:
                if socket_path is None or self._is_ssh_control_socket_usable(socket_path):
                    return  # ssh master is alive, reuse its TCP connection (this is the common case and the ultra-fast path)
            ssh_cmd: list[str] = self._ssh_cmd
            ssh_sock_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_sock_cmd += ["-O", "check", remote.ssh_user_host]
            # extend lifetime of ssh master by $ssh_control_persist_secs via `ssh -O check` if master is still running.
            # `ssh -S /path/to/socket -O check` doesn't talk over the network, hence is still a low latency fast path.
            sp: Subprocesses = job.subprocesses
            t: float | None = timeout(job)
            if sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, timeout=t, log=log).returncode == 0:
                log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_sock_cmd))
            else:  # ssh master is not alive; start a new master:
                log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_sock_cmd))
                if socket_path is not None and not self._is_ssh_control_socket_usable(socket_path):
                    with contextlib.suppress(OSError):
                        os.unlink(socket_path)  # if present, remove stale ssh control socket path before master restart
                ssh_control_persist_secs: int = max(1, remote.ssh_control_persist_secs)
                if any(opt.startswith("-v") and all(char == "v" for char in opt[1:]) for opt in remote.ssh_extra_opts):
                    # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                    # foreground and blocks until the ControlPersist timer expires (90 secs). To make progress earlier we ...
                    ssh_control_persist_secs = min(1, ssh_control_persist_secs)  # tell ssh block as briefly as possible (1s)
                ssh_sock_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
                ssh_sock_cmd += ["-M", f"-oControlPersist={ssh_control_persist_secs}s", remote.ssh_user_host, "exit"]
                log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_sock_cmd))
                t = timeout(job)
                try:
                    sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, check=True, timeout=t, log=log)
                except subprocess.CalledProcessError as e:
                    log.error("%s", stderr_to_str(e.stderr).rstrip())
                    raise RetryableError(
                        f"Cannot ssh into remote host via '{' '.join(ssh_sock_cmd)}'. Fix ssh configuration first, "
                        "considering diagnostic log file output from running with -v -v -v.",
                        display_msg="ssh connect",
                    ) from e
            # reached after a successful `-O check` as well as after a successful master (re)start
            self._last_refresh_time = time.monotonic_ns()
            if self._connection_lease is not None:
                self._connection_lease.set_socket_mtime_to_now()

    def _is_ssh_control_socket_usable(self, socket_path: str) -> bool:
        """To improve ssh perf, check whether a control socket path is a live Unix-domain listener; this helps detect stale
        socket files that still exist after master crashes.

        _is_ssh_control_socket_usable() is ~300x faster than `ssh ... -O check`: ~5 microseconds vs ~1-2 milliseconds.

        Returns False if the path does not exist, is not a socket, or cannot be connected to.
        """
        try:
            st_mode: int = os.stat(socket_path, follow_symlinks=False).st_mode
            if not stat.S_ISSOCK(st_mode):
                return False
        except OSError:
            return False
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                return sock.connect_ex(socket_path) == 0
        except OSError:
            return False

    def _increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots by ``value`` (negative to take a slot, positive to give one back)."""
        self._free += value
        assert self._free >= 0
        assert self._free <= self._capacity

    def _is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self._free <= 0

    def _update_last_modified(self, last_modified: int) -> None:
        """Records the sequence number of the most recent checkout or return; used as ordering tiebreaker."""
        self._last_modified = last_modified

    def shutdown(self, msg_prefix: str) -> None:
        """Closes the underlying SSH master connection and releases the corresponding connection lease.

        With a lease, only the lease is released (the master is left running). Without a lease, `ssh -O exit` is run
        best-effort with a 0.1s timeout; a timeout or non-zero exit code is merely logged. No-op in local mode or if
        connection reuse is disabled.
        """
        ssh_cmd: list[str] = self._ssh_cmd
        if ssh_cmd and self._reuse_ssh_connection:
            if self._connection_lease is None:
                ssh_sock_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
                log = self._remote.params.log
                # pass msg_prefix as a lazy log argument rather than interpolating it into the format string, so that a
                # '%' inside msg_prefix cannot corrupt the %-formatting of the log record
                log.log(LOG_TRACE, "Executing %s: %s", msg_prefix, shlex.join(ssh_sock_cmd))
                try:
                    proc: subprocess.CompletedProcess = subprocess.run(
                        ssh_sock_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=0.1
                    )
                except subprocess.TimeoutExpired as e:  # harmless as master auto-exits after ssh_control_persist_secs anyway
                    log.log(LOG_TRACE, "Harmless ssh master connection shutdown timeout: %s", e)
                else:
                    if proc.returncode != 0:  # harmless for the same reason
                        log.log(LOG_TRACE, "Harmless ssh master connection shutdown issue: %s", proc.stderr.rstrip())
            else:
                self._connection_lease.release()
499#############################################################################
class ConnectionPool:
    """Hands out TCP connections for use in SSH sessions and takes them back for future reuse; Note that
    max_concurrent_ssh_sessions_per_tcp_connection must not be larger than the server-side sshd_config(5) MaxSessions
    parameter (which defaults to 10, see https://manpages.ubuntu.com/manpages/man5/sshd_config.5.html)."""

    def __init__(
        self, remote: MiniRemote, connpool_name: str, max_concurrent_ssh_sessions_per_tcp_connection: int = 8
    ) -> None:
        """Creates an empty pool; connections are created lazily on get_connection()."""
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._connpool_name: Final[str] = connpool_name
        self._priority_queue: Final[SmallPriorityQueue[Connection]] = SmallPriorityQueue(
            reverse=True  # sorted by #free slots and last_modified
        )
        self._last_modified: int = 0  # monotonically increasing sequence number
        self._lock: Final[threading.Lock] = threading.Lock()
        pooled: MiniRemote = self._remote
        # leases only apply to multiplexed ssh connections whose master outlives this pool
        uses_leases = pooled.ssh_user_host and pooled.reuse_ssh_connection and not pooled.ssh_exit_on_shutdown
        self._lease_mgr: Final[ConnectionLeaseManager | None] = (
            ConnectionLeaseManager(
                root_dir=pooled.ssh_socket_dir,
                namespace=f"{pooled.location}#{pooled.cache_namespace()}#{self._connpool_name}",
                ssh_control_persist_secs=max(90 * 60, 2 * pooled.ssh_control_persist_secs + 2),
                log=pooled.params.log,
            )
            if uses_leases
            else None
        )

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that checks a connection out of the pool and hands it back on __exit__, even on error."""
        checked_out: Connection = self.get_connection()
        try:
            yield checked_out
        finally:
            self.return_connection(checked_out)

    def get_connection(self) -> Connection:
        """Checks out one SSH session slot and returns the Connection that provides it, creating a new Connection if the
        pool is empty or its best candidate has no free slot.

        A checked out Connection intentionally stays in the priority queue; return_connection() later removes that
        identical object, gives the slot back, and immediately reinserts it. Hence every Connection is contained in the
        priority queue at all times, which keeps ordering/fairness accurate while avoiding duplicate Connection instances.
        """
        with self._lock:
            queue = self._priority_queue
            candidate = queue.pop() if len(queue) > 0 else None
            if candidate is not None and candidate._is_full():  # noqa: SLF001 # pylint: disable=protected-access
                queue.push(candidate)  # saturated connection stays in the queue
                candidate = None
            conn: Connection = self._new_connection() if candidate is None else candidate  # add a new connection
            self._last_modified += 1
            conn._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
            conn._increment_free(-1)  # noqa: SLF001 # pylint: disable=protected-access
            queue.push(conn)
            return conn

    def _new_connection(self) -> Connection:
        """Creates a Connection, backed by a freshly acquired lease if this pool has a lease manager."""
        mgr = self._lease_mgr
        lease: ConnectionLease | None = mgr.acquire() if mgr is not None else None
        return Connection(self._remote, self._capacity, lease=lease)

    def return_connection(self, conn: Connection) -> None:
        """Gives the session slot of the given connection back to the pool and updates the connection's priority."""
        assert conn is not None
        with self._lock:
            queue = self._priority_queue
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue
            if not queue.remove(conn):
                return  # conn is not contained only if ConnectionPool.shutdown() was called
            conn._increment_free(1)  # noqa: SLF001 # pylint: disable=protected-access
            self._last_modified += 1
            conn._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
            queue.push(conn)

    def shutdown(self, msg_prefix: str = "") -> None:
        """Shuts down all SSH connections managed by this pool, then empties the pool."""
        with self._lock:
            queue = self._priority_queue
            try:
                if not self._remote.reuse_ssh_connection:
                    return  # nothing to close; the finally clause still empties the queue
                label: str = msg_prefix + "/" + self._connpool_name
                for conn in queue:
                    conn.shutdown(label)
            finally:
                queue.clear()

    def __repr__(self) -> str:
        with self._lock:
            pending = self._priority_queue
            return str({"capacity": self._capacity, "queue_len": len(pending), "queue": pending})
589#############################################################################
@final
class ConnectionPools:
    """Registry of named ConnectionPool instances for a single remote, each with its own multiplexing capacity."""

    def __init__(self, remote: MiniRemote, *, capacities: dict[str, int]) -> None:
        """Builds one ConnectionPool per (name, capacity) entry, in the iteration order of ``capacities``."""
        pools: dict[str, ConnectionPool] = {}
        for pool_name, capacity in capacities.items():
            pools[pool_name] = ConnectionPool(remote, pool_name, capacity)
        self._pools: Final[dict[str, ConnectionPool]] = pools

    def __repr__(self) -> str:
        return str(self._pools)

    def pool(self, name: str) -> ConnectionPool:
        """Looks up the pool registered under the given name; raises KeyError if there is none."""
        return self._pools[name]

    def shutdown(self, msg_prefix: str = "") -> None:
        """Calls shutdown(msg_prefix) on each contained pool, in registration order."""
        for member in self._pools.values():
            member.shutdown(msg_prefix)