Coverage for bzfs_main/connection.py: 100%

196 statements  

coverage.py v7.10.2, created at 2025-08-06 13:30 +0000

# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Network connection management lives in refresh_ssh_connection_if_necessary() and class ConnectionPool; they reuse
multiplexed ssh connections for low latency."""

from __future__ import annotations
import contextlib
import copy
import logging
import shlex
import subprocess
import threading
import time
from dataclasses import dataclass
from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess
from typing import (
    TYPE_CHECKING,
    Counter,
    Iterator,
)

from bzfs_main.retry import (
    RetryableError,
)
from bzfs_main.utils import (
    LOG_TRACE,
    PROG_NAME,
    SmallPriorityQueue,
    die,
    list_formatter,
    stderr_to_str,
    subprocess_run,
    xprint,
)

if TYPE_CHECKING:  # pragma: no cover - for type hints only
    from bzfs_main.bzfs import Job
    from bzfs_main.configuration import Params, Remote

# constants:
SHARED: str = "shared"
DEDICATED: str = "dedicated"

def run_ssh_command(
    job: Job,
    remote: Remote,
    level: int = -1,
    is_dry: bool = False,
    check: bool = True,
    print_stdout: bool = False,
    print_stderr: bool = True,
    cmd: list[str] | None = None,
) -> str:
    """Runs the given cmd via ssh on the given remote, and returns stdout.

    The full command is the concatenation of both the command to run on the localhost in order to talk to the remote host
    ($remote.local_ssh_command()) and the command to run on the given remote host ($cmd).
    """
    level = level if level >= 0 else logging.INFO
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    p, log = job.params, job.params.log
    quoted_cmd: list[str] = [shlex.quote(arg) for arg in cmd]
    conn_pool: ConnectionPool = p.connection_pools[remote.location].pool(SHARED)
    with conn_pool.connection() as conn:
        ssh_cmd: list[str] = conn.ssh_cmd
        if remote.ssh_user_host != "":
            refresh_ssh_connection_if_necessary(job, remote, conn)
            cmd = quoted_cmd
        msg: str = "Would execute: %s" if is_dry else "Executing: %s"
        log.log(level, msg, list_formatter(conn.ssh_cmd_quoted + quoted_cmd, lstrip=True))
        if is_dry:
            return ""
        try:
            process: CompletedProcess = subprocess_run(
                ssh_cmd + cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True, timeout=timeout(job), check=check
            )
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired, UnicodeDecodeError) as e:
            if not isinstance(e, UnicodeDecodeError):
                xprint(log, stderr_to_str(e.stdout), run=print_stdout, end="")
                xprint(log, stderr_to_str(e.stderr), run=print_stderr, end="")
            raise
        else:
            xprint(log, process.stdout, run=print_stdout, end="")
            xprint(log, process.stderr, run=print_stderr, end="")
            return process.stdout  # type: ignore[no-any-return]  # need to ignore on python <= 3.8
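
For orientation, here is a minimal standalone sketch (not part of the module; all values are hypothetical) of the concatenation that run_ssh_command performs. The remote command is shell-quoted because ssh joins its argument vector into a single command line that the remote shell re-parses:

import shlex

local_ssh_command = ["ssh", "-S", "/tmp/bzfs_socket", "user@host"]  # ~ what remote.local_ssh_command() might return
cmd = ["zfs", "list", "-t", "snapshot", "tank/my data"]  # note the embedded space

quoted_cmd = [shlex.quote(arg) for arg in cmd]
full_cmd = local_ssh_command + quoted_cmd
# ['ssh', '-S', '/tmp/bzfs_socket', 'user@host', 'zfs', 'list', '-t', 'snapshot', "'tank/my data'"]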

def try_ssh_command(
    job: Job,
    remote: Remote,
    level: int,
    is_dry: bool = False,
    print_stdout: bool = False,
    cmd: list[str] | None = None,
    exists: bool = True,
    error_trigger: str | None = None,
) -> str | None:
    """Convenience method that helps retry/react to a dataset or pool that potentially doesn't exist anymore."""
    assert cmd is not None and isinstance(cmd, list) and len(cmd) > 0
    log = job.params.log
    try:
        maybe_inject_error(job, cmd=cmd, error_trigger=error_trigger)
        return run_ssh_command(job, remote, level=level, is_dry=is_dry, print_stdout=print_stdout, cmd=cmd)
    except (subprocess.CalledProcessError, UnicodeDecodeError) as e:
        if not isinstance(e, UnicodeDecodeError):
            stderr: str = stderr_to_str(e.stderr)
            if exists and (
                ": dataset does not exist" in stderr
                or ": filesystem does not exist" in stderr  # solaris 11.4.0
                or ": does not exist" in stderr  # solaris 11.4.0 'zfs send' with missing snapshot
                or ": no such pool" in stderr
            ):
                return None
            log.warning("%s", stderr.rstrip())
        raise RetryableError("Subprocess failed") from e
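
A hedged caller-side sketch of the exists contract (the dataset name is made up, and a valid job and remote are assumed to be in scope): a None result means the dataset or pool vanished, while any other failure surfaces as a RetryableError for the caller's retry loop:

import logging

stdout = try_ssh_command(job, remote, level=logging.DEBUG, cmd=["zfs", "list", "-H", "tank/ds"])
if stdout is None:
    pass  # dataset/pool no longer exists; treat as "nothing to do" rather than as an error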

def refresh_ssh_connection_if_necessary(job: Job, remote: Remote, conn: Connection) -> None:
    """Maintain or create an ssh master connection for low latency reuse."""
    p, log = job.params, job.params.log
    if remote.ssh_user_host == "":
        return  # we're in local mode; no ssh required
    if not p.is_program_available("ssh", "local"):
        die(f"{p.ssh_program} CLI is not available to talk to remote host. Install {p.ssh_program} first!")
    if not remote.reuse_ssh_connection:
        return
    # Performance: reuse ssh connection for low latency startup of frequent ssh invocations via the 'ssh -S' and
    # 'ssh -S -M -oControlPersist=60s' options. See https://en.wikibooks.org/wiki/OpenSSH/Cookbook/Multiplexing
    control_persist_limit_nanos: int = (job.control_persist_secs - job.control_persist_margin_secs) * 1_000_000_000
    with conn.lock:
        if time.monotonic_ns() - conn.last_refresh_time < control_persist_limit_nanos:
            return  # ssh master is alive, reuse its TCP connection (this is the common case & the ultra-fast path)
        ssh_cmd: list[str] = conn.ssh_cmd
        ssh_socket_cmd: list[str] = ssh_cmd[0:-1]  # omit trailing ssh_user_host
        ssh_socket_cmd += ["-O", "check", remote.ssh_user_host]
        # extend lifetime of ssh master by $control_persist_secs via 'ssh -O check' if master is still running.
        # 'ssh -S /path/to/socket -O check' doesn't talk over the network, hence is still a low latency fast path.
        t: float | None = timeout(job)
        if subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stdout=PIPE, stderr=PIPE, text=True, timeout=t).returncode == 0:
            log.log(LOG_TRACE, "ssh connection is alive: %s", list_formatter(ssh_socket_cmd))
        else:  # ssh master is not alive; start a new master:
            log.log(LOG_TRACE, "ssh connection is not yet alive: %s", list_formatter(ssh_socket_cmd))
            control_persist_secs: int = job.control_persist_secs
            if "-v" in remote.ssh_extra_opts:
                # Unfortunately, with `ssh -v` (debug mode), the ssh master won't background; instead it stays in the
                # foreground and blocks until the ControlPersist timer expires (90 secs). To make progress earlier we ...
                control_persist_secs = min(control_persist_secs, 1)  # tell ssh to block as briefly as possible (1 sec)
            ssh_socket_cmd = ssh_cmd[0:-1]  # omit trailing ssh_user_host
            ssh_socket_cmd += ["-M", f"-oControlPersist={control_persist_secs}s", remote.ssh_user_host, "exit"]
            log.log(LOG_TRACE, "Executing: %s", list_formatter(ssh_socket_cmd))
            process = subprocess_run(ssh_socket_cmd, stdin=DEVNULL, stderr=PIPE, text=True, timeout=timeout(job))
            if process.returncode != 0:
                log.error("%s", process.stderr.rstrip())
                die(
                    f"Cannot ssh into remote host via '{' '.join(ssh_socket_cmd)}'. Fix ssh configuration "
                    f"first, considering diagnostic log file output from running {PROG_NAME} with: -v -v -v"
                )
        conn.last_refresh_time = time.monotonic_ns()
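
The OpenSSH multiplexing handshake above boils down to two plain commands; a minimal standalone sketch, with a hypothetical control socket path and host:

import subprocess

sock = "/tmp/bzfs_demo_socket"  # hypothetical ControlPath socket
host = "user@host"              # hypothetical remote

# Probe whether a master is already listening on the control socket (no network round trip):
alive = subprocess.run(["ssh", "-S", sock, "-O", "check", host], capture_output=True).returncode == 0
if not alive:
    # Start a backgrounded master that lingers for 60s after the last multiplexed session exits:
    subprocess.run(["ssh", "-S", sock, "-M", "-oControlPersist=60s", host, "exit"], check=True)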

def timeout(job: Job) -> float | None:
    """Raises TimeoutExpired if necessary, else returns the number of seconds left until timeout is to occur."""
    timeout_nanos: int | None = job.timeout_nanos
    if timeout_nanos is None:
        return None  # never raise a timeout
    delta_nanos: int = timeout_nanos - time.monotonic_ns()
    if delta_nanos <= 0:
        assert job.params.timeout_nanos is not None
        raise subprocess.TimeoutExpired(PROG_NAME + "_timeout", timeout=job.params.timeout_nanos / 1_000_000_000)
    return delta_nanos / 1_000_000_000  # seconds
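
The deadline arithmetic is plain monotonic-clock math; a standalone sketch (the 30s budget is hypothetical):

import time

deadline_nanos = time.monotonic_ns() + 30 * 1_000_000_000  # ~ job.timeout_nanos
remaining_secs = (deadline_nanos - time.monotonic_ns()) / 1_000_000_000
# Passing remaining_secs as the timeout= of each subprocess call enforces one overall
# deadline across many sequential subprocess invocations.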

def maybe_inject_error(job: Job, cmd: list[str], error_trigger: str | None = None) -> None:
    """For testing only; for unit tests to simulate errors during replication and test correct handling of them."""
    if error_trigger:
        counter = job.error_injection_triggers.get("before")
        if counter and decrement_injection_counter(job, counter, error_trigger):
            try:
                raise CalledProcessError(returncode=1, cmd=" ".join(cmd), stderr=error_trigger + ":dataset is busy")
            except subprocess.CalledProcessError as e:
                if error_trigger.startswith("retryable_"):
                    raise RetryableError("Subprocess failed") from e
                else:
                    raise


def decrement_injection_counter(job: Job, counter: Counter[str], trigger: str) -> bool:
    """For testing only."""
    with job.injection_lock:
        if counter[trigger] <= 0:
            return False
        counter[trigger] -= 1
        return True
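
To illustrate the trigger bookkeeping, a hedged test-side sketch (the trigger name is made up, and a job under test is assumed to be in scope):

from collections import Counter

job.error_injection_triggers = {"before": Counter({"retryable_dataset_busy": 2})}
# The next two try_ssh_command(..., error_trigger="retryable_dataset_busy") calls now raise
# RetryableError before executing anything; the third call runs normally.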

#############################################################################
@dataclass(order=True, repr=False)
class Connection:
    """Represents the ability to multiplex N=capacity concurrent SSH sessions over the same TCP connection."""

    free: int  # sort order evens out the number of concurrent sessions among the TCP connections
    last_modified: int  # LIFO: tiebreaker favors latest returned conn as that's most alive and hot

    def __init__(self, remote: Remote, max_concurrent_ssh_sessions_per_tcp_connection: int, cid: int) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self.capacity: int = max_concurrent_ssh_sessions_per_tcp_connection
        self.free: int = max_concurrent_ssh_sessions_per_tcp_connection
        self.last_modified: int = 0
        self.cid: int = cid
        self.ssh_cmd: list[str] = remote.local_ssh_command()
        self.ssh_cmd_quoted: list[str] = [shlex.quote(item) for item in self.ssh_cmd]
        self.lock: threading.Lock = threading.Lock()
        self.last_refresh_time: int = 0

    def __repr__(self) -> str:
        return str({"free": self.free, "cid": self.cid})

    def increment_free(self, value: int) -> None:
        """Adjusts the count of available SSH slots."""
        self.free += value
        assert self.free >= 0
        assert self.free <= self.capacity

    def is_full(self) -> bool:
        """Returns True if no more SSH sessions may be opened over this TCP connection."""
        return self.free <= 0

    def update_last_modified(self, last_modified: int) -> None:
        """Records when the connection was last used."""
        self.last_modified = last_modified

    def shutdown(self, msg_prefix: str, p: Params) -> None:
        """Closes the underlying SSH master connection."""
        ssh_cmd: list[str] = self.ssh_cmd
        if ssh_cmd:
            ssh_socket_cmd: list[str] = ssh_cmd[0:-1] + ["-O", "exit", ssh_cmd[-1]]
            p.log.log(LOG_TRACE, f"Executing {msg_prefix}: %s", shlex.join(ssh_socket_cmd))
            process: CompletedProcess = subprocess.run(ssh_socket_cmd, stdin=DEVNULL, stderr=PIPE, text=True)
            if process.returncode != 0:
                p.log.log(LOG_TRACE, "%s", process.stderr.rstrip())
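
Because the dataclass orders instances by (free, last_modified) and the pool sorts in reverse, the pool always hands out the connection with the most free session slots, breaking ties in favor of the most recently touched one. A tiny standalone illustration of that sort key (numbers are hypothetical):

conns = [(2, 5), (4, 1), (4, 7)]  # (free, last_modified) pairs
best = max(conns)  # (4, 7): most free slots first, then the most recently touched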

#############################################################################
class ConnectionPool:
    """Fetch a TCP connection for use in an SSH session, use it, finally return it back to the pool for future reuse."""

    def __init__(self, remote: Remote, max_concurrent_ssh_sessions_per_tcp_connection: int) -> None:
        assert max_concurrent_ssh_sessions_per_tcp_connection > 0
        self.remote: Remote = copy.copy(remote)  # shallow copy for immutability (Remote is mutable)
        self.capacity: int = max_concurrent_ssh_sessions_per_tcp_connection
        self.priority_queue: SmallPriorityQueue[Connection] = SmallPriorityQueue(
            reverse=True  # sorted by #free slots and last_modified
        )
        self.last_modified: int = 0  # monotonically increasing sequence number
        self.cid: int = 0  # monotonically increasing connection number
        self._lock: threading.Lock = threading.Lock()

    @contextlib.contextmanager
    def connection(self) -> Iterator[Connection]:
        """Context manager that yields a connection from the pool and automatically returns it on __exit__."""
        conn: Connection = self.get_connection()
        try:
            yield conn
        finally:
            self.return_connection(conn)

    def get_connection(self) -> Connection:
        """Any Connection object returned on get_connection() also remains intentionally contained in the priority queue, and
        that identical Connection object is later, on return_connection(), temporarily removed from the priority queue,
        updated with an incremented "free" slot count and then immediately reinserted into the priority queue.

        In effect, any Connection object remains intentionally contained in the priority queue at all times.
        """
        with self._lock:
            conn = self.priority_queue.pop() if len(self.priority_queue) > 0 else None
            if conn is None or conn.is_full():
                if conn is not None:
                    self.priority_queue.push(conn)
                conn = Connection(self.remote, self.capacity, self.cid)  # add a new connection
                self.last_modified += 1
                conn.update_last_modified(self.last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self.cid += 1
            conn.increment_free(-1)
            self.priority_queue.push(conn)
            return conn

    def return_connection(self, conn: Connection) -> None:
        """Returns the given connection to the pool and updates its priority."""
        assert conn is not None
        with self._lock:
            # update priority = remove conn from queue, increment priority, finally reinsert updated conn into queue
            if self.priority_queue.remove(conn):  # conn is not contained only if ConnectionPool.shutdown() was called
                conn.increment_free(1)
                self.last_modified += 1
                conn.update_last_modified(self.last_modified)  # LIFO tiebreaker favors latest conn as that's most alive
                self.priority_queue.push(conn)

    def shutdown(self, msg_prefix: str) -> None:
        """Closes all SSH connections managed by this pool."""
        with self._lock:
            if self.remote.reuse_ssh_connection:
                for conn in self.priority_queue:
                    conn.shutdown(msg_prefix, self.remote.params)
            self.priority_queue.clear()

    def __repr__(self) -> str:
        with self._lock:
            queue = self.priority_queue
            return str({"capacity": self.capacity, "queue_len": len(queue), "cid": self.cid, "queue": queue})
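
A hedged usage sketch of the pool (a configured Remote object named remote is assumed to be in scope; the capacity is arbitrary):

pool = ConnectionPool(remote, max_concurrent_ssh_sessions_per_tcp_connection=8)
with pool.connection() as conn:
    print(conn.ssh_cmd)  # the local ssh prefix to prepend to any remote command
pool.shutdown("demo")  # closes the ssh master behind each pooled connection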

#############################################################################
class ConnectionPools:
    """A bunch of named connection pools with various multiplexing capacities."""

    def __init__(self, remote: Remote, capacities: dict[str, int]) -> None:
        """Creates one connection pool per name with the given capacities."""
        self.pools: dict[str, ConnectionPool] = {
            name: ConnectionPool(remote, capacity) for name, capacity in capacities.items()
        }

    def __repr__(self) -> str:
        return str(self.pools)

    def pool(self, name: str) -> ConnectionPool:
        """Returns the pool associated with the given name."""
        return self.pools[name]

    def shutdown(self, msg_prefix: str) -> None:
        """Shuts down every contained pool."""
        for name, pool in self.pools.items():
            pool.shutdown(msg_prefix + "/" + name)
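
Finally, a sketch of how the named pools tie together (capacities are hypothetical; SHARED and DEDICATED are the module's own pool names, and a configured Remote named remote is assumed):

pools = ConnectionPools(remote, {SHARED: 8, DEDICATED: 1})
with pools.pool(SHARED).connection() as conn:
    ...  # run short control commands over the shared, highly multiplexed pool
pools.shutdown("cleanup")  # closes every ssh master in every named pool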