Coverage for bzfs_main / util / utils.py: 100%

789 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-24 10:16 +0000

1# Copyright 2024 Wolfgang Hoschek AT mac DOT com 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# 

15"""Collection of helper functions used across bzfs; includes environment variable parsing, process management and lightweight 

16concurrency primitives, etc. 

17 

18Everything in this module relies only on the Python standard library so other modules remain dependency free. Each utility 

19favors simple, predictable behavior on all supported platforms. 

20""" 

21 

22from __future__ import ( 

23 annotations, 

24) 

25import argparse 

26import base64 

27import bisect 

28import contextlib 

29import dataclasses 

30import errno 

31import hashlib 

32import itertools 

33import logging 

34import operator 

35import os 

36import platform 

37import pwd 

38import random 

39import re 

40import signal 

41import stat 

42import subprocess 

43import sys 

44import threading 

45import time 

46import types 

47from collections import ( 

48 defaultdict, 

49 deque, 

50) 

51from collections.abc import ( 

52 ItemsView, 

53 Iterable, 

54 Iterator, 

55 Sequence, 

56) 

57from concurrent.futures import ( 

58 Executor, 

59 Future, 

60 ThreadPoolExecutor, 

61) 

62from dataclasses import ( 

63 dataclass, 

64) 

65from datetime import ( 

66 datetime, 

67 timedelta, 

68 timezone, 

69 tzinfo, 

70) 

71from subprocess import ( 

72 DEVNULL, 

73 PIPE, 

74) 

75from typing import ( 

76 IO, 

77 Any, 

78 Callable, 

79 Final, 

80 Generic, 

81 Literal, 

82 NoReturn, 

83 Protocol, 

84 SupportsIndex, 

85 TextIO, 

86 TypeVar, 

87 cast, 

88 final, 

89) 

90 

# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"  # all config env var names share this common prefix, e.g. "bzfs_foo"
DIE_STATUS: Final[int] = 3  # default process exit status used by die()
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level is halfway in between
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level is halfway in between
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level is halfway in between
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""  # sentinel: "no dataset is currently being skipped" (see has_siblings())
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"  # intentionally not included: -_.:/
SHELL_CHARS_AND_SLASH: Final[str] = SHELL_CHARS + "/"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # see Google for 'sun_path'

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias: (compiled regex, is_negation) pairs

111 

112 

def getenv_any(key: str, default: str | None = None, env_var_prefix: str = ENV_VAR_PREFIX) -> str | None:
    """Looks up the config env var formed by prepending ``env_var_prefix`` to ``key``; returns ``default`` if unset."""
    full_name: str = env_var_prefix + key
    return os.environ.get(full_name, default)

116 

117 

def getenv_int(key: str, default: int, env_var_prefix: str = ENV_VAR_PREFIX) -> int:
    """Returns environment variable ``key`` parsed as an int, falling back to ``default`` when unset."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return int(raw)

121 

122 

def getenv_bool(key: str, default: bool = False, env_var_prefix: str = ENV_VAR_PREFIX) -> bool:
    """Returns environment variable ``key`` as bool with ``default`` fallback; only "true" (case-insensitive) is truthy."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return raw.lower().strip() == "true"

126 

127 

def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Retains only column number 'field' in a list of TSV/CSV lines; Analog to Unix 'cut' CLI command."""
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field == 1:  # keep everything before the first separator
        return [row[: row.index(separator)] for row in lines]
    if field == 2:  # keep everything after the first separator
        return [row[row.index(separator) + 1 :] for row in lines]
    raise ValueError(f"Invalid field value: {field}")

139 

140 

def drain(iterable: Iterable[Any]) -> None:
    """Consumes all items in the iterable, effectively draining it."""
    it = iter(iterable)
    while True:
        try:
            item = next(it)
        except StopIteration:
            return
        del item  # help gc (iterable can block)

145 

146 

# Generic type variables shared by the dict/value helpers in this module:
_K_ = TypeVar("_K_")  # dict key type
_V_ = TypeVar("_V_")  # dict value type
_R_ = TypeVar("_R_")  # generic result type (presumably used by helpers elsewhere in this module — confirm)

150 

151 

def shuffle_dict(dictionary: dict[_K_, _V_], /, rand: random.Random = random.SystemRandom()) -> dict[_K_, _V_]:  # noqa: B008
    """Returns a fresh dict holding the same key-value pairs, inserted in a randomly shuffled order."""
    pairs: list[tuple[_K_, _V_]] = [*dictionary.items()]
    rand.shuffle(pairs)
    return dict(pairs)

157 

158 

def sorted_dict(
    dictionary: dict[_K_, _V_], /, *, key: Callable[[tuple[_K_, _V_]], Any] | None = None, reverse: bool = False
) -> dict[_K_, _V_]:
    """Returns a new dict with items sorted, primarily by key and secondarily by value (unless a custom key is supplied)."""
    ordered_items = sorted(dictionary.items(), key=key, reverse=reverse)
    return dict(ordered_items)

164 

165 

def tail(file: str, *, n: int, errors: str | None = None) -> Sequence[str]:
    """Return the last ``n`` lines of ``file`` without following symlinks."""
    if os.path.isfile(file):
        with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False) as handle:
            last_lines: deque[str] = deque(maxlen=n)  # deque caps memory to the last n lines
            last_lines.extend(handle)
            return last_lines
    return []

172 

173 

# Matches the opening of a named capturing group at the start of the string, e.g. "(?P<name>":
_NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))
_NUMERIC_BACKREFERENCE_REGEX: Final[re.Pattern[str]] = re.compile(r"\\\d+")  # example: \1

176 

177 

def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces parenthesis '(' followed by a char other than question mark '?', but not preceded by a backslash
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    i = regex.find("[")
    if i >= 0 and regex.find("(", i) >= 0:
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the case where the regex might
        # contain a regex character class that contains parenthesis.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    if "(?P=" in regex or "(?(" in regex or _NUMERIC_BACKREFERENCE_REGEX.search(regex):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely if the regex might contain a
        # (named or conditional or numeric) backreference.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    # Scan right-to-left so that earlier rewrites never shift the indices of yet-unprocessed '(' characters.
    i = len(regex) - 2
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):  # ignore escaped '\(' literals
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + _NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex

214 

215 

def get_home_directory() -> str:
    """Reliably detects the current user's home dir without consulting the HOME env var."""
    # Thread-safe alternative to: os.environ.pop('HOME', None); os.path.expanduser('~')
    uid = os.getuid()
    return pwd.getpwuid(uid).pw_dir

