Coverage for bzfs_main/connection.py: 99%

217 statements  

coverage.py v7.11.0, created at 2025-11-07 04:44 +0000

# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

15"""Network connection management is in refresh_ssh_connection_if_necessary() and class ConnectionPool; They reuse multiplexed 

16ssh connections for low latency.""" 

from __future__ import (
    annotations,
)
import contextlib
import copy
import logging
import shlex
import subprocess
import threading
import time
from collections import (
    Counter,
)
from collections.abc import (
    Iterator,
)
from dataclasses import (
    dataclass,
)
from subprocess import (
    DEVNULL,
    PIPE,
    CalledProcessError,
    CompletedProcess,
)
from typing import (
    TYPE_CHECKING,
    Final,
)

from bzfs_main.connection_lease import (
    ConnectionLease,
    ConnectionLeaseManager,
)
from bzfs_main.retry import (
    RetryableError,
)
from bzfs_main.utils import (
    LOG_TRACE,
    PROG_NAME,
    SmallPriorityQueue,
    Subprocesses,
    die,
    list_formatter,
    stderr_to_str,
    xprint,
)

if TYPE_CHECKING:  # pragma: no cover - for type hints only
    from bzfs_main.bzfs import (
        Job,
    )
    from bzfs_main.configuration import (
        Params,
        Remote,
    )

# constants:
SHARED: Final[str] = "shared"
DEDICATED: Final[str] = "dedicated"


def run_ssh_command(
    job: Job,
    remote: Remote,
    level: int = -1,
    is_dry: bool = False,
    check: bool = True,
    print_stdout: bool = False,
    print_stderr: bool = True,
    cmd: list[str] | None = None,
) -> str:
    """Runs the given CLI cmd via ssh on the given remote, and returns stdout.

    The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
    ($remote.local_ssh_command()) and the command to run on the given remote host ($cmd).

    Note: When executing on a remote host (remote.ssh_user_host is set), cmd arguments are pre-quoted with shlex.quote to
    safely traverse the ssh "remote shell" boundary, as ssh concatenates argv into a single remote shell string. In local
    mode (no remote.ssh_user_host) argv is executed directly without an intermediate shell.
    """

    level = level if level >= 0 else logging.INFO
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    p, log = job.params, job.params.log
    quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
    conn_pool: ConnectionPool = p.connection_pools[remote.location].pool(SHARED)
    with conn_pool.connection() as conn:
        ssh_cmd: list[str] = conn.ssh_cmd
        if remote.ssh_user_host:
            refresh_ssh_connection_if_necessary(job, remote, conn)
            cmd = quoted_cmd
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(level, msg, list_formatter(conn.ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return ""
        try:
            sp: Subprocesses = job.subprocesses
            process: CompletedProcess[str] = sp.subprocess_run(
                ssh_cmd + cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True, timeout=timeout(job), check=check
            )
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired, UnicodeDecodeError) as e:
            if not isinstance(e, UnicodeDecodeError):
                xprint(log, stderr_to_str(e.stdout), run=print_stdout, end="")
                xprint(log, stderr_to_str(e.stderr), run=print_stderr, end="")
            raise
        else:
            xprint(log, process.stdout, run=print_stdout, end="")
            xprint(log, process.stderr, run=print_stderr, end="")
            return process.stdout
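# Illustrative sketch, not part of the original module: why the docstring above insists on pre-quoting
# with shlex.quote. ssh concatenates its argv into a single string that the remote shell re-parses, so
# an unquoted argument containing spaces or shell metacharacters would be split or interpreted remotely.
# This standalone demo uses only the stdlib shlex module that is already imported above.
def _example_remote_shell_quoting() -> None:
    """Demo: shlex.quote preserves argument boundaries across a remote shell round-trip."""
    argv = ["zfs", "list", "tank/my dataset; echo pwned"]  # an argument with a space and a metacharacter
    # Without pre-quoting, the remote shell would re-parse the joined string into a different argv:
    assert shlex.split(" ".join(argv)) != argv
    # With pre-quoting, the remote shell recovers the exact original argv:
    assert shlex.split(" ".join(shlex.quote(arg) for arg in argv)) == argv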

def try_ssh_command(
    job: Job,
    remote: Remote,
    level: int,
    is_dry: bool = False,
    print_stdout: bool = False,
    cmd: list[str] | None = None,
    exists: bool = True,
    error_trigger: str | None = None,
) -> str | None:
    """Convenience method that helps retry/react to a dataset or pool that potentially doesn't exist anymore."""
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    log = job.params.log
    try:
        maybe_inject_error(job, cmd=cmd, error_trigger=error_trigger)
        return run_ssh_command(job, remote, level=level, is_dry=is_dry, print_stdout=print_stdout, cmd=cmd)
    except (subprocess.CalledProcessError, UnicodeDecodeError) as e:
        if not isinstance(e, UnicodeDecodeError):
            stderr: str = stderr_to_str(e.stderr)
            if exists and (
                ": dataset does not exist" in stderr
                or ": filesystem does not exist" in stderr  # solaris 11.4.0
                or ": no such pool" in stderr
            ):
                return None
            log.warning("%s", stderr.rstrip())
        raise RetryableError("Subprocess failed") from e
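# Hypothetical usage sketch (not part of the original module; `job` and `remote` are assumed to be
# configured by the caller): a vanished dataset yields None instead of an exception, so callers can
# skip it gracefully, while other failures surface as RetryableError for the retry machinery:
#
#     stdout = try_ssh_command(job, remote, LOG_TRACE, cmd=["zfs", "list", "-t", "filesystem", "tank/foo"])
#     if stdout is None:
#         ...  # dataset or pool no longer exists; skip it instead of failing the whole job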

def refresh_ssh_connection_if_necessary(job: Job, remote: Remote, conn: Connection) -> None:
    """Maintains or creates an ssh master connection for low latency reuse."""
    p, log = job.params, job.params.log
    if not remote.ssh_user_host:
        return  # we're in local mode; no ssh required
    if not p.is_program_available("ssh", "local"):
        die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
    if not remote.reuse_ssh_connection:
        return
    # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
    # 'ssh -S -M -oControlPersist=90s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
    # and https://chessman7.substack.com/p/how-ssh-multiplexing-reuses-master
    control_persist_limit_nanos: int = (remote.ssh_control_persist_secs - job.control_persist_margin_secs) * 1_000_000_000
    with conn.lock:
        if time.monotonic_ns() < conn.last_refresh_time + control_persist_limit_nanos:
            return  # ssh master is alive, reuse its TCP connection (this is the common case and the ultra-fast path)
        ssh_cmd: list[str] = conn.ssh_cmd
        ssh_socket_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
        ssh_socket_cmd += ["-O", "check", remote.ssh_user_host]
        # extend lifetime of ssh master by $ssh_control_persist_secs via 'ssh -O check' if master is still running.
        # 'ssh -S /path/to/socket -O check' doesn't talk over the network, hence is still a low latency fast path.
        sp: Subprocesses = job.subprocesses
        if sp.subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, timeout=timeout(job)).returncode == 0:
            log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_socket_cmd))
        else:  # ssh master is not alive; start a new master:
            log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_socket_cmd))
            ssh_control_persist_secs: int = remote.ssh_control_persist_secs
            if "-v" in remote.ssh_extra_opts:
                # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                # foreground and blocks until the ControlPersist timer expires (90 secs). To make progress earlier we ...
                ssh_control_persist_secs = min(1, ssh_control_persist_secs)  # tell ssh to block as briefly as possible (1s)
            ssh_socket_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_socket_cmd += ["-M", f"-oControlPersist={ssh_control_persist_secs}s", remote.ssh_user_host, "exit"]
            log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_socket_cmd))
            try:
                sp.subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, check=True, timeout=timeout(job))
            except subprocess.CalledProcessError as e:
                log.error("%s", stderr_to_str(e.stderr).rstrip())
                raise RetryableError(
                    f"Cannot ssh into remote host via '{' '.join(ssh_socket_cmd)}'. Fix ssh configuration "
                    f"first, considering diagnostic log file output from running {PROG_NAME} with -v -v -v."
                ) from e
        conn.last_refresh_time = time.monotonic_ns()
        if conn.connection_lease is not None:
            conn.connection_lease.set_socket_mtime_to_now()
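# Illustrative sketch, not part of the original module: the two OpenSSH command shapes used above.
# 'ssh -S <socket> -O check <host>' probes the local control socket without any network I/O, while
# 'ssh -S <socket> -M -oControlPersist=<secs>s <host> exit' starts a backgrounded master whose TCP
# connection subsequent sessions can multiplex over. Names and arguments here are hypothetical placeholders.
def _example_ssh_multiplexing_commands(user_host: str, socket_path: str, persist_secs: int) -> tuple[list[str], list[str]]:
    """Demo: builds an 'is the master alive?' probe cmd and a 'start a new master' cmd."""
    base = ["ssh", "-S", socket_path]  # -S selects the control socket used for connection multiplexing
    check_cmd = base + ["-O", "check", user_host]  # exit code 0 means the master is still running
    master_cmd = base + ["-M", f"-oControlPersist={persist_secs}s", user_host, "exit"]  # master backgrounds after 'exit'
    return check_cmd, master_cmd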

def timeout(job: Job) -> float | None:
    """Raises TimeoutExpired if the deadline has already passed, else returns the number of seconds left until the
    timeout occurs."""
    timeout_nanos: int | None = job.timeout_nanos
    if timeout_nanos is None:
        return None  # never raise a timeout
    delta_nanos: int = timeout_nanos - time.monotonic_ns()
    if delta_nanos <= 0:
        assert job.params.timeout_nanos is not None
        raise subprocess.TimeoutExpired(PROG_NAME + "_timeout", timeout=job.params.timeout_nanos / 1_000_000_000)
    return delta_nanos / 1_000_000_000  # seconds
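# Illustrative sketch, not part of the original module: timeout() treats job.timeout_nanos as an
# absolute deadline on the monotonic clock and converts whatever remains into seconds, which is the
# unit the subprocess APIs expect. A self-contained demo of the same arithmetic:
def _example_deadline_math() -> None:
    """Demo: absolute monotonic deadline in nanos -> remaining seconds."""
    deadline_nanos: int = time.monotonic_ns() + 5 * 1_000_000_000  # deadline 5 seconds from now
    remaining_secs: float = (deadline_nanos - time.monotonic_ns()) / 1_000_000_000
    assert 0.0 < remaining_secs <= 5.0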

def maybe_inject_error(job: Job, cmd: list[str], error_trigger: str | None = None) -> None:
    """For testing only; lets unit tests simulate errors during replication and verify correct handling of them."""
    if error_trigger:
        counter = job.error_injection_triggers.get("before")
        if counter and decrement_injection_counter(job, counter, error_trigger):
            try:
                raise CalledProcessError(returncode=1, cmd=" ".join(cmd), stderr=error_trigger + ":dataset is busy")
            except subprocess.CalledProcessError as e:
                if error_trigger.startswith("retryable_"):
                    raise RetryableError("Subprocess failed") from e
                else:
                    raise

def decrement_injection_counter(job: Job, counter: Counter[str], trigger: str) -> bool:
    """For testing only."""
    with job.injection_lock:
        if counter[trigger] <= 0:
            return False
        counter[trigger] -= 1
        return True

#############################################################################
@dataclass(order=True, repr=False)
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection."""

    _free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    _last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot; also ensures no dupes

    def __init__(
        self,
        remote: Remote,
        max_concurrent_ssh_sessions_per_tcp_connection: int,
        cid: int,
        lease: ConnectionLease | None = None,
    ) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._free: int = max_concurrent_ssh_sessions_per_tcp_connection
        self._last_modified: int = 0  # monotonically increasing
        self._cid: Final[int] = cid
        self.last_refresh_time: int = 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self._reuse_ssh_connection: Final[bool] = remote.reuse_ssh_connection
        self.connection_lease: Final[ConnectionLease | None] = lease
        self.ssh_cmd: Final[list[str]] = remote.local_ssh_command(
            None if self.connection_lease is None else self.connection_lease.socket_path
        )
        self.ssh_cmd_quoted: Final[list[str]] = [shlex.quote(item) for item in self.ssh_cmd]

    def __repr__(self) -> str:
        return str({"free": self._free, "cid": self._cid})

    def increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots."""
        self._free += value
        assert self._free >= 0
        assert self._free <= self._capacity

    def is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self._free <= 0

    def update_last_modified(self, last_modified: int) -> None:
        """Records when the connection was last used."""
        self._last_modified = last_modified

    def shutdown(self, msg_prefix: str, p: Params) -> None:
        """Closes the underlying SSH master connection and releases the corresponding connection lease."""
        ssh_cmd: list[str] = self.ssh_cmd
        if ssh_cmd and self._reuse_ssh_connection:
            if self.connection_lease is None:
                ssh_sock_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
                p.log.log(LOG_TRACE, f"Executing {msg_prefix}: %s", shlex.join(ssh_sock_cmd))
                try:
                    proc: CompletedProcess = subprocess.run(ssh_sock_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=0.1)
                except subprocess.TimeoutExpired as e:  # harmless as master auto-exits after ssh_control_persist_secs anyway
                    p.log.log(LOG_TRACE, "Harmless ssh master connection shutdown timeout: %s", e)
                else:
                    if proc.returncode != 0:  # harmless for the same reason
                        p.log.log(LOG_TRACE, "Harmless ssh master connection shutdown issue: %s", proc.stderr.rstrip())
            else:
                self.connection_lease.release()
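# Illustrative sketch, not part of the original module: how @dataclass(order=True) drives the pool's
# ordering. Only the declared fields participate in comparisons (first _free, then _last_modified), so
# a reverse-sorted queue surfaces the connection with the most free SSH slots and breaks ties in favor
# of the most recently used (LIFO) connection. A minimal analogue with a hypothetical toy class:
def _example_connection_ordering() -> None:
    """Demo: field order (_free, _last_modified) defines the comparison key."""

    @dataclass(order=True)
    class _Conn:
        free: int
        last_modified: int

    a, b, c = _Conn(2, 1), _Conn(2, 7), _Conn(1, 9)
    # reverse sort, as in SmallPriorityQueue(reverse=True): most free slots first, newest wins the tie
    assert sorted([a, b, c], reverse=True) == [b, a, c]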

#############################################################################
class ConnectionPool:
    """Fetch a TCP connection for use in an SSH session, use it, finally return it back to the pool for future reuse."""

    def __init__(self, remote: Remote, max_concurrent_ssh_sessions_per_tcp_connection: int, connpool_name: str) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[Remote] = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._connpool_name: Final[str] = connpool_name
        self._priority_queue: Final[SmallPriorityQueue[Connection]] = SmallPriorityQueue(
            reverse=True  # sorted by #free slots and last_modified
        )
        self._last_modified: int = 0  # monotonically increasing sequence number
        self._cid: int = 0  # monotonically increasing connection number
        self._lock: Final[threading.Lock] = threading.Lock()
        lease_mgr: ConnectionLeaseManager | None = None
        if self._remote.ssh_user_host and self._remote.reuse_ssh_connection and not self._remote.ssh_exit_on_shutdown:
            lease_mgr = ConnectionLeaseManager(
                root_dir=self._remote.ssh_socket_dir,
                namespace=f"{self._remote.location}#{self._remote.cache_namespace()}#{self._connpool_name}",
                ssh_control_persist_secs=max(90 * 60, 2 * self._remote.ssh_control_persist_secs + 2),
                log=self._remote.params.log,
            )
        self._lease_mgr: Final[ConnectionLeaseManager | None] = lease_mgr

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that yields a connection from the pool and automatically returns it on __exit__."""
        conn: Connection = self.get_connection()
        try:
            yield conn
        finally:
            self.return_connection(conn)

    def get_connection(self) -> Connection:
        """Any Connection object returned on get_connection() also remains intentionally contained in the priority queue
        while it is "checked out", and that identical Connection object is later, on return_connection(), temporarily
        removed from the priority queue, updated with an incremented "free" slot count and then immediately reinserted
        into the priority queue.

        In effect, any Connection object remains intentionally contained in the priority queue at all times. This design
        keeps ordering/fairness accurate while avoiding duplicate Connection instances.
        """
        with self._lock:
            conn = self._priority_queue.pop() if len(self._priority_queue) > 0 else None
            if conn is None or conn.is_full():
                if conn is not None:
                    self._priority_queue.push(conn)
                lease: ConnectionLease | None = None if self._lease_mgr is None else self._lease_mgr.acquire()
                conn = Connection(self._remote, self._capacity, self._cid, lease=lease)  # add a new connection
                self._last_modified += 1
                conn.update_last_modified(self._last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self._cid += 1
            conn.increment_free(-1)
            self._priority_queue.push(conn)
            return conn

    def return_connection(self, conn: Connection) -> None:
        """Returns the given connection to the pool and updates its priority."""
        assert conn is not None
        with self._lock:
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue
            if self._priority_queue.remove(conn):  # conn is not contained only if ConnectionPool.shutdown() was called
                conn.increment_free(1)
                self._last_modified += 1
                conn.update_last_modified(self._last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self._priority_queue.push(conn)

    def shutdown(self, msg_prefix: str) -> None:
        """Closes all SSH connections managed by this pool."""
        with self._lock:
            try:
                if self._remote.reuse_ssh_connection:
                    for conn in self._priority_queue:
                        conn.shutdown(msg_prefix, self._remote.params)
            finally:
                self._priority_queue.clear()

    def __repr__(self) -> str:
        with self._lock:
            queue = self._priority_queue
            return str({"capacity": self._capacity, "queue_len": len(queue), "cid": self._cid, "queue": queue})
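# Hypothetical usage sketch (not part of the original module; the `remote` object is assumed to be
# configured elsewhere). The context manager guarantees the checkout is returned even on error:
#
#     pool = ConnectionPool(remote, max_concurrent_ssh_sessions_per_tcp_connection=8, connpool_name=SHARED)
#     with pool.connection() as conn:
#         ...  # run `conn.ssh_cmd + remote_cmd` via subprocess, as run_ssh_command() does above
#     pool.shutdown("cleanup")  # closes the ssh masters and releases any connection leases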

#############################################################################
class ConnectionPools:
    """A bunch of named connection pools with various multiplexing capacities."""

    def __init__(self, remote: Remote, capacities: dict[str, int]) -> None:
        """Creates one connection pool per name with the given capacities."""
        self._pools: Final[dict[str, ConnectionPool]] = {
            name: ConnectionPool(remote, capacity, name) for name, capacity in capacities.items()
        }

    def __repr__(self) -> str:
        return str(self._pools)

    def pool(self, name: str) -> ConnectionPool:
        """Returns the pool associated with the given name."""
        return self._pools[name]

    def shutdown(self, msg_prefix: str) -> None:
        """Shuts down every contained pool."""
        for name, pool in self._pools.items():
            pool.shutdown(msg_prefix + "/" + name)
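# Hypothetical wiring sketch (not part of the original module; the capacity numbers are illustrative,
# not the program's actual configuration):
#
#     pools = ConnectionPools(remote, {SHARED: 8, DEDICATED: 1})
#     with pools.pool(SHARED).connection() as conn:
#         ...  # issue ssh sessions over the shared, multiplexed TCP connections
#     pools.shutdown("terminating")  # shuts down every named pool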