Coverage for bzfs_main / util / connection.py: 99%
292 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-24 10:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-24 10:16 +0000
1# Copyright 2024 Wolfgang Hoschek AT mac DOT com
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
"""Efficient thread-safe SSH command client; see run_ssh_command() and refresh_ssh_connection_if_necessary() and class
ConnectionPool and class ConnectionLease.

Can be configured to reuse multiplexed SSH connections for low latency, even on fresh process startup, for example leading to
ballpark 3-5ms total time for running `/bin/echo hello` end-to-end over SSH on LAN, which requires two (sequential) network
round trips (one for CHANNEL_OPEN, plus a subsequent one for CHANNEL_REQUEST).
Has zero dependencies beyond the standard OpenSSH client CLI (`ssh`); also works with `hpnssh`. The latter uses larger TCP
window sizes for best throughput over high speed long distance networks, aka paths with large bandwidth-delay product. Also
see https://youtu.be/fcHXOgl3dis?t=473 and https://gist.github.com/rapier1/325de17bbb85f1ce663ccb866ce22639

Example usage:

import logging
from subprocess import DEVNULL, PIPE
from bzfs_main.util.connection import ConnectionPool, create_simple_minijob, create_simple_miniremote
from bzfs_main.util.retry import Retry, RetryPolicy, call_with_retries

log = logging.getLogger(__name__)
remote = create_simple_miniremote(log=log, ssh_user_host="alice@127.0.0.1")
connection_pool = ConnectionPool(remote, connpool_name="example")
try:
    job = create_simple_minijob()
    retry_policy = RetryPolicy(
        max_retries=10,
        min_sleep_secs=0,
        initial_max_sleep_secs=0.125,
        max_sleep_secs=10,
        max_elapsed_secs=60,
    )

    def run_cmd(retry: Retry) -> str:
        with connection_pool.connection() as conn:
            stdout: str = conn.run_ssh_command(
                cmd=["echo", "hello"], job=job, check=True, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True
            ).stdout
            return stdout

    stdout = call_with_retries(fn=run_cmd, policy=retry_policy, log=log)
    print(f"stdout: {stdout}")
finally:
    connection_pool.shutdown()
"""
58from __future__ import (
59 annotations,
60)
61import contextlib
62import copy
63import logging
64import os
65import shlex
66import subprocess
67import threading
68import time
69from collections.abc import (
70 Iterator,
71)
72from dataclasses import (
73 dataclass,
74)
75from subprocess import (
76 DEVNULL,
77 PIPE,
78)
79from typing import (
80 Any,
81 Final,
82 Protocol,
83 final,
84 runtime_checkable,
85)
87from bzfs_main.util.connection_lease import (
88 ConnectionLease,
89 ConnectionLeaseManager,
90)
91from bzfs_main.util.retry import (
92 RetryableError,
93)
94from bzfs_main.util.utils import (
95 LOG_TRACE,
96 SHELL_CHARS_AND_SLASH,
97 SmallPriorityQueue,
98 Subprocesses,
99 die,
100 get_home_directory,
101 list_formatter,
102 sha256_urlsafe_base64,
103 stderr_to_str,
104)
# constants (string identifiers exported by this module):
SHARED: Final[str] = "shared"  # NOTE(review): not referenced in this module; presumably a pool name/kind -- confirm
DEDICATED: Final[str] = "dedicated"  # NOTE(review): not referenced in this module; presumably a pool name/kind -- confirm
#############################################################################
@runtime_checkable
class MiniJob(Protocol):
    """Smallest Job surface the connections module depends on; a structural Protocol, so any object with these
    attributes qualifies (loose coupling)."""

    timeout_nanos: int | None  # deadline as an instant on the time.monotonic_ns() clock; None means "never time out"
    timeout_duration_nanos: int | None  # relative duration, not an instant; for logging only
    subprocesses: Subprocesses  # runner through which this module launches its child processes
#############################################################################
@runtime_checkable
class MiniParams(Protocol):
    """Structural subset of Params that the connections module reads; kept minimal for loose coupling."""

    log: logging.Logger  # logger used by this module's connection code
    ssh_program: str  # executable name or path, e.g. "ssh"; "hpnssh" works as well