220 

221 

def human_readable_bytes(num_bytes: float, *, separator: str = " ", precision: int | None = None) -> str:
    """Formats 'num_bytes' as a human-readable size; for example "567 MiB"."""
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    sign = "" if num_bytes >= 0 else "-"
    magnitude = abs(num_bytes)
    unit_idx = 0
    max_idx = len(units) - 1
    while magnitude >= 1024 and unit_idx < max_idx:  # scale down by factors of 1024 until below the next unit
        magnitude /= 1024
        unit_idx += 1
    num_str = f"{magnitude:.{precision}f}" if precision is not None else human_readable_float(magnitude)
    return f"{sign}{num_str}{separator}{units[unit_idx]}"

234 

235 

def human_readable_duration(duration: float, *, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling as needed; for example "567ms".

    ``duration`` is interpreted in ``unit``, which must be one of "ns", "μs", "ms", "s", "m", "h", "d".
    ``precision`` of None selects human_readable_float() formatting; otherwise a fixed decimal count is used.
    """
    sign = "-" if duration < 0 else ""
    t = abs(duration)
    units = ("ns", "μs", "ms", "s", "m", "h", "d")
    i = units.index(unit)
    if t < 1 and t != 0:  # sub-unit value: first convert down to nanoseconds, then scale up from there
        nanos = (1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000)
        t *= nanos[i]
        i = 0
    while t >= 1000 and i < 3:  # step ns -> μs -> ms -> s in factors of 1000
        t /= 1000
        i += 1
    if i >= 3:
        while t >= 60 and i < 5:  # step s -> m -> h in factors of 60
            t /= 60
            i += 1
        if i >= 5:
            while t >= 24 and i < len(units) - 1:  # step h -> d in factors of 24
                t /= 24
                i += 1
    formatted_num = human_readable_float(t) if precision is None else f"{t:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"

259 

260 

def human_readable_float(number: float) -> str:
    """Formats ``number`` with magnitude-dependent precision, mirroring how humans round values when scanning logs.

    Two decimals for abs(number) < 10 (3.14559 -> "3.15"), one decimal for abs(number) < 100 (12.36 -> "12.4"),
    none otherwise (123.556 -> "124").  Unnecessary trailing zeros and a dangling decimal point are stripped,
    e.g. 1.500 -> "1.5" and 1.00 -> "1"; a result of "-0" is normalized to "0".
    """
    magnitude = abs(number)
    if magnitude >= 100:
        return str(round(number))
    digits = 2 if magnitude < 10 else 1
    text = f"{number:.{digits}f}"
    assert "." in text
    text = text.rstrip("0").rstrip(".")  # drop trailing zeros, then a dangling decimal point
    return "0" if text == "-0" else text

285 

286 

def percent(number: int, total: int, *, print_total: bool = False) -> str:
    """Returns percentage string of ``number`` relative to ``total``."""
    suffix: str = f"/{total}" if print_total else ""
    if total == 0:
        return f"{number}{suffix}=inf%"  # avoid ZeroDivisionError
    return f"{number}{suffix}={human_readable_float(100 * number / total)}%"

291 

292 

def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if basename of path is a symlink.

    Also, can specify custom permissions on O_CREAT, and verify ownership.

    If check_owner=True, write-capable opens require ownership by the effective UID; read-only opens also allow ownership by
    uid 0 (root). This allows safe reads of root-owned system files while preventing writes to files not owned by the caller.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    # map the leading mode character to the corresponding low-level open(2) flag combination
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC  # O_NOFOLLOW: fail if the final path component is a symlink
    fd: int = os.open(path, flags=flags, mode=perm)  # 'mode' here is the permission bits, applied only on O_CREAT
    try:
        if check_owner:
            st_uid: int = os.fstat(fd).st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                if (flags & (os.O_WRONLY | os.O_RDWR)) != 0:  # require that writer owns the file
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
                elif st_uid != 0:  # it's ok for root to own a file that we'll merely read
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()} or 0", path)
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        try:  # don't leak the raw file descriptor when fdopen() or the ownership check fails
            os.close(fd)
        except OSError:
            pass
        raise

342 

343 

def close_quietly(fd: int) -> None:
    """Closes the given file descriptor while silently swallowing any OSError that might arise as part of this."""
    if fd < 0:
        return  # negative means "no descriptor"
    with contextlib.suppress(OSError):
        os.close(fd)

351 

352 

_P = TypeVar("_P")  # item type of the sequence searched by find_match()

354 

355 

def find_match(
    seq: Sequence[_P],
    predicate: Callable[[_P], bool],
    start: SupportsIndex | None = None,
    end: SupportsIndex | None = None,
    *,
    reverse: bool = False,
    raises: bool | object | Callable[[], object] = False,  # raises: bool | object | Callable = False, # python >= 3.10
) -> int:
    """Returns the index within ``seq`` of the first (or last, if reverse=True) item satisfying ``predicate``.

    Respects Python slicing semantics for ``start``/``end`` (including clamping), analogous to ``str.find()``; seq may be
    any sequence, e.g. a list, tuple or str.

    When no item matches: returns -1 if ``raises`` is False/None; raises ValueError with a default message if ``raises``
    is True; otherwise raises ValueError with ``raises`` itself as the message — ``raises`` may also be a zero-arg
    callable, which defers (potentially expensive) message construction until a miss actually occurs.
    """
    if start is None and end is None:  # fast path: no slicing arithmetic needed
        indices = range(len(seq) - 1, -1, -1) if reverse else range(len(seq))
    else:
        lo, hi, _ = slice(start, end).indices(len(seq))
        indices = range(hi - 1, lo - 1, -1) if reverse else range(lo, hi)
    for idx in indices:
        if predicate(seq[idx]):
            return idx
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    raise ValueError(raises() if callable(raises) else raises)

399 

400 

def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` equals ``of_root_dataset`` or lies beneath it in the dataset hierarchy."""
    if dataset == of_root_dataset:
        return True
    return dataset.startswith(of_root_dataset + "/")

404 

405 

def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any adjacent items within the given sorted sequence are equal."""
    neighbors = itertools.islice(sorted_list, 1, None)  # the same list shifted by one, without copying
    return any(prev == curr for prev, curr in zip(sorted_list, neighbors))

409 

410 

def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings.

    Two datasets are siblings if they share the same parent path; a second top-level (root) dataset also counts.
    Requires ``sorted_datasets`` to be sorted and duplicate-free (verified only when ``is_test_mode`` is True).
    """
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    skip_dataset: str = DONT_SKIP_DATASET  # the root dataset whose subtree we are currently scanning
    parents: set[str] = set()
    for dataset in sorted_datasets:
        assert dataset
        parent = os.path.dirname(dataset)
        if parent in parents:
            return True  # I have a sibling if my parent already has another child
        parents.add(parent)
        if is_descendant(dataset, of_root_dataset=skip_dataset):
            continue  # still inside the subtree of the current root dataset
        if skip_dataset != DONT_SKIP_DATASET:
            return True  # I have a sibling if I am a root dataset and another root dataset already exists
        skip_dataset = dataset
    return False

429 

430 

def dry(msg: str, is_dry_run: bool) -> str:
    """Prefixes ``msg`` with 'Dry ' when in dry-run mode; returns it unchanged otherwise."""
    if is_dry_run:
        return "Dry " + msg
    return msg

434 

435 

def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Converts an absolute dataset path to one relative to ``root_dataset``.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    prefix_len = len(root_dataset)
    return dataset[prefix_len:]

442 

443 

def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    search_from: int = 0
    while True:
        sep = dataset.find("/", search_from)
        if sep < 0:  # no more separators: the full dataset name is the final prefix
            yield dataset
            return
        yield dataset[:sep]
        search_from = sep + 1

454 

455 

def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """In a string s, replaces a leading old_prefix string with new_prefix; assumes the leading string is present."""
    assert s.startswith(old_prefix)
    return new_prefix + s.removeprefix(old_prefix)

460 

461 

def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces ``old`` with ``new`` in-place for every string in ``lines``; ``count`` limits replacements per line."""
    for i, line in enumerate(lines):
        lines[i] = line.replace(old, new, count)

466 

467 

_TAPPEND = TypeVar("_TAPPEND")  # item type for the list-append helpers below

469 

470 

def append_if_absent(lst: list[_TAPPEND], *items: _TAPPEND) -> list[_TAPPEND]:
    """Appends each item to ``lst`` unless it is already contained; returns ``lst`` for convenience."""
    for candidate in items:
        if candidate not in lst:
            lst.append(candidate)
    return lst

477 

478 

def xappend(lst: list[_TAPPEND], *items: _TAPPEND | Iterable[_TAPPEND]) -> list[_TAPPEND]:
    """Appends each "truthy" item (e.g. not None and not an empty string) to ``lst``; non-string iterables are flattened
    recursively; returns ``lst`` for convenience."""
    for element in items:
        if not isinstance(element, str) and isinstance(element, Iterable):
            xappend(lst, *element)  # flatten nested iterables
        elif element:
            lst.append(element)
    return lst

489 

490 

def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if the name matches at least one of the include regexes but none of the exclude regexes; else False.

    A regex that starts with a `!` is a negation - the regex matches if the regex without the `!` prefix does not match.
    """

    def matches(regex: re.Pattern[str], is_negation: bool) -> bool:
        """Evaluates one (regex, negation) pair against ``name``; a literal ".*" pattern always matches (fast path)."""
        hit = True if regex.pattern == ".*" else bool(regex.fullmatch(name))
        return (not hit) if is_negation else hit

    if any(matches(rx, neg) for rx, neg in exclude_regexes):
        return False  # any matching exclude wins
    return any(matches(rx, neg) for rx, neg in include_regexes)

511 

512 

def compile_regexes(regexes: list[str], *, suffix: str = "") -> RegexList:
    """Compiles regex strings and keeps track of negations.

    A leading '!' marks a regex as a negation.  A non-empty ``suffix`` (e.g. DESCENDANTS_RE_SUFFIX) is appended to each
    regex so descendants of matching datasets also match; in that case any non-trailing '$' is rejected.
    """
    assert isinstance(regexes, list)
    compiled_regexes: RegexList = []
    for regex in regexes:
        if suffix:  # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
            if regex.endswith("\\$"):
                pass  # trailing literal $ is ok
            elif regex.endswith("$"):
                regex = regex[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in regex:
                raise re.error("Must not use non-trailing '$' character", regex)
        if is_negation := regex.startswith("!"):
            regex = regex[1:]  # strip the '!' marker; the negation is tracked separately
        regex = replace_capturing_groups_with_non_capturing_groups(regex)
        # ".*" followed by an optional-group suffix still matches everything, so skip appending the suffix in that case:
        if regex != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            regex = f"{regex}{suffix}"
        compiled_regexes.append((re.compile(regex), is_negation))
    return compiled_regexes

532 

533 

def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Lazy formatter joining items with ``separator``; defers the join until str() so disabled log levels pay nothing."""

    @final
    class CustomListFormatter:
        """Joins the captured iterable only when rendered via ``str``."""

        def __str__(self) -> str:
            joined = separator.join(str(item) for item in iterable)
            return joined.lstrip() if lstrip else joined

    return CustomListFormatter()

546 

547 

def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Lazy pprint formatter; defers formatting until str() to avoid overhead in disabled log levels."""

    @final
    class PrettyPrintFormatter:
        """Pretty-prints ``vars(obj_to_format)`` only when rendered via ``str``."""

        def __str__(self) -> str:
            import pprint  # lazy import for startup perf

            attributes = vars(obj_to_format)
            return pprint.pformat(attributes)

    return PrettyPrintFormatter()

561 

562 

def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    if isinstance(stderr, bytes):
        return stderr.decode("utf-8", errors="replace")
    return str(stderr)

566 

567 

def xprint(log: logging.Logger, value: Any, *, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Optionally logs ``value`` at the custom stdout/stderr log level; no-op when ``run`` is False or value is falsy."""
    if not (run and value):
        return
    rendered = value if end else str(value).rstrip()  # empty 'end' mimics print(end="") by stripping trailing whitespace
    level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
    log.log(level, "%s", rendered)

574 

575 

def sha256_hex(text: str) -> str:
    """Returns the sha256 hex string for the given text."""
    digest = hashlib.sha256(text.encode())
    return digest.hexdigest()

579 

580 

def sha256_urlsafe_base64(text: str, *, padding: bool = True) -> str:
    """Returns the URL-safe base64-encoded sha256 value for the given text."""
    raw_digest: bytes = hashlib.sha256(text.encode()).digest()
    encoded: str = base64.urlsafe_b64encode(raw_digest).decode()
    return encoded if padding else encoded.rstrip("=")

586 

587 

def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    full = sha256_urlsafe_base64(text, padding=False)
    return full[: len(full) // 2]

592 

593 

def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left one third portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    full = sha256_urlsafe_base64(text, padding=False)
    return full[: len(full) // 3]

598 

599 

def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, *, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 string encoding of the int value, assuming it is contained in the range [0..max_value]."""
    assert 0 <= value <= max_value
    width: int = (max_value.bit_length() + 7) // 8  # smallest byte count that can represent max_value
    encoded: str = base64.urlsafe_b64encode(value.to_bytes(width, byteorder)).decode()
    return encoded if padding else encoded.rstrip("=")

609 

610 

def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Exits the program with ``exit_code`` after logging ``msg``; delegates to ``parser.error`` when a parser is given."""
    if parser is not None:
        parser.error(msg)  # argparse prints usage + msg and exits on its own
    error = SystemExit(msg)
    error.code = exit_code
    raise error

619 

620 

def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior except it enhances cleanup on TimeoutExpired, and
    provides optional child PID tracking, and optional logging of execution status via ``log`` and ``loglevel`` params."""
    # Pop the extension params (and those we must handle ourselves) so the remaining kwargs forward to Popen verbatim:
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)  # optional per-job child PID tracker
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE  # same semantics as subprocess.run(input=...)

    log: logging.Logger | None = kwargs.pop("log", None)
    loglevel: int | None = kwargs.pop("loglevel", None)
    start_time_nanos: int = time.monotonic_ns()
    is_timeout: bool = False
    is_cancel: bool = False
    exitcode: int | None = None

    def log_status() -> None:
        """Logs a one-line summary (status, elapsed time, command) if logging is enabled at the effective level."""
        if log is not None:
            _loglevel: int = loglevel if loglevel is not None else getenv_int("subprocess_run_loglevel", LOG_TRACE)
            if log.isEnabledFor(_loglevel):
                elapsed_time: str = human_readable_float((time.monotonic_ns() - start_time_nanos) / 1_000_000) + "ms"
                status = "cancel" if is_cancel else "timeout" if is_timeout else "success" if exitcode == 0 else "failure"
                cmd = kwargs["args"] if "args" in kwargs else (args[0] if args else None)
                cmd_str: str = " ".join(str(arg) for arg in iter(cmd)) if isinstance(cmd, (list, tuple)) else str(cmd)
                log.log(_loglevel, f"Executed [{status}] [{elapsed_time}]: %s", cmd_str)

    # NOTE(review): xfinally is defined elsewhere in this module; presumably it runs log_status() on exit even when an
    # exception propagates — confirm against its definition.
    with xfinally(log_status):
        ctx: contextlib.AbstractContextManager[subprocess.Popen]
        if subprocesses is None:
            ctx = subprocess.Popen(*args, **kwargs)
        else:
            ctx = subprocesses.popen_and_track(*args, **kwargs)  # also registers the child PID for per-job cleanup
        with ctx as proc:
            try:
                sp = subprocesses
                if sp is not None and sp._is_terminated():  # noqa: SLF001 pylint: disable=protected-access
                    is_cancel = True
                    timeout = 0.0  # force an immediate TimeoutExpired to enact async cancellation
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        is_timeout = True
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child proc and descendants
                finally:
                    proc.kill()  # ensure the direct child is dead regardless of what went wrong
                raise
            else:
                exitcode = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)

677 

678 

def terminate_process_subtree(
    *, except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes.

    ``root_pids`` defaults to the current process.  ``except_current_process`` controls whether the current process also
    signals itself when it is one of the roots.  OSError from signal delivery (e.g. PID already gone) is ignored.
    """
    current_pid: int = os.getpid()
    root_pids = [current_pid] if root_pids is None else root_pids
    all_pids: list[list[int]] = _get_descendant_processes(root_pids)  # one descendant list per root PID, in order
    assert len(all_pids) == len(root_pids)
    for i, pids in enumerate(all_pids):
        root_pid = root_pids[i]
        if root_pid == current_pid:
            pids += [] if except_current_process else [current_pid]  # signal ourselves last, if at all
        else:
            pids.insert(0, root_pid)  # signal the root before its descendants
        for pid in pids:
            with contextlib.suppress(OSError):  # tolerate races with processes that have already exited
                os.kill(pid, sig)

696 

697 

def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
    """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems.

    Takes a single `ps -Ao pid,ppid` snapshot of the whole system, builds a parent->children map, then walks the tree
    recursively per root PID; the root itself is not included in its own descendant list.
    """
    if len(root_pids) == 0:
        return []
    cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
    try:
        lines: list[str] = subprocess.run(cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout.splitlines()
    except PermissionError:
        # degrade gracefully in sandbox environments that deny executing `ps` entirely
        return [[] for _ in root_pids]
    procs: dict[int, list[int]] = defaultdict(list)  # parent PID -> direct child PIDs
    for line in lines[1:]:  # all lines except the header line
        splits: list[str] = line.split()
        assert len(splits) == 2
        pid = int(splits[0])
        ppid = int(splits[1])
        procs[ppid].append(pid)

    def recursive_append(ppid: int, descendants: list[int]) -> None:
        """Recursively collect descendant PIDs starting from ``ppid``."""
        for child_pid in procs[ppid]:
            descendants.append(child_pid)
            recursive_append(child_pid, descendants)

    all_descendants: list[list[int]] = []
    for root_pid in root_pids:
        descendants: list[int] = []
        recursive_append(root_pid, descendants)
        all_descendants.append(descendants)
    return all_descendants

728 

729 

@contextlib.contextmanager
def termination_signal_handler(
    termination_events: list[threading.Event],
    *,
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set all ``termination_events`` and, by default, terminate
    all descendant processes.

    The previously installed handlers are restored when the context exits, even on exceptions.
    """
    termination_events = list(termination_events)  # shallow copy so later caller-side list mutations don't affect us

    def _handler(_sig: int, _frame: object) -> None:
        # set all events first so waiting threads wake up, then run the (potentially slower) termination action
        for event in termination_events:
            event.set()
        termination_handler()

    previous_int_handler = signal.signal(signal.SIGINT, _handler)  # install new signal handler
    previous_term_handler = signal.signal(signal.SIGTERM, _handler)  # install new signal handler
    try:
        yield  # run body of context manager
    finally:
        signal.signal(signal.SIGINT, previous_int_handler)  # restore original signal handler
        signal.signal(signal.SIGTERM, previous_term_handler)  # restore original signal handler

752 

753 

def return_false() -> bool:
    """Always returns ``False``; defined at module level so it is picklable (unlike a lambda)."""
    return False

757 

758 

def sleep_nanos(delay_nanos: int) -> None:
    """Same as time.sleep() but takes a relative duration in nanoseconds; module-level so it is picklable."""
    seconds = delay_nanos / 1_000_000_000
    time.sleep(seconds)

762 

763 

764############################################################################# 

#############################################################################
@dataclass(frozen=True)
@final
class TaskTiming:
    """Immutable bundle of callbacks for reading the monotonic clock, sleeping, and detecting async termination."""

    monotonic_ns: Callable[[], int] = time.monotonic_ns

    is_terminated: Callable[[], bool] = return_false
    """Predicate reporting whether a termination condition (e.g. system shutdown) has become true; defaults to a
    function that always returns ``False``."""

    sleep: Callable[[int], None] = sleep_nanos
    """Sleeps for the given number of nanoseconds; thread-safe."""

    def copy(self, **override_kwargs: Any) -> TaskTiming:
        """Returns a new frozen instance equal to this one except for the overridden fields; thread-safe."""
        return dataclasses.replace(self, **override_kwargs)

    @staticmethod
    def make_from(termination_event: threading.Event | None) -> TaskTiming:
        """Factory: the returned object's sleep wakes early and its predicate fires once ``termination_event`` is set."""
        if termination_event is None:
            return TaskTiming()

        def _interruptible_sleep(delay_nanos: int) -> None:
            termination_event.wait(delay_nanos / 1_000_000_000)  # wakes early if the event gets set meanwhile

        return TaskTiming(is_terminated=termination_event.is_set, sleep=_interruptible_sleep)

794 

795 

796############################################################################# 

@final
class Subprocesses:
    """Provides per-job tracking of child PIDs so a job can safely terminate only the subprocesses it spawned itself; used
    when multiple jobs run concurrently within the same Python process.

    Optionally binds to an ``_is_terminated`` predicate to enforce async cancellation by forcing immediate timeouts for newly
    spawned subprocesses once cancellation is requested.
    """

    def __init__(self, is_terminated: Callable[[], bool] = return_false) -> None:
        # Cancellation predicate (see class docstring); presumably consulted by utils.subprocess_run() — TODO confirm.
        self._is_terminated: Final[Callable[[], bool]] = is_terminated
        self._lock: Final[threading.Lock] = threading.Lock()  # guards _child_pids and the spawn+register critical section
        self._child_pids: Final[dict[int, None]] = {}  # a set that preserves insertion order

    @contextlib.contextmanager
    def popen_and_track(self, *popen_args: Any, **popen_kwargs: Any) -> Iterator[subprocess.Popen]:
        """Context manager that calls subprocess.Popen() and tracks the child PID for per-job termination.

        Holds a lock across Popen+PID registration to prevent a race when terminate_process_subtrees() is invoked (e.g. from
        SIGINT/SIGTERM handlers), ensuring newly spawned child processes cannot escape termination. The child PID is
        unregistered on context exit.
        """
        with self._lock:  # spawn and register atomically so a concurrent terminate cannot miss this child
            proc: subprocess.Popen = subprocess.Popen(*popen_args, **popen_kwargs)
            self._child_pids[proc.pid] = None
        try:
            yield proc
        finally:
            with self._lock:
                self._child_pids.pop(proc.pid, None)  # tolerate PID already removed by terminate_process_subtrees()

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Wrapper around utils.subprocess_run() that auto-registers/unregisters child PIDs for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Sends the given signal to all tracked child PIDs and their descendants, ignoring errors for dead PIDs."""
        with self._lock:
            pids: list[int] = list(self._child_pids)  # snapshot in insertion order
            self._child_pids.clear()
            # Signaling while holding the lock blocks concurrent popen_and_track() calls until termination completes.
            terminate_process_subtree(root_pids=pids, sig=sig)

838 

839 

840############################################################################# 

#############################################################################
def pid_exists(pid: int) -> bool | None:
    """Returns True if a process with the given PID exists, False if not, or None on an unexpected error."""
    if pid <= 0:
        return False  # POSIX PIDs are positive; kill() with pid<=0 would target process groups instead
    try:
        # Signal number 0 delivers no signal, but existence/permission checks are still performed (POSIX).
        os.kill(pid, 0)
    except OSError as err:
        if err.errno == errno.ESRCH:  # No such process
            return False
        return True if err.errno == errno.EPERM else None  # EPERM means the process exists but isn't ours
    return True

854 

855 

def nprefix(s: str) -> str:
    """Returns the canonical snapshot prefix: ``s`` plus a trailing underscore, interned for cheap comparisons."""
    return sys.intern(f"{s}_")

859 

860 

def ninfix(s: str) -> str:
    """Returns the canonical infix: ``s`` plus a trailing underscore (interned), or "" when ``s`` is empty."""
    if not s:
        return ""
    return sys.intern(f"{s}_")

864 

865 

def nsuffix(s: str) -> str:
    """Returns the canonical suffix: a leading underscore plus ``s`` (interned), or "" when ``s`` is empty."""
    if not s:
        return ""
    return sys.intern(f"_{s}")

869 

870 

def format_dict(dictionary: dict[Any, Any]) -> str:
    """Returns the dictionary's string form wrapped in double quotes, for consistent log output."""
    return '"' + str(dictionary) + '"'

874 

875 

def format_obj(obj: object) -> str:
    """Returns the object's str() form wrapped in double quotes, for consistent log output."""
    return '"' + str(obj) + '"'

879 

880 

def validate_dataset_name(dataset: str, input_text: str) -> None:
    """Rejects dataset names that the 'zfs create' CLI would not accept (empty, leading/trailing slash, shell chars, etc.)."""
    # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424
    # and https://github.com/openzfs/zfs/issues/8798
    # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html
    forbidden_chars: str = SHELL_CHARS
    bad: bool = dataset in ("", ".", "..")
    bad = bad or dataset.startswith(("/", "./", "../"))
    bad = bad or dataset.endswith(("/", "/.", "/.."))
    bad = bad or any(seq in dataset for seq in ("//", "/./", "/../"))
    bad = bad or any(ch in forbidden_chars or (ch.isspace() and ch != " ") for ch in dataset)
    bad = bad or not dataset[0].isalpha()  # safe: the empty string already set bad=True above, so `or` short-circuits
    if bad:
        die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'")

896 

897 

def validate_property_name(propname: str, input_text: str) -> str:
    """Checks that the ZFS property name is non-empty, has no leading dash, and contains no whitespace or shell chars."""
    forbidden_chars: str = SHELL_CHARS
    has_bad_char: bool = any(ch.isspace() or ch in forbidden_chars for ch in propname)
    if (not propname) or propname.startswith("-") or has_bad_char:
        die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'")
    return propname

904 

905 

def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None:
    """Dies (optionally via ``parser``) if ``path`` is a symbolic link; ``msg`` prefixes the error message."""
    is_link: bool = os.path.islink(path)
    if is_link:
        die(f"{msg}must not be a symlink: {path}", parser=parser)

910 

911 

def validate_file_permissions(path: str, mode: int) -> None:
    """Dies unless ``path`` is owned by the current effective UID and carries exactly the permission bits in ``mode``."""
    stats: os.stat_result = os.stat(path, follow_symlinks=False)  # don't let a symlink hide the real target's perms
    euid: int = os.geteuid()
    if stats.st_uid != euid:  # verify ownership is current effective UID
        die(f"{path!r} is owned by uid {stats.st_uid}, not {euid}")
    actual_mode: int = stat.S_IMODE(stats.st_mode)
    if actual_mode != mode:
        die(
            f"{path!r} has permissions {actual_mode:03o} aka {stat.filemode(actual_mode)[1:]}, "
            f"not {mode:03o} aka {stat.filemode(mode)[1:]})"
        )

924 

925 

def parse_duration_to_milliseconds(duration: str, *, regex_suffix: str = "", context: str = "") -> int:
    """Parses a human duration string like '5minutes' or '2 hours' into milliseconds.

    ``regex_suffix`` is extra regex appended to the accepted pattern. When ``context`` is non-empty, invalid input calls
    die() mentioning the context; otherwise ValueError is raised.
    """
    unit_milliseconds: dict[str, int] = {
        "milliseconds": 1,
        "millis": 1,
        "seconds": 1000,
        "secs": 1000,
        "minutes": 60 * 1000,
        "mins": 60 * 1000,
        "hours": 60 * 60 * 1000,
        "days": 86400 * 1000,
        "weeks": 7 * 86400 * 1000,
        "months": round(30.5 * 86400 * 1000),  # average month length
        "years": 365 * 86400 * 1000,
    }
    # Derive the regex alternation from the dict keys so the accepted unit names can never drift out of sync with the
    # conversion table; dict insertion order preserves the original alternation order (e.g. 'milliseconds' before 'millis').
    match = re.fullmatch(r"(\d+)\s*(" + "|".join(unit_milliseconds) + r")" + regex_suffix, duration)
    if not match:
        if context:
            die(f"Invalid duration format: {duration} within {context}")
        else:
            raise ValueError(f"Invalid duration format: {duration}")
    assert match  # for the type checker; die() does not return
    quantity: int = int(match.group(1))
    unit: str = match.group(2)
    return quantity * unit_milliseconds[unit]

954 

955 

def unixtime_fromisoformat(datetime_str: str) -> int:
    """Converts an ISO 8601 datetime string into UTC Unix time, truncated to integer seconds."""
    dt: datetime = datetime.fromisoformat(datetime_str)
    return int(dt.timestamp())

959 

960 

def isotime_from_unixtime(unixtime_in_seconds: int) -> str:
    """Converts UTC Unix time in seconds into an ISO 8601 string like '1970-01-01_00:00:00+00:00'."""
    moment: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=timezone.utc)
    return moment.isoformat(sep="_", timespec="seconds")

966 

967 

def current_datetime(
    tz_spec: str | None = None,
    now_fn: Callable[[tzinfo | None], datetime] | None = None,
) -> datetime:
    """Returns the current time in the timezone denoted by ``tz_spec``, or in the local timezone when unspecified;
    ``now_fn`` can inject a clock for testing."""
    clock: Callable[[tzinfo | None], datetime] = datetime.now if now_fn is None else now_fn
    return clock(get_timezone(tz_spec))

976 

977 

def get_timezone(tz_spec: str | None = None) -> tzinfo | None:
    """Returns the timezone denoted by ``tz_spec``: None means local time, 'UTC' means UTC, '+HH:MM'/'-HHMM' means a
    fixed offset, and names containing '/' are looked up as IANA zones; raises ValueError otherwise."""
    if tz_spec is None:
        return None  # callers interpret None as the local timezone
    if tz_spec == "UTC":
        return timezone.utc
    offset_match = re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec)
    if offset_match:
        sign, hours, minutes = offset_match.groups()
        total_minutes: int = int(hours) * 60 + int(minutes)
        if sign == "-":
            total_minutes = -total_minutes
        return timezone(timedelta(minutes=total_minutes))
    if "/" in tz_spec:
        from zoneinfo import ZoneInfo  # lazy import for startup perf

        return ZoneInfo(tz_spec)
    raise ValueError(f"Invalid timezone specification: {tz_spec}")

