Coverage for bzfs_main / util / connection.py: 99%

292 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-24 10:16 +0000

1# Copyright 2024 Wolfgang Hoschek AT mac DOT com 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# 

15"""Efficient thread-safe SSH command client; See run_ssh_command() and refresh_ssh_connection_if_necessary() and class 

16ConnectionPool and class ConnectionLease. 

17 

18Can be configured to reuse multiplexed SSH connections for low latency, even on fresh process startup, for example leading to 

19ballpark 3-5ms total time for running `/bin/echo hello` end-to-end over SSH on LAN, which requires two (sequential) network 

20round trips (one for CHANNEL_OPEN, plus a subsequent one for CHANNEL_REQUEST). 

21Has zero dependencies beyond the standard OpenSSH client CLI (`ssh`); also works with `hpnssh`. The latter uses larger TCP 

22window sizes for best throughput over high speed long distance networks, aka paths with large bandwidth-delay product. Also 

23see https://youtu.be/fcHXOgl3dis?t=473 and https://gist.github.com/rapier1/325de17bbb85f1ce663ccb866ce22639 

24 

25Example usage: 

26 

27import logging 

28from subprocess import DEVNULL, PIPE 

29from bzfs_main.util.connection import ConnectionPool, create_simple_minijob, create_simple_miniremote 

30from bzfs_main.util.retry import Retry, RetryPolicy, call_with_retries 

31 

32log = logging.getLogger(__name__) 

33remote = create_simple_miniremote(log=log, ssh_user_host="alice@127.0.0.1") 

34connection_pool = ConnectionPool(remote, connpool_name="example") 

35try: 

36 job = create_simple_minijob() 

37 retry_policy = RetryPolicy( 

38 max_retries=10, 

39 min_sleep_secs=0, 

40 initial_max_sleep_secs=0.125, 

41 max_sleep_secs=10, 

42 max_elapsed_secs=60, 

43 ) 

44 

45 def run_cmd(retry: Retry) -> str: 

46 with connection_pool.connection() as conn: 

47 stdout: str = conn.run_ssh_command( 

48 cmd=["echo", "hello"], job=job, check=True, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True 

49 ).stdout 

50 return stdout 

51 

52 stdout = call_with_retries(fn=run_cmd, policy=retry_policy, log=log) 

53 print(f"stdout: {stdout}") 

54finally: 

55 connection_pool.shutdown() 

56""" 

57 

58from __future__ import ( 

59 annotations, 

60) 

61import contextlib 

62import copy 

63import logging 

64import os 

65import shlex 

66import subprocess 

67import threading 

68import time 

69from collections.abc import ( 

70 Iterator, 

71) 

72from dataclasses import ( 

73 dataclass, 

74) 

75from subprocess import ( 

76 DEVNULL, 

77 PIPE, 

78) 

79from typing import ( 

80 Any, 

81 Final, 

82 Protocol, 

83 final, 

84 runtime_checkable, 

85) 

86 

87from bzfs_main.util.connection_lease import ( 

88 ConnectionLease, 

89 ConnectionLeaseManager, 

90) 

91from bzfs_main.util.retry import ( 

92 RetryableError, 

93) 

94from bzfs_main.util.utils import ( 

95 LOG_TRACE, 

96 SHELL_CHARS_AND_SLASH, 

97 SmallPriorityQueue, 

98 Subprocesses, 

99 die, 

100 get_home_directory, 

101 list_formatter, 

102 sha256_urlsafe_base64, 

103 stderr_to_str, 

104) 

105 

# constants:
# NOTE(review): SHARED and DEDICATED are not referenced anywhere in this module; presumably they are
# conventional pool names for ConnectionPools (e.g. pools={SHARED: ..., DEDICATED: ...}) — confirm against callers.
SHARED: Final[str] = "shared"
DEDICATED: Final[str] = "dedicated"

109 

110 

111############################################################################# 

