Coverage for bzfs_main / util / connection.py: 99%

315 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-29 12:49 +0000

1# Copyright 2024 Wolfgang Hoschek AT mac DOT com 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# 

15"""Efficient thread-safe SSH command client; See run_ssh_command() and refresh_ssh_connection_if_necessary() and class 

16ConnectionPool and class ConnectionLease. 

17 

18Can be configured to reuse multiplexed SSH connections for low latency, even on fresh process startup, for example leading to 

19ballpark 3-5ms total time for running `/bin/echo hello` end-to-end over SSH on LAN, which requires two (sequential) network 

20round trips (one for CHANNEL_OPEN, plus a subsequent one for CHANNEL_REQUEST). 

21Has zero dependencies beyond the standard OpenSSH client CLI (`ssh`); also works with `hpnssh`. The latter uses larger TCP 

22window sizes for best throughput over high speed long distance networks, aka paths with large bandwidth-delay product. Also 

23see https://youtu.be/fcHXOgl3dis?t=473 and https://gist.github.com/rapier1/325de17bbb85f1ce663ccb866ce22639 

24 

25Example usage: 

26 

27import logging 

28from subprocess import DEVNULL, PIPE 

29from bzfs_main.util.connection import ConnectionPool, create_simple_minijob, create_simple_miniremote 

30from bzfs_main.util.retry import Retry, RetryPolicy, call_with_retries 

31 

32log = logging.getLogger(__name__) 

33remote = create_simple_miniremote(log=log, ssh_user_host="alice@127.0.0.1") 

34connection_pool = ConnectionPool(remote, connpool_name="example") 

35try: 

36 job = create_simple_minijob() 

37 retry_policy = RetryPolicy( 

38 max_retries=10, 

39 min_sleep_secs=0, 

40 initial_max_sleep_secs=0.125, 

41 max_sleep_secs=10, 

42 max_elapsed_secs=60, 

43 ) 

44 

45 def run_cmd(retry: Retry) -> str: 

46 with connection_pool.connection() as conn: 

47 stdout: str = conn.run_ssh_command( 

48 cmd=["echo", "hello"], job=job, check=True, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True 

49 ).stdout 

50 return stdout 

51 

52 stdout = call_with_retries(fn=run_cmd, policy=retry_policy, log=log) 

53 print(f"stdout: {stdout}") 

54finally: 

55 connection_pool.shutdown() 

56""" 

57 

58from __future__ import ( 

59 annotations, 

60) 

61import contextlib 

62import copy 

63import logging 

64import os 

65import shlex 

66import socket 

67import stat 

68import subprocess 

69import threading 

70import time 

71from collections.abc import ( 

72 Iterator, 

73) 

74from dataclasses import ( 

75 dataclass, 

76) 

77from subprocess import ( 

78 DEVNULL, 

79 PIPE, 

80) 

81from typing import ( 

82 Any, 

83 Final, 

84 Protocol, 

85 final, 

86 runtime_checkable, 

87) 

88 

89from bzfs_main.util.connection_lease import ( 

90 ConnectionLease, 

91 ConnectionLeaseManager, 

92) 

93from bzfs_main.util.retry import ( 

94 RetryableError, 

95) 

96from bzfs_main.util.utils import ( 

97 LOG_TRACE, 

98 SHELL_CHARS_AND_SLASH, 

99 SmallPriorityQueue, 

100 Subprocesses, 

101 die, 

102 get_home_directory, 

103 list_formatter, 

104 sha256_urlsafe_base64, 

105 stderr_to_str, 

106) 

107 

# constants:
# Well-known connection pool names; presumably used as `connpool_name` by callers elsewhere — not referenced in this module.
SHARED: Final[str] = "shared"
DEDICATED: Final[str] = "dedicated"

111 

112 

113############################################################################# 

@runtime_checkable
class MiniJob(Protocol):
    """Minimal Job interface required by the connections module; for loose coupling.

    Only the three attributes below are read by this module (see timeout() and Connection.run_ssh_command()).
    """

    timeout_nanos: int | None  # timestamp aka instant in time; compared against time.monotonic_ns() by timeout()
    timeout_duration_nanos: int | None  # duration (not a timestamp); for logging only
    subprocesses: Subprocesses  # provides subprocess_run(), used to launch the ssh CLI for this job