998 

999 

1000############################################################################### 

@final
class SnapshotPeriods:  # thread-safe
    """Parses snapshot period suffix strings like '10minutely' and converts between suffixes and durations."""

    def __init__(self) -> None:
        """Builds the suffix->milliseconds and suffix->unit-label lookup tables plus the matching regexes."""
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * 86400 * 1000,
            "monthly": round(30.5 * 86400 * 1000),
            "weekly": 7 * 86400 * 1000,
            "daily": 86400 * 1000,
            "hourly": 60 * 60 * 1000,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        alternation: str = "|".join(self.suffix_milliseconds.keys())
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({alternation})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' into (10, 'minutely')."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Like :meth:`suffix_to_duration0` but requires a leading underscore, e.g. '_10minutely'."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Example: Converts '2hourly' to (2, 'hourly') and 'hourly' to (1, 'hourly'); returns (0, '') on mismatch."""
        match = regex.fullmatch(suffix)
        if match is None:
            return 0, ""
        amount_str: str | None = match.group(1)
        duration_amount: int = int(amount_str) if amount_str else 1  # missing count means 1, e.g. plain 'hourly'
        assert duration_amount > 0
        return duration_amount, match.group(2)

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns the duration encoded in the suffix after the last underscore of ``snapshot``, in milliseconds; 0 if
        there is no underscore or the suffix doesn't parse."""
        idx: int = snapshot.rfind("_")
        candidate: str = snapshot[idx + 1 :] if idx >= 0 else ""
        amount, unit = self._suffix_to_duration(candidate, self._suffix_regex0)
        return amount * self.suffix_milliseconds.get(unit, 0)

