Coverage for bzfs_main / util / utils.py: 100%

797 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-29 12:49 +0000

1# Copyright 2024 Wolfgang Hoschek AT mac DOT com 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# 

15"""Collection of helper functions used across bzfs; includes environment variable parsing, process management and lightweight 

16concurrency primitives, etc. 

17 

18Everything in this module relies only on the Python standard library so other modules remain dependency free. Each utility 

19favors simple, predictable behavior on all supported platforms. 

20""" 

21 

22from __future__ import ( 

23 annotations, 

24) 

25import argparse 

26import base64 

27import bisect 

28import contextlib 

29import dataclasses 

30import errno 

31import hashlib 

32import itertools 

33import logging 

34import operator 

35import os 

36import platform 

37import pwd 

38import random 

39import re 

40import signal 

41import stat 

42import subprocess 

43import sys 

44import threading 

45import time 

46import types 

47from collections import ( 

48 defaultdict, 

49 deque, 

50) 

51from collections.abc import ( 

52 ItemsView, 

53 Iterable, 

54 Iterator, 

55 Sequence, 

56) 

57from concurrent.futures import ( 

58 Executor, 

59 Future, 

60) 

61from concurrent.futures.thread import ( 

62 ThreadPoolExecutor, 

63) 

64from dataclasses import ( 

65 dataclass, 

66) 

67from datetime import ( 

68 datetime, 

69 timedelta, 

70 timezone, 

71 tzinfo, 

72) 

73from subprocess import ( 

74 DEVNULL, 

75 PIPE, 

76) 

77from typing import ( 

78 IO, 

79 Any, 

80 Callable, 

81 Final, 

82 Generic, 

83 Literal, 

84 NoReturn, 

85 Protocol, 

86 SupportsIndex, 

87 TextIO, 

88 TypeVar, 

89 cast, 

90 final, 

91) 

92 

# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"  # env vars read via getenv_any()/getenv_int()/getenv_bool() use this prefix
DIE_STATUS: Final[int] = 3  # default exit code raised by die()
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level is halfway in between
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level is halfway in between
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level is halfway in between
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""  # sentinel meaning "no dataset is currently being skipped"; see has_siblings()
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"  # intentionally not included: -_.:/
SHELL_CHARS_AND_SLASH: Final[str] = SHELL_CHARS + "/"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # see Google for 'sun_path'

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias: (compiled regex, is_negation) pairs; built by compile_regexes()

113 

114 

def getenv_any(key: str, default: str | None = None, env_var_prefix: str = ENV_VAR_PREFIX) -> str | None:
    """Reads the shell environment variable ``env_var_prefix + key``; all config env var names share this prefix."""
    qualified_name = env_var_prefix + key
    return os.getenv(qualified_name, default)

118 

119 

def getenv_int(key: str, default: int, env_var_prefix: str = ENV_VAR_PREFIX) -> int:
    """Reads environment variable ``key`` as an int, falling back to ``default`` when unset."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return int(raw)

123 

124 

def getenv_bool(key: str, default: bool = False, env_var_prefix: str = ENV_VAR_PREFIX) -> bool:
    """Reads environment variable ``key`` as a bool, falling back to ``default`` when unset; only "true" (any case) is truthy."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return raw.lower().strip() == "true"

128 

129 

def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Projects two-column TSV/CSV ``lines`` onto one column, like the Unix 'cut' CLI command.

    ``field`` must be 1 (text before the first separator) or 2 (text after the first separator);
    every line must contain the separator, otherwise ValueError propagates from str.index().
    """
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field not in (1, 2):
        raise ValueError(f"Invalid field value: {field}")
    if field == 1:
        return [line[: line.index(separator)] for line in lines]
    return [line[line.index(separator) + 1 :] for line in lines]

141 

142 

def drain(iterable: Iterable[Any]) -> None:
    """Exhausts ``iterable``, discarding each item as soon as it is produced."""
    for item in iterable:
        del item  # drop the reference right away to help gc (iterable can block)

147 

148 

_K_ = TypeVar("_K_")  # dict key type, used by shuffle_dict() / sorted_dict()
_V_ = TypeVar("_V_")  # dict value type, used by shuffle_dict() / sorted_dict()
_R_ = TypeVar("_R_")  # generic result type; no uses visible in this part of the module

152 

153 

def shuffle_dict(dictionary: dict[_K_, _V_], /, rand: random.Random = random.SystemRandom()) -> dict[_K_, _V_]:  # noqa: B008
    """Builds a new dict containing the same items, but in a randomly shuffled insertion order."""
    entries: list[tuple[_K_, _V_]] = list(dictionary.items())
    rand.shuffle(entries)
    return dict(entries)

159 

160 

def sorted_dict(
    dictionary: dict[_K_, _V_], /, *, key: Callable[[tuple[_K_, _V_]], Any] | None = None, reverse: bool = False
) -> dict[_K_, _V_]:
    """Builds a new dict whose insertion order is sorted; by default primarily by key, secondarily by value."""
    ordered_items = sorted(dictionary.items(), key=key, reverse=reverse)
    return dict(ordered_items)

166 

167 

def tail(file: str, *, n: int, errors: str | None = None) -> Sequence[str]:
    """Returns the last ``n`` lines of ``file`` without following symlinks; returns [] if it is not a regular file."""
    if os.path.isfile(file):
        with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False, check_perm=False) as handle:
            return deque(handle, maxlen=n)  # deque keeps only the trailing n lines while streaming
    return []

174 

175 

# matches the leading "(?P<name>" prefix of a named capturing group; used by
# replace_capturing_groups_with_non_capturing_groups() to rewrite it to "(?:"
_NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))
_NUMERIC_BACKREFERENCE_REGEX: Final[re.Pattern[str]] = re.compile(r"\\\d+")  # example: \1

178 

179 

def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces parenthesis '(' followed by a char other than question mark '?', but not preceded by a backslash
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    i = regex.find("[")
    if i >= 0 and regex.find("(", i) >= 0:
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the case where the regex might
        # contain a regex character class that contains parenthesis.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    if "(?P=" in regex or "(?(" in regex or _NUMERIC_BACKREFERENCE_REGEX.search(regex):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely if the regex might contain a
        # (named or conditional or numeric) backreference.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    # Scan right-to-left so that inserting '?:' text never shifts the still-unvisited positions to the left of i.
    i = len(regex) - 2  # a '(' in the very last position cannot open a group, so start one char earlier
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)  # find the next '(' at or before position i
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):  # ignore escaped literal '\('
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + _NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex

216 

217 

def get_home_directory() -> str:
    """Reliably detects the home dir without consulting the HOME env var.

    Thread-safe equivalent of: os.environ.pop('HOME', None); os.path.expanduser('~') -- the passwd
    database cannot be perturbed by concurrent env var mutation.
    """
    uid: int = os.getuid()
    passwd_entry = pwd.getpwuid(uid)
    return passwd_entry.pw_dir

222 

223 

def human_readable_bytes(num_bytes: float, *, separator: str = " ", precision: int | None = None) -> str:
    """Formats ``num_bytes`` with binary (1024-based) units; for example "567 MiB"."""
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    sign = "-" if num_bytes < 0 else ""
    magnitude = abs(num_bytes)
    unit_index = 0
    last_unit = len(units) - 1
    while magnitude >= 1024 and unit_index < last_unit:
        magnitude /= 1024
        unit_index += 1
    # precision=None delegates rounding to human_readable_float(); otherwise use a fixed number of decimals
    num_str = f"{magnitude:.{precision}f}" if precision is not None else human_readable_float(magnitude)
    return f"{sign}{num_str}{separator}{units[unit_index]}"

236 

237 

def human_readable_duration(duration: float, *, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling up or down as needed; for example "567ms"."""
    all_units = ("ns", "μs", "ms", "s", "m", "h", "d")
    sign = "-" if duration < 0 else ""
    value = abs(duration)
    index = all_units.index(unit)
    if value != 0 and value < 1:
        # sub-unit values are first converted down to nanoseconds, then scaled back up below
        nanos_per_unit = (
            1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000
        )
        value *= nanos_per_unit[index]
        index = 0
    while value >= 1000 and index < 3:  # ns -> μs -> ms -> s
        value /= 1000
        index += 1
    if index >= 3:
        while value >= 60 and index < 5:  # s -> m -> h
            value /= 60
            index += 1
        if index >= 5:
            while value >= 24 and index < len(all_units) - 1:  # h -> d
                value /= 24
                index += 1
    formatted = human_readable_float(value) if precision is None else f"{value:.{precision}f}"
    return f"{sign}{formatted}{separator}{all_units[index]}"

