Coverage for bzfs_main/connection.py: 99%
217 statements
coverage.py v7.11.0, created at 2025-11-07 04:44 +0000

# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
15"""Network connection management is in refresh_ssh_connection_if_necessary() and class ConnectionPool; They reuse multiplexed
16ssh connections for low latency."""

from __future__ import (
    annotations,
)
import contextlib
import copy
import logging
import shlex
import subprocess
import threading
import time
from collections import (
    Counter,
)
from collections.abc import (
    Iterator,
)
from dataclasses import (
    dataclass,
)
from subprocess import (
    DEVNULL,
    PIPE,
    CalledProcessError,
    CompletedProcess,
)
from typing import (
    TYPE_CHECKING,
    Final,
)

from bzfs_main.connection_lease import (
    ConnectionLease,
    ConnectionLeaseManager,
)
from bzfs_main.retry import (
    RetryableError,
)
from bzfs_main.utils import (
    LOG_TRACE,
    PROG_NAME,
    SmallPriorityQueue,
    Subprocesses,
    die,
    list_formatter,
    stderr_to_str,
    xprint,
)

if TYPE_CHECKING:  # pragma: no cover - for type hints only
    from bzfs_main.bzfs import (
        Job,
    )
    from bzfs_main.configuration import (
        Params,
        Remote,
    )

# constants:
SHARED: Final[str] = "shared"
DEDICATED: Final[str] = "dedicated"


def run_ssh_command(
    job: Job,
    remote: Remote,
    level: int = -1,
    is_dry: bool = False,
    check: bool = True,
    print_stdout: bool = False,
    print_stderr: bool = True,
    cmd: list[str] | None = None,
) -> str:
    """Runs the given CLI cmd via ssh on the given remote, and returns stdout.

    The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
    ($remote.local_ssh_command()) and the command to run on the given remote host ($cmd).

    Note: When executing on a remote host (remote.ssh_user_host is set), cmd arguments are pre-quoted with shlex.quote to
    safely traverse the ssh "remote shell" boundary, as ssh concatenates argv into a single remote shell string. In local
    mode (no remote.ssh_user_host) argv is executed directly without an intermediate shell.
    """
    level = level if level >= 0 else logging.INFO
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    p, log = job.params, job.params.log
    quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
    conn_pool: ConnectionPool = p.connection_pools[remote.location].pool(SHARED)
    with conn_pool.connection() as conn:
        ssh_cmd: list[str] = conn.ssh_cmd
        if remote.ssh_user_host:
            refresh_ssh_connection_if_necessary(job, remote, conn)
            cmd = quoted_cmd
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(level, msg, list_formatter(conn.ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return ""
        try:
            sp: Subprocesses = job.subprocesses
            process: CompletedProcess[str] = sp.subprocess_run(
                ssh_cmd + cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True, timeout=timeout(job), check=check
            )
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired, UnicodeDecodeError) as e:
            if not isinstance(e, UnicodeDecodeError):
                xprint(log, stderr_to_str(e.stdout), run=print_stdout, end="")
                xprint(log, stderr_to_str(e.stderr), run=print_stderr, end="")
            raise
        else:
            xprint(log, process.stdout, run=print_stdout, end="")
            xprint(log, process.stderr, run=print_stderr, end="")
            return process.stdout
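

# Illustrative sketch (not part of the original module): how argv crosses the ssh "remote shell" boundary described in the
# docstring above. ssh joins its trailing argv into a single string that the remote login shell re-parses, so each remote
# argument is pre-quoted with shlex.quote, while the local ssh argv is executed directly without a shell. The socket path,
# host and zfs command below are made-up placeholder values.
def _example_quoting_sketch() -> list[str]:
    local_ssh_cmd = ["ssh", "-S", "/tmp/example.sock", "user@host"]  # stand-in for remote.local_ssh_command()
    remote_cmd = ["zfs", "list", "-t", "snapshot", "tank/my data"]  # note the embedded space in the dataset name
    quoted_cmd = [shlex.quote(arg) for arg in remote_cmd]  # "tank/my data" -> "'tank/my data'"
    return local_ssh_cmd + quoted_cmd  # the full argv handed to the subprocess, mirroring ssh_cmd + cmd above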


def try_ssh_command(
    job: Job,
    remote: Remote,
    level: int,
    is_dry: bool = False,
    print_stdout: bool = False,
    cmd: list[str] | None = None,
    exists: bool = True,
    error_trigger: str | None = None,
) -> str | None:
    """Convenience method that helps retry/react to a dataset or pool that potentially doesn't exist anymore."""
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    log = job.params.log
    try:
        maybe_inject_error(job, cmd=cmd, error_trigger=error_trigger)
        return run_ssh_command(job, remote, level=level, is_dry=is_dry, print_stdout=print_stdout, cmd=cmd)
    except (subprocess.CalledProcessError, UnicodeDecodeError) as e:
        if not isinstance(e, UnicodeDecodeError):
            stderr: str = stderr_to_str(e.stderr)
            if exists and (
                ": dataset does not exist" in stderr
                or ": filesystem does not exist" in stderr  # solaris 11.4.0
                or ": no such pool" in stderr
            ):
                return None
            log.warning("%s", stderr.rstrip())
        raise RetryableError("Subprocess failed") from e
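

# Illustrative sketch (not part of the original module): the stderr classification performed by try_ssh_command() above.
# A "does not exist" style zfs message is treated as a benign "dataset/pool is gone" outcome (None), anything else is
# escalated as retryable. The return strings merely describe the two outcomes; they are not part of the real API.
def _example_classify_stderr(stderr: str, exists: bool = True) -> str:
    benign_markers = (": dataset does not exist", ": filesystem does not exist", ": no such pool")
    if exists and any(marker in stderr for marker in benign_markers):
        return "benign: dataset/pool no longer exists -> return None"
    return "unexpected failure -> log warning and raise RetryableError"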


def refresh_ssh_connection_if_necessary(job: Job, remote: Remote, conn: Connection) -> None:
    """Maintains or creates an ssh master connection for low latency reuse."""
    p, log = job.params, job.params.log
    if not remote.ssh_user_host:
        return  # we're in local mode; no ssh required
    if not p.is_program_available("ssh", "local"):
        die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
    if not remote.reuse_ssh_connection:
        return
    # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
    # 'ssh -S -M -oControlPersist=90s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
    # and https://chessman7.substack.com/p/how-ssh-multiplexing-reuses-master
    control_persist_limit_nanos: int = (remote.ssh_control_persist_secs - job.control_persist_margin_secs) * 1_000_000_000
    with conn.lock:
        if time.monotonic_ns() < conn.last_refresh_time + control_persist_limit_nanos:
            return  # ssh master is alive, reuse its TCP connection (this is the common case and the ultra-fast path)
        ssh_cmd: list[str] = conn.ssh_cmd
        ssh_socket_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
        ssh_socket_cmd += ["-O", "check", remote.ssh_user_host]
        # Extend lifetime of ssh master by $ssh_control_persist_secs via 'ssh -O check' if master is still running.
        # 'ssh -S /path/to/socket -O check' doesn't talk over the network, hence is still a low latency fast path.
        sp: Subprocesses = job.subprocesses
        if sp.subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, timeout=timeout(job)).returncode == 0:
            log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_socket_cmd))
        else:  # ssh master is not alive; start a new master:
            log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_socket_cmd))
            ssh_control_persist_secs: int = remote.ssh_control_persist_secs
            if "-v" in remote.ssh_extra_opts:
                # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                # foreground and blocks until the ControlPersist timer expires (90 secs). To make progress earlier we ...
                ssh_control_persist_secs = min(1, ssh_control_persist_secs)  # tell ssh to block as briefly as possible (1s)
            ssh_socket_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_socket_cmd += ["-M", f"-oControlPersist={ssh_control_persist_secs}s", remote.ssh_user_host, "exit"]
            log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_socket_cmd))
            try:
                sp.subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, check=True, timeout=timeout(job))
            except subprocess.CalledProcessError as e:
                log.error("%s", stderr_to_str(e.stderr).rstrip())
                raise RetryableError(
                    f"Cannot ssh into remote host via '{' '.join(ssh_socket_cmd)}'. Fix ssh configuration "
                    f"first, considering diagnostic log file output from running {PROG_NAME} with -v -v -v."
                ) from e
        conn.last_refresh_time = time.monotonic_ns()
        if conn.connection_lease is not None:
            conn.connection_lease.set_socket_mtime_to_now()
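

# Illustrative sketch (not part of the original module): the two ssh invocations behind the multiplexing fast path above,
# spelled out as plain argv lists. The socket path, user@host and the 90s ControlPersist value are made-up placeholders;
# the real values come from remote.local_ssh_command() and remote.ssh_control_persist_secs.
def _example_multiplexing_commands() -> tuple[list[str], list[str]]:
    base = ["ssh", "-S", "/tmp/example.sock"]  # '-S' points ssh at the control socket of the (potential) master
    check_cmd = base + ["-O", "check", "user@host"]  # probes only the local master socket; no network round trip
    start_master_cmd = base + ["-M", "-oControlPersist=90s", "user@host", "exit"]  # starts a backgrounded master
    return check_cmd, start_master_cmd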


def timeout(job: Job) -> float | None:
    """Raises TimeoutExpired if the job deadline has already passed; else returns the number of seconds left until the
    deadline, or None if no deadline is configured."""
    timeout_nanos: int | None = job.timeout_nanos
    if timeout_nanos is None:
        return None  # never raise a timeout
    delta_nanos: int = timeout_nanos - time.monotonic_ns()
    if delta_nanos <= 0:
        assert job.params.timeout_nanos is not None
        raise subprocess.TimeoutExpired(PROG_NAME + "_timeout", timeout=job.params.timeout_nanos / 1_000_000_000)
    return delta_nanos / 1_000_000_000  # seconds
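

# Illustrative sketch (not part of the original module): the deadline arithmetic behind timeout(). job.timeout_nanos is an
# absolute monotonic deadline, so the remaining budget is recomputed for every subprocess call and shrinks over time. The
# 5-second budget below is an arbitrary example value.
def _example_deadline_arithmetic() -> float:
    deadline_nanos = time.monotonic_ns() + 5 * 1_000_000_000  # absolute deadline, 5 seconds from now
    remaining_nanos = deadline_nanos - time.monotonic_ns()  # what timeout() converts to seconds for subprocess
    return remaining_nanos / 1_000_000_000  # ~5.0 on the first call, less on later calls, TimeoutExpired once <= 0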


def maybe_inject_error(job: Job, cmd: list[str], error_trigger: str | None = None) -> None:
    """For testing only; for unit tests to simulate errors during replication and test correct handling of them."""
    if error_trigger:
        counter = job.error_injection_triggers.get("before")
        if counter and decrement_injection_counter(job, counter, error_trigger):
            try:
                raise CalledProcessError(returncode=1, cmd=" ".join(cmd), stderr=error_trigger + ":dataset is busy")
            except subprocess.CalledProcessError as e:
                if error_trigger.startswith("retryable_"):
                    raise RetryableError("Subprocess failed") from e
                else:
                    raise


def decrement_injection_counter(job: Job, counter: Counter[str], trigger: str) -> bool:
    """For testing only."""
    with job.injection_lock:
        if counter[trigger] <= 0:
            return False
        counter[trigger] -= 1
        return True
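

# Illustrative sketch (not part of the original module): how a unit test might arm the error injection machinery above.
# The trigger name and the count of 2 are made-up; the "before" key and the "retryable_" prefix convention come from
# maybe_inject_error().
def _example_arm_error_injection() -> Counter[str]:
    triggers: Counter[str] = Counter({"retryable_dataset_busy": 2})  # simulate a failure on the next 2 matching calls
    # A test would store this as job.error_injection_triggers["before"] and pass error_trigger="retryable_dataset_busy"
    # into try_ssh_command(); the "retryable_" prefix turns the simulated CalledProcessError into a RetryableError.
    return triggers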


#############################################################################
@dataclass(order=True, repr=False)
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection."""

    _free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    _last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot; also ensures no dupes

    def __init__(
        self,
        remote: Remote,
        max_concurrent_ssh_sessions_per_tcp_connection: int,
        cid: int,
        lease: ConnectionLease | None = None,
    ) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._free: int = max_concurrent_ssh_sessions_per_tcp_connection
        self._last_modified: int = 0  # monotonically increasing
        self._cid: Final[int] = cid
        self.last_refresh_time: int = 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self._reuse_ssh_connection: Final[bool] = remote.reuse_ssh_connection
        self.connection_lease: Final[ConnectionLease | None] = lease
        self.ssh_cmd: Final[list[str]] = remote.local_ssh_command(
            None if self.connection_lease is None else self.connection_lease.socket_path
        )
        self.ssh_cmd_quoted: Final[list[str]] = [shlex.quote(item) for item in self.ssh_cmd]

    def __repr__(self) -> str:
        return str({"free": self._free, "cid": self._cid})

    def increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots."""
        self._free += value
        assert self._free >= 0
        assert self._free <= self._capacity

    def is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self._free <= 0

    def update_last_modified(self, last_modified: int) -> None:
        """Records when the connection was last used."""
        self._last_modified = last_modified

    def shutdown(self, msg_prefix: str, p: Params) -> None:
        """Closes the underlying SSH master connection and releases the corresponding connection lease."""
        ssh_cmd: list[str] = self.ssh_cmd
        if ssh_cmd and self._reuse_ssh_connection:
            if self.connection_lease is None:
                ssh_sock_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
                p.log.log(LOG_TRACE, f"Executing {msg_prefix}: %s", shlex.join(ssh_sock_cmd))
                try:
                    proc: CompletedProcess = subprocess.run(ssh_sock_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=0.1)
                except subprocess.TimeoutExpired as e:  # harmless as master auto-exits after ssh_control_persist_secs anyway
                    p.log.log(LOG_TRACE, "Harmless ssh master connection shutdown timeout: %s", e)
                else:
                    if proc.returncode != 0:  # harmless for the same reason
                        p.log.log(LOG_TRACE, "Harmless ssh master connection shutdown issue: %s", proc.stderr.rstrip())
            else:
                self.connection_lease.release()
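

# Illustrative sketch (not part of the original module): the effect of dataclass(order=True) on (_free, _last_modified)
# above. Combined with the pool's reverse-sorted SmallPriorityQueue, the connection with the most free SSH slots appears
# to be preferred, and ties go to the most recently used (hottest) connection. The tuples below are stand-ins for
# (_free, _last_modified) pairs, not real Connection objects.
def _example_connection_ordering() -> tuple[int, int]:
    candidates = [(2, 5), (4, 3), (4, 7)]  # (free_slots, last_modified)
    return max(candidates)  # -> (4, 7): most free slots wins; the LIFO tiebreaker picks the latest-used one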


#############################################################################
class ConnectionPool:
    """Fetch a TCP connection for use in an SSH session, use it, finally return it back to the pool for future reuse."""

    def __init__(self, remote: Remote, max_concurrent_ssh_sessions_per_tcp_connection: int, connpool_name: str) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self._remote: Final[Remote] = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self._capacity: Final[int] = max_concurrent_ssh_sessions_per_tcp_connection
        self._connpool_name: Final[str] = connpool_name
        self._priority_queue: Final[SmallPriorityQueue[Connection]] = SmallPriorityQueue(
            reverse=True  # sorted by #free slots and last_modified
        )
        self._last_modified: int = 0  # monotonically increasing sequence number
        self._cid: int = 0  # monotonically increasing connection number
        self._lock: Final[threading.Lock] = threading.Lock()
        lease_mgr: ConnectionLeaseManager | None = None
        if self._remote.ssh_user_host and self._remote.reuse_ssh_connection and not self._remote.ssh_exit_on_shutdown:
            lease_mgr = ConnectionLeaseManager(
                root_dir=self._remote.ssh_socket_dir,
                namespace=f"{self._remote.location}#{self._remote.cache_namespace()}#{self._connpool_name}",
                ssh_control_persist_secs=max(90 * 60, 2 * self._remote.ssh_control_persist_secs + 2),
                log=self._remote.params.log,
            )
        self._lease_mgr: Final[ConnectionLeaseManager | None] = lease_mgr

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that yields a connection from the pool and automatically returns it on __exit__."""
        conn: Connection = self.get_connection()
        try:
            yield conn
        finally:
            self.return_connection(conn)

    def get_connection(self) -> Connection:
        """Returns a Connection that still has a free SSH slot, creating a new one if all existing connections are full.

        The returned Connection object intentionally remains contained in the priority queue while it is "checked out";
        return_connection() later temporarily removes that identical object from the queue, increments its "free" slot
        count, and immediately reinserts it. In effect, every Connection stays in the priority queue at all times, which
        keeps ordering/fairness accurate while avoiding duplicate Connection instances.
        """
        with self._lock:
            conn = self._priority_queue.pop() if len(self._priority_queue) > 0 else None
            if conn is None or conn.is_full():
                if conn is not None:
                    self._priority_queue.push(conn)
                lease: ConnectionLease | None = None if self._lease_mgr is None else self._lease_mgr.acquire()
                conn = Connection(self._remote, self._capacity, self._cid, lease=lease)  # add a new connection
                self._last_modified += 1
                conn.update_last_modified(self._last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self._cid += 1
            conn.increment_free(-1)
            self._priority_queue.push(conn)
            return conn

    def return_connection(self, conn: Connection) -> None:
        """Returns the given connection to the pool and updates its priority."""
        assert conn is not None
        with self._lock:
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue
            if self._priority_queue.remove(conn):  # conn is not contained only if ConnectionPool.shutdown() was called
                conn.increment_free(1)
                self._last_modified += 1
                conn.update_last_modified(self._last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self._priority_queue.push(conn)

    def shutdown(self, msg_prefix: str) -> None:
        """Closes all SSH connections managed by this pool."""
        with self._lock:
            try:
                if self._remote.reuse_ssh_connection:
                    for conn in self._priority_queue:
                        conn.shutdown(msg_prefix, self._remote.params)
            finally:
                self._priority_queue.clear()

    def __repr__(self) -> str:
        with self._lock:
            queue = self._priority_queue
            return str({"capacity": self._capacity, "queue_len": len(queue), "cid": self._cid, "queue": queue})
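

# Illustrative sketch (not part of the original module): the intended checkout/return lifecycle of a ConnectionPool. The
# connection() context manager is the preferred form because it guarantees return_connection() runs even on error; the
# explicit get/return pair is the equivalent manual spelling. 'pool' is assumed to be an already constructed
# ConnectionPool; building one requires a configured Remote, which is out of scope here.
def _example_pool_checkout(pool: ConnectionPool) -> None:
    with pool.connection() as conn:  # borrows an SSH slot; the slot is given back when the 'with' block exits
        _ = conn.ssh_cmd  # e.g. prepend this to the remote command, as run_ssh_command() does
    conn2 = pool.get_connection()  # manual equivalent of the above
    try:
        _ = conn2.ssh_cmd_quoted
    finally:
        pool.return_connection(conn2)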


#############################################################################
class ConnectionPools:
    """A bunch of named connection pools with various multiplexing capacities."""

    def __init__(self, remote: Remote, capacities: dict[str, int]) -> None:
        """Creates one connection pool per name with the given capacities."""
        self._pools: Final[dict[str, ConnectionPool]] = {
            name: ConnectionPool(remote, capacity, name) for name, capacity in capacities.items()
        }

    def __repr__(self) -> str:
        return str(self._pools)

    def pool(self, name: str) -> ConnectionPool:
        """Returns the pool associated with the given name."""
        return self._pools[name]

    def shutdown(self, msg_prefix: str) -> None:
        """Shuts down every contained pool."""
        for name, pool in self._pools.items():
            pool.shutdown(msg_prefix + "/" + name)
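

# Illustrative sketch (not part of the original module): wiring the named pools together. The capacity values are made-up
# examples; the SHARED and DEDICATED names come from the constants defined at the top of this module, and
# run_ssh_command() above looks up its pool via p.connection_pools[remote.location].pool(SHARED).
def _example_connection_pools(remote: Remote) -> None:
    pools = ConnectionPools(remote, {SHARED: 8, DEDICATED: 1})  # e.g. 8 multiplexed sessions for shared traffic
    shared_pool = pools.pool(SHARED)
    with shared_pool.connection() as conn:
        _ = conn.ssh_cmd  # local ssh argv for this TCP connection
    pools.shutdown("example")  # closes the ssh masters of every contained pool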