1055 

1056 

1057############################################################################# 

@final
class JobStats:
    """Thread-safe counters summarizing overall job progress."""

    def __init__(self, jobs_all: int) -> None:
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self.jobs_all: int = jobs_all  # total number of jobs expected
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0  # cumulative elapsed time over all completed jobs
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Records that a job named ``job_name`` was submitted; returns a progress summary string."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Records that a job finished (successfully or not) after ``elapsed_nanos``; returns a progress summary string."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            if failed:
                self.jobs_failed += 1
            self.sum_elapsed_nanos += elapsed_nanos
            summary = str(self)
            # Sanity invariants; the summary string aids debugging if any of them trips.
            assert self.sum_elapsed_nanos >= 0, summary
            assert self.jobs_running >= 0, summary
            assert self.jobs_failed >= 0, summary
            assert self.jobs_failed <= self.jobs_completed, summary
            assert self.jobs_completed <= self.jobs_started, summary
            assert self.jobs_started <= self.jobs_all, summary
            return summary

    def __repr__(self) -> str:
        def fmt(number: int) -> str:
            """Formats ``number`` as a percentage of all jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        avg = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, self.jobs_completed))
        return (
            f"all:{self.jobs_all}, started:{fmt(self.jobs_started)}, completed:{fmt(self.jobs_completed)}, "
            f"failed:{fmt(self.jobs_failed)}, running:{self.jobs_running}, {avg}"
        )

1107 

1108 

1109############################################################################# 

class Comparable(Protocol):
    """Structural typing protocol for values that support ``<`` (partial ordering); lets generic containers below accept
    any orderable element type without a common base class."""

    def __lt__(self, other: Any) -> bool: ...


TComparable = TypeVar("TComparable", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue

1117 

1118 

@final
class SmallPriorityQueue(Generic[TComparable]):
    """Priority queue backed by a plain sorted list; efficient for small queues (up to thousands of elements), as is the
    case for us.

    Supports updating the priority of an element that is already contained: remove it, mutate it, then push it again.
    Duplicate elements (if any) keep their insertion order relative to other duplicates. A SortedList
    (https://github.com/grantjenks/python-sortedcontainers) or an indexed priority queue (https://github.com/nvictus/pqdict)
    would also work, but binary search on a plain list avoids any external dependency.
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; when ``reverse`` is True, pop()/peek() yield the largest element instead of the smallest."""
        self._lst: Final[list[TComparable]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Empties the queue."""
        self._lst.clear()

    def push(self, element: TComparable) -> None:
        """Inserts ``element`` at its sorted position, after any equal elements (preserving FIFO among duplicates)."""
        lst = self._lst
        lst.insert(bisect.bisect_right(lst, element), element)

    def pop(self) -> TComparable:
        """Removes and returns the smallest (or largest if reverse == True) element from the queue."""
        if self._reverse:
            return self._lst.pop()
        return self._lst.pop(0)

    def peek(self) -> TComparable:
        """Returns the smallest (or largest if reverse == True) element without removing it."""
        if self._reverse:
            return self._lst[-1]
        return self._lst[0]

    def remove(self, element: TComparable) -> bool:
        """Removes the left-most aka oldest occurrence of ``element``; returns True if one was present."""
        lst = self._lst
        pos = bisect.bisect_left(lst, element)
        found = pos != len(lst) and lst[pos] == element
        if found:
            del lst[pos]  # is an optimized memmove()
        return found

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: TComparable) -> bool:
        """Returns ``True`` if an equal element is queued."""
        lst = self._lst
        pos = bisect.bisect_left(lst, element)
        return pos != len(lst) and lst[pos] == element

    def __iter__(self) -> Iterator[TComparable]:
        """Yields the queued elements in priority order (respecting ``reverse``)."""
        if self._reverse:
            return reversed(self._lst)
        return iter(self._lst)

    def __repr__(self) -> str:
        """Shows the contents in the same order __iter__ would produce them."""
        if self._reverse:
            return repr(list(reversed(self._lst)))
        return repr(self._lst)