122 

123############################################################################# 

@runtime_checkable
class MiniParams(Protocol):
    """Minimal Params interface used by the connections module; for loose coupling."""

    log: logging.Logger  # destination for all log output emitted by this module
    ssh_program: str  # name or path of executable; "hpnssh" is also valid

130 

131 

132############################################################################# 

@runtime_checkable
class MiniRemote(Protocol):
    """Minimal Remote interface used by the connections module; for loose coupling."""

    params: MiniParams
    location: str  # "src" or "dst"
    ssh_user_host: str  # use the empty string to indicate local mode (no ssh)
    ssh_extra_opts: tuple[str, ...]  # extra CLI options passed verbatim to the ssh program
    reuse_ssh_connection: bool  # if True, multiplex SSH sessions over a persistent ssh master connection
    ssh_control_persist_secs: int  # master connection lifetime; fed into `ssh -oControlPersist=<secs>s`
    ssh_control_persist_margin_secs: int  # subtracted from the above when deciding whether to re-check the master
    ssh_exit_on_shutdown: bool  # if True, `ssh -O exit` the master on shutdown; if False, sockets are leased for reuse
    ssh_socket_dir: str  # root directory under which control socket leases are managed

    def is_ssh_available(self) -> bool:
        """Return True if the ssh client program required for this remote is available on the local host."""

    def local_ssh_command(self, socket_file: str | None) -> tuple[list[str], str | None]:
        """Returns the ssh CLI command to run locally in order to talk to the remote host; This excludes the (trailing)
        command to run on the remote host, which will be appended later; also returns the effective ControlPath used by the
        ssh CLI command, or ``None`` when SSH multiplexing is not active."""

    def cache_namespace(self) -> str:
        """Returns cache namespace string which is a stable, unique directory component for caches that distinguishes
        endpoints by username+host+port+ssh_config_file where applicable, and uses '-' when no user/host is present (local
        mode)."""

160 

161############################################################################# 