@runtime_checkable
class MiniJob(Protocol):
    """Minimal Job interface required by the connections module; for loose coupling."""

    # Absolute deadline measured on the time.monotonic_ns() clock; compared against "now" by timeout().
    timeout_nanos: int | None  # timestamp aka instant in time
    # The originally requested duration; timeout() only uses it to populate the raised TimeoutExpired.
    timeout_duration_nanos: int | None  # duration (not a timestamp); for logging only
    # Runner used to spawn the ssh child processes; see Connection.run_ssh_command().
    subprocesses: Subprocesses

119 

120 

121############################################################################# 

@runtime_checkable
class MiniParams(Protocol):
    """Minimal Params interface used by the connections module; for loose coupling."""

    # Logger used for all diagnostics emitted by Connection and ConnectionPool.
    log: logging.Logger
    ssh_program: str  # name or path of executable; "hpnssh" is also valid

127 ssh_program: str # name or path of executable; "hpnssh" is also valid 

128 

129 

130############################################################################# 

@runtime_checkable
class MiniRemote(Protocol):
    """Minimal Remote interface used by the connections module; for loose coupling."""

    params: MiniParams
    location: str  # "src" or "dst"
    ssh_user_host: str  # use the empty string to indicate local mode (no ssh)
    ssh_extra_opts: tuple[str, ...]
    reuse_ssh_connection: bool  # when True, Connection maintains a multiplexed ssh master ('ssh -S -M')
    ssh_control_persist_secs: int  # lifetime handed to 'ssh -oControlPersist=<N>s' when starting a master
    ssh_control_persist_margin_secs: int  # safety margin subtracted before deciding a master must be refreshed
    ssh_exit_on_shutdown: bool  # when False (and reusing connections), ConnectionPool manages socket leases instead
    ssh_socket_dir: str  # root directory for ssh control sockets managed by ConnectionLeaseManager

    def is_ssh_available(self) -> bool:
        """Return True if the ssh client program required for this remote is available on the local host."""

    def local_ssh_command(self, socket_file: str | None) -> list[str]:
        """Returns the ssh CLI command to run locally in order to talk to the remote host; This excludes the (trailing)
        command to run on the remote host, which will be appended later."""

    def cache_namespace(self) -> str:
        """Returns cache namespace string which is a stable, unique directory component for caches that distinguishes
        endpoints by username+host+port+ssh_config_file where applicable, and uses '-' when no user/host is present (local
        mode)."""

156 

157 

158############################################################################# 