1181 

1182 

1183############################################################################### 

@final
class SortedInterner(Generic[TComparable]):
    """Local (non-global) analog of sys.intern() backed by a pre-sorted list; membership and dedupe lookups use binary
    search, so the given list must already be sorted."""

    def __init__(self, sorted_list: list[TComparable]) -> None:
        self._lst: Final[list[TComparable]] = sorted_list

    def interned(self, element: TComparable) -> TComparable:
        """Returns the deduped equal item from the backing list if one is contained, else the given item unchanged."""
        idx = binary_search(self._lst, element)
        return element if idx < 0 else self._lst[idx]

    def __contains__(self, element: TComparable) -> bool:
        """Returns ``True`` if an equal element is present in the backing sorted list."""
        return binary_search(self._lst, element) >= 0

1200 

1201 

def binary_search(sorted_list: list[TComparable], item: TComparable) -> int:
    """Java-style binary search: returns the left-most index of an equal item if present, else ``-insertion_point - 1``
    (always negative), which encodes where the item would be inserted to keep the list sorted."""
    pos = bisect.bisect_left(sorted_list, item)
    if pos != len(sorted_list) and sorted_list[pos] == item:
        return pos
    return -pos - 1

1207 

1208 

1209############################################################################### 

_S = TypeVar("_S")  # generic element type (used by HashedInterner below)