def create_simple_miniremote(
    *,
    log: logging.Logger,
    ssh_user_host: str = "",  # option passed to `ssh` CLI; empty string indicates local mode
    ssh_port: int | None = None,  # option passed to `ssh -p` CLI
    ssh_extra_opts: list[str] | None = None,  # optional args passed to `ssh` CLI
    ssh_verbose: bool = False,  # option passed to `ssh -v` CLI
    ssh_config_file: str = "",  # option passed to `ssh -F` CLI; path to ssh_config(5) file; e.g /path/to/homedir/.ssh/config
    ssh_cipher: str = "^aes256-gcm@openssh.com",  # option passed to `ssh -c` CLI
    ssh_connect_timeout_secs: int | None = None,  # option passed to `ssh -oConnectTimeout=N`; default is system TCP timeout
    ssh_program: str = "ssh",  # name or path of CLI executable; "hpnssh" is also valid
    reuse_ssh_connection: bool = True,  # if True, multiplex SSH sessions over a persistent master connection
    ssh_control_persist_secs: int = 600,  # master connection lifetime; must be >= 1
    ssh_control_persist_margin_secs: int = 2,  # safety margin used when deciding whether to re-check the master
    ssh_socket_dir: str = os.path.join(get_home_directory(), ".ssh", "bzfs"),  # NOTE: default is computed once, at import time
    location: str = "dst",  # "src" or "dst"
) -> MiniRemote:
    """Factory that returns a simple implementation of the MiniRemote interface.

    Raises ValueError if log is None, ssh_program is empty, location is not 'src'/'dst', ssh_user_host contains
    suspicious characters, or ssh_control_persist_secs < 1.
    """

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniParams(MiniParams):
        log: logging.Logger
        ssh_program: str

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniRemote(MiniRemote):
        params: MiniParams
        location: str  # "src" or "dst"
        ssh_user_host: str
        ssh_extra_opts: tuple[str, ...]
        reuse_ssh_connection: bool
        ssh_control_persist_secs: int
        ssh_control_persist_margin_secs: int
        ssh_exit_on_shutdown: bool
        ssh_socket_dir: str
        ssh_port: int | None
        ssh_config_file: str
        ssh_config_file_hash: str  # stable hash of the absolute ssh_config_file path; feeds into cache_namespace()

        def is_ssh_available(self) -> bool:
            """This simple implementation optimistically assumes the ssh CLI is installed."""
            return True

        def local_ssh_command(self, socket_file: str | None) -> tuple[list[str], str | None]:
            """Builds the local ssh argv prefix (without the remote command) plus the ControlPath in use, if any."""
            if not self.ssh_user_host:
                return [], None  # local mode
            ssh_cmd: list[str] = [self.params.ssh_program]
            ssh_cmd.extend(self.ssh_extra_opts)
            socket_path: str | None = None
            if self.reuse_ssh_connection and socket_file:
                ssh_cmd.append("-S")  # reuse the master connection listening on the given control socket
                ssh_cmd.append(socket_file)
                socket_path = socket_file
            ssh_cmd.append(self.ssh_user_host)
            return ssh_cmd, socket_path

        def cache_namespace(self) -> str:
            """Returns a stable per-endpoint namespace string; '-' in local mode."""
            if not self.ssh_user_host:
                return "-"  # local mode
            return f"{self.ssh_user_host}#{self.ssh_port or ''}#{self.ssh_config_file_hash}"

    def validate_userhost(userhost: str) -> None:
        """Rejects [user@]host values that look like CLI options or contain shell/path/whitespace characters."""
        invalid_chars: str = SHELL_CHARS_AND_SLASH
        uh: str = userhost.replace("@", "", 1)  # a single '@' separator is legal; drop it before the per-char checks
        if (not uh) or userhost.startswith("-") or ".." in userhost or any(c.isspace() or c in invalid_chars for c in uh):
            raise ValueError(f"Invalid [user@]host: '{userhost}'")

    if log is None:
        raise ValueError("log must not be None")
    if not ssh_program:
        raise ValueError("ssh_program must be a non-empty string")
    if location not in ("src", "dst"):
        raise ValueError("location must be 'src' or 'dst'")
    if ssh_user_host:
        validate_userhost(ssh_user_host)
    if ssh_control_persist_secs < 1:
        raise ValueError("ssh_control_persist_secs must be >= 1")
    params: MiniParams = SimpleMiniParams(log=log, ssh_program=ssh_program)

    ssh_extra_opts = (  # disable interactive password prompts and X11 forwarding and pseudo-terminal allocation
        ["-oBatchMode=yes", "-oServerAliveInterval=0", "-x", "-T"] if ssh_extra_opts is None else list(ssh_extra_opts)
    )
    ssh_extra_opts += ["-v"] if ssh_verbose else []
    ssh_extra_opts += ["-F", ssh_config_file] if ssh_config_file else []
    ssh_extra_opts += ["-c", ssh_cipher] if ssh_cipher else []
    ssh_extra_opts += ["-p", str(ssh_port)] if ssh_port is not None else []
    ssh_extra_opts += [] if ssh_connect_timeout_secs is None else [f"-oConnectTimeout={max(0, ssh_connect_timeout_secs)}s"]
    ssh_config_file_hash = sha256_urlsafe_base64(os.path.abspath(ssh_config_file), padding=False) if ssh_config_file else ""
    return SimpleMiniRemote(
        params=params,
        location=location,
        ssh_user_host=ssh_user_host,
        ssh_extra_opts=tuple(ssh_extra_opts),
        reuse_ssh_connection=reuse_ssh_connection,
        ssh_control_persist_secs=ssh_control_persist_secs,
        ssh_control_persist_margin_secs=ssh_control_persist_margin_secs,
        ssh_exit_on_shutdown=False,  # masters are leased for reuse rather than torn down on pool shutdown
        ssh_socket_dir=ssh_socket_dir,
        ssh_port=ssh_port,
        ssh_config_file=ssh_config_file,
        ssh_config_file_hash=ssh_config_file_hash,
    )

265 

266 