def create_simple_miniremote(
    *,
    log: logging.Logger,
    ssh_user_host: str = "",  # option passed to `ssh` CLI; empty string indicates local mode
    ssh_port: int | None = None,  # option passed to `ssh -p` CLI
    ssh_extra_opts: list[str] | None = None,  # optional args passed to `ssh` CLI
    ssh_verbose: bool = False,  # option passed to `ssh -v` CLI
    ssh_config_file: str = "",  # option passed to `ssh -F` CLI; path to ssh_config(5) file; e.g /path/to/homedir/.ssh/config
    ssh_cipher: str = "^aes256-gcm@openssh.com",  # option passed to `ssh -c` CLI
    ssh_connect_timeout_secs: int | None = None,  # option passed to `ssh -oConnectTimeout=N`; default is system TCP timeout
    ssh_program: str = "ssh",  # name or path of CLI executable; "hpnssh" is also valid
    reuse_ssh_connection: bool = True,
    ssh_control_persist_secs: int = 600,
    ssh_control_persist_margin_secs: int = 2,
    ssh_socket_dir: str | None = None,  # defaults to <home>/.ssh/bzfs; resolved lazily per call (see below)
    location: str = "dst",
) -> MiniRemote:
    """Factory that returns a simple implementation of the MiniRemote interface.

    Raises:
        ValueError: if log is None, ssh_program is empty, location is not 'src'/'dst',
            ssh_user_host contains unsafe characters, or ssh_control_persist_secs < 1.
    """

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniParams(MiniParams):
        log: logging.Logger
        ssh_program: str

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniRemote(MiniRemote):
        params: MiniParams
        location: str  # "src" or "dst"
        ssh_user_host: str
        ssh_extra_opts: tuple[str, ...]
        reuse_ssh_connection: bool
        ssh_control_persist_secs: int
        ssh_control_persist_margin_secs: int
        ssh_exit_on_shutdown: bool
        ssh_socket_dir: str
        ssh_port: int | None
        ssh_config_file: str
        ssh_config_file_hash: str

        def is_ssh_available(self) -> bool:
            return True

        def local_ssh_command(self, socket_file: str | None) -> list[str]:
            """Builds the local ssh argv, optionally pointing at a multiplexing control socket via -S."""
            if not self.ssh_user_host:
                return []  # local mode
            ssh_cmd: list[str] = [self.params.ssh_program]
            ssh_cmd.extend(self.ssh_extra_opts)
            if self.reuse_ssh_connection and socket_file:
                ssh_cmd.append("-S")
                ssh_cmd.append(socket_file)
            ssh_cmd.append(self.ssh_user_host)
            return ssh_cmd

        def cache_namespace(self) -> str:
            """Stable directory component distinguishing user+host+port+config-file endpoints."""
            if not self.ssh_user_host:
                return "-"  # local mode
            return f"{self.ssh_user_host}#{self.ssh_port or ''}#{self.ssh_config_file_hash}"

    def validate_userhost(userhost: str) -> None:
        # Reject shell metacharacters, whitespace, a leading '-' (ssh option injection) and '..'
        # (path traversal), because the userhost string later becomes part of cache directory
        # names via cache_namespace().
        invalid_chars: str = SHELL_CHARS_AND_SLASH
        uh: str = userhost.replace("@", "", 1)
        if (not uh) or userhost.startswith("-") or ".." in userhost or any(c.isspace() or c in invalid_chars for c in uh):
            raise ValueError(f"Invalid [user@]host: '{userhost}'")

    if log is None:
        raise ValueError("log must not be None")
    if not ssh_program:
        raise ValueError("ssh_program must be a non-empty string")
    if location not in ("src", "dst"):
        raise ValueError("location must be 'src' or 'dst'")
    if ssh_user_host:
        validate_userhost(ssh_user_host)
    if ssh_control_persist_secs < 1:
        raise ValueError("ssh_control_persist_secs must be >= 1")
    if ssh_socket_dir is None:
        # Fix: compute the default lazily per call rather than once at import time (the previous
        # default argument froze get_home_directory() at module import, ignoring any later change
        # to the home directory environment, e.g. in tests).
        ssh_socket_dir = os.path.join(get_home_directory(), ".ssh", "bzfs")
    params: MiniParams = SimpleMiniParams(log=log, ssh_program=ssh_program)

    ssh_extra_opts = (  # disable interactive password prompts and X11 forwarding and pseudo-terminal allocation
        ["-oBatchMode=yes", "-oServerAliveInterval=0", "-x", "-T"] if ssh_extra_opts is None else list(ssh_extra_opts)
    )
    ssh_extra_opts += ["-v"] if ssh_verbose else []
    ssh_extra_opts += ["-F", ssh_config_file] if ssh_config_file else []
    ssh_extra_opts += ["-c", ssh_cipher] if ssh_cipher else []
    ssh_extra_opts += ["-p", str(ssh_port)] if ssh_port is not None else []
    ssh_extra_opts += [] if ssh_connect_timeout_secs is None else [f"-oConnectTimeout={max(0, ssh_connect_timeout_secs)}s"]
    ssh_config_file_hash = sha256_urlsafe_base64(os.path.abspath(ssh_config_file), padding=False) if ssh_config_file else ""
    return SimpleMiniRemote(
        params=params,
        location=location,
        ssh_user_host=ssh_user_host,
        ssh_extra_opts=tuple(ssh_extra_opts),
        reuse_ssh_connection=reuse_ssh_connection,
        ssh_control_persist_secs=ssh_control_persist_secs,
        ssh_control_persist_margin_secs=ssh_control_persist_margin_secs,
        ssh_exit_on_shutdown=False,
        ssh_socket_dir=ssh_socket_dir,
        ssh_port=ssh_port,
        ssh_config_file=ssh_config_file,
        ssh_config_file_hash=ssh_config_file_hash,
    )