1211 

1212 

@final
class HashedInterner(Generic[_S]):
    """Local (non-global) analog of sys.intern() that works for any hashable type, not just str."""

    def __init__(self, items: Iterable[_S] = frozenset()) -> None:
        self._items: Final[dict[_S, _S]] = {item: item for item in items}

    def intern(self, item: _S) -> _S:
        """Dedupes ``item``: returns the stored equal object, first inserting ``item`` itself if none was stored yet."""
        return self._items.setdefault(item, item)

    def interned(self, item: _S) -> _S:
        """Returns the stored equal object if one is contained, else ``item`` itself; never inserts."""
        return self._items.get(item, item)

    def __contains__(self, item: _S) -> bool:
        """Returns ``True`` if an equal item has been interned."""
        return item in self._items

1230 

1231 

1232############################################################################# 

@final
class SynchronizedBool:
    """Mutex-protected wrapper around a plain bool, offering atomic read, write, swap and compare-and-set operations."""

    def __init__(self, val: bool) -> None:
        assert isinstance(val, bool)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._value: bool = val

    @property
    def value(self) -> bool:
        """Atomically reads the current boolean value."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: bool) -> None:
        """Atomically assigns ``new_value``."""
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: bool) -> bool:
        """Atomically swaps in ``new_value`` and returns the value it replaced."""
        with self._lock:
            previous = self._value
            self._value = new_value
            return previous

    def compare_and_set(self, expected_value: bool, new_value: bool) -> bool:
        """Atomically assigns ``new_value`` only if the current value equals ``expected_value``; returns whether it did."""
        with self._lock:
            matched: bool = self._value == expected_value
            if matched:
                self._value = new_value
            return matched

    def __bool__(self) -> bool:
        return self.value

    def __repr__(self) -> str:
        return repr(self.value)

    def __str__(self) -> str:
        return str(self.value)

1277 

1278 

1279############################################################################# 

_K = TypeVar("_K")  # generic key type (used by SynchronizedDict below)
_V = TypeVar("_V")  # generic value type (used by SynchronizedDict below)

1282 

1283 

@final
class SynchronizedDict(Generic[_K, _V]):
    """Lock-guarded wrapper around a regular dict; every public operation is atomic."""

    def __init__(self, val: dict[_K, _V]) -> None:
        assert isinstance(val, dict)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._dict: Final[dict[_K, _V]] = val

    def __getitem__(self, key: _K) -> _V:
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key: _K, value: _V) -> None:
        with self._lock:
            self._dict[key] = value

    def __delitem__(self, key: _K) -> None:
        with self._lock:
            del self._dict[key]

    def __contains__(self, key: _K) -> bool:
        with self._lock:
            return key in self._dict

    def __len__(self) -> int:
        with self._lock:
            return len(self._dict)

    def __repr__(self) -> str:
        with self._lock:
            return repr(self._dict)

    def __str__(self) -> str:
        with self._lock:
            return str(self._dict)

    def get(self, key: _K, default: _V | None = None) -> _V | None:
        """Atomic equivalent of ``dict.get``."""
        with self._lock:
            return self._dict.get(key, default)

    def pop(self, key: _K, default: _V | None = None) -> _V | None:
        """Atomically removes ``key`` and returns its value, or ``default`` when absent."""
        with self._lock:
            return self._dict.pop(key, default)

    def clear(self) -> None:
        """Atomically removes all items."""
        with self._lock:
            self._dict.clear()

    def items(self) -> ItemsView[_K, _V]:
        """Returns an items view over a point-in-time copy of the dict."""
        with self._lock:
            return dict(self._dict).items()

1340 

1341 

1342############################################################################# 

@final
class InterruptibleSleep:
    """Sleep that other threads can cut short via interrupt(); the underlying lock is configurable."""

    def __init__(self, lock: threading.Lock | None = None) -> None:
        self._is_stopping: bool = False  # once True, sleep() returns immediately until reset()
        self._lock: Final[threading.Lock] = threading.Lock() if lock is None else lock
        self._condition: Final[threading.Condition] = threading.Condition(self._lock)

    def sleep(self, duration_nanos: int) -> bool:
        """Blocks the current thread for up to ``duration_nanos`` nanoseconds; returns True iff the sleep got
        interrupted; Equivalent to threading.Event.wait()."""
        deadline_nanos: int = time.monotonic_ns() + duration_nanos
        with self._lock:
            while not self._is_stopping:
                remaining_nanos: int = deadline_nanos - time.monotonic_ns()
                if remaining_nanos <= 0:
                    return False
                self._condition.wait(timeout=remaining_nanos / 1_000_000_000)  # releases lock while blocked
            return True

    def interrupt(self) -> None:
        """Wakes all sleeping threads and makes any future sleep()s a no-op; Equivalent to threading.Event.set()."""
        with self._lock:
            if not self._is_stopping:
                self._is_stopping = True
                self._condition.notify_all()

    def reset(self) -> None:
        """Re-arms the sleeper so future sleep()s block again; Equivalent to threading.Event.clear()."""
        with self._lock:
            self._is_stopping = False

1375 

1376 

1377############################################################################# 

@final
class SynchronousExecutor(Executor):
    """Executor that runs each submitted task inline in the caller's thread, one after another."""

    def __init__(self) -> None:
        self._shutdown: bool = False  # True once shutdown() was called; further submissions are rejected

    def submit(self, fn: Callable[..., _R_], /, *args: Any, **kwargs: Any) -> Future[_R_]:
        """Runs `fn(*args, **kwargs)` immediately and returns a Future already holding its result or exception."""
        future: Future[_R_] = Future()
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        try:
            outcome: _R_ = fn(*args, **kwargs)
        except BaseException as exc:
            future.set_exception(exc)
        else:
            future.set_result(outcome)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Marks the executor closed; there are no worker threads to join or clean up."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Factory: returns a SynchronousExecutor when 0 <= max_workers <= 1, else a ThreadPoolExecutor of that size."""
        if 0 <= max_workers <= 1:
            return cls()
        return ThreadPoolExecutor(max_workers=max_workers)