def create_simple_minijob(
    *, timeout_duration_secs: float | None = None, subprocesses: Subprocesses | None = None
) -> MiniJob:
    """Factory that returns a simple implementation of the MiniJob interface."""

    @dataclass(frozen=True)  # aka immutable
    @final
    class SimpleMiniJob(MiniJob):
        timeout_nanos: int | None  # absolute deadline (monotonic clock instant)
        timeout_duration_nanos: int | None  # requested duration (not a timestamp); for logging only
        subprocesses: Subprocesses

    # Translate the optional duration (seconds) into an optional duration and absolute deadline, both in nanos:
    if timeout_duration_secs is None:
        duration_nanos: int | None = None
        deadline_nanos: int | None = None
    else:
        duration_nanos = int(timeout_duration_secs * 1_000_000_000)
        deadline_nanos = time.monotonic_ns() + duration_nanos
    procs: Subprocesses = subprocesses if subprocesses is not None else Subprocesses()
    return SimpleMiniJob(timeout_nanos=deadline_nanos, timeout_duration_nanos=duration_nanos, subprocesses=procs)

283 

284 

285############################################################################# 

def timeout(job: MiniJob) -> float | None:
    """Raises TimeoutExpired if necessary, else returns the number of seconds left until timeout is to occur."""
    deadline_nanos: int | None = job.timeout_nanos
    if deadline_nanos is None:  # no deadline configured; never raise a timeout
        return None
    assert job.timeout_duration_nanos is not None
    remaining_nanos: int = deadline_nanos - time.monotonic_ns()
    if remaining_nanos > 0:
        return remaining_nanos / 1_000_000_000  # nanos -> seconds
    raise subprocess.TimeoutExpired("_timeout", timeout=job.timeout_duration_nanos / 1_000_000_000)

296 

297 

def squote(remote: MiniRemote, arg: str) -> str:
    """Quotes an argument only when running remotely over ssh."""
    assert arg is not None
    if remote.ssh_user_host:
        return shlex.quote(arg)  # argv crosses the ssh remote-shell boundary, hence needs quoting
    return arg  # local mode: argv is executed directly without an intermediate shell

302 

303 

def dquote(arg: str) -> str:
    """Shell-escapes backslash and double quotes and dollar and backticks, then surrounds with double quotes; For an example
    how to safely construct and quote complex shell pipeline commands for use over SSH, see
    replication.py:_prepare_zfs_send_receive()"""
    escaped: str = arg
    for special in ("\\", '"', "$", "`"):  # backslash must be escaped first to avoid double-escaping
        escaped = escaped.replace(special, "\\" + special)
    return f'"{escaped}"'

310 

311 

312############################################################################# 