261 

262 

def human_readable_float(number: float) -> str:
    """Formats ``number`` the way a human scanning logs would round it.

    Digits after the decimal point shrink as magnitude grows:
      abs(number) < 10   -> two decimals   (3.14559 -> "3.15")
      abs(number) < 100  -> one decimal    (12.36   -> "12.4")
      abs(number) >= 100 -> no decimals    (123.556 -> "124")
    Trailing zeros (and a then-empty decimal point) are stripped: 1.500 -> "1.5", 1.00 -> "1".
    A rounded-to-zero negative is normalized: "-0" -> "0".
    """
    magnitude = abs(number)
    if magnitude >= 100:
        return str(round(number))
    digits = 2 if magnitude < 10 else 1
    text = f"{number:.{digits}f}"
    assert "." in text
    text = text.rstrip("0").rstrip(".")  # drop trailing zeros, then a dangling decimal point
    return "0" if text == "-0" else text

287 

288 

def percent(number: int, total: int, *, print_total: bool = False) -> str:
    """Renders ``number`` as a percentage of ``total``; a zero total renders as 'inf'."""
    total_part: str = f"/{total}" if print_total else ""
    pct: str = "inf" if total == 0 else human_readable_float(100 * number / total)
    return f"{number}{total_part}={pct}%"

293 

294 

def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    check_perm: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if basename of path is a symlink.

    Also, can specify custom permissions on O_CREAT, and verify secure ownership and mode bits.

    If check_owner=True, write-capable opens require ownership by the effective UID; read-only opens also allow ownership by
    uid 0 (root). This allows safe reads of root-owned system files while preventing writes to files not owned by the caller.

    If check_perm=True, the file must also not be writable by group or others, and must not be executable by anyone.

    Raises ValueError for an empty or unrecognized ``mode``, and PermissionError (errno.EPERM) when a security check fails.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    # map the leading mode char to the corresponding low-level open(2) flags
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC  # O_NOFOLLOW is what makes a symlinked basename fail the open
    fd: int = os.open(path, flags=flags, mode=perm)
    try:
        # fstat on the already-open fd (not stat on the path) avoids a check-then-use race
        stats: os.stat_result | None = None
        if check_owner:
            stats = os.fstat(fd)
            st_uid: int = stats.st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                if (flags & (os.O_WRONLY | os.O_RDWR)) != 0:  # require that writer owns the file
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
                elif st_uid != 0:  # it's ok for root to own a file that we'll merely read
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()} or 0", path)
        if check_perm:
            stats = os.fstat(fd) if stats is None else stats  # reuse the stat result from the ownership check if available
            if (stats.st_mode & (stat.S_IWGRP | stat.S_IWOTH | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)) != 0:
                st_mode: int = stat.S_IMODE(stats.st_mode)
                raise PermissionError(
                    errno.EPERM,
                    f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]} "
                    "but must not be writable by group or others, or executable by anyone",
                    path,
                )
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        # on any failure after the raw open, close the fd so it does not leak, then re-raise the original error
        try:
            os.close(fd)
        except OSError:
            pass
        raise

359 

360 

def close_quietly(fd: int) -> None:
    """Closes file descriptor ``fd`` (if non-negative), silently swallowing any OSError raised by the close."""
    if fd >= 0:
        with contextlib.suppress(OSError):
            os.close(fd)

368 

369 

_P = TypeVar("_P")  # item type of the sequence passed to find_match()

371 

372 

def find_match(
    seq: Sequence[_P],
    predicate: Callable[[_P], bool],
    start: SupportsIndex | None = None,
    end: SupportsIndex | None = None,
    *,
    reverse: bool = False,
    raises: bool | object | Callable[[], object] = False,  # raises: bool | object | Callable = False, # python >= 3.10
) -> int:
    """Returns the index within ``seq`` of the first item (or last item if reverse=True) satisfying ``predicate``.

    Analog to ``str.find()``: ``start``/``end`` follow Python slicing semantics, including clamping and negative
    indices. ``seq`` may be any sequence, e.g. a list, tuple or str.

    If no item matches, the outcome depends on ``raises``: False/None returns -1; True raises a generic ValueError;
    any other object raises ValueError with that object as the message; a callable is invoked lazily to produce the
    message, enabling efficient deferred message construction.

    Example usage:
        lst = ["a", "b", "-c", "d"]
        i = find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True)
        if i >= 0:
            print(lst[i])
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=f"Tag {tag} not found in {file}")
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=lambda: f"Tag {tag} not found in {file}")
    """
    if start is None and end is None:
        lo, hi = 0, len(seq)  # fast path: skip slice() normalization entirely
    else:
        lo, hi, _ = slice(start, end).indices(len(seq))
    index_order = range(hi - 1, lo - 1, -1) if reverse else range(lo, hi)
    for idx in index_order:
        if predicate(seq[idx]):
            return idx
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    raise ValueError(raises() if callable(raises) else raises)