260 

261 

def create_simple_minijob(
    *, timeout_duration_secs: float | None = None, subprocesses: Subprocesses | None = None
) -> MiniJob:
    """Factory that returns a simple implementation of the MiniJob interface."""

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniJob(MiniJob):
        timeout_nanos: int | None  # absolute deadline on the monotonic clock, or None for "never time out"
        timeout_duration_nanos: int | None  # originally requested duration; informational only
        subprocesses: Subprocesses

    if timeout_duration_secs is None:
        duration_nanos: int | None = None
        deadline_nanos: int | None = None
    else:
        duration_nanos = int(timeout_duration_secs * 1_000_000_000)
        deadline_nanos = time.monotonic_ns() + duration_nanos
    runner: Subprocesses = subprocesses if subprocesses is not None else Subprocesses()
    return SimpleMiniJob(timeout_nanos=deadline_nanos, timeout_duration_nanos=duration_nanos, subprocesses=runner)

278 

279 

280############################################################################# 

def timeout(job: MiniJob) -> float | None:
    """Returns the number of seconds left until the job's deadline, or None if the job has no deadline;
    raises subprocess.TimeoutExpired if the deadline has already passed."""
    deadline_nanos: int | None = job.timeout_nanos
    if deadline_nanos is None:
        return None  # never raise a timeout
    assert job.timeout_duration_nanos is not None
    remaining_nanos: int = deadline_nanos - time.monotonic_ns()
    if remaining_nanos > 0:
        return remaining_nanos / 1_000_000_000  # convert nanos to seconds
    # deadline passed; report the originally requested duration in the exception
    raise subprocess.TimeoutExpired("_timeout", timeout=job.timeout_duration_nanos / 1_000_000_000)

291 

292 

def squote(remote: MiniRemote, arg: str) -> str:
    """Applies shell quoting to arg, but only in remote (ssh) mode; in local mode arg passes through unchanged."""
    assert arg is not None
    if remote.ssh_user_host:
        return shlex.quote(arg)
    return arg

297 

298 

def dquote(arg: str) -> str:
    """Shell-escapes backslash, double quote, dollar and backtick, then wraps the result in double quotes; For an
    example how to safely construct and quote complex shell pipeline commands for use over SSH, see
    replication.py:_prepare_zfs_send_receive()"""
    escaped: str = arg
    for ch in ("\\", '"', "$", "`"):  # backslash must be escaped first so it doesn't re-escape later escapes
        escaped = escaped.replace(ch, "\\" + ch)
    return f'"{escaped}"'

305 

306 

307############################################################################# 

