Coverage for bzfs_main/connection.py: 100%
196 statements
# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
15"""Network connection management is in refresh_ssh_connection_if_necessary() and class ConnectionPool; They reuse multiplexed
16ssh connections for low latency."""

from __future__ import annotations
import contextlib
import copy
import logging
import shlex
import subprocess
import threading
import time
from dataclasses import dataclass
from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess
from typing import (
    TYPE_CHECKING,
    Counter,
    Iterator,
)

from bzfs_main.retry import (
    RetryableError,
)
from bzfs_main.utils import (
    LOG_TRACE,
    PROG_NAME,
    SmallPriorityQueue,
    die,
    list_formatter,
    stderr_to_str,
    subprocess_run,
    xprint,
)

if TYPE_CHECKING:  # pragma: no cover - for type hints only
    from bzfs_main.bzfs import Job
    from bzfs_main.configuration import Params, Remote

# constants:
SHARED: str = "shared"
DEDICATED: str = "dedicated"


def run_ssh_command(
    job: Job,
    remote: Remote,
    level: int = -1,
    is_dry: bool = False,
    check: bool = True,
    print_stdout: bool = False,
    print_stderr: bool = True,
    cmd: list[str] | None = None,
) -> str:
67 """Runs the given cmd via ssh on the given remote, and returns stdout.
69 The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
70 ($remote.local_ssh_command()) and the command to run on the given remote host ($cmd).
71 """
    level = level if level >= 0 else logging.INFO
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    p, log = job.params, job.params.log
    quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
    conn_pool: ConnectionPool = p.connection_pools[remote.location].pool(SHARED)
    with conn_pool.connection() as conn:
        ssh_cmd: list[str] = conn.ssh_cmd
        if remote.ssh_user_host != "":
            refresh_ssh_connection_if_necessary(job, remote, conn)
            cmd = quoted_cmd
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(level, msg, list_formatter(conn.ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return ""
        try:
            process: CompletedProcess = subprocess_run(
                ssh_cmd + cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True, timeout=timeout(job), check=check
            )
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired, UnicodeDecodeError) as e:
            if not isinstance(e, UnicodeDecodeError):
                xprint(log, stderr_to_str(e.stdout), run=print_stdout, end="")
                xprint(log, stderr_to_str(e.stderr), run=print_stderr, end="")
            raise
        else:
            xprint(log, process.stdout, run=print_stdout, end="")
            xprint(log, process.stderr, run=print_stderr, end="")
            return process.stdout  # type: ignore[no-any-return]  # need to ignore on python <= 3.8
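

# Minimal usage sketch, assuming `job` and `remote` are fully configured by the caller; the function name is
# illustrative and not part of this module:
def _example_remote_uname(job: Job, remote: Remote) -> str:  # pragma: no cover - illustrative only
    """Runs 'uname -a' on the remote host over a pooled, multiplexed ssh connection and returns its stdout."""
    return run_ssh_command(job, remote, level=logging.DEBUG, cmd=["uname", "-a"])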


def try_ssh_command(
    job: Job,
    remote: Remote,
    level: int,
    is_dry: bool = False,
    print_stdout: bool = False,
    cmd: list[str] | None = None,
    exists: bool = True,
    error_trigger: str | None = None,
) -> str | None:
111 """Convenience method that helps retry/react to a dataset or pool that potentially doesn't exist anymore."""
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    log = job.params.log
    try:
        maybe_inject_error(job, cmd=cmd, error_trigger=error_trigger)
        return run_ssh_command(job, remote, level=level, is_dry=is_dry, print_stdout=print_stdout, cmd=cmd)
    except (subprocess.CalledProcessError, UnicodeDecodeError) as e:
        if not isinstance(e, UnicodeDecodeError):
            stderr: str = stderr_to_str(e.stderr)
            if exists and (
                ": dataset does not exist" in stderr
                or ": filesystem does not exist" in stderr  # solaris 11.4.0
                or ": does not exist" in stderr  # solaris 11.4.0 'zfs send' with missing snapshot
                or ": no such pool" in stderr
            ):
                return None
            log.warning("%s", stderr.rstrip())
        raise RetryableError("Subprocess failed") from e
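

# Hypothetical caller sketch (the function name is illustrative, not part of this module): a missing dataset yields
# None, whereas other failures surface as RetryableError for the retry layer to handle.
def _example_list_snapshots(job: Job, remote: Remote, dataset: str) -> list[str]:  # pragma: no cover - illustrative only
    """Returns the snapshot names of the given dataset, or an empty list if the dataset no longer exists."""
    cmd = ["zfs", "list", "-t", "snapshot", "-Hp", "-o", "name", dataset]
    out = try_ssh_command(job, remote, logging.DEBUG, cmd=cmd)
    return [] if out is None else out.splitlines()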


def refresh_ssh_connection_if_necessary(job: Job, remote: Remote, conn: Connection) -> None:
    """Maintains or creates an ssh master connection for low latency reuse."""
    p, log = job.params, job.params.log
    if remote.ssh_user_host == "":
        return  # we're in local mode; no ssh required
    if not p.is_program_available("ssh", "local"):
        die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
    if not remote.reuse_ssh_connection:
        return
    # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
    # 'ssh -S -M -oControlPersist=60s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
    control_persist_limit_nanos: int = (job.control_persist_secs - job.control_persist_margin_secs) * 1_000_000_000
    with conn.lock:
        if time.monotonic_ns() - conn.last_refresh_time < control_persist_limit_nanos:
            return  # ssh master is alive, reuse its TCP connection (this is the common case & the ultra-fast path)
        ssh_cmd: list[str] = conn.ssh_cmd
        ssh_socket_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
        ssh_socket_cmd += ["-O", "check", remote.ssh_user_host]
        # extend lifetime of ssh master by $control_persist_secs via 'ssh -O check' if master is still running.
        # 'ssh -S /path/to/socket -O check' doesn't talk over the network, hence is still a low latency fast path.
        t: float | None = timeout(job)
        if subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True, timeout=t).returncode == 0:
            log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_socket_cmd))
        else:  # ssh master is not alive; start a new master:
            log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_socket_cmd))
            control_persist_secs: int = job.control_persist_secs
            if "-v" in remote.ssh_extra_opts:
                # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                # foreground and blocks until the ControlPersist timer expires (90 secs). To make progress earlier we ...
                control_persist_secs = min(control_persist_secs, 1)  # tell ssh to block as briefly as possible (1 sec)
            ssh_socket_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_socket_cmd += ["-M", f"-oControlPersist={control_persist_secs}s", remote.ssh_user_host, "exit"]
            log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_socket_cmd))
            process = subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=timeout(job))
            if process.returncode != 0:
                log.error("%s", process.stderr.rstrip())
                die(
                    f"Cannot ssh into remote host via '{' '.join(ssh_socket_cmd)}'. Fix ssh configuration "
                    f"first, considering diagnostic log file output from running {PROG_NAME} with: -v -v -v"
                )
        conn.last_refresh_time = time.monotonic_ns()
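
# For orientation, the multiplexing handshake above corresponds roughly to these OpenSSH CLI invocations (the socket
# path is a placeholder; the real one comes from remote.local_ssh_command()):
#   ssh -S /path/to/socket -M -oControlPersist=60s user@host exit   # start a backgrounded ssh master
#   ssh -S /path/to/socket -O check user@host                       # probe the master and extend its lifetime
#   ssh -S /path/to/socket -O exit user@host                        # tear the master down (see Connection.shutdown)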


def timeout(job: Job) -> float | None:
    """Raises TimeoutExpired if the deadline has passed, else returns the number of seconds left until the timeout occurs."""
    timeout_nanos: int | None = job.timeout_nanos
    if timeout_nanos is None:
        return None  # never raise a timeout
    delta_nanos: int = timeout_nanos - time.monotonic_ns()
    if delta_nanos <= 0:
        assert job.params.timeout_nanos is not None
        raise subprocess.TimeoutExpired(PROG_NAME + "_timeout", timeout=job.params.timeout_nanos / 1_000_000_000)
    return delta_nanos / 1_000_000_000  # seconds
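
# Worked example for the arithmetic above, assuming job.timeout_nanos holds an absolute time.monotonic_ns() deadline:
# with a 30s budget armed at t0, a subprocess started 29.2s later receives timeout=0.8 seconds, while one started
# after the deadline raises TimeoutExpired instead of running at all.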


def maybe_inject_error(job: Job, cmd: list[str], error_trigger: str | None = None) -> None:
    """For testing only; for unit tests to simulate errors during replication and test correct handling of them."""
    if error_trigger:
        counter = job.error_injection_triggers.get("before")
        if counter and decrement_injection_counter(job, counter, error_trigger):
            try:
                raise CalledProcessError(returncode=1, cmd=" ".join(cmd), stderr=error_trigger + ":dataset is busy")
            except subprocess.CalledProcessError as e:
                if error_trigger.startswith("retryable_"):
                    raise RetryableError("Subprocess failed") from e
                else:
                    raise


def decrement_injection_counter(job: Job, counter: Counter[str], trigger: str) -> bool:
    """For testing only."""
    with job.injection_lock:
        if counter[trigger] <= 0:
            return False
        counter[trigger] -= 1
        return True
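
# Hypothetical unit test sketch (the trigger name is illustrative, and job.error_injection_triggers["before"] is
# assumed to be an initialized Counter): arming a trigger makes the next matching maybe_inject_error() call fail once;
# a "retryable_" prefix converts that failure into a RetryableError.
#   job.error_injection_triggers["before"]["retryable_dataset_busy"] += 1
#   try_ssh_command(job, remote, logging.DEBUG, cmd=["zfs", "list"], error_trigger="retryable_dataset_busy")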


#############################################################################
@dataclass(order=True, repr=False)
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection."""

    free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot

    def __init__(self, remote: Remote, max_concurrent_ssh_sessions_per_tcp_connection: int, cid: int) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self.capacity: int = max_concurrent_ssh_sessions_per_tcp_connection
        self.free: int = max_concurrent_ssh_sessions_per_tcp_connection
        self.last_modified: int = 0
        self.cid: int = cid
        self.ssh_cmd: list[str] = remote.local_ssh_command()
        self.ssh_cmd_quoted: list[str] = [shlex.quote(item) for item in self.ssh_cmd]
        self.lock: threading.Lock = threading.Lock()
        self.last_refresh_time: int = 0

    def __repr__(self) -> str:
        return str({"free": self.free, "cid": self.cid})

    def increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots."""
        self.free += value
        assert self.free >= 0
        assert self.free <= self.capacity

    def is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self.free <= 0

    def update_last_modified(self, last_modified: int) -> None:
        """Records when the connection was last used."""
        self.last_modified = last_modified

    def shutdown(self, msg_prefix: str, p: Params) -> None:
        """Closes the underlying SSH master connection."""
        ssh_cmd: list[str] = self.ssh_cmd
        if ssh_cmd:
            ssh_socket_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
            p.log.log(LOG_TRACE, f"Executing {msg_prefix}: %s", shlex.join(ssh_socket_cmd))
            process: CompletedProcess = subprocess.run(ssh_socket_cmd, stdin=DEVNULL, stderr=PIPE, text=True)
            if process.returncode != 0:
                p.log.log(LOG_TRACE, "%s", process.stderr.rstrip())
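
# Ordering note: @dataclass(order=True) compares Connection objects by the (free, last_modified) field tuple. Because
# ConnectionPool keeps its SmallPriorityQueue reverse-sorted, pop() yields the connection with the most free SSH
# slots, and among equals the most recently returned one (LIFO), which is the most likely to still have a hot master.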


#############################################################################
class ConnectionPool:
    """Fetch a TCP connection for use in an SSH session, use it, and finally return it to the pool for future reuse."""

    def __init__(self, remote: Remote, max_concurrent_ssh_sessions_per_tcp_connection: int) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self.remote: Remote = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self.capacity: int = max_concurrent_ssh_sessions_per_tcp_connection
        self.priority_queue: SmallPriorityQueue[Connection] = SmallPriorityQueue(
            reverse=True  # sorted by #free slots and last_modified
        )
        self.last_modified: int = 0  # monotonically increasing sequence number
        self.cid: int = 0  # monotonically increasing connection number
        self._lock: threading.Lock = threading.Lock()

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that yields a connection from the pool and automatically returns it on __exit__."""
        conn: Connection = self.get_connection()
        try:
            yield conn
        finally:
            self.return_connection(conn)

    def get_connection(self) -> Connection:
        """Returns a connection with a free SSH slot, creating a new one if all existing connections are full.

        Any Connection object returned by get_connection() intentionally remains contained in the priority queue; on
        return_connection(), that identical Connection object is temporarily removed from the priority queue, updated
        with an incremented "free" slot count, and then immediately reinserted into the priority queue. In effect, every
        Connection object remains contained in the priority queue at all times.
        """
        with self._lock:
            conn = self.priority_queue.pop() if len(self.priority_queue) > 0 else None
            if conn is None or conn.is_full():
                if conn is not None:
                    self.priority_queue.push(conn)
                conn = Connection(self.remote, self.capacity, self.cid)  # add a new connection
                self.last_modified += 1
                conn.update_last_modified(self.last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self.cid += 1
            conn.increment_free(-1)
            self.priority_queue.push(conn)
            return conn

    def return_connection(self, conn: Connection) -> None:
        """Returns the given connection to the pool and updates its priority."""
        assert conn is not None
        with self._lock:
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue
            if self.priority_queue.remove(conn):  # conn is not contained only if ConnectionPool.shutdown() was called
                conn.increment_free(1)
                self.last_modified += 1
                conn.update_last_modified(self.last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self.priority_queue.push(conn)

    def shutdown(self, msg_prefix: str) -> None:
        """Closes all SSH connections managed by this pool."""
        with self._lock:
            if self.remote.reuse_ssh_connection:
                for conn in self.priority_queue:
                    conn.shutdown(msg_prefix, self.remote.params)
            self.priority_queue.clear()

    def __repr__(self) -> str:
        with self._lock:
            queue = self.priority_queue
            return str({"capacity": self.capacity, "queue_len": len(queue), "cid": self.cid, "queue": queue})
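

# Minimal usage sketch, assuming a fully configured Remote; the function name and capacity are illustrative, not part
# of this module:
def _example_pool_usage(remote: Remote) -> None:  # pragma: no cover - illustrative only
    """Borrows a connection, uses it, lets the context manager return it, then closes all ssh masters."""
    pool = ConnectionPool(remote, max_concurrent_ssh_sessions_per_tcp_connection=8)
    with pool.connection() as conn:
        print(conn.ssh_cmd_quoted)  # the local ssh command this connection would run
    pool.shutdown("example")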


#############################################################################
class ConnectionPools:
    """A bunch of named connection pools with various multiplexing capacities."""

    def __init__(self, remote: Remote, capacities: dict[str, int]) -> None:
        """Creates one connection pool per name with the given capacities."""
        self.pools: dict[str, ConnectionPool] = {
            name: ConnectionPool(remote, capacity) for name, capacity in capacities.items()
        }

    def __repr__(self) -> str:
        return str(self.pools)

    def pool(self, name: str) -> ConnectionPool:
        """Returns the pool associated with the given name."""
        return self.pools[name]

    def shutdown(self, msg_prefix: str) -> None:
        """Shuts down every contained pool."""
        for name, pool in self.pools.items():
            pool.shutdown(msg_prefix + "/" + name)
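

# Minimal construction sketch, assuming a fully configured Remote; the capacities (8 shared sessions per TCP
# connection, 1 dedicated) are arbitrary example values:
def _example_make_pools(remote: Remote) -> ConnectionPools:  # pragma: no cover - illustrative only
    """Creates one multiplexing 'shared' pool and one single-session 'dedicated' pool."""
    return ConnectionPools(remote, {SHARED: 8, DEDICATED: 1})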