@dataclass(order=True, repr=False)
@final
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection.

    dataclass(order=True) makes instances comparable by (_free, _last_modified), which is how ConnectionPool's priority
    queue ranks connections.
    """

    _free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    _last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot; also ensures no dupes

    def __init__(
        self,
        remote: MiniRemote,
        max_concurrent_ssh_sessions_per_tcp_connection: int,
        *,
        lease: ConnectionLease | None = None,  # optional persistent control socket lease; None disables leasing
    ) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = remote
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._free: int = max_concurrent_ssh_sessions_per_tcp_connection  # SSH session slots still available
        self._last_modified: int = 0  # monotonically increasing
        self._last_refresh_time: int = 1 - (1 << 150)  # negative infinity for all practical purposes
        self._lock: Final[threading.Lock] = threading.Lock()
        self._reuse_ssh_connection: Final[bool] = remote.reuse_ssh_connection
        self._connection_lease: Final[ConnectionLease | None] = lease
        ssh_cmd, ssh_socket_path = remote.local_ssh_command(
            None if self._connection_lease is None else self._connection_lease.socket_path
        )
        self._ssh_socket_path: Final[str | None] = ssh_socket_path  # effective ControlPath, or None if not multiplexing
        self._ssh_cmd: Final[list[str]] = ssh_cmd  # local argv prefix; empty list in local mode
        self._ssh_cmd_quoted: Final[list[str]] = [shlex.quote(item) for item in self._ssh_cmd]

    @property
    def ssh_cmd(self) -> list[str]:
        """Returns a defensive copy of the local ssh argv prefix."""
        return self._ssh_cmd.copy()

    @property
    def ssh_cmd_quoted(self) -> list[str]:
        """Returns a defensive copy of the shell-quoted local ssh argv prefix."""
        return self._ssh_cmd_quoted.copy()

    def __repr__(self) -> str:
        return str({"free": self._free})

    def run_ssh_command(
        self,
        cmd: list[str],
        *,
        job: MiniJob,
        loglevel: int = logging.INFO,
        is_dry: bool = False,  # if True, only log what would be executed; don't actually run it
        **kwargs: Any,  # optional low-level keyword args to be forwarded to subprocess.run()
    ) -> subprocess.CompletedProcess:
        """Runs the given CLI cmd via ssh on the given remote, and returns CompletedProcess including stdout and stderr.

        The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
        (``remote.local_ssh_command(...)[0]``) and the command to run on the given remote host (``cmd``).

        Note: When executing on a remote host (remote.ssh_user_host is set), cmd arguments are pre-quoted with shlex.quote to
        safely traverse the ssh "remote shell" boundary, as ssh concatenates argv into a single remote shell string. In local
        mode (no remote.ssh_user_host) argv is executed directly without an intermediate shell.

        Raises ValueError on empty cmd; may raise subprocess.TimeoutExpired via timeout(job).
        """
        if not cmd:
            raise ValueError("run_ssh_command requires a non-empty cmd list")
        log: logging.Logger = self._remote.params.log
        quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
        ssh_cmd: list[str] = self._ssh_cmd
        if self._remote.ssh_user_host:
            self.refresh_ssh_connection_if_necessary(job)  # ensure an ssh master is alive before piggybacking on it
            cmd = quoted_cmd  # remote mode: use the shell-quoted form
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(loglevel, msg, list_formatter(self._ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return subprocess.CompletedProcess(ssh_cmd + cmd, returncode=0, stdout=None, stderr=None)
        else:
            sp: Subprocesses = job.subprocesses
            return sp.subprocess_run(ssh_cmd + cmd, timeout=timeout(job), log=log, **kwargs)

    def refresh_ssh_connection_if_necessary(self, job: MiniJob) -> None:
        """Maintain or create an ssh master connection for low latency reuse.

        Raises RetryableError (wrapping CalledProcessError) if a new ssh master cannot be started.
        """
        remote: MiniRemote = self._remote
        p: MiniParams = remote.params
        log: logging.Logger = p.log
        if not remote.ssh_user_host:
            return  # we're in local mode; no ssh required
        if not remote.is_ssh_available():
            die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
        if not remote.reuse_ssh_connection:
            return

        # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
        # 'ssh -S -M -oControlPersist=90s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
        # and https://chessman7.substack.com/p/how-ssh-multiplexing-reuses-master
        control_limit_nanos: int = (remote.ssh_control_persist_secs - remote.ssh_control_persist_margin_secs) * 1_000_000_000
        socket_path: str | None = self._ssh_socket_path
        with self._lock:
            if time.monotonic_ns() < self._last_refresh_time + control_limit_nanos:
                if socket_path is None or self._is_ssh_control_socket_usable(socket_path):
                    return  # ssh master is alive, reuse its TCP connection (this is the common case and the ultra-fast path)
            ssh_cmd: list[str] = self._ssh_cmd
            ssh_sock_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_sock_cmd += ["-O", "check", remote.ssh_user_host]
            # extend lifetime of ssh master by $ssh_control_persist_secs via `ssh -O check` if master is still running.
            # `ssh -S /path/to/socket -O check` doesn't talk over the network, hence is still a low latency fast path.
            sp: Subprocesses = job.subprocesses
            t: float | None = timeout(job)
            if sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, timeout=t, log=log).returncode == 0:
                log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_sock_cmd))
            else:  # ssh master is not alive; start a new master:
                log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_sock_cmd))
                if socket_path is not None and not self._is_ssh_control_socket_usable(socket_path):
                    with contextlib.suppress(OSError):
                        os.unlink(socket_path)  # if present, remove stale ssh control socket path before master restart
                ssh_control_persist_secs: int = max(1, remote.ssh_control_persist_secs)
                if any(opt.startswith("-v") and all(char == "v" for char in opt[1:]) for opt in remote.ssh_extra_opts):
                    # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                    # foreground and blocks until the ControlPersist timer expires (90 secs). To make progress earlier we ...
                    ssh_control_persist_secs = min(1, ssh_control_persist_secs)  # tell ssh block as briefly as possible (1s)
                ssh_sock_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
                ssh_sock_cmd += ["-M", f"-oControlPersist={ssh_control_persist_secs}s", remote.ssh_user_host, "exit"]
                log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_sock_cmd))
                t = timeout(job)
                try:
                    sp.subprocess_run(ssh_sock_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, check=True, timeout=t, log=log)
                except subprocess.CalledProcessError as e:
                    log.error("%s", stderr_to_str(e.stderr).rstrip())
                    raise RetryableError(
                        f"Cannot ssh into remote host via '{' '.join(ssh_sock_cmd)}'. Fix ssh configuration first, "
                        "considering diagnostic log file output from running with -v -v -v.",
                        display_msg="ssh connect",
                    ) from e
            self._last_refresh_time = time.monotonic_ns()
            if self._connection_lease is not None:
                self._connection_lease.set_socket_mtime_to_now()  # keep the lease looking fresh for other processes

    def _is_ssh_control_socket_usable(self, socket_path: str) -> bool:
        """To improve ssh perf, check whether a control socket path is a live Unix-domain listener; this helps detect stale
        socket files that still exist after master crashes.

        _is_ssh_control_socket_usable() is ~300x faster than `ssh ... -O check`: ~5 microseconds vs ~1-2 milliseconds.
        """
        try:
            st_mode: int = os.stat(socket_path, follow_symlinks=False).st_mode
            if not stat.S_ISSOCK(st_mode):
                return False  # path exists but is not a socket
        except OSError:
            return False  # path is missing or unreadable
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                return sock.connect_ex(socket_path) == 0  # 0 means a listener accepted the connect
        except OSError:
            return False

    def _increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots."""
        self._free += value
        assert self._free >= 0
        assert self._free <= self._capacity

    def _is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self._free <= 0

    def _update_last_modified(self, last_modified: int) -> None:
        """Records when the connection was last used."""
        self._last_modified = last_modified

    def shutdown(self, msg_prefix: str) -> None:
        """Closes the underlying SSH master connection and releases the corresponding connection lease."""
        ssh_cmd: list[str] = self._ssh_cmd
        if ssh_cmd and self._reuse_ssh_connection:
            if self._connection_lease is None:  # no lease: tell the master to exit now via `ssh -O exit`
                ssh_sock_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
                log = self._remote.params.log
                log.log(LOG_TRACE, f"Executing {msg_prefix}: %s", shlex.join(ssh_sock_cmd))
                try:
                    proc: subprocess.CompletedProcess = subprocess.run(
                        ssh_sock_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=0.1
                    )
                except subprocess.TimeoutExpired as e:  # harmless as master auto-exits after ssh_control_persist_secs anyway
                    log.log(LOG_TRACE, "Harmless ssh master connection shutdown timeout: %s", e)
                else:
                    if proc.returncode != 0:  # harmless for the same reason
                        log.log(LOG_TRACE, "Harmless ssh master connection shutdown issue: %s", proc.stderr.rstrip())
            else:  # leased: keep the master running for reuse by later processes; just give the lease back
                self._connection_lease.release()