#############################################################################
@runtime_checkable
class MiniRemote(Protocol):
    """Structural subset of Remote that the connections module depends on; any object with these members fits."""

    params: MiniParams
    location: str  # either "src" or "dst"
    ssh_user_host: str  # the empty string means local mode, i.e. commands run without ssh
    ssh_extra_opts: tuple[str, ...]
    reuse_ssh_connection: bool
    ssh_control_persist_secs: int
    ssh_control_persist_margin_secs: int
    ssh_exit_on_shutdown: bool
    ssh_socket_dir: str

    def is_ssh_available(self) -> bool:
        """Tells whether the ssh client program this remote needs can be found on the local host."""
        ...

    def local_ssh_command(self, socket_file: str | None) -> list[str]:
        """Builds the locally executed ssh CLI prefix used to reach the remote host; the command to run on the remote
        host itself is not included and gets appended afterwards."""
        ...

    def cache_namespace(self) -> str:
        """Yields a stable, unique directory component for caches; it tells endpoints apart by
        username+host+port+ssh_config_file where applicable, and is '-' when no user/host is present (local mode)."""
        ...
#############################################################################
def create_simple_miniremote(
    *,
    log: logging.Logger,
    ssh_user_host: str = "",  # option passed to `ssh` CLI; empty string indicates local mode
    ssh_port: int | None = None,  # option passed to `ssh -p` CLI
    ssh_extra_opts: list[str] | None = None,  # optional args passed to `ssh` CLI
    ssh_verbose: bool = False,  # option passed to `ssh -v` CLI
    ssh_config_file: str = "",  # option passed to `ssh -F` CLI; path to ssh_config(5) file; e.g /path/to/homedir/.ssh/config
    ssh_cipher: str = "^aes256-gcm@openssh.com",  # option passed to `ssh -c` CLI
    ssh_connect_timeout_secs: int | None = None,  # option passed to `ssh -oConnectTimeout=N`; default is system TCP timeout
    ssh_program: str = "ssh",  # name or path of CLI executable; "hpnssh" is also valid
    reuse_ssh_connection: bool = True,
    ssh_control_persist_secs: int = 600,
    ssh_control_persist_margin_secs: int = 2,
    # Note: this default value is computed once, when the function is defined (module import), not on each call.
    ssh_socket_dir: str = os.path.join(get_home_directory(), ".ssh", "bzfs"),
    location: str = "dst",
) -> MiniRemote:
    """Factory that returns a simple, immutable (frozen dataclass) implementation of the MiniRemote interface.

    If ssh_extra_opts is given, it replaces (rather than extends) the built-in defaults, which disable interactive password
    prompts, X11 forwarding and pseudo-terminal allocation. The -v/-F/-c/-p/-oConnectTimeout options derived from the
    other arguments are always appended after them. An empty ssh_cipher omits the `-c` option entirely.

    Raises:
        ValueError: if log is None, ssh_program is empty, location is neither 'src' nor 'dst', ssh_user_host is malformed,
            ssh_control_persist_secs < 1, or ssh_control_persist_margin_secs < 0.
    """

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniParams(MiniParams):
        log: logging.Logger
        ssh_program: str

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniRemote(MiniRemote):
        params: MiniParams
        location: str  # "src" or "dst"
        ssh_user_host: str
        ssh_extra_opts: tuple[str, ...]
        reuse_ssh_connection: bool
        ssh_control_persist_secs: int
        ssh_control_persist_margin_secs: int
        ssh_exit_on_shutdown: bool
        ssh_socket_dir: str
        ssh_port: int | None
        ssh_config_file: str
        ssh_config_file_hash: str

        def is_ssh_available(self) -> bool:
            return True

        def local_ssh_command(self, socket_file: str | None) -> list[str]:
            if not self.ssh_user_host:
                return []  # local mode
            ssh_cmd: list[str] = [self.params.ssh_program]
            ssh_cmd.extend(self.ssh_extra_opts)
            if self.reuse_ssh_connection and socket_file:
                ssh_cmd.append("-S")
                ssh_cmd.append(socket_file)
            ssh_cmd.append(self.ssh_user_host)
            return ssh_cmd

        def cache_namespace(self) -> str:
            if not self.ssh_user_host:
                return "-"  # local mode
            return f"{self.ssh_user_host}#{self.ssh_port or ''}#{self.ssh_config_file_hash}"

    def validate_userhost(userhost: str) -> None:
        """Rejects empty hosts, option-like values, '..' and any whitespace or shell/slash characters."""
        invalid_chars: str = SHELL_CHARS_AND_SLASH
        uh: str = userhost.replace("@", "", 1)
        if (not uh) or userhost.startswith("-") or ".." in userhost or any(c.isspace() or c in invalid_chars for c in uh):
            raise ValueError(f"Invalid [user@]host: '{userhost}'")

    if log is None:
        raise ValueError("log must not be None")
    if not ssh_program:
        raise ValueError("ssh_program must be a non-empty string")
    if location not in ("src", "dst"):
        raise ValueError("location must be 'src' or 'dst'")
    if ssh_user_host:
        validate_userhost(ssh_user_host)
    if ssh_control_persist_secs < 1:
        raise ValueError("ssh_control_persist_secs must be >= 1")
    if ssh_control_persist_margin_secs < 0:
        # The master is trusted to be alive for (persist - margin) secs after a refresh; a negative margin would make that
        # window longer than the master's actual ControlPersist lifetime.
        raise ValueError("ssh_control_persist_margin_secs must be >= 0")
    params: MiniParams = SimpleMiniParams(log=log, ssh_program=ssh_program)

    ssh_extra_opts = (  # disable interactive password prompts and X11 forwarding and pseudo-terminal allocation
        ["-oBatchMode=yes", "-oServerAliveInterval=0", "-x", "-T"] if ssh_extra_opts is None else list(ssh_extra_opts)
    )
    ssh_extra_opts += ["-v"] if ssh_verbose else []
    ssh_extra_opts += ["-F", ssh_config_file] if ssh_config_file else []
    ssh_extra_opts += ["-c", ssh_cipher] if ssh_cipher else []
    ssh_extra_opts += ["-p", str(ssh_port)] if ssh_port is not None else []
    ssh_extra_opts += [] if ssh_connect_timeout_secs is None else [f"-oConnectTimeout={max(0, ssh_connect_timeout_secs)}s"]
    # hash of the absolute config file path distinguishes cache namespaces per ssh_config file
    ssh_config_file_hash = sha256_urlsafe_base64(os.path.abspath(ssh_config_file), padding=False) if ssh_config_file else ""
    return SimpleMiniRemote(
        params=params,
        location=location,
        ssh_user_host=ssh_user_host,
        ssh_extra_opts=tuple(ssh_extra_opts),
        reuse_ssh_connection=reuse_ssh_connection,
        ssh_control_persist_secs=ssh_control_persist_secs,
        ssh_control_persist_margin_secs=ssh_control_persist_margin_secs,
        ssh_exit_on_shutdown=False,
        ssh_socket_dir=ssh_socket_dir,
        ssh_port=ssh_port,
        ssh_config_file=ssh_config_file,
        ssh_config_file_hash=ssh_config_file_hash,
    )