1406 

1407 

1408############################################################################# 

@final
class _XFinally(contextlib.AbstractContextManager):
    """Context manager ensuring cleanup code executes after ``with`` blocks, without letting a cleanup failure mask an
    exception raised in the body; see xfinally() for the full contract."""

    def __init__(self, cleanup: Callable[[], None]) -> None:
        """Records the callable to run upon exit."""
        self._cleanup: Final = cleanup  # Zero-argument callable executed after the `with` block exits.

    def __exit__(
        self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None
    ) -> Literal[False]:
        """Runs cleanup and propagate any exceptions appropriately; returning False tells Python to re-raise ``exc``."""
        try:
            self._cleanup()
        except BaseException as cleanup_exc:
            if exc is None:
                raise  # No main error --> propagate cleanup error normally
            # Both failed
            # if sys.version_info >= (3, 11):
            #     raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None
            # <= 3.10: attach so it shows up in traceback but doesn't mask
            exc.__context__ = cleanup_exc
            return False  # reraise original exception
        return False  # propagate main exception if any

1434 

def xfinally(cleanup: Callable[[], None]) -> _XFinally:
    """Usage: with xfinally(lambda: cleanup()): ...

    Returns a context manager that always runs ``cleanup()`` on exit while guaranteeing that an error raised by
    ``cleanup()`` can never mask an exception raised earlier inside the body of the ``with`` block — both problems are
    surfaced when possible.

    Problem it solves
    -----------------
    A naive ``try ... finally`` may lose the original exception:

        try:
            work()
        finally:
            cleanup()  # <-- if this raises an exception, it replaces the real error!

    `_XFinally` preserves exception priority:

    * Body raises, cleanup succeeds --> the original body exception is re-raised.
    * Body raises, cleanup also raises --> the body exception is re-raised; the cleanup exception is linked via
      ``__context__``.
    * Body succeeds, cleanup raises --> the cleanup exception propagates normally.

    Example:
    -------
    >>> with xfinally(lambda: release_resources()):  # doctest: +SKIP
    ...     run_tasks()

    The single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information.
    """
    return _XFinally(cleanup)