497 

498 

499############################################################################# 

class ConnectionPool:
    """Fetch a TCP connection for use in an SSH session, use it, finally return it back to the pool for future reuse;
    Note that max_concurrent_ssh_sessions_per_tcp_connection must not be larger than the server-side sshd_config(5)
    MaxSessions parameter (which defaults to 10, see https://manpages.ubuntu.com/manpages/man5/sshd_config.5.html)."""

    def __init__(
        self, remote: MiniRemote, connpool_name: str, max_concurrent_ssh_sessions_per_tcp_connection: int = 8
    ) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[MiniRemote] = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._connpool_name: Final[str] = connpool_name
        self._priority_queue: Final[SmallPriorityQueue[Connection]] = SmallPriorityQueue(
            reverse=True  # sorted by #free slots and last_modified
        )
        self._last_modified: int = 0  # monotonically increasing sequence number
        self._lock: Final[threading.Lock] = threading.Lock()  # guards the priority queue and _last_modified
        lease_mgr: ConnectionLeaseManager | None = None
        # Leases only apply when multiplexing over ssh AND masters shall survive pool shutdown (no exit-on-shutdown):
        if self._remote.ssh_user_host and self._remote.reuse_ssh_connection and not self._remote.ssh_exit_on_shutdown:
            lease_mgr = ConnectionLeaseManager(
                root_dir=self._remote.ssh_socket_dir,
                namespace=f"{self._remote.location}#{self._remote.cache_namespace()}#{self._connpool_name}",
                ssh_control_persist_secs=max(90 * 60, 2 * self._remote.ssh_control_persist_secs + 2),
                log=self._remote.params.log,
            )
        self._lease_mgr: Final[ConnectionLeaseManager | None] = lease_mgr

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that yields a connection from the pool and automatically returns it on __exit__."""
        conn: Connection = self.get_connection()
        try:
            yield conn
        finally:
            self.return_connection(conn)

    def get_connection(self) -> Connection:
        """Any Connection object returned on get_connection() also remains intentionally contained in the priority queue
        while it is "checked out", and that identical Connection object is later, on return_connection(), temporarily removed
        from the priority queue, updated with an incremented "free" slot count and then immediately reinserted into the
        priority queue.

        In effect, any Connection object remains intentionally contained in the priority queue at all times. This design
        keeps ordering/fairness accurate while avoiding duplicate Connection instances.
        """
        with self._lock:
            conn = self._priority_queue.pop() if len(self._priority_queue) > 0 else None
            if conn is None or conn._is_full():  # noqa: SLF001 # pylint: disable=protected-access
                if conn is not None:
                    self._priority_queue.push(conn)  # reinsert the full connection; it stays in the queue at all times
                conn = self._new_connection()  # add a new connection
            self._last_modified += 1
            conn._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
            conn._increment_free(-1)  # noqa: SLF001 # pylint: disable=protected-access  # consume one session slot
            self._priority_queue.push(conn)
            return conn

    def _new_connection(self) -> Connection:
        """Creates a fresh Connection, acquiring a control socket lease when lease management is enabled."""
        lease: ConnectionLease | None = None if self._lease_mgr is None else self._lease_mgr.acquire()
        return Connection(self._remote, self._capacity, lease=lease)

    def return_connection(self, conn: Connection) -> None:
        """Returns the given connection to the pool and updates its priority."""
        assert conn is not None
        with self._lock:
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue
            if self._priority_queue.remove(conn):  # conn is not contained only if ConnectionPool.shutdown() was called
                conn._increment_free(1)  # noqa: SLF001 # pylint: disable=protected-access
                self._last_modified += 1
                conn._update_last_modified(self._last_modified)  # noqa: SLF001 # pylint: disable=protected-access
                self._priority_queue.push(conn)

    def shutdown(self, msg_prefix: str = "") -> None:
        """Closes all SSH connections managed by this pool."""
        with self._lock:
            try:
                if self._remote.reuse_ssh_connection:
                    msg_prefix = msg_prefix + "/" + self._connpool_name
                    for conn in self._priority_queue:
                        conn.shutdown(msg_prefix)
            finally:
                self._priority_queue.clear()  # emptied even if a conn.shutdown() raised

    def __repr__(self) -> str:
        with self._lock:
            queue = self._priority_queue
            return str({"capacity": self._capacity, "queue_len": len(queue), "queue": queue})

587 

588 

589############################################################################# 

@final
class ConnectionPools:
    """A collection of named connection pools, each with its own multiplexing capacity."""

    def __init__(self, remote: MiniRemote, *, capacities: dict[str, int]) -> None:
        """Creates one connection pool per name with the given capacities."""
        pools: dict[str, ConnectionPool] = {}
        for pool_name, pool_capacity in capacities.items():
            pools[pool_name] = ConnectionPool(remote, pool_name, pool_capacity)
        self._pools: Final[dict[str, ConnectionPool]] = pools

    def __repr__(self) -> str:
        return str(self._pools)

    def pool(self, name: str) -> ConnectionPool:
        """Returns the pool associated with the given name."""
        return self._pools[name]

    def shutdown(self, msg_prefix: str = "") -> None:
        """Shuts down every contained pool."""
        for connpool in self._pools.values():
            connpool.shutdown(msg_prefix)