def create_simple_minijob(
    *, timeout_duration_secs: float | None = None, subprocesses: Subprocesses | None = None
) -> MiniJob:
    """Builds a minimal, immutable MiniJob.

    When timeout_duration_secs is given, the job's deadline lies that many seconds after "now" on the time.monotonic_ns()
    clock; otherwise the job never times out. A fresh Subprocesses instance is created unless one is passed in.
    """

    @dataclass(frozen=True)  # immutable
    @final
    class SimpleMiniJob(MiniJob):
        timeout_nanos: int | None  # deadline, i.e. an instant in time
        timeout_duration_nanos: int | None  # relative duration (not an instant); for logging only
        subprocesses: Subprocesses

    if timeout_duration_secs is None:
        duration_nanos: int | None = None
        deadline_nanos: int | None = None
    else:
        duration_nanos = int(timeout_duration_secs * 1_000_000_000)
        deadline_nanos = time.monotonic_ns() + duration_nanos
    if subprocesses is None:
        subprocesses = Subprocesses()
    return SimpleMiniJob(timeout_nanos=deadline_nanos, timeout_duration_nanos=duration_nanos, subprocesses=subprocesses)
#############################################################################
def timeout(job: MiniJob) -> float | None:
    """Returns the seconds remaining until the job's deadline, None if the job has no deadline, and raises
    subprocess.TimeoutExpired once the deadline has been reached."""
    deadline_nanos: int | None = job.timeout_nanos
    if deadline_nanos is not None:
        assert job.timeout_duration_nanos is not None
        remaining_nanos: int = deadline_nanos - time.monotonic_ns()
        if remaining_nanos > 0:
            return remaining_nanos / 1_000_000_000  # seconds
        raise subprocess.TimeoutExpired("_timeout", timeout=job.timeout_duration_nanos / 1_000_000_000)
    return None  # no deadline, hence never time out