@dataclass(order=True, repr=False)  # order=True generates comparisons over the declared fields (_free, _last_modified)
@final
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection.

    The dataclass-generated ordering (by _free, then _last_modified) is what SmallPriorityQueue in ConnectionPool
    sorts by; the explicitly defined __init__ and __repr__ below take precedence over dataclass-generated ones.
    """

    _free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    _last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot; also ensures no dupes

    def __init__(
        self,
        remote: MiniRemote,
        max_concurrent_ssh_sessions_per_tcp_connection: int,
        *,
        lease: ConnectionLease | None = None,
    ) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = remote
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._free: int = max_concurrent_ssh_sessions_per_tcp_connection  # starts fully available
        self._last_modified: int = 0  # monotonically increasing
        self._last_refresh_time: int = 0  # monotonic ns of the last master refresh; 0 forces a refresh on first use
        self._lock: Final[threading.Lock] = threading.Lock()
        self._reuse_ssh_connection: Final[bool] = remote.reuse_ssh_connection
        self._connection_lease: Final[ConnectionLease | None] = lease
        # With a lease, the ssh argv points -S at the leased control socket path; without one, no -S flag is added.
        self._ssh_cmd: Final[list[str]] = remote.local_ssh_command(
            None if self._connection_lease is None else self._connection_lease.socket_path
        )
        self._ssh_cmd_quoted: Final[list[str]] = [shlex.quote(item) for item in self._ssh_cmd]

    @property
    def ssh_cmd(self) -> list[str]:
        # Returns a defensive copy so callers cannot mutate the internal argv.
        return self._ssh_cmd.copy()

    @property
    def ssh_cmd_quoted(self) -> list[str]:
        # Returns a defensive copy of the shlex-quoted argv (for logging/display).
        return self._ssh_cmd_quoted.copy()

    def __repr__(self) -> str:
        return str({"free": self._free})

    def run_ssh_command(
        self,
        cmd: list[str],
        *,
        job: MiniJob,
        loglevel: int = logging.INFO,
        is_dry: bool = False,
        **kwargs: Any,  # optional low-level keyword args to be forwarded to subprocess.run()
    ) -> subprocess.CompletedProcess:
        """Runs the given CLI cmd via ssh on the given remote, and returns CompletedProcess including stdout and stderr.

        The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
        ($remote.local_ssh_command()) and the command to run on the given remote host ($cmd).

        Note: When executing on a remote host (remote.ssh_user_host is set), cmd arguments are pre-quoted with shlex.quote to
        safely traverse the ssh "remote shell" boundary, as ssh concatenates argv into a single remote shell string. In local
        mode (no remote.ssh_user_host) argv is executed directly without an intermediate shell.

        Raises ValueError on empty cmd; with is_dry=True only logs and returns a synthetic CompletedProcess(returncode=0).
        """
        if not cmd:
            raise ValueError("run_ssh_command requires a non-empty cmd list")
        log: logging.Logger = self._remote.params.log
        quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
        ssh_cmd: list[str] = self._ssh_cmd
        if self._remote.ssh_user_host:
            self.refresh_ssh_connection_if_necessary(job)  # ensure a live ssh master before running the session
            cmd = quoted_cmd  # remote mode: send the quoted form across the ssh remote-shell boundary
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(loglevel, msg, list_formatter(self._ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return subprocess.CompletedProcess(ssh_cmd + cmd, returncode=0, stdout=None, stderr=None)
        else:
            sp: Subprocesses = job.subprocesses
            return sp.subprocess_run(ssh_cmd + cmd, timeout=timeout(job), log=log, **kwargs)

    def refresh_ssh_connection_if_necessary(self, job: MiniJob) -> None:
        """Maintain or create an ssh master connection for low latency reuse.

        No-op in local mode or when connection reuse is disabled; dies if the ssh CLI is unavailable.
        """
        remote: MiniRemote = self._remote
        p: MiniParams = remote.params
        log: logging.Logger = p.log
        if not remote.ssh_user_host:
            return  # we're in local mode; no ssh required
        if not remote.is_ssh_available():
            die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
        if not remote.reuse_ssh_connection:
            return

        # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
        # 'ssh -S -M -oControlPersist=90s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
        # and https://chessman7.substack.com/p/how-ssh-multiplexing-reuses-master
        # The margin is subtracted so we refresh comfortably before ControlPersist actually expires.
        control_limit_nanos: int = (remote.ssh_control_persist_secs - remote.ssh_control_persist_margin_secs) * 1_000_000_000
        with self._lock:
            if time.monotonic_ns() < self._last_refresh_time + control_limit_nanos:
                return  # ssh master is alive, reuse its TCP connection (this is the common case and the ultra-fast path)
            ssh_cmd: list[str] = self._ssh_cmd
            ssh_sock_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_sock_cmd += ["-O", "check", remote.ssh_user_host]
            # extend lifetime of ssh master by $ssh_control_persist_secs via `ssh -O check` if master is still running.
            # `ssh -S /path/to/socket -O check` doesn't talk over the network, hence is still a low latency fast path.
            sp: Subprocesses = job.subprocesses
            t: float | None = timeout(job)
            if sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, timeout=t, log=log).returncode == 0:
                log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_sock_cmd))
            else:  # ssh master is not alive; start a new master:
                log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_sock_cmd))
                ssh_control_persist_secs: int = max(1, remote.ssh_control_persist_secs)
                # detect -v/-vv/-vvv style debug flags among the extra opts:
                if any(opt.startswith("-v") and all(char == "v" for char in opt[1:]) for opt in remote.ssh_extra_opts):
                    # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                    # foreground and blocks until the ControlPersist timer expires (90 secs). To make progress earlier we ...
                    ssh_control_persist_secs = min(1, ssh_control_persist_secs)  # tell ssh block as briefly as possible (1s)
                ssh_sock_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
                ssh_sock_cmd += ["-M", f"-oControlPersist={ssh_control_persist_secs}s", remote.ssh_user_host, "exit"]
                log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_sock_cmd))
                t = timeout(job)
                try:
                    sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, check=True, timeout=t, log=log)
                except subprocess.CalledProcessError as e:
                    log.error("%s", stderr_to_str(e.stderr).rstrip())
                    # RetryableError lets call_with_retries() re-attempt transient connect failures
                    raise RetryableError(
                        f"Cannot ssh into remote host via '{' '.join(ssh_sock_cmd)}'. Fix ssh configuration first, "
                        "considering diagnostic log file output from running with -v -v -v.",
                        display_msg="ssh connect",
                    ) from e
            self._last_refresh_time = time.monotonic_ns()
            if self._connection_lease is not None:
                self._connection_lease.set_socket_mtime_to_now()  # keep the lease's socket file fresh for lease GC

    def _increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots; caller (ConnectionPool) holds its lock while calling this."""
        self._free += value
        assert self._free >= 0
        assert self._free <= self._capacity

    def _is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self._free <= 0

    def _update_last_modified(self, last_modified: int) -> None:
        """Records when the connection was last used; value is the pool's monotonically increasing sequence number."""
        self._last_modified = last_modified

    def shutdown(self, msg_prefix: str) -> None:
        """Closes the underlying SSH master connection and releases the corresponding connection lease.

        Without a lease, tells the master to exit via 'ssh -O exit' (best-effort: timeouts and nonzero exits are
        harmless because the master auto-exits after ControlPersist anyway). With a lease, only the lease is released.
        """
        ssh_cmd: list[str] = self._ssh_cmd
        if ssh_cmd and self._reuse_ssh_connection:
            if self._connection_lease is None:
                ssh_sock_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
                log = self._remote.params.log
                log.log(LOG_TRACE, f"Executing {msg_prefix}: %s", shlex.join(ssh_sock_cmd))
                try:
                    proc: subprocess.CompletedProcess = subprocess.run(
                        ssh_sock_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=0.1
                    )
                except subprocess.TimeoutExpired as e:  # harmless as master auto-exits after ssh_control_persist_secs anyway
                    log.log(LOG_TRACE, "Harmless ssh master connection shutdown timeout: %s", e)
                else:
                    if proc.returncode != 0:  # harmless for the same reason
                        log.log(LOG_TRACE, "Harmless ssh master connection shutdown issue: %s", proc.stderr.rstrip())
            else:
                self._connection_lease.release()

