Coverage for bzfs_main/utils.py: 100%

723 statements  

coverage.py v7.11.0, created at 2025-11-07 04:44 +0000

# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Collection of helper functions used across bzfs; includes environment variable parsing, process management, and
lightweight concurrency primitives.

Everything in this module relies only on the standard library, so other modules remain dependency-free. Each utility
favors simple, predictable behavior on all supported platforms.
"""

from __future__ import (
    annotations,
)
import argparse
import base64
import bisect
import collections
import contextlib
import errno
import hashlib
import logging
import os
import platform
import pwd
import random
import re
import signal
import stat
import subprocess
import sys
import threading
import time
import types
from collections import (
    defaultdict,
    deque,
)
from collections.abc import (
    ItemsView,
    Iterable,
    Iterator,
    Sequence,
)
from concurrent.futures import (
    Executor,
    Future,
    ThreadPoolExecutor,
)
from datetime import (
    datetime,
    timedelta,
    timezone,
    tzinfo,
)
from subprocess import (
    DEVNULL,
    PIPE,
)
from typing import (
    IO,
    Any,
    Callable,
    Final,
    Generic,
    Literal,
    NoReturn,
    Protocol,
    TextIO,
    TypeVar,
    cast,
)

# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"
DIE_STATUS: Final[int] = 3
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level is halfway in between
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level is halfway in between
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level is halfway in between
SNAPSHOT_FILTERS_VAR: Final[str] = "snapshot_filters_var"
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # see Google for 'sun_path'

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias


def getenv_any(key: str, default: str | None = None) -> str | None:
    """Returns the value of environment variable ``ENV_VAR_PREFIX + key``; all shell environment variable names used for
    configuration start with this prefix."""
    return os.getenv(ENV_VAR_PREFIX + key, default)


def getenv_int(key: str, default: int) -> int:
    """Returns environment variable ``key`` as int with ``default`` fallback."""
    return int(cast(str, getenv_any(key, str(default))))


def getenv_bool(key: str, default: bool = False) -> bool:
    """Returns environment variable ``key`` as bool with ``default`` fallback."""
    return cast(str, getenv_any(key, str(default))).lower().strip() == "true"


def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Retains only column number 'field' in a list of TSV/CSV lines; analogous to the Unix 'cut' CLI command."""
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field == 1:
        return [line[0 : line.index(separator)] for line in lines]
    elif field == 2:
        return [line[line.index(separator) + 1 :] for line in lines]
    else:
        raise ValueError(f"Invalid field value: {field}")
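

# Illustrative usage sketch, not part of utils.py: cut() splits TSV lines into one of two columns.
# The sample lines below are hypothetical.
def _demo_cut() -> None:
    lines = ["tank/data\t12345", "tank/home\t67890"]
    assert cut(1, lines=lines) == ["tank/data", "tank/home"]  # column before the first separator
    assert cut(2, lines=lines) == ["12345", "67890"]  # column after the first separator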


def drain(iterable: Iterable[Any]) -> None:
    """Consumes all items in the iterable, effectively draining it."""
    for _ in iterable:
        _ = None  # help gc (iterable can block)


K_ = TypeVar("K_")
V_ = TypeVar("V_")
R_ = TypeVar("R_")


def shuffle_dict(dictionary: dict[K_, V_], rand: random.Random = random.SystemRandom()) -> dict[K_, V_]:  # noqa: B008
    """Returns a new dict with items shuffled randomly."""
    items: list[tuple[K_, V_]] = list(dictionary.items())
    rand.shuffle(items)
    return dict(items)


def sorted_dict(dictionary: dict[K_, V_]) -> dict[K_, V_]:
    """Returns a new dict with items sorted primarily by key and secondarily by value."""
    return dict(sorted(dictionary.items()))


def tail(file: str, n: int, errors: str | None = None) -> Sequence[str]:
    """Returns the last ``n`` lines of ``file`` without following symlinks."""
    if not os.path.isfile(file):
        return []
    with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False) as fd:
        return deque(fd, maxlen=n)


NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))


def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces parenthesis '(' followed by a char other than question mark '?', but not preceded by a backslash,
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash,
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    if "(" in regex and (
        "[" in regex  # literal left square bracket
        or "\\N{LEFT SQUARE BRACKET}" in regex  # named Unicode escape for '['
        or "\\x5b" in regex  # hex escape for '[' (lowercase)
        or "\\x5B" in regex  # hex escape for '[' (uppercase)
        or "\\u005b" in regex  # 4-digit Unicode escape for '[' (lowercase)
        or "\\u005B" in regex  # 4-digit Unicode escape for '[' (uppercase)
        or "\\U0000005b" in regex  # 8-digit Unicode escape for '[' (lowercase)
        or "\\U0000005B" in regex  # 8-digit Unicode escape for '[' (uppercase)
        or "\\133" in regex  # octal escape for '['
    ):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the rare case where the regex might
        # contain a pathological regex character class that contains parenthesis, or when '[' is expressed via escapes.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    i = len(regex) - 2
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex
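

# Illustrative sketch, not part of utils.py: how the capturing-group rewrite transforms patterns.
def _demo_replace_capturing_groups() -> None:
    assert replace_capturing_groups_with_non_capturing_groups("(.*/)?tmp(foo|bar)") == "(?:.*/)?tmp(?:foo|bar)"
    assert replace_capturing_groups_with_non_capturing_groups("(?P<name>abc)") == "(?:abc)"
    assert replace_capturing_groups_with_non_capturing_groups(r"\(literal\)") == r"\(literal\)"  # escaped parens untouched
    assert replace_capturing_groups_with_non_capturing_groups("[(]") == "[(]"  # '[' triggers the conservative fallback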


def get_home_directory() -> str:
    """Reliably detects home dir without using HOME env var."""
    # thread-safe version of: os.environ.pop('HOME', None); os.path.expanduser('~')
    return pwd.getpwuid(os.getuid()).pw_dir


def human_readable_bytes(num_bytes: float, separator: str = " ", precision: int | None = None) -> str:
    """Formats 'num_bytes' as a human-readable size; for example "567 MiB"."""
    sign = "-" if num_bytes < 0 else ""
    s = abs(num_bytes)
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    n = len(units) - 1
    i = 0
    while s >= 1024 and i < n:
        s /= 1024
        i += 1
    formatted_num = human_readable_float(s) if precision is None else f"{s:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_duration(duration: float, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling as needed; for example "567ms"."""
    sign = "-" if duration < 0 else ""
    t = abs(duration)
    units = ("ns", "μs", "ms", "s", "m", "h", "d")
    i = units.index(unit)
    if t < 1 and t != 0:
        nanos = (1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000)
        t *= nanos[i]
        i = 0
    while t >= 1000 and i < 3:
        t /= 1000
        i += 1
    if i >= 3:
        while t >= 60 and i < 5:
            t /= 60
            i += 1
        if i >= 5:
            while t >= 24 and i < len(units) - 1:
                t /= 24
                i += 1
    formatted_num = human_readable_float(t) if precision is None else f"{t:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_float(number: float) -> str:
    """Formats ``number`` with a variable precision depending on magnitude.

    This design mirrors the way humans round values when scanning logs.

    If the number has one digit before the decimal point (0 <= abs(number) < 10):
        Round and use two decimals after the decimal point (e.g., 3.14559 --> "3.15").

    If the number has two digits before the decimal point (10 <= abs(number) < 100):
        Round and use one decimal after the decimal point (e.g., 12.36 --> "12.4").

    If the number has three or more digits before the decimal point (abs(number) >= 100):
        Round and use zero decimals after the decimal point (e.g., 123.556 --> "124").

    Ensures no unnecessary trailing zeroes are retained: Example: 1.500 --> "1.5", 1.00 --> "1"
    """
    abs_number = abs(number)
    precision = 2 if abs_number < 10 else 1 if abs_number < 100 else 0
    if precision == 0:
        return str(round(number))
    result = f"{number:.{precision}f}"
    assert "." in result
    result = result.rstrip("0").rstrip(".")  # Remove trailing zeros and trailing decimal point if empty
    return "0" if result == "-0" else result


def percent(number: int, total: int, print_total: bool = False) -> str:
    """Returns percentage string of ``number`` relative to ``total``."""
    tot: str = f"/{total}" if print_total else ""
    return f"{number}{tot}={'inf' if total == 0 else human_readable_float(100 * number / total)}%"
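

# Illustrative sketch, not part of utils.py: expected outputs of the human-readable formatters.
def _demo_human_readable_formatters() -> None:
    assert human_readable_bytes(567 * 1024 * 1024) == "567 MiB"
    assert human_readable_bytes(1536, precision=1) == "1.5 KiB"
    assert human_readable_duration(567_000_000) == "567ms"  # the default input unit is nanoseconds
    assert human_readable_duration(90, unit="s") == "1.5m"
    assert percent(3, 12, print_total=True) == "3/12=25%"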


def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if the basename of the path is a symlink.

    Also, can specify permissions on O_CREAT, and verify ownership.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC
    fd: int = os.open(path, flags=flags, mode=perm)
    try:
        if check_owner:
            st_uid: int = os.fstat(fd).st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        try:
            os.close(fd)
        except OSError:
            pass
        raise
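

# Illustrative sketch, not part of utils.py: writing a private file that must not be a symlink.
# The path is hypothetical; the defaults enforce rw------- permissions on create and same-UID ownership.
def _demo_open_nofollow() -> None:
    with open_nofollow("/tmp/bzfs_demo_state.txt", "w", encoding="utf-8") as f:  # raises OSError for a symlink
        f.write("example\n")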


def close_quietly(fd: int) -> None:
    """Closes the given file descriptor while silently swallowing any OSError that might arise as part of this."""
    if fd >= 0:
        try:
            os.close(fd)
        except OSError:
            pass


P = TypeVar("P")


def find_match(
    seq: Sequence[P],
    predicate: Callable[[P], bool],
    start: int | None = None,
    end: int | None = None,
    reverse: bool = False,
    raises: bool | str | Callable[[], str] = False,
) -> int:
    """Returns the integer index within seq of the first item (or last item if reverse==True) that matches the given
    predicate condition. If no matching item is found, returns -1 or raises ValueError, depending on the raises parameter,
    which is a bool indicating whether to raise an error, or a string containing the error message, but can also be a
    Callable/lambda in order to support efficient deferred generation of error messages. Analogous to str.find(), including
    slicing semantics with parameters start and end. For example, seq can be a list, tuple or str.

    Example usage:
        lst = ["a", "b", "-c", "d"]
        i = find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True)
        if i >= 0:
            ...
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=f"Tag {tag} not found in {file}")
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=lambda: f"Tag {tag} not found in {file}")
    """
    offset: int = 0 if start is None else start if start >= 0 else len(seq) + start
    if start is not None or end is not None:
        seq = seq[start:end]
    for i, item in enumerate(reversed(seq) if reverse else seq):
        if predicate(item):
            if reverse:
                return len(seq) - i - 1 + offset
            else:
                return i + offset
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    if callable(raises):
        raises = raises()
    raise ValueError(raises)
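

# Illustrative sketch, not part of utils.py: str.find()-style searching with a predicate.
def _demo_find_match() -> None:
    lst = ["a", "b", "-c", "d"]
    assert find_match(lst, lambda arg: arg.startswith("-")) == 2  # index of the first match
    assert find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True) == 2  # slice semantics
    assert find_match(lst, lambda arg: arg.startswith("x")) == -1  # no match; raises=False by default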


def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` lies under ``of_root_dataset`` in the dataset hierarchy, or is the same."""
    return dataset == of_root_dataset or dataset.startswith(of_root_dataset + "/")


def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any adjacent items within the given sorted sequence are equal."""
    return any(a == b for a, b in zip(sorted_list, sorted_list[1:]))


def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings."""
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    skip_dataset: str = DONT_SKIP_DATASET
    parents: set[str] = set()
    for dataset in sorted_datasets:
        assert dataset
        parent = os.path.dirname(dataset)
        if parent in parents:
            return True  # I have a sibling if my parent already has another child
        parents.add(parent)
        if is_descendant(dataset, of_root_dataset=skip_dataset):
            continue
        if skip_dataset != DONT_SKIP_DATASET:
            return True  # I have a sibling if I am a root dataset and another root dataset already exists
        skip_dataset = dataset
    return False
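

# Illustrative sketch, not part of utils.py: sibling detection over sorted dataset lists.
def _demo_has_siblings() -> None:
    assert not has_siblings(["tank", "tank/foo", "tank/foo/bar"])  # a single chain of datasets has no siblings
    assert has_siblings(["tank", "tank/bar", "tank/foo"])  # 'tank/bar' and 'tank/foo' share the parent 'tank'
    assert has_siblings(["pool1", "pool2"])  # two root datasets are siblings of each other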


def dry(msg: str, is_dry_run: bool) -> str:
    """Prefixes ``msg`` with 'Dry' when in dry-run mode."""
    return "Dry " + msg if is_dry_run else msg


def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Converts an absolute dataset path to one relative to ``root_dataset``.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    return dataset[len(root_dataset) :]


def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    i: int = 0
    while i >= 0:
        i = dataset.find("/", i)
        if i < 0:
            yield dataset
        else:
            yield dataset[:i]
            i += 1
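

# Illustrative sketch, not part of utils.py: enumerating the ancestor paths of a dataset name.
def _demo_dataset_paths() -> None:
    assert list(dataset_paths("a/b/c")) == ["a", "a/b", "a/b/c"]
    assert list(dataset_paths("tank")) == ["tank"]  # a root dataset yields only itself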


def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """In a string s, replaces a leading old_prefix string with new_prefix; assumes the leading string is present."""
    assert s.startswith(old_prefix)
    return new_prefix + s[len(old_prefix) :]


def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces ``old`` with ``new`` in-place for every string in ``lines``."""
    for i in range(len(lines)):
        lines[i] = lines[i].replace(old, new, count)


TAPPEND = TypeVar("TAPPEND")


def append_if_absent(lst: list[TAPPEND], *items: TAPPEND) -> list[TAPPEND]:
    """Appends items to list if they are not already present."""
    for item in items:
        if item not in lst:
            lst.append(item)
    return lst


def xappend(lst: list[TAPPEND], *items: TAPPEND | Iterable[TAPPEND]) -> list[TAPPEND]:
    """Appends each of the items to the given list if the item is "truthy", for example not None and not an empty string;
    if an item is an iterable, does so recursively, flattening the output."""
    for item in items:
        if isinstance(item, str) or not isinstance(item, collections.abc.Iterable):
            if item:
                lst.append(item)
        else:
            xappend(lst, *item)
    return lst


def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if the name matches at least one of the include regexes but none of the exclude regexes; else False.

    A regex that starts with a `!` is a negation - the regex matches if the regex without the `!` prefix does not match.
    """
    for regex, is_negation in exclude_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return False

    for regex, is_negation in include_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return True

    return False


def compile_regexes(regexes: list[str], suffix: str = "") -> RegexList:
    """Compiles regex strings and keeps track of negations."""
    assert isinstance(regexes, list)
    compiled_regexes: RegexList = []
    for regex in regexes:
        if suffix:  # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
            if regex.endswith("\\$"):
                pass  # trailing literal $ is ok
            elif regex.endswith("$"):
                regex = regex[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in regex:
                raise re.error("Must not use non-trailing '$' character", regex)
        if is_negation := regex.startswith("!"):
            regex = regex[1:]
        regex = replace_capturing_groups_with_non_capturing_groups(regex)
        if regex != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            regex = f"{regex}{suffix}"
        compiled_regexes.append((re.compile(regex), is_negation))
    return compiled_regexes
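

# Illustrative sketch, not part of utils.py: combining compile_regexes() with is_included(), using
# DESCENDANTS_RE_SUFFIX so that a dataset regex also matches descendants of matching datasets.
def _demo_dataset_filtering() -> None:
    includes = compile_regexes(["tank/home.*"], suffix=DESCENDANTS_RE_SUFFIX)
    excludes = compile_regexes(["!.*"], suffix=DESCENDANTS_RE_SUFFIX)  # the negated '.*' excludes nothing
    assert is_included("tank/home/alice", includes, excludes)
    assert not is_included("tank/var", includes, excludes)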


def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Lazy formatter joining items with ``separator`` used to avoid overhead in disabled log levels."""

    class CustomListFormatter:
        """Formatter object that joins items when converted to ``str``."""

        def __str__(self) -> str:
            s = separator.join(map(str, iterable))
            return s.lstrip() if lstrip else s

    return CustomListFormatter()


def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Lazy pprint formatter used to avoid overhead in disabled log levels."""

    class PrettyPrintFormatter:
        """Formatter that pretty-prints the object on conversion to ``str``."""

        def __str__(self) -> str:
            import pprint  # lazy import for startup perf

            return pprint.pformat(vars(obj_to_format))

    return PrettyPrintFormatter()


def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    return str(stderr) if not isinstance(stderr, bytes) else stderr.decode("utf-8", errors="replace")


def xprint(log: logging.Logger, value: Any, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Optionally logs ``value`` at stdout/stderr level."""
    if run and value:
        value = value if end else str(value).rstrip()
        level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
        log.log(level, "%s", value)


def sha256_hex(text: str) -> str:
    """Returns the sha256 hex string for the given text."""
    return hashlib.sha256(text.encode()).hexdigest()


def sha256_urlsafe_base64(text: str, padding: bool = True) -> str:
    """Returns the URL-safe base64-encoded sha256 value for the given text."""
    digest: bytes = hashlib.sha256(text.encode()).digest()
    s: str = base64.urlsafe_b64encode(digest).decode()
    return s if padding else s.rstrip("=")


def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 2]


def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left one third portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 3]


def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 string encoding of the int value, assuming it is contained in the range [0..max_value]."""
    assert 0 <= value <= max_value
    max_bytes: int = (max_value.bit_length() + 7) // 8
    value_bytes: bytes = value.to_bytes(max_bytes, byteorder)
    s: str = base64.urlsafe_b64encode(value_bytes).decode()
    return s if padding else s.rstrip("=")
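

# Illustrative sketch, not part of utils.py: fixed-width, URL-safe encodings for digests and IDs.
def _demo_urlsafe_encodings() -> None:
    assert len(sha256_urlsafe_base64("x", padding=False)) == 43  # 32 digest bytes --> 43 unpadded base64 chars
    assert len(sha256_128_urlsafe_base64("x")) == 21  # the left half of the 43 unpadded chars
    assert urlsafe_base64(0, max_value=2**64 - 1) == "AAAAAAAAAAA="  # 8 zero bytes, big-endian, padded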


def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Exits the program with ``exit_code`` after logging ``msg``."""
    if parser is None:
        ex = SystemExit(msg)
        ex.code = exit_code
        raise ex
    else:
        parser.error(msg)


def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior, except that it enhances cleanup on TimeoutExpired
    and provides optional child PID tracking."""
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE

    pid: int | None = None
    try:
        with subprocess.Popen(*args, **kwargs) as proc:
            pid = proc.pid
            if subprocesses is not None:
                subprocesses.register_child_pid(pid)
            try:
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child process and its descendants
                finally:
                    proc.kill()
                raise
            else:
                exitcode: int | None = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)
    finally:
        if subprocesses is not None and isinstance(pid, int):
            subprocesses.unregister_child_pid(pid)
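

# Illustrative sketch, not part of utils.py: subprocess_run() accepts the usual subprocess.run() kwargs,
# plus the optional 'subprocesses' tracker shown in the Subprocesses class below.
def _demo_subprocess_run() -> None:
    result = subprocess_run(["echo", "hello"], stdout=PIPE, stderr=PIPE, text=True, timeout=10, check=True)
    assert result.stdout.strip() == "hello"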


def terminate_process_subtree(
    except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes."""
    current_pid: int = os.getpid()
    root_pids = [current_pid] if root_pids is None else root_pids
    all_pids: list[list[int]] = _get_descendant_processes(root_pids)
    assert len(all_pids) == len(root_pids)
    for i, pids in enumerate(all_pids):
        root_pid = root_pids[i]
        if root_pid == current_pid:
            pids += [] if except_current_process else [current_pid]
        else:
            pids.insert(0, root_pid)
        for pid in pids:
            with contextlib.suppress(OSError):
                os.kill(pid, sig)


def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
    """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems."""
    if len(root_pids) == 0:
        return []
    cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
    try:
        lines: list[str] = subprocess.run(cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout.splitlines()
    except PermissionError:
        # degrade gracefully in sandbox environments that deny executing `ps` entirely
        return [[] for _ in root_pids]
    procs: dict[int, list[int]] = defaultdict(list)
    for line in lines[1:]:  # all lines except the header line
        splits: list[str] = line.split()
        assert len(splits) == 2
        pid = int(splits[0])
        ppid = int(splits[1])
        procs[ppid].append(pid)

    def recursive_append(ppid: int, descendants: list[int]) -> None:
        """Recursively collect descendant PIDs starting from ``ppid``."""
        for child_pid in procs[ppid]:
            descendants.append(child_pid)
            recursive_append(child_pid, descendants)

    all_descendants: list[list[int]] = []
    for root_pid in root_pids:
        descendants: list[int] = []
        recursive_append(root_pid, descendants)
        all_descendants.append(descendants)
    return all_descendants


@contextlib.contextmanager
def termination_signal_handler(
    termination_event: threading.Event,
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set ``termination_event`` and, by default, terminate all
    descendant processes."""
    assert termination_event is not None

    def _handler(_sig: int, _frame: object) -> None:
        termination_event.set()
        termination_handler()

    previous_int_handler = signal.signal(signal.SIGINT, _handler)  # install new signal handler
    previous_term_handler = signal.signal(signal.SIGTERM, _handler)  # install new signal handler
    try:
        yield  # run body of context manager
    finally:
        signal.signal(signal.SIGINT, previous_int_handler)  # restore original signal handler
        signal.signal(signal.SIGTERM, previous_term_handler)  # restore original signal handler


#############################################################################
class Subprocesses:
    """Provides per-job tracking of child PIDs so a job can safely terminate only the subprocesses it spawned itself; used
    when multiple jobs run concurrently within the same Python process."""

    def __init__(self) -> None:
        self._lock: Final[threading.Lock] = threading.Lock()
        self._child_pids: Final[dict[int, None]] = {}  # a set that preserves insertion order

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Wrapper around utils.subprocess_run() that auto-registers/unregisters child PIDs for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def register_child_pid(self, pid: int) -> None:
        """Registers a child PID as managed by this instance."""
        with self._lock:
            self._child_pids[pid] = None

    def unregister_child_pid(self, pid: int) -> None:
        """Unregisters a child PID that has exited or is no longer tracked."""
        with self._lock:
            self._child_pids.pop(pid, None)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Sends the given signal to all tracked child PIDs and their descendants, ignoring errors for dead PIDs."""
        with self._lock:
            pids: list[int] = list(self._child_pids)
            self._child_pids.clear()
        terminate_process_subtree(root_pids=pids, sig=sig)


#############################################################################
def pid_exists(pid: int) -> bool | None:
    """Returns True if a process with PID exists, False if not, or None on error."""
    if pid <= 0:
        return False
    try:  # with signal=0, no signal is actually sent, but error checking is still performed
        os.kill(pid, 0)  # ... which can be used to check for process existence on POSIX systems
    except OSError as err:
        if err.errno == errno.ESRCH:  # No such process
            return False
        if err.errno == errno.EPERM:  # Operation not permitted
            return True
        return None
    return True


def nprefix(s: str) -> str:
    """Returns a canonical snapshot prefix with trailing underscore."""
    return sys.intern(s + "_")


def ninfix(s: str) -> str:
    """Returns a canonical infix with trailing underscore when not empty."""
    return sys.intern(s + "_") if s else ""


def nsuffix(s: str) -> str:
    """Returns a canonical suffix with leading underscore when not empty."""
    return sys.intern("_" + s) if s else ""

772 

773 

774def format_dict(dictionary: dict[Any, Any]) -> str: 

775 """Returns a formatted dictionary using repr for consistent output.""" 

776 return f'"{dictionary}"' 

777 

778 

779def format_obj(obj: object) -> str: 

780 """Returns a formatted str using repr for consistent output.""" 

781 return f'"{obj}"' 

782 

783 

784def validate_dataset_name(dataset: str, input_text: str) -> None: 

785 """'zfs create' CLI does not accept dataset names that are empty or start or end in a slash, etc.""" 

786 # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424 

787 # and https://github.com/openzfs/zfs/issues/8798 

788 # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html 

789 invalid_chars: str = SHELL_CHARS 

790 if ( 

791 dataset in ("", ".", "..") 

792 or dataset.startswith(("/", "./", "../")) 

793 or dataset.endswith(("/", "/.", "/..")) 

794 or any(substring in dataset for substring in ("//", "/./", "/../")) 

795 or any(char in invalid_chars or (char.isspace() and char != " ") for char in dataset) 

796 or not dataset[0].isalpha() 

797 ): 

798 die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'") 

799 

800 

801def validate_property_name(propname: str, input_text: str) -> str: 

802 """Checks that the ZFS property name contains no spaces or shell chars.""" 

803 invalid_chars: str = SHELL_CHARS 

804 if not propname or any(char.isspace() or char in invalid_chars for char in propname): 

805 die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'") 

806 return propname 

807 

808 

809def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None: 

810 """Checks that the given path is not a symbolic link.""" 

811 if os.path.islink(path): 

812 die(f"{msg}must not be a symlink: {path}", parser=parser) 

813 

814 

815def validate_file_permissions(path: str, mode: int) -> None: 

816 """Verify permissions and that ownership is current effective UID.""" 

817 stats: os.stat_result = os.stat(path, follow_symlinks=False) 

818 st_uid: int = stats.st_uid 

819 if st_uid != os.geteuid(): # verify ownership is current effective UID 

820 die(f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}") 

821 st_mode = stat.S_IMODE(stats.st_mode) 

822 if st_mode != mode: 

823 die( 

824 f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]}, " 

825 f"not {mode:03o} aka {stat.filemode(mode)[1:]})" 

826 ) 

827 

828 

829def parse_duration_to_milliseconds(duration: str, regex_suffix: str = "", context: str = "") -> int: 

830 """Parses human duration strings like '5m' or '2 hours' to milliseconds.""" 

831 unit_milliseconds: dict[str, int] = { 

832 "milliseconds": 1, 

833 "millis": 1, 

834 "seconds": 1000, 

835 "secs": 1000, 

836 "minutes": 60 * 1000, 

837 "mins": 60 * 1000, 

838 "hours": 60 * 60 * 1000, 

839 "days": 86400 * 1000, 

840 "weeks": 7 * 86400 * 1000, 

841 "months": round(30.5 * 86400 * 1000), 

842 "years": 365 * 86400 * 1000, 

843 } 

844 match = re.fullmatch( 

845 r"(\d+)\s*(milliseconds|millis|seconds|secs|minutes|mins|hours|days|weeks|months|years)" + regex_suffix, 

846 duration, 

847 ) 

848 if not match: 

849 if context: 

850 die(f"Invalid duration format: {duration} within {context}") 

851 else: 

852 raise ValueError(f"Invalid duration format: {duration}") 

853 assert match 

854 quantity: int = int(match.group(1)) 

855 unit: str = match.group(2) 

856 return quantity * unit_milliseconds[unit] 

857 

858 

859def unixtime_fromisoformat(datetime_str: str) -> int: 

860 """Converts ISO 8601 datetime string into UTC Unix time seconds.""" 

861 return int(datetime.fromisoformat(datetime_str).timestamp()) 

862 

863 

864def isotime_from_unixtime(unixtime_in_seconds: int) -> str: 

865 """Converts UTC Unix time seconds into ISO 8601 datetime string.""" 

866 tz: tzinfo = timezone.utc 

867 dt: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=tz) 

868 return dt.isoformat(sep="_", timespec="seconds") 

869 

870 

871def current_datetime( 

872 tz_spec: str | None = None, 

873 now_fn: Callable[[tzinfo | None], datetime] | None = None, 

874) -> datetime: 

875 """Returns current time in ``tz_spec`` timezone or local timezone.""" 

876 if now_fn is None: 

877 now_fn = datetime.now 

878 return now_fn(get_timezone(tz_spec)) 

879 

880 

881def get_timezone(tz_spec: str | None = None) -> tzinfo | None: 

882 """Returns timezone from spec or local timezone if unspecified.""" 

883 tz: tzinfo | None 

884 if tz_spec is None: 

885 tz = None 

886 elif tz_spec == "UTC": 

887 tz = timezone.utc 

888 else: 

889 if match := re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec): 

890 sign, hours, minutes = match.groups() 

891 offset: int = int(hours) * 60 + int(minutes) 

892 offset = -offset if sign == "-" else offset 

893 tz = timezone(timedelta(minutes=offset)) 

894 elif "/" in tz_spec: 

895 from zoneinfo import ZoneInfo # lazy import for startup perf 

896 

897 tz = ZoneInfo(tz_spec) 

898 else: 

899 raise ValueError(f"Invalid timezone specification: {tz_spec}") 

900 return tz 
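

# Illustrative sketch, not part of utils.py: accepted duration strings and timezone specs.
def _demo_time_parsing() -> None:
    assert parse_duration_to_milliseconds("5mins") == 5 * 60 * 1000
    assert get_timezone("UTC") is timezone.utc
    assert get_timezone("+05:30") == timezone(timedelta(hours=5, minutes=30))
    assert get_timezone("America/New_York") is not None  # IANA names containing '/' resolve via zoneinfo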


###############################################################################
class SnapshotPeriods:  # thread-safe
    """Parses snapshot suffix strings and converts between durations."""

    def __init__(self) -> None:
        """Initializes lookup tables of suffixes and corresponding millis."""
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * 86400 * 1000,
            "monthly": round(30.5 * 86400 * 1000),
            "weekly": 7 * 86400 * 1000,
            "daily": 86400 * 1000,
            "hourly": 60 * 60 * 1000,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({'|'.join(self.suffix_milliseconds.keys())})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' to (10, 'minutely')."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Like :meth:`suffix_to_duration0` but expects an underscore prefix."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Example: Converts '2hourly' to (2, 'hourly') and 'hourly' to (1, 'hourly')."""
        if match := regex.fullmatch(suffix):
            duration_amount: int = int(match.group(1)) if match.group(1) else 1
            assert duration_amount > 0
            duration_unit: str = match.group(2)
            return duration_amount, duration_unit
        else:
            return 0, ""

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns the duration encoded in the ``snapshot`` suffix, in milliseconds."""
        i = snapshot.rfind("_")
        snapshot = "" if i < 0 else snapshot[i + 1 :]
        duration_amount, duration_unit = self._suffix_to_duration(snapshot, self._suffix_regex0)
        return duration_amount * self.suffix_milliseconds.get(duration_unit, 0)
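

# Illustrative sketch, not part of utils.py: parsing period suffixes of snapshot names.
def _demo_snapshot_periods() -> None:
    periods = SnapshotPeriods()
    assert periods.suffix_to_duration0("10minutely") == (10, "minutely")
    assert periods.suffix_to_duration0("hourly") == (1, "hourly")  # a missing amount defaults to 1
    assert periods.label_milliseconds("bzfs_2024-01-01_daily") == 86400 * 1000  # parses the part after the last '_'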


#############################################################################
class JobStats:
    """Simple thread-safe counters summarizing job progress."""

    def __init__(self, jobs_all: int) -> None:
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self.jobs_all: int = jobs_all
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Counts a job submission."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Counts a job completion."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            self.jobs_failed += 1 if failed else 0
            self.sum_elapsed_nanos += elapsed_nanos
            msg = str(self)
            assert self.sum_elapsed_nanos >= 0, msg
            assert self.jobs_running >= 0, msg
            assert self.jobs_failed >= 0, msg
            assert self.jobs_failed <= self.jobs_completed, msg
            assert self.jobs_completed <= self.jobs_started, msg
            assert self.jobs_started <= self.jobs_all, msg
            return msg

    def __repr__(self) -> str:
        def pct(number: int) -> str:
            """Returns percentage string relative to total jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        al, started, completed, failed = self.jobs_all, self.jobs_started, self.jobs_completed, self.jobs_failed
        running = self.jobs_running
        t = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, completed))
        return f"all:{al}, started:{pct(started)}, completed:{pct(completed)}, failed:{pct(failed)}, running:{running}, {t}"


#############################################################################
class Comparable(Protocol):
    """Partial ordering protocol."""

    def __lt__(self, other: Any) -> bool:  # pragma: no cover - behavior defined by implementer
        ...


T = TypeVar("T", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue


class SmallPriorityQueue(Generic[T]):
    """A priority queue that can handle updates to the priority of any element that is already contained in the queue, and
    does so very efficiently if there are a small number of elements in the queue (no more than thousands), as is the case
    for us.

    Could be implemented using a SortedList via https://github.com/grantjenks/python-sortedcontainers or using an indexed
    priority queue via https://github.com/nvictus/pqdict. But, to avoid an external dependency, it is actually implemented
    using a simple yet effective binary search-based sorted list that can handle updates to the priority of elements that
    are already contained in the queue, via removal of the element, followed by update of the element, followed by
    (re)insertion. Duplicate elements (if any) are maintained in their order of insertion relative to other duplicates.
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; sort order flips when ``reverse`` is True."""
        self._lst: Final[list[T]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Removes all elements from the queue."""
        self._lst.clear()

    def push(self, element: T) -> None:
        """Inserts ``element`` while maintaining sorted order."""
        bisect.insort(self._lst, element)

    def pop(self) -> T:
        """Removes and returns the smallest (or largest if reverse == True) element from the queue."""
        return self._lst.pop() if self._reverse else self._lst.pop(0)

    def peek(self) -> T:
        """Returns the smallest (or largest if reverse == True) element without removing it."""
        return self._lst[-1] if self._reverse else self._lst[0]

    def remove(self, element: T) -> bool:
        """Removes the first occurrence of ``element`` and returns True if it was present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        is_contained = i < len(lst) and lst[i] == element
        if is_contained:
            del lst[i]  # is an optimized memmove()
        return is_contained

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: T) -> bool:
        """Returns ``True`` if ``element`` is present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        return i < len(lst) and lst[i] == element

    def __iter__(self) -> Iterator[T]:
        """Iterates over queued elements in priority order."""
        return reversed(self._lst) if self._reverse else iter(self._lst)

    def __repr__(self) -> str:
        """Representation showing queue contents in current order."""
        return repr(list(reversed(self._lst))) if self._reverse else repr(self._lst)
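

# Illustrative sketch, not part of utils.py: updating an element's priority via remove() + push().
def _demo_small_priority_queue() -> None:
    queue: SmallPriorityQueue[int] = SmallPriorityQueue()
    for priority in (5, 1, 3):
        queue.push(priority)
    assert queue.peek() == 1
    assert queue.remove(3)  # update an element's priority by removing it ...
    queue.push(2)  # ... and re-inserting it with its new priority
    assert [queue.pop(), queue.pop(), queue.pop()] == [1, 2, 5]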


###############################################################################
class SortedInterner(Generic[T]):
    """Same as sys.intern() except that it isn't global and that it assumes the input list is sorted (for binary search)."""

    def __init__(self, sorted_list: list[T]) -> None:
        self._lst: Final[list[T]] = sorted_list

    def interned(self, element: T) -> T:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        lst = self._lst
        i = binary_search(lst, element)
        return lst[i] if i >= 0 else element

    def __contains__(self, element: T) -> bool:
        """Returns ``True`` if ``element`` is present."""
        return binary_search(self._lst, element) >= 0


def binary_search(sorted_list: list[T], item: T) -> int:
    """Java-style binary search; Returns index >= 0 if an equal item is found in the list, else '-insertion_point - 1'; If
    it returns index >= 0, the index will be the left-most index in case multiple such equal items are contained."""
    i = bisect.bisect_left(sorted_list, item)
    return i if i < len(sorted_list) and sorted_list[i] == item else -i - 1
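

# Illustrative sketch, not part of utils.py: decoding the Java-style return convention of binary_search().
def _demo_binary_search() -> None:
    lst = ["a", "c", "c", "e"]
    assert binary_search(lst, "c") == 1  # the left-most index of an equal item
    assert binary_search(lst, "b") == -2  # not found: -insertion_point - 1, i.e. would insert at index 1
    assert binary_search(lst, "z") == -5  # would insert at index 4, the end of the list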


###############################################################################
S = TypeVar("S")


class Interner(Generic[S]):
    """Same as sys.intern() except that it isn't global and can also be used for types other than str."""

    def __init__(self, items: Iterable[S] = frozenset()) -> None:
        self._items: Final[dict[S, S]] = {v: v for v in items}

    def intern(self, item: S) -> S:
        """Interns the given item."""
        return self._items.setdefault(item, item)

    def interned(self, item: S) -> S:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        return self._items.get(item, item)

    def __contains__(self, item: S) -> bool:
        return item in self._items

1130 

1131############################################################################# 

1132class SynchronizedBool: 

1133 """Thread-safe wrapper around a regular bool.""" 

1134 

1135 def __init__(self, val: bool) -> None: 

1136 assert isinstance(val, bool) 

1137 self._lock: Final[threading.Lock] = threading.Lock() 

1138 self._value: bool = val 

1139 

1140 @property 

1141 def value(self) -> bool: 

1142 """Returns the current boolean value.""" 

1143 with self._lock: 

1144 return self._value 

1145 

1146 @value.setter 

1147 def value(self, new_value: bool) -> None: 

1148 """Atomically assign ``new_value``.""" 

1149 with self._lock: 

1150 self._value = new_value 

1151 

1152 def get_and_set(self, new_value: bool) -> bool: 

1153 """Swaps in ``new_value`` and return the previous value.""" 

1154 with self._lock: 

1155 old_value = self._value 

1156 self._value = new_value 

1157 return old_value 

1158 

1159 def compare_and_set(self, expected_value: bool, new_value: bool) -> bool: 

1160 """Sets to ``new_value`` only if current value equals ``expected_value``.""" 

1161 with self._lock: 

1162 eq: bool = self._value == expected_value 

1163 if eq: 

1164 self._value = new_value 

1165 return eq 

1166 

1167 def __bool__(self) -> bool: 

1168 return self.value 

1169 

1170 def __repr__(self) -> str: 

1171 return repr(self.value) 

1172 

1173 def __str__(self) -> str: 

1174 return str(self.value) 

1175 

1176 


#############################################################################
K = TypeVar("K")
V = TypeVar("V")


class SynchronizedDict(Generic[K, V]):
    """Thread-safe wrapper around a regular dict."""

    def __init__(self, val: dict[K, V]) -> None:
        assert isinstance(val, dict)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._dict: Final[dict[K, V]] = val

    def __getitem__(self, key: K) -> V:
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key: K, value: V) -> None:
        with self._lock:
            self._dict[key] = value

    def __delitem__(self, key: K) -> None:
        with self._lock:
            self._dict.pop(key)

    def __contains__(self, key: K) -> bool:
        with self._lock:
            return key in self._dict

    def __len__(self) -> int:
        with self._lock:
            return len(self._dict)

    def __repr__(self) -> str:
        with self._lock:
            return repr(self._dict)

    def __str__(self) -> str:
        with self._lock:
            return str(self._dict)

    def get(self, key: K, default: V | None = None) -> V | None:
        """Returns ``self[key]`` or ``default`` if missing."""
        with self._lock:
            return self._dict.get(key, default)

    def pop(self, key: K, default: V | None = None) -> V | None:
        """Removes ``key`` and returns its value, or ``default`` if missing."""
        with self._lock:
            return self._dict.pop(key, default)

    def clear(self) -> None:
        """Removes all items atomically."""
        with self._lock:
            self._dict.clear()

    def items(self) -> ItemsView[K, V]:
        """Returns a snapshot of dictionary items."""
        with self._lock:
            return self._dict.copy().items()


#############################################################################
class InterruptibleSleep:
    """Provides a sleep(timeout) function that can be interrupted by another thread; The underlying lock is configurable."""

    def __init__(self, lock: threading.Lock | None = None) -> None:
        self._is_stopping: bool = False
        self._lock: Final[threading.Lock] = lock if lock is not None else threading.Lock()
        self._condition: Final[threading.Condition] = threading.Condition(self._lock)

    def sleep(self, duration_nanos: int) -> bool:
        """Delays the current thread by the given number of nanoseconds; Returns True if the sleep got interrupted;
        Equivalent to threading.Event.wait()."""
        end_time_nanos: int = time.monotonic_ns() + duration_nanos
        with self._lock:
            while not self._is_stopping:
                diff_nanos: int = end_time_nanos - time.monotonic_ns()
                if diff_nanos <= 0:
                    return False
                self._condition.wait(timeout=diff_nanos / 1_000_000_000)  # release, then block until notified or timeout
            return True

    def interrupt(self) -> None:
        """Wakes sleeping threads and makes any future sleep()s a no-op; Equivalent to threading.Event.set()."""
        with self._lock:
            if not self._is_stopping:
                self._is_stopping = True
                self._condition.notify_all()

    def reset(self) -> None:
        """Makes any future sleep()s no longer a no-op; Equivalent to threading.Event.clear()."""
        with self._lock:
            self._is_stopping = False
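

# Illustrative sketch, not part of utils.py: a blocking wait that another thread can cut short.
def _demo_interruptible_sleep() -> None:
    sleeper = InterruptibleSleep()
    threading.Timer(0.05, sleeper.interrupt).start()  # interrupt from another thread after 50 milliseconds
    interrupted = sleeper.sleep(10 * 1_000_000_000)  # request a 10 second sleep, expressed in nanoseconds
    assert interrupted  # woke early because interrupt() was called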


#############################################################################
class SynchronousExecutor(Executor):
    """Executor that runs tasks inline in the calling thread, sequentially."""

    def __init__(self) -> None:
        self._shutdown: bool = False

    def submit(self, fn: Callable[..., R_], /, *args: Any, **kwargs: Any) -> Future[R_]:
        """Executes `fn(*args, **kwargs)` immediately and returns its Future."""
        future: Future[R_] = Future()
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        try:
            result: R_ = fn(*args, **kwargs)
        except BaseException as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Prevents new submissions; no worker resources to join/cleanup."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Factory returning a SynchronousExecutor if 0 <= max_workers <= 1; else a ThreadPoolExecutor."""
        return cls() if 0 <= max_workers <= 1 else ThreadPoolExecutor(max_workers=max_workers)

1302 

1303############################################################################# 

1304class _XFinally(contextlib.AbstractContextManager): 

1305 """Context manager ensuring cleanup code executes after ``with`` blocks.""" 

1306 

1307 def __init__(self, cleanup: Callable[[], None]) -> None: 

1308 """Records the callable to run upon exit.""" 

1309 self._cleanup: Final = cleanup # Zero-argument callable executed after the `with` block exits. 

1310 

1311 def __exit__( 

1312 self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None 

1313 ) -> Literal[False]: 

1314 """Runs cleanup and propagate any exceptions appropriately.""" 

1315 try: 

1316 self._cleanup() 

1317 except BaseException as cleanup_exc: 

1318 if exc is None: 

1319 raise # No main error --> propagate cleanup error normally 

1320 # Both failed 

1321 # if sys.version_info >= (3, 11): 

1322 # raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None 

1323 # <= 3.10: attach so it shows up in traceback but doesn't mask 

1324 exc.__context__ = cleanup_exc 

1325 return False # reraise original exception 

1326 return False # propagate main exception if any 

1327 

1328 

1329def xfinally(cleanup: Callable[[], None]) -> _XFinally: 

1330 """Usage: with xfinally(lambda: cleanup()): ... 

1331 Returns a context manager that guarantees that cleanup() runs on exit and guarantees any error in cleanup() will never 

1332 mask an exception raised earlier inside the body of the `with` block, while still surfacing both problems when possible. 

1333 

1334 Problem it solves 

1335 ----------------- 

1336 A naive ``try ... finally`` may lose the original exception: 

1337 

1338 try: 

1339 work() 

1340 finally: 

1341 cleanup() # <-- if this raises an exception, it replaces the real error! 

1342 

1343 `_XFinally` preserves exception priority: 

1344 

1345 * Body raises, cleanup succeeds --> original body exception is re-raised. 

1346 * Body raises, cleanup also raises --> re-raises body exception; cleanup exception is linked via ``__context__``. 

1347 * Body succeeds, cleanup raises --> cleanup exception propagates normally. 

1348 

1349 Example: 

1350 ------- 

1351 >>> with xfinally(lambda: release_resources()): # doctest: +SKIP 

1352 ... run_tasks() 

1353 

1354 The single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information. 

1355 """ 

1356 return _XFinally(cleanup)