def squote(remote: MiniRemote, arg: str) -> str:
    """Shell-quotes arg when the remote is reached over ssh; in local mode arg is returned unchanged."""
    assert arg is not None
    if not remote.ssh_user_host:
        return arg  # local mode: argv is not passed through a remote shell
    return shlex.quote(arg)
def dquote(arg: str) -> str:
    """Wraps arg in double quotes after backslash-escaping the characters a POSIX shell still interprets inside double
    quotes: backslash, double quote, dollar and backtick. For an example of how to safely construct and quote complex
    shell pipeline commands for use over SSH, see replication.py:_prepare_zfs_send_receive()"""
    # A single translate pass is equivalent to the sequential replaces: each source char is escaped exactly once, and the
    # inserted backslashes are never re-escaped.
    escaped: str = arg.translate({ord("\\"): "\\\\", ord('"'): '\\"', ord("$"): "\\$", ord("`"): "\\`"})
    return f'"{escaped}"'
#############################################################################
@dataclass(order=True, repr=False)
@final
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection.

    The dataclass-generated comparison methods compare the tuple (_free, _last_modified), which is the ordering that
    ConnectionPool's priority queue relies on. Since eq=True without frozen=True, instances are unhashable.
    """

    _free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    _last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot; also ensures no dupes

    def __init__(
        self,
        remote: MiniRemote,
        max_concurrent_ssh_sessions_per_tcp_connection: int,
        *,
        lease: ConnectionLease | None = None,
    ) -> None:
        """Creates a connection whose capacity slots are all free.

        Args:
            remote: endpoint to talk to; its local_ssh_command() is evaluated once here and cached.
            max_concurrent_ssh_sessions_per_tcp_connection: capacity, i.e. the max number of concurrent SSH sessions.
            lease: optional connection lease; its socket_path is handed to remote.local_ssh_command().
        """
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = remote
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._free: int = max_concurrent_ssh_sessions_per_tcp_connection
        self._last_modified: int = 0  # monotonically increasing
        # time.monotonic_ns() of the last successful master check/start; None means "never refreshed". A sentinel is
        # required because monotonic_ns() has an undefined reference point (often system boot) and may well be smaller
        # than the refresh window, in which case an initial value of 0 would wrongly take the "master is alive" fast path.
        self._last_refresh_time: int | None = None
        self._lock: Final[threading.Lock] = threading.Lock()
        self._reuse_ssh_connection: Final[bool] = remote.reuse_ssh_connection
        self._connection_lease: Final[ConnectionLease | None] = lease
        self._ssh_cmd: Final[list[str]] = remote.local_ssh_command(
            None if self._connection_lease is None else self._connection_lease.socket_path
        )
        self._ssh_cmd_quoted: Final[list[str]] = [shlex.quote(item) for item in self._ssh_cmd]

    @property
    def ssh_cmd(self) -> list[str]:
        """Returns a copy of the local ssh CLI prefix (empty in local mode)."""
        return self._ssh_cmd.copy()

    @property
    def ssh_cmd_quoted(self) -> list[str]:
        """Returns a copy of the local ssh CLI prefix with every element shlex-quoted."""
        return self._ssh_cmd_quoted.copy()

    def __repr__(self) -> str:
        return str({"free": self._free})

    def run_ssh_command(
        self,
        cmd: list[str],
        *,
        job: MiniJob,
        loglevel: int = logging.INFO,
        is_dry: bool = False,
        **kwargs: Any,  # optional low-level keyword args to be forwarded to subprocess.run()
    ) -> subprocess.CompletedProcess:
        """Runs the given CLI cmd via ssh on the given remote, and returns CompletedProcess including stdout and stderr.

        The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
        ($remote.local_ssh_command()) and the command to run on the given remote host ($cmd).

        Note: When executing on a remote host (remote.ssh_user_host is set), cmd arguments are pre-quoted with shlex.quote to
        safely traverse the ssh "remote shell" boundary, as ssh concatenates argv into a single remote shell string. In local
        mode (no remote.ssh_user_host) argv is executed directly without an intermediate shell.

        With is_dry=True the cmd itself is not executed and a CompletedProcess with returncode 0 and no output is returned;
        in remote mode the ssh master is nonetheless refreshed first.

        Raises:
            ValueError: if cmd is empty.
        """
        if not cmd:
            raise ValueError("run_ssh_command requires a non-empty cmd list")
        log: logging.Logger = self._remote.params.log
        quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
        ssh_cmd: list[str] = self._ssh_cmd
        if self._remote.ssh_user_host:
            self.refresh_ssh_connection_if_necessary(job)
            cmd = quoted_cmd
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(loglevel, msg, list_formatter(self._ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return subprocess.CompletedProcess(ssh_cmd + cmd, returncode=0, stdout=None, stderr=None)
        else:
            sp: Subprocesses = job.subprocesses
            return sp.subprocess_run(ssh_cmd + cmd, timeout=timeout(job), log=log, **kwargs)

    def refresh_ssh_connection_if_necessary(self, job: MiniJob) -> None:
        """Maintain or create an ssh master connection for low latency reuse.

        This is a no-op in local mode or when connection reuse is disabled. Otherwise, unless a previous refresh succeeded
        less than (ssh_control_persist_secs - ssh_control_persist_margin_secs) ago, it runs `ssh -O check` against the
        control socket and, if that fails, starts a new master via `ssh -M -oControlPersist=...`. This connection's lock is
        held meanwhile, so concurrent callers on the same Connection serialize here.

        Raises:
            RetryableError: if starting a new ssh master fails with a nonzero exit status.
            subprocess.TimeoutExpired: raised by timeout(job) once the job's deadline has passed.
        """
        remote: MiniRemote = self._remote
        p: MiniParams = remote.params
        log: logging.Logger = p.log
        if not remote.ssh_user_host:
            return  # we're in local mode; no ssh required
        if not remote.is_ssh_available():
            die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
        if not remote.reuse_ssh_connection:
            return

        # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
        # 'ssh -S -M -oControlPersist=90s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
        # and https://chessman7.substack.com/p/how-ssh-multiplexing-reuses-master
        control_limit_nanos: int = (remote.ssh_control_persist_secs - remote.ssh_control_persist_margin_secs) * 1_000_000_000
        with self._lock:
            last_refresh_time: int | None = self._last_refresh_time
            if last_refresh_time is not None and time.monotonic_ns() < last_refresh_time + control_limit_nanos:
                return  # ssh master is alive, reuse its TCP connection (this is the common case and the ultra-fast path)
            ssh_cmd: list[str] = self._ssh_cmd
            ssh_sock_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_sock_cmd += ["-O", "check", remote.ssh_user_host]
            # extend lifetime of ssh master by $ssh_control_persist_secs via `ssh -O check` if master is still running.
            # `ssh -S /path/to/socket -O check` doesn't talk over the network, hence is still a low latency fast path.
            sp: Subprocesses = job.subprocesses
            t: float | None = timeout(job)
            if sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, timeout=t, log=log).returncode == 0:
                log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_sock_cmd))
            else:  # ssh master is not alive; start a new master:
                log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_sock_cmd))
                ssh_control_persist_secs: int = max(1, remote.ssh_control_persist_secs)
                if any(opt.startswith("-v") and all(char == "v" for char in opt[1:]) for opt in remote.ssh_extra_opts):
                    # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                    # foreground and blocks until the ControlPersist timer expires. To make progress earlier we ...
                    ssh_control_persist_secs = min(1, ssh_control_persist_secs)  # tell ssh block as briefly as possible (1s)
                ssh_sock_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
                ssh_sock_cmd += ["-M", f"-oControlPersist={ssh_control_persist_secs}s", remote.ssh_user_host, "exit"]
                log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_sock_cmd))
                t = timeout(job)
                try:
                    sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, check=True, timeout=t, log=log)
                except subprocess.CalledProcessError as e:
                    log.error("%s", stderr_to_str(e.stderr).rstrip())
                    raise RetryableError(
                        f"Cannot ssh into remote host via '{' '.join(ssh_sock_cmd)}'. Fix ssh configuration first, "
                        "considering diagnostic log file output from running with -v -v -v.",
                        display_msg="ssh connect",
                    ) from e
            self._last_refresh_time = time.monotonic_ns()
            if self._connection_lease is not None:
                self._connection_lease.set_socket_mtime_to_now()

    def _increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots by value (negative to occupy, positive to release)."""
        self._free += value
        assert self._free >= 0
        assert self._free <= self._capacity

    def _is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self._free <= 0

    def _update_last_modified(self, last_modified: int) -> None:
        """Records the pool's sequence number of the latest checkout/return of this connection."""
        self._last_modified = last_modified

    def shutdown(self, msg_prefix: str) -> None:
        """Tears down this connection; does nothing in local mode or when connection reuse is disabled.

        Without a connection lease, the ssh master is asked to exit via `ssh -O exit` on a best-effort basis; a short timeout
        or a nonzero exit status is only logged, as the master auto-exits after ssh_control_persist_secs anyway. With a
        lease, the master is left alone and only the lease is released.
        """
        ssh_cmd: list[str] = self._ssh_cmd
        if ssh_cmd and self._reuse_ssh_connection:
            if self._connection_lease is None:
                ssh_sock_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
                log = self._remote.params.log
                # msg_prefix is passed as a %-arg (not embedded via f-string) so that a '%' in it, e.g. from a pool name,
                # cannot corrupt the logging format string.
                log.log(LOG_TRACE, "Executing %s: %s", msg_prefix, shlex.join(ssh_sock_cmd))
                try:
                    proc: subprocess.CompletedProcess = subprocess.run(
                        ssh_sock_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=0.1
                    )
                except subprocess.TimeoutExpired as e:  # harmless as master auto-exits after ssh_control_persist_secs anyway
                    log.log(LOG_TRACE, "Harmless ssh master connection shutdown timeout: %s", e)
                else:
                    if proc.returncode != 0:  # harmless for the same reason
                        log.log(LOG_TRACE, "Harmless ssh master connection shutdown issue: %s", proc.stderr.rstrip())
            else:
                self._connection_lease.release()