467 

468 

469############################################################################# 

class ConnectionPool:
    """Hands out TCP connections for SSH sessions: fetch one, use it, then return it to the pool for future reuse;
    Note that max_concurrent_ssh_sessions_per_tcp_connection must not be larger than the server-side sshd_config(5)
    MaxSessions parameter (which defaults to 10, see https://manpages.ubuntu.com/manpages/man5/sshd_config.5.html)."""

    def __init__(
        self, remote: MiniRemote, connpool_name: str, max_concurrent_ssh_sessions_per_tcp_connection: int = 8
    ) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._connpool_name: Final[str] = connpool_name
        # queue is sorted by #free slots and last_modified, highest priority first
        self._priority_queue: Final[SmallPriorityQueue[Connection]] = SmallPriorityQueue(reverse=True)
        self._last_modified: int = 0  # monotonically increasing sequence number
        self._lock: Final[threading.Lock] = threading.Lock()
        manager: ConnectionLeaseManager | None = None
        rem: MiniRemote = self._remote
        if rem.ssh_user_host and rem.reuse_ssh_connection and not rem.ssh_exit_on_shutdown:
            manager = ConnectionLeaseManager(
                root_dir=rem.ssh_socket_dir,
                namespace=f"{rem.location}#{rem.cache_namespace()}#{self._connpool_name}",
                ssh_control_persist_secs=max(90 * 60, 2 * rem.ssh_control_persist_secs + 2),
                log=rem.params.log,
            )
        self._lease_mgr: Final[ConnectionLeaseManager | None] = manager

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that yields a connection from the pool and automatically returns it on __exit__."""
        leased: Connection = self.get_connection()
        try:
            yield leased
        finally:
            self.return_connection(leased)

    def get_connection(self) -> Connection:
        """Any Connection object returned on get_connection() also remains intentionally contained in the priority queue
        while it is "checked out", and that identical Connection object is later, on return_connection(), temporarily removed
        from the priority queue, updated with an incremented "free" slot count and then immediately reinserted into the
        priority queue.

        In effect, any Connection object remains intentionally contained in the priority queue at all times. This design
        keeps ordering/fairness accurate while avoiding duplicate Connection instances.
        """
        with self._lock:
            queue = self._priority_queue
            candidate: Connection | None = queue.pop() if len(queue) > 0 else None
            if candidate is None or candidate._is_full():  # noqa: SLF001 # pylint: disable=protected-access
                if candidate is not None:
                    queue.push(candidate)  # put the saturated connection back before allocating a fresh one
                candidate = self._new_connection()  # add a new connection
            self._last_modified += 1
            candidate._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
            candidate._increment_free(-1)  # noqa: SLF001 # pylint: disable=protected-access
            queue.push(candidate)
            return candidate

    def _new_connection(self) -> Connection:
        """Creates a fresh Connection, acquiring a socket lease first if lease management is enabled."""
        lease: ConnectionLease | None = None if self._lease_mgr is None else self._lease_mgr.acquire()
        return Connection(self._remote, self._capacity, lease=lease)

    def return_connection(self, conn: Connection) -> None:
        """Returns the given connection to the pool and updates its priority."""
        assert conn is not None
        with self._lock:
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue;
            # conn is not contained only if ConnectionPool.shutdown() was called in the meantime
            if self._priority_queue.remove(conn):
                conn._increment_free(1)  # noqa: SLF001 # pylint: disable=protected-access
                self._last_modified += 1
                conn._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
                self._priority_queue.push(conn)

    def shutdown(self, msg_prefix: str = "") -> None:
        """Closes all SSH connections managed by this pool."""
        with self._lock:
            try:
                if self._remote.reuse_ssh_connection:
                    label: str = msg_prefix + "/" + self._connpool_name
                    for connection in self._priority_queue:
                        connection.shutdown(label)
            finally:
                self._priority_queue.clear()

    def __repr__(self) -> str:
        with self._lock:
            q = self._priority_queue
            return str({"capacity": self._capacity, "queue_len": len(q), "queue": q})

557 

558 

559############################################################################# 

@final
class ConnectionPools:
    """A bunch of named connection pools with various multiplexing capacities."""

    def __init__(self, remote: MiniRemote, *, capacities: dict[str, int]) -> None:
        """Creates one connection pool per name with the given capacities."""
        pools: dict[str, ConnectionPool] = {}
        for name, capacity in capacities.items():
            pools[name] = ConnectionPool(remote, name, capacity)
        self._pools: Final[dict[str, ConnectionPool]] = pools

    def __repr__(self) -> str:
        return str(self._pools)

    def pool(self, name: str) -> ConnectionPool:
        """Returns the pool associated with the given name."""
        return self._pools[name]

    def shutdown(self, msg_prefix: str = "") -> None:
        """Shuts down every contained pool."""
        for contained_pool in self._pools.values():
            contained_pool.shutdown(msg_prefix)