416 

417 

def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` equals ``of_root_dataset`` or lies underneath it in the dataset hierarchy."""
    if dataset == of_root_dataset:
        return True
    return dataset.startswith(of_root_dataset + "/")

421 

422 

def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any two adjacent items of the given sorted sequence compare equal."""
    shifted = itertools.islice(sorted_list, 1, None)
    return any(left == right for left, right in zip(sorted_list, shifted))

426 

427 

def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings.

    Two root-level datasets count as siblings, as do any two datasets sharing the same parent.
    """
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    skipped_root: str = DONT_SKIP_DATASET
    seen_parents: set[str] = set()
    for dataset in sorted_datasets:
        assert dataset
        parent_path = os.path.dirname(dataset)
        if parent_path in seen_parents:
            return True  # my parent already has another child, so I have a sibling
        seen_parents.add(parent_path)
        if is_descendant(dataset, of_root_dataset=skipped_root):
            continue  # still inside the subtree of the most recent root dataset
        if skipped_root != DONT_SKIP_DATASET:
            return True  # I am a root dataset and a different root dataset already exists
        skipped_root = dataset
    return False

446 

447 

def dry(msg: str, is_dry_run: bool) -> str:
    """Prefixes ``msg`` with 'Dry ' when in dry-run mode; returns it unchanged otherwise."""
    if is_dry_run:
        return "Dry " + msg
    return msg

451 

452 

def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Strips the leading ``root_dataset`` portion off an absolute dataset path.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    prefix_length = len(root_dataset)
    return dataset[prefix_length:]

459 

460 

def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    components = dataset.split("/")
    prefix = components[0]
    yield prefix
    for component in components[1:]:
        prefix = f"{prefix}/{component}"
        yield prefix

471 

472 

def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """Swaps a leading ``old_prefix`` of ``s`` for ``new_prefix``; the prefix is assumed (and asserted) to be present."""
    assert s.startswith(old_prefix)
    suffix = s[len(old_prefix) :]
    return new_prefix + suffix

477 

478 

def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces up to ``count`` occurrences of ``old`` with ``new``, in-place, within every string in ``lines``."""
    for index, line in enumerate(lines):
        lines[index] = line.replace(old, new, count)

483 

484 

_TAPPEND = TypeVar("_TAPPEND")  # list item type for append_if_absent() and xappend()

486 

487 

def append_if_absent(lst: list[_TAPPEND], *items: _TAPPEND) -> list[_TAPPEND]:
    """Appends each item to ``lst`` unless it is already contained; returns ``lst`` for chaining."""
    for candidate in items:
        if candidate not in lst:
            lst.append(candidate)
    return lst

494 

495 

def xappend(lst: list[_TAPPEND], *items: _TAPPEND | Iterable[_TAPPEND]) -> list[_TAPPEND]:
    """Appends every "truthy" item (e.g. not None, not an empty string) to ``lst``; iterables (other than strings) are
    flattened recursively; returns ``lst`` for chaining."""
    for item in items:
        if isinstance(item, Iterable) and not isinstance(item, str):
            xappend(lst, *item)  # recurse to flatten nested iterables
        elif item:
            lst.append(item)
    return lst

506 

507 

def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if ``name`` fully matches at least one include regex but none of the exclude regexes; else False.

    A regex paired with is_negation=True matches when the bare regex does NOT match. The common pattern ".*" is
    short-circuited to an unconditional match without invoking the regex engine.
    """

    def any_rule_matches(rules: RegexList) -> bool:
        """Returns whether any (regex, is_negation) rule in ``rules`` accepts ``name``."""
        for regex, is_negation in rules:
            matched = True if regex.pattern == ".*" else regex.fullmatch(name)
            if is_negation:
                matched = not matched
            if matched:
                return True
        return False

    if any_rule_matches(exclude_regexes):
        return False
    return any_rule_matches(include_regexes)

528 

529 

def compile_regexes(regexes: list[str], *, suffix: str = "") -> RegexList:
    """Compiles each regex string into a (pattern, is_negation) pair; a leading '!' marks a negation.

    When ``suffix`` is nonempty, a non-trailing '$' is rejected (so appending the suffix cannot change its meaning),
    while a trailing unescaped '$' is simply dropped because all users match via re.fullmatch() anyway.
    """
    assert isinstance(regexes, list)
    result: RegexList = []
    for expression in regexes:
        if suffix:  # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
            if expression.endswith("\\$"):
                pass  # trailing literal $ is ok
            elif expression.endswith("$"):
                expression = expression[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in expression:
                raise re.error("Must not use non-trailing '$' character", expression)
        is_negation = expression.startswith("!")
        if is_negation:
            expression = expression[1:]
        expression = replace_capturing_groups_with_non_capturing_groups(expression)
        # ".*" followed by an optional suffix group is already equivalent to plain ".*", so skip the suffix then
        if expression != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            expression = f"{expression}{suffix}"
        result.append((re.compile(expression), is_negation))
    return result

549 

550 

def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Returns a lazy formatter that joins the items with ``separator`` only when rendered, avoiding any join/str
    overhead while the corresponding log level is disabled."""

    @final
    class _LazyJoin:
        """Joins and optionally left-strips the items on conversion to ``str``."""

        def __str__(self) -> str:
            rendered = separator.join(str(item) for item in iterable)
            return rendered.lstrip() if lstrip else rendered

    return _LazyJoin()

563 

564 

def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Returns a lazy formatter that pretty-prints ``vars(obj_to_format)`` only when rendered, avoiding pprint overhead
    while the corresponding log level is disabled."""

    @final
    class _LazyPrettyPrint:
        """Pretty-prints the object's attribute dict on conversion to ``str``."""

        def __str__(self) -> str:
            import pprint  # lazy import for startup perf

            return pprint.pformat(vars(obj_to_format))

    return _LazyPrettyPrint()

578 

579 

def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    if isinstance(stderr, bytes):
        return stderr.decode("utf-8", errors="replace")
    return str(stderr)

583 

584 

def xprint(log: logging.Logger, value: Any, *, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Logs non-empty ``value`` at the custom stdout/stderr log level; a no-op when ``run`` is False or value is falsy."""
    if not (run and value):
        return
    rendered = value if end else str(value).rstrip()  # an empty ``end`` emulates print(..., end="") by stripping the newline
    level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
    log.log(level, "%s", rendered)

591 

592 

def sha256_hex(text: str) -> str:
    """Returns the sha256 digest of the UTF-8 encoding of ``text``, as a lowercase hex string."""
    digest = hashlib.sha256(text.encode())
    return digest.hexdigest()

596 

597 

def sha256_urlsafe_base64(text: str, *, padding: bool = True) -> str:
    """Returns the URL-safe base64 encoding of the sha256 digest of ``text``; '=' padding is optional."""
    raw_digest: bytes = hashlib.sha256(text.encode()).digest()
    encoded: str = base64.urlsafe_b64encode(raw_digest).decode()
    return encoded if padding else encoded.rstrip("=")

603 

604 

def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    full = sha256_urlsafe_base64(text, padding=False)
    return full[: len(full) // 2]

609 

610 

def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left one third of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    full = sha256_urlsafe_base64(text, padding=False)
    return full[: len(full) // 3]

615 

616 

def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, *, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 encoding of ``value``, which must lie in [0..max_value].

    The byte width is derived from ``max_value`` so all values in the range encode to the same length.
    """
    assert 0 <= value <= max_value
    width: int = (max_value.bit_length() + 7) // 8  # smallest byte count that can represent max_value
    encoded: str = base64.urlsafe_b64encode(value.to_bytes(width, byteorder)).decode()
    return encoded if padding else encoded.rstrip("=")

626 

627 

def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Terminates the program: delegates to ``parser.error`` when a parser is given, else raises SystemExit with
    ``exit_code`` carrying ``msg``."""
    if parser is not None:
        parser.error(msg)  # argparse prints usage + msg and exits on its own
    ex = SystemExit(msg)
    ex.code = exit_code
    raise ex

636 

637 

def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior except it enhances cleanup on TimeoutExpired, and
    provides optional child PID tracking, and optional logging of execution status via ``log`` and ``loglevel`` params.

    Extra keyword-only params (all popped before the remaining kwargs reach Popen):
        input, timeout, check: same meaning as in subprocess.run(); on timeout the child's entire process subtree is
            sent SIGTERM before the child itself is killed.
        subprocesses: optional tracker object; when given, its popen_and_track() is used instead of subprocess.Popen,
            and a tracker that is already terminated cancels the run by forcing an immediate timeout.
        log, loglevel: optional logger (and level) for recording status ("success"/"failure"/"timeout"/"cancel") and
            elapsed time of the execution; the level defaults to the "subprocess_run_loglevel" env var or LOG_TRACE.
    """
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE  # mirror subprocess.run(): feed ``input`` via a stdin pipe

    log: logging.Logger | None = kwargs.pop("log", None)
    loglevel: int | None = kwargs.pop("loglevel", None)
    start_time_nanos: int = time.monotonic_ns()
    is_timeout: bool = False
    is_cancel: bool = False
    exitcode: int | None = None

    def log_status() -> None:
        # runs on every exit path (via xfinally) and records status + elapsed time, if logging is enabled
        if log is not None:
            _loglevel: int = loglevel if loglevel is not None else getenv_int("subprocess_run_loglevel", LOG_TRACE)
            if log.isEnabledFor(_loglevel):
                elapsed_time: str = human_readable_float((time.monotonic_ns() - start_time_nanos) / 1_000_000) + "ms"
                status = "cancel" if is_cancel else "timeout" if is_timeout else "success" if exitcode == 0 else "failure"
                cmd = kwargs["args"] if "args" in kwargs else (args[0] if args else None)
                cmd_str: str = " ".join(str(arg) for arg in iter(cmd)) if isinstance(cmd, (list, tuple)) else str(cmd)
                log.log(_loglevel, f"Executed [{status}] [{elapsed_time}]: %s", cmd_str)

    with xfinally(log_status):
        ctx: contextlib.AbstractContextManager[subprocess.Popen]
        if subprocesses is None:
            ctx = subprocess.Popen(*args, **kwargs)
        else:
            ctx = subprocesses.popen_and_track(*args, **kwargs)  # also registers the child PID with the tracker
        with ctx as proc:
            try:
                sp = subprocesses
                if sp is not None and sp._is_terminated():  # noqa: SLF001 pylint: disable=protected-access
                    # tracker was already terminated: cancel by forcing communicate() to time out immediately
                    is_cancel = True
                    timeout = 0.0
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        is_timeout = True
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child proc and descendants
                finally:
                    proc.kill()  # ensure the direct child is reaped in all failure cases
                raise
            else:
                exitcode = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)

694 

695 

def terminate_process_subtree(
    *, except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes.

    Defaults to the current process as the sole root. When a root is the current process, the signal is sent to the
    descendants first and to the current process last (or not at all if except_current_process=True, the default);
    otherwise the root itself is signaled before its descendants. OSError from os.kill() (e.g. because a process
    already exited) is silently ignored.
    """
    current_pid: int = os.getpid()
    root_pids = [current_pid] if root_pids is None else root_pids
    all_pids: list[list[int]] = _get_descendant_processes(root_pids)  # one descendant list per root PID, in order
    assert len(all_pids) == len(root_pids)
    for i, pids in enumerate(all_pids):
        root_pid = root_pids[i]
        if root_pid == current_pid:
            pids += [] if except_current_process else [current_pid]  # signal ourselves last, if at all
        else:
            pids.insert(0, root_pid)  # signal the root before its descendants
        for pid in pids:
            with contextlib.suppress(OSError):
                os.kill(pid, sig)

713 

714 

def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
    """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems.

    The process tree is read via a single `ps -Ao pid,ppid` invocation; descendants are listed in depth-first
    preorder. Returns empty lists if executing `ps` is denied (e.g. in a sandbox).
    """
    if not root_pids:
        return []
    ps_cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
    try:
        ps_lines: list[str] = subprocess.run(ps_cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout.splitlines()
    except PermissionError:
        # degrade gracefully in sandbox environments that deny executing `ps` entirely
        return [[] for _ in root_pids]
    children_of: dict[int, list[int]] = defaultdict(list)
    for ps_line in ps_lines[1:]:  # all lines except the header line
        columns: list[str] = ps_line.split()
        assert len(columns) == 2
        child_pid = int(columns[0])
        parent_pid = int(columns[1])
        children_of[parent_pid].append(child_pid)

    def collect(parent_pid: int, acc: list[int]) -> None:
        """Recursively accumulates the descendant PIDs of ``parent_pid`` into ``acc`` (depth-first preorder)."""
        for child_pid in children_of[parent_pid]:
            acc.append(child_pid)
            collect(child_pid, acc)

    results: list[list[int]] = []
    for root_pid in root_pids:
        descendants: list[int] = []
        collect(root_pid, descendants)
        results.append(descendants)
    return results

745 

746 

@contextlib.contextmanager
def termination_signal_handler(
    termination_events: list[threading.Event],
    *,
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set all ``termination_events`` and, by default, terminate
    all descendant processes.

    On exit from the ``with`` body the previously installed SIGINT/SIGTERM handlers are restored, even if the body
    raises. The events are set before ``termination_handler`` runs, so waiters observe the shutdown first.
    """
    termination_events = list(termination_events)  # shallow copy so later caller-side mutation cannot affect the handler

    def _handler(_sig: int, _frame: object) -> None:
        # runs in the main thread on SIGINT/SIGTERM delivery
        for event in termination_events:
            event.set()
        termination_handler()

    previous_int_handler = signal.signal(signal.SIGINT, _handler)  # install new signal handler
    previous_term_handler = signal.signal(signal.SIGTERM, _handler)  # install new signal handler
    try:
        yield  # run body of context manager
    finally:
        signal.signal(signal.SIGINT, previous_int_handler)  # restore original signal handler
        signal.signal(signal.SIGTERM, previous_term_handler)  # restore original signal handler

769 

770 

def return_false() -> bool:
    """Constant predicate that always yields ``False``; defined at module level so it stays picklable."""
    return False

774 

775 

def sleep_nanos(delay_nanos: int) -> None:
    """Blocks the calling thread for ``delay_nanos`` nanoseconds; module-level (hence picklable) analog of
    ``time.sleep()`` that takes nanoseconds instead of seconds."""
    seconds: float = delay_nanos / 1_000_000_000
    time.sleep(seconds)

779 

780 

781############################################################################# 

@dataclass(frozen=True)
@final
class TaskTiming:
    """Immutable bundle of callbacks for reading the current monotonic time, sleeping, and observing optional async
    termination; fields are customizable per instance."""

    monotonic_ns: Callable[[], int] = time.monotonic_ns

    is_terminated: Callable[[], bool] = return_false
    """Predicate reporting whether termination (e.g. system shutdown or similar cancellation) has been requested;
    defaults to a constant ``False``."""

    sleep: Callable[[int], None] = sleep_nanos
    """Sleeps for the given number of nanoseconds; thread-safe."""

    def copy(self, **override_kwargs: Any) -> TaskTiming:
        """Returns a new instance identical to this one except for the given field overrides; thread-safe."""
        return dataclasses.replace(self, **override_kwargs)

    @staticmethod
    def make_from(termination_event: threading.Event | None) -> TaskTiming:
        """Factory: builds an instance whose sleep() wakes up early once ``termination_event`` fires; plain defaults
        when no event is given."""
        if termination_event is None:
            return TaskTiming()

        def _interruptible_sleep(delay_nanos: int) -> None:
            # Event.wait() returns as soon as the event is set, enabling prompt async termination
            termination_event.wait(delay_nanos / 1_000_000_000)

        return TaskTiming(is_terminated=termination_event.is_set, sleep=_interruptible_sleep)

811 

812 

813############################################################################# 

@final
class Subprocesses:
    """Tracks, per job, the PIDs of the child processes that job spawned, so each job can terminate exactly its own
    subprocesses; relevant when multiple jobs run concurrently inside one Python process.

    May optionally be bound to an ``is_terminated`` predicate to enforce async cancellation by forcing immediate
    timeouts for subprocesses spawned after cancellation was requested.
    """

    def __init__(self, is_terminated: Callable[[], bool] = return_false) -> None:
        self._is_terminated: Final[Callable[[], bool]] = is_terminated
        self._lock: Final[threading.Lock] = threading.Lock()
        self._child_pids: Final[dict[int, None]] = {}  # used as an insertion-ordered set of live child PIDs

    @contextlib.contextmanager
    def popen_and_track(self, *popen_args: Any, **popen_kwargs: Any) -> Iterator[subprocess.Popen]:
        """Context manager that spawns a subprocess via subprocess.Popen() and registers its PID for per-job
        termination.

        The lock is held across Popen + PID registration so that terminate_process_subtrees() (e.g. invoked from
        SIGINT/SIGTERM handlers) can never miss a child that is just being spawned; the PID is deregistered when the
        context exits.
        """
        with self._lock:
            child: subprocess.Popen = subprocess.Popen(*popen_args, **popen_kwargs)
            self._child_pids[child.pid] = None
        try:
            yield child
        finally:
            with self._lock:
                self._child_pids.pop(child.pid, None)

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Delegates to utils.subprocess_run(), auto-registering/unregistering the child PID for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Signals every tracked child PID plus all of its descendants; errors for already-dead PIDs are ignored."""
        with self._lock:
            tracked: list[int] = list(self._child_pids)
            self._child_pids.clear()
        terminate_process_subtree(root_pids=tracked, sig=sig)

855 

856 

857############################################################################# 

def pid_exists(pid: int) -> bool | None:
    """Returns True if a process with PID exists, False if not, or None on error."""
    if pid <= 0:
        return False  # POSIX PIDs are positive; 0/negative would address process groups instead
    try:
        os.kill(pid, 0)  # signal number 0 delivers nothing but still performs existence/permission checks
    except OSError as err:
        if err.errno == errno.EPERM:
            return True  # process exists but we lack permission to signal it
        return False if err.errno == errno.ESRCH else None  # ESRCH == no such process; anything else is an error
    return True

871 

872 

def nprefix(s: str) -> str:
    """Returns a canonical snapshot prefix with trailing underscore."""
    return sys.intern(f"{s}_")

876 

877 

def ninfix(s: str) -> str:
    """Returns a canonical infix with trailing underscore when not empty."""
    if not s:
        return ""
    return sys.intern(f"{s}_")

881 

882 

def nsuffix(s: str) -> str:
    """Returns a canonical suffix with leading underscore when not empty."""
    if not s:
        return ""
    return sys.intern(f"_{s}")

886 

887 

def format_dict(dictionary: dict[Any, Any]) -> str:
    """Returns the dictionary rendered via str() and wrapped in double quotes, for consistent log output."""
    return '"%s"' % (dictionary,)

891 

892 

def format_obj(obj: object) -> str:
    """Returns the object rendered via str() and wrapped in double quotes, for consistent log output."""
    return '"%s"' % (obj,)

896 

897 

def validate_dataset_name(dataset: str, input_text: str) -> None:
    """'zfs create' CLI does not accept dataset names that are empty or start or end in a slash, etc."""
    # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424
    # and https://github.com/openzfs/zfs/issues/8798
    # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html
    forbidden_chars: str = SHELL_CHARS
    is_invalid: bool = (
        dataset in ("", ".", "..")
        or dataset.startswith(("/", "./", "../"))
        or dataset.endswith(("/", "/.", "/.."))
        or "//" in dataset
        or "/./" in dataset
        or "/../" in dataset
        or any(char in forbidden_chars or (char.isspace() and char != " ") for char in dataset)
        or not dataset[0].isalpha()  # safe: the empty-string case short-circuits above
    )
    if is_invalid:
        die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'")

913 

914 

def validate_property_name(propname: str, input_text: str) -> str:
    """Checks that the ZFS property name is non-empty and contains no whitespace, shell metacharacters, or leading
    dash; returns the validated name."""
    forbidden_chars: str = SHELL_CHARS
    has_bad_char: bool = any(char.isspace() or char in forbidden_chars for char in propname)
    if not propname or propname.startswith("-") or has_bad_char:
        die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'")
    return propname

921 

922 

def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None:
    """Aborts via die() if ``path`` is a symbolic link; otherwise does nothing."""
    if not os.path.islink(path):
        return
    die(f"{msg}must not be a symlink: {path}", parser=parser)

927 

928 

def validate_file_permissions(path: str, mode: int) -> None:
    """Verifies that ``path`` (inspected without following symlinks) is owned by the current effective UID and carries
    exactly the permission bits ``mode``; calls die() with a diagnostic message on any violation."""
    stats: os.stat_result = os.stat(path, follow_symlinks=False)
    st_uid: int = stats.st_uid
    if st_uid != os.geteuid():  # verify ownership is current effective UID
        die(f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}")
    st_mode: int = stat.S_IMODE(stats.st_mode)
    if st_mode != mode:
        # fixed: removed the stray ')' that left the original error message unbalanced
        die(
            f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]}, "
            f"not {mode:03o} aka {stat.filemode(mode)[1:]}"
        )

941 

942 

def parse_duration_to_milliseconds(duration: str, *, regex_suffix: str = "", context: str = "") -> int:
    """Parses human duration strings like '5minutes' or '2 hours' to milliseconds.

    On malformed input, calls die() when a non-empty ``context`` is supplied, else raises ValueError.
    """
    millis_per_unit: dict[str, int] = {
        "milliseconds": 1,
        "millis": 1,
        "seconds": 1000,
        "secs": 1000,
        "minutes": 60 * 1000,
        "mins": 60 * 1000,
        "hours": 60 * 60 * 1000,
        "days": 86400 * 1000,
        "weeks": 7 * 86400 * 1000,
        "months": round(30.5 * 86400 * 1000),
        "years": 365 * 86400 * 1000,
    }
    pattern = r"(\d+)\s*(milliseconds|millis|seconds|secs|minutes|mins|hours|days|weeks|months|years)" + regex_suffix
    match = re.fullmatch(pattern, duration)
    if not match:
        if context:
            die(f"Invalid duration format: {duration} within {context}")
        else:
            raise ValueError(f"Invalid duration format: {duration}")
    assert match  # for the type checker: both failure branches above terminate
    amount: int = int(match.group(1))
    return amount * millis_per_unit[match.group(2)]

971 

972 

def unixtime_fromisoformat(datetime_str: str) -> int:
    """Converts ISO 8601 datetime string into UTC Unix time in integer seconds."""
    parsed: datetime = datetime.fromisoformat(datetime_str)
    return int(parsed.timestamp())

976 

977 

def isotime_from_unixtime(unixtime_in_seconds: int) -> str:
    """Converts UTC Unix time seconds into ISO 8601 datetime string."""
    moment: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=timezone.utc)
    return moment.isoformat(sep="_", timespec="seconds")

983 

984 

def current_datetime(
    tz_spec: str | None = None,
    now_fn: Callable[[tzinfo | None], datetime] | None = None,
) -> datetime:
    """Returns the current time in the ``tz_spec`` timezone, or in the local timezone when unspecified; ``now_fn`` may
    inject a custom clock for testing."""
    clock: Callable[[tzinfo | None], datetime] = datetime.now if now_fn is None else now_fn
    return clock(get_timezone(tz_spec))

993 

994 

def get_timezone(tz_spec: str | None = None) -> tzinfo | None:
    """Returns the timezone described by ``tz_spec`` ('UTC', '+HH:MM'/'+HHMM' offsets, or IANA names containing '/');
    None means local timezone; raises ValueError on anything else."""
    if tz_spec is None:
        return None  # callers interpret None as "use the local timezone"
    if tz_spec == "UTC":
        return timezone.utc
    match = re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec)
    if match is not None:
        sign, hh, mm = match.groups()
        total_minutes: int = int(hh) * 60 + int(mm)
        if sign == "-":
            total_minutes = -total_minutes
        return timezone(timedelta(minutes=total_minutes))
    if "/" in tz_spec:
        from zoneinfo import ZoneInfo  # lazy import for startup perf

        return ZoneInfo(tz_spec)
    raise ValueError(f"Invalid timezone specification: {tz_spec}")

1015 

1016 

1017############################################################################### 

@final
class SnapshotPeriods:  # thread-safe
    """Parses snapshot period suffix strings (e.g. '10minutely') and converts between suffixes and durations."""

    def __init__(self) -> None:
        """Builds the suffix-->milliseconds and suffix-->unit-label lookup tables plus the parsing regexes."""
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * 86400 * 1000,
            "monthly": round(30.5 * 86400 * 1000),
            "weekly": 7 * 86400 * 1000,
            "daily": 86400 * 1000,
            "hourly": 60 * 60 * 1000,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({'|'.join(self.suffix_milliseconds.keys())})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' into (10, 'minutely'); no leading underscore expected."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Same as :meth:`suffix_to_duration0` but requires a leading underscore, e.g. '_daily'."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Converts '2hourly' to (2, 'hourly') and 'hourly' to (1, 'hourly'); yields (0, '') when nothing matches."""
        match = regex.fullmatch(suffix)
        if match is None:
            return 0, ""
        amount_str: str | None = match.group(1)
        amount: int = int(amount_str) if amount_str else 1  # an omitted count means 1
        assert amount > 0
        return amount, match.group(2)

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns the duration, in milliseconds, encoded by the text after the final underscore of ``snapshot``."""
        idx = snapshot.rfind("_")
        tail: str = snapshot[idx + 1 :] if idx >= 0 else ""
        amount, unit = self._suffix_to_duration(tail, self._suffix_regex0)
        return amount * self.suffix_milliseconds.get(unit, 0)

1072 

1073 

1074############################################################################# 

@final
class JobStats:
    """Thread-safe counters summarizing the progress of a batch of jobs."""

    def __init__(self, jobs_all: int) -> None:
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self.jobs_all: int = jobs_all
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Records that a job has been submitted; returns a progress summary string."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Records that a job finished (successfully or not); returns a progress summary string."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            if failed:
                self.jobs_failed += 1
            self.sum_elapsed_nanos += elapsed_nanos
            msg = str(self)
            # sanity-check counter invariants; the summary string aids debugging if one ever trips
            assert self.sum_elapsed_nanos >= 0, msg
            assert self.jobs_running >= 0, msg
            assert self.jobs_failed >= 0, msg
            assert self.jobs_failed <= self.jobs_completed, msg
            assert self.jobs_completed <= self.jobs_started, msg
            assert self.jobs_started <= self.jobs_all, msg
            return msg

    def __repr__(self) -> str:
        def pct(number: int) -> str:
            """Formats ``number`` as a percentage of all jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        jobs_all, started = self.jobs_all, self.jobs_started
        completed, failed, running = self.jobs_completed, self.jobs_failed, self.jobs_running
        avg = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, completed))
        return (
            f"all:{jobs_all}, started:{pct(started)}, completed:{pct(completed)}, "
            f"failed:{pct(failed)}, running:{running}, {avg}"
        )

1124 

1125 

1126############################################################################# 

class Comparable(Protocol):
    """Structural protocol for objects orderable via ``<``."""

    def __lt__(self, other: Any) -> bool: ...


TComparable = TypeVar("TComparable", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue


@final
class SmallPriorityQueue(Generic[TComparable]):
    """Priority queue supporting efficient priority updates for elements already contained in the queue, provided the
    queue stays small (at most thousands of elements), which holds for our use cases.

    A SortedList (https://github.com/grantjenks/python-sortedcontainers) or an indexed priority queue
    (https://github.com/nvictus/pqdict) would also fit, but to avoid external dependencies this is a plain
    binary-search-backed sorted list: a priority update is remove + mutate + (re)insert. Equal elements keep their
    relative insertion order.
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; ``reverse=True`` flips the pop/peek/iteration order to largest-first."""
        self._lst: Final[list[TComparable]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Discards every queued element."""
        self._lst.clear()

    def push(self, element: TComparable) -> None:
        """Adds ``element``, keeping the backing list sorted ascending."""
        bisect.insort(self._lst, element)

    def pop(self) -> TComparable:
        """Removes and returns the highest-priority element (smallest, or largest when reverse=True)."""
        if self._reverse:
            return self._lst.pop()
        return self._lst.pop(0)

    def peek(self) -> TComparable:
        """Returns, without removing, the highest-priority element (smallest, or largest when reverse=True)."""
        return self._lst[-1] if self._reverse else self._lst[0]

    def remove(self, element: TComparable) -> bool:
        """Deletes the earliest-inserted (FIFO) occurrence of ``element``; returns whether it was present."""
        items = self._lst
        idx = bisect.bisect_left(items, element)
        found = idx < len(items) and items[idx] == element
        if found:
            del items[idx]  # list deletion is an optimized memmove()
        return found

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: TComparable) -> bool:
        """Membership test via binary search."""
        items = self._lst
        idx = bisect.bisect_left(items, element)
        return idx < len(items) and items[idx] == element

    def __iter__(self) -> Iterator[TComparable]:
        """Yields queued elements in priority order."""
        if self._reverse:
            return reversed(self._lst)
        return iter(self._lst)

    def __repr__(self) -> str:
        """Shows queue contents in the current priority order."""
        ordered = list(reversed(self._lst)) if self._reverse else self._lst
        return repr(ordered)

1198 

1199 

1200############################################################################### 

@final
class SortedInterner(Generic[TComparable]):
    """Deduplication helper akin to sys.intern(), but scoped to one presorted list (enabling binary search) instead of
    being global."""

    def __init__(self, sorted_list: list[TComparable]) -> None:
        self._lst: Final[list[TComparable]] = sorted_list

    def interned(self, element: TComparable) -> TComparable:
        """Returns the canonical equal element from the backing list if one is contained, else ``element`` itself."""
        idx = binary_search(self._lst, element)
        if idx < 0:
            return element
        return self._lst[idx]

    def __contains__(self, element: TComparable) -> bool:
        """Membership test via binary search."""
        return binary_search(self._lst, element) >= 0

1217 

1218 

def binary_search(sorted_list: list[TComparable], item: TComparable) -> int:
    """Java-style binary search: returns the left-most index of an equal item when found (>= 0), else
    ``-insertion_point - 1``."""
    idx = bisect.bisect_left(sorted_list, item)
    if idx != len(sorted_list) and sorted_list[idx] == item:
        return idx
    return -idx - 1

1224 

1225 

1226############################################################################### 

_S = TypeVar("_S")


@final
class HashedInterner(Generic[_S]):
    """Local, type-generic analog of sys.intern(): dedupes equal hashable items against a private pool."""

    def __init__(self, items: Iterable[_S] = frozenset()) -> None:
        self._items: Final[dict[_S, _S]] = {item: item for item in items}

    def intern(self, item: _S) -> _S:
        """Adds ``item`` to the pool if absent, then returns the pooled (canonical) instance."""
        return self._items.setdefault(item, item)

    def interned(self, item: _S) -> _S:
        """Returns the pooled instance equal to ``item`` if one exists, else ``item`` unchanged; never mutates the pool."""
        return self._items.get(item, item)

    def __contains__(self, item: _S) -> bool:
        return item in self._items

1247 

1248 

1249############################################################################# 

@final
class SynchronizedBool:
    """Mutex-guarded boolean supporting atomic read, write, swap, and compare-and-set."""

    def __init__(self, val: bool) -> None:
        assert isinstance(val, bool)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._value: bool = val

    @property
    def value(self) -> bool:
        """Atomically reads the current value."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: bool) -> None:
        """Atomically writes ``new_value``."""
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: bool) -> bool:
        """Atomically stores ``new_value`` and returns the value it replaced."""
        with self._lock:
            previous = self._value
            self._value = new_value
            return previous

    def compare_and_set(self, expected_value: bool, new_value: bool) -> bool:
        """Atomically writes ``new_value`` iff the current value equals ``expected_value``; returns whether it matched."""
        with self._lock:
            matched: bool = self._value == expected_value
            if matched:
                self._value = new_value
            return matched

    def __bool__(self) -> bool:
        return self.value

    def __repr__(self) -> str:
        return repr(self.value)

    def __str__(self) -> str:
        return str(self.value)

1294 

1295 

1296############################################################################# 

_K = TypeVar("_K")
_V = TypeVar("_V")


@final
class SynchronizedDict(Generic[_K, _V]):
    """Mutex-guarded dictionary; every operation executes atomically under a private lock."""

    def __init__(self, val: dict[_K, _V]) -> None:
        assert isinstance(val, dict)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._dict: Final[dict[_K, _V]] = val

    def __getitem__(self, key: _K) -> _V:
        """Atomic ``self[key]``; raises KeyError when absent."""
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key: _K, value: _V) -> None:
        """Atomic ``self[key] = value``."""
        with self._lock:
            self._dict[key] = value

    def __delitem__(self, key: _K) -> None:
        """Atomic ``del self[key]``; raises KeyError when absent."""
        with self._lock:
            self._dict.pop(key)

    def __contains__(self, key: _K) -> bool:
        """Atomic membership test."""
        with self._lock:
            return key in self._dict

    def __len__(self) -> int:
        """Atomic element count."""
        with self._lock:
            return len(self._dict)

    def __repr__(self) -> str:
        with self._lock:
            return repr(self._dict)

    def __str__(self) -> str:
        with self._lock:
            return str(self._dict)

    def get(self, key: _K, default: _V | None = None) -> _V | None:
        """Returns the value for ``key``, or ``default`` when absent."""
        with self._lock:
            return self._dict.get(key, default)

    def pop(self, key: _K, default: _V | None = None) -> _V | None:
        """Removes ``key`` if present and returns its value, else ``default``."""
        with self._lock:
            return self._dict.pop(key, default)

    def clear(self) -> None:
        """Atomically empties the dictionary."""
        with self._lock:
            self._dict.clear()

    def items(self) -> ItemsView[_K, _V]:
        """Returns an items view over a point-in-time copy, safe to iterate without holding the lock."""
        with self._lock:
            return self._dict.copy().items()

1357 

1358 

1359############################################################################# 

@final
class InterruptibleSleep:
    """Sleep helper whose sleep(timeout) can be cut short by another thread; the underlying lock is injectable."""

    def __init__(self, lock: threading.Lock | None = None) -> None:
        self._is_stopping: bool = False
        self._lock: Final[threading.Lock] = threading.Lock() if lock is None else lock
        self._condition: Final[threading.Condition] = threading.Condition(self._lock)

    def sleep(self, duration_nanos: int) -> bool:
        """Blocks for up to ``duration_nanos`` nanoseconds; returns True iff the sleep was interrupted (mirrors
        threading.Event.wait())."""
        deadline_nanos: int = time.monotonic_ns() + duration_nanos
        with self._lock:
            while not self._is_stopping:
                remaining_nanos: int = deadline_nanos - time.monotonic_ns()
                if remaining_nanos <= 0:
                    return False  # full duration elapsed without interruption
                # releases the lock, then blocks until notified or the timeout lapses
                self._condition.wait(timeout=remaining_nanos / 1_000_000_000)
            return True  # interrupt() fired before the deadline

    def interrupt(self) -> None:
        """Wakes all sleepers and turns future sleep() calls into no-ops (mirrors threading.Event.set())."""
        with self._lock:
            if self._is_stopping:
                return
            self._is_stopping = True
            self._condition.notify_all()

    def reset(self) -> None:
        """Re-arms the object so future sleep() calls block again (mirrors threading.Event.clear())."""
        with self._lock:
            self._is_stopping = False

1392 

1393 

1394############################################################################# 

@final
class SynchronousExecutor(Executor):
    """Executor that executes each submitted task immediately and sequentially on the calling thread."""

    def __init__(self) -> None:
        self._shutdown: bool = False

    def submit(self, fn: Callable[..., _R_], /, *args: Any, **kwargs: Any) -> Future[_R_]:
        """Runs ``fn(*args, **kwargs)`` inline and returns an already-completed Future holding its result or exception."""
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        future: Future[_R_] = Future()
        try:
            future.set_result(fn(*args, **kwargs))
        except BaseException as exc:
            future.set_exception(exc)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Disallows further submissions; there are no worker threads to wind down."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Factory: returns an inline executor when 0 <= max_workers <= 1, else a real thread pool."""
        if 0 <= max_workers <= 1:
            return cls()
        return ThreadPoolExecutor(max_workers=max_workers)

1423 

1424 

1425############################################################################# 

@final
class _XFinally(contextlib.AbstractContextManager):
    """Context manager ensuring cleanup code executes after ``with`` blocks, without ever letting a cleanup failure
    mask an exception raised earlier inside the block body; see xfinally() for the full contract."""

    def __init__(self, cleanup: Callable[[], None]) -> None:
        """Records the callable to run upon exit."""
        self._cleanup: Final = cleanup  # Zero-argument callable executed after the `with` block exits.

    def __exit__(
        self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None
    ) -> Literal[False]:
        """Runs cleanup and propagate any exceptions appropriately.

        Priority rules: a body exception always wins over a cleanup exception; a cleanup exception propagates on its
        own only when the body succeeded. Returning False tells Python to re-raise ``exc`` (if any).
        """
        try:
            self._cleanup()
        except BaseException as cleanup_exc:
            if exc is None:
                raise  # No main error --> propagate cleanup error normally
            # Both failed
            # if sys.version_info >= (3, 11):
            #     raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None
            # <= 3.10: attach so it shows up in traceback but doesn't mask
            # chaining via __context__ keeps the cleanup failure visible in the traceback of the body's exception
            exc.__context__ = cleanup_exc
            return False  # reraise original exception
        return False  # propagate main exception if any

1450 

1451 

def xfinally(cleanup: Callable[[], None]) -> _XFinally:
    """Usage: ``with xfinally(lambda: cleanup()): ...``

    Returns a context manager that always runs ``cleanup()`` when the ``with`` block exits, and that never lets an
    exception raised by ``cleanup()`` mask an exception raised earlier inside the block body, while still surfacing
    both problems when possible.

    Why not a plain ``try ... finally``? Because a raising ``finally`` clause replaces the real error:

        try:
            work()
        finally:
            cleanup()  # an exception raised here hides the actual failure from work()!

    Guarantees provided by the returned ``_XFinally``:

    * Body raises, cleanup succeeds --> the body's exception propagates.
    * Body raises, cleanup raises too --> the body's exception propagates; the cleanup exception is linked via
      ``__context__`` so both remain visible in the traceback.
    * Body succeeds, cleanup raises --> the cleanup exception propagates normally.

    Example:
    -------
    >>> with xfinally(lambda: release_resources()):  # doctest: +SKIP
    ...     run_tasks()

    A single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information.
    """
    return _XFinally(cleanup)