#############################################################################
class ConnectionPool:
    """Fetch a TCP connection for use in an SSH session, use it, finally return it back to the pool for future reuse.

    Note that max_concurrent_ssh_sessions_per_tcp_connection must not be larger than the server-side sshd_config(5)
    MaxSessions parameter (which defaults to 10, see https://manpages.ubuntu.com/manpages/man5/sshd_config.5.html).
    """

    def __init__(
        self, remote: MiniRemote, connpool_name: str, max_concurrent_ssh_sessions_per_tcp_connection: int = 8
    ) -> None:
        """Creates an empty pool; Connections are created lazily by get_connection().

        A ConnectionLeaseManager is only created for ssh mode with connection reuse enabled and ssh_exit_on_shutdown
        disabled; otherwise new Connections get no lease.
        """
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._connpool_name: Final[str] = connpool_name
        self._priority_queue: Final[SmallPriorityQueue[Connection]] = SmallPriorityQueue(
            reverse=True  # sorted by #free slots and last_modified
        )
        self._last_modified: int = 0  # monotonically increasing sequence number
        self._lock: Final[threading.Lock] = threading.Lock()
        lease_mgr: ConnectionLeaseManager | None = None
        if self._remote.ssh_user_host and self._remote.reuse_ssh_connection and not self._remote.ssh_exit_on_shutdown:
            lease_mgr = ConnectionLeaseManager(
                root_dir=self._remote.ssh_socket_dir,
                namespace=f"{self._remote.location}#{self._remote.cache_namespace()}#{self._connpool_name}",
                ssh_control_persist_secs=max(90 * 60, 2 * self._remote.ssh_control_persist_secs + 2),
                log=self._remote.params.log,
            )
        self._lease_mgr: Final[ConnectionLeaseManager | None] = lease_mgr

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that yields a connection from the pool and automatically returns it on __exit__."""
        conn: Connection = self.get_connection()
        try:
            yield conn
        finally:
            self.return_connection(conn)

    def get_connection(self) -> Connection:
        """Any Connection object returned on get_connection() also remains intentionally contained in the priority queue
        while it is "checked out", and that identical Connection object is later, on return_connection(), temporarily removed
        from the priority queue, updated with an incremented "free" slot count and then immediately reinserted into the
        priority queue.

        In effect, any Connection object remains intentionally contained in the priority queue at all times. This design
        keeps ordering/fairness accurate while avoiding duplicate Connection instances.
        """
        with self._lock:
            # NOTE(review): presumably pop() yields the max element (most free slots, then most recent) due to reverse=True
            conn = self._priority_queue.pop() if len(self._priority_queue) > 0 else None
            if conn is None or conn._is_full():  # noqa: SLF001 # pylint: disable=protected-access
                if conn is not None:
                    self._priority_queue.push(conn)
                conn = self._new_connection()  # add a new connection
            self._last_modified += 1
            conn._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
            conn._increment_free(-1)  # noqa: SLF001 # pylint: disable=protected-access
            self._priority_queue.push(conn)
            return conn

    def _new_connection(self) -> Connection:
        """Creates a Connection, with a freshly acquired lease if a lease manager exists; called under self._lock."""
        lease: ConnectionLease | None = None if self._lease_mgr is None else self._lease_mgr.acquire()
        return Connection(self._remote, self._capacity, lease=lease)

    def return_connection(self, conn: Connection) -> None:
        """Returns the given connection to the pool and updates its priority."""
        assert conn is not None
        with self._lock:
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue
            if self._priority_queue.remove(conn):  # conn is not contained only if ConnectionPool.shutdown() was called
                conn._increment_free(1)  # noqa: SLF001 # pylint: disable=protected-access
                self._last_modified += 1
                conn._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
                self._priority_queue.push(conn)

    def shutdown(self, msg_prefix: str = "") -> None:
        """Closes all SSH connections managed by this pool and empties the pool.

        Every connection is attempted even if shutting down an earlier one fails, so that a single failure cannot prevent
        the remaining connections from being shut down; the first failure is re-raised afterwards. The queue is always
        cleared.
        """
        with self._lock:
            try:
                if self._remote.reuse_ssh_connection:
                    msg_prefix = msg_prefix + "/" + self._connpool_name
                    first_error: Exception | None = None
                    for conn in self._priority_queue:
                        try:
                            conn.shutdown(msg_prefix)
                        except Exception as e:  # deliberately broad: keep shutting down the remaining connections
                            if first_error is None:
                                first_error = e
                    if first_error is not None:
                        raise first_error
            finally:
                self._priority_queue.clear()

    def __repr__(self) -> str:
        with self._lock:
            queue = self._priority_queue
            return str({"capacity": self._capacity, "queue_len": len(queue), "queue": queue})
#############################################################################
@final
class ConnectionPools:
    """A bunch of named connection pools with various multiplexing capacities."""

    def __init__(self, remote: MiniRemote, *, capacities: dict[str, int]) -> None:
        """Creates one connection pool per name with the given capacities."""
        self._pools: Final[dict[str, ConnectionPool]] = {
            name: ConnectionPool(remote, name, capacity) for name, capacity in capacities.items()
        }

    def __repr__(self) -> str:
        return str(self._pools)

    def pool(self, name: str) -> ConnectionPool:
        """Returns the pool associated with the given name; raises KeyError for an unknown name."""
        return self._pools[name]

    def shutdown(self, msg_prefix: str = "") -> None:
        """Shuts down every contained pool.

        All pools are attempted even if an earlier one fails, so that one failing pool cannot prevent the remaining pools
        from being shut down; the first failure is re-raised after all pools have been attempted.
        """
        first_error: Exception | None = None
        for pool in self._pools.values():
            try:
                pool.shutdown(msg_prefix)
            except Exception as e:  # deliberately broad: keep shutting down the remaining pools
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error