Coverage for bzfs_main/util/utils.py: 100% (759 statements; coverage.py v7.13.0, created at 2025-12-22 08:03 +0000)


# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Collection of helper functions used across bzfs, including environment variable parsing, process management, and
lightweight concurrency primitives.

Everything in this module relies only on the Python standard library so other modules remain dependency free. Each utility
favors simple, predictable behavior on all supported platforms.
"""

from __future__ import (
    annotations,
)
import argparse
import base64
import bisect
import collections
import contextlib
import errno
import hashlib
import itertools
import logging
import operator
import os
import platform
import pwd
import random
import re
import signal
import stat
import subprocess
import sys
import threading
import time
import types
from collections import (
    defaultdict,
    deque,
)
from collections.abc import (
    ItemsView,
    Iterable,
    Iterator,
    Sequence,
)
from concurrent.futures import (
    Executor,
    Future,
    ThreadPoolExecutor,
)
from datetime import (
    datetime,
    timedelta,
    timezone,
    tzinfo,
)
from subprocess import (
    DEVNULL,
    PIPE,
)
from typing import (
    IO,
    Any,
    Callable,
    Final,
    Generic,
    Literal,
    NoReturn,
    Protocol,
    TextIO,
    TypeVar,
    cast,
    final,
)


# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"
DIE_STATUS: Final[int] = 3
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level halfway between INFO and WARNING
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level halfway between LOG_STDERR and INFO
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level halfway between NOTSET and DEBUG
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years, and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # 'sun_path' length limit of sockaddr_un

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias



def getenv_any(key: str, default: str | None = None, env_var_prefix: str = ENV_VAR_PREFIX) -> str | None:
    """Returns the environment variable named ``env_var_prefix + key``, or ``default`` if unset; all shell environment
    variable names used for configuration start with this prefix."""
    return os.getenv(env_var_prefix + key, default)


def getenv_int(key: str, default: int, env_var_prefix: str = ENV_VAR_PREFIX) -> int:
    """Returns environment variable ``key`` as int with ``default`` fallback."""
    return int(cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix)))


def getenv_bool(key: str, default: bool = False, env_var_prefix: str = ENV_VAR_PREFIX) -> bool:
    """Returns environment variable ``key`` as bool with ``default`` fallback."""
    return cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix)).lower().strip() == "true"
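
# A minimal usage sketch of the getenv helpers (hypothetical key names; results depend on the caller's environment):
#
#     >>> os.environ[ENV_VAR_PREFIX + "max_retries"] = "7"
#     >>> getenv_int("max_retries", 3)  # reads "bzfs_max_retries"
#     7
#     >>> getenv_bool("dryrun")  # unset --> falls back to the default
#     False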



def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Retains only column number 'field' in a list of TSV/CSV lines; analogous to the Unix 'cut' CLI command."""
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field == 1:
        return [line[0 : line.index(separator)] for line in lines]
    elif field == 2:
        return [line[line.index(separator) + 1 :] for line in lines]
    else:
        raise ValueError(f"Invalid field value: {field}")
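
# A minimal sketch of cut(); only the first separator per line matters for field == 2:
#
#     >>> cut(1, lines=["a\tb\tc", "x\ty"])
#     ['a', 'x']
#     >>> cut(2, lines=["a\tb\tc"])
#     ['b\tc']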


def drain(iterable: Iterable[Any]) -> None:
    """Consumes all items in the iterable, effectively draining it."""
    for _ in iterable:
        del _  # help gc (iterable can block)


_K_ = TypeVar("_K_")
_V_ = TypeVar("_V_")
_R_ = TypeVar("_R_")


def shuffle_dict(dictionary: dict[_K_, _V_], rand: random.Random = random.SystemRandom()) -> dict[_K_, _V_]:  # noqa: B008
    """Returns a new dict with items shuffled randomly."""
    items: list[tuple[_K_, _V_]] = list(dictionary.items())
    rand.shuffle(items)
    return dict(items)


def sorted_dict(dictionary: dict[_K_, _V_]) -> dict[_K_, _V_]:
    """Returns a new dict with items sorted primarily by key and secondarily by value."""
    return dict(sorted(dictionary.items()))


def tail(file: str, n: int, errors: str | None = None) -> Sequence[str]:
    """Returns the last ``n`` lines of ``file`` without following symlinks."""
    if not os.path.isfile(file):
        return []
    with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False) as fd:
        return deque(fd, maxlen=n)

_NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))


def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces parenthesis '(' followed by a char other than question mark '?', but not preceded by a backslash,
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash,
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    if "(" in regex and (
        "[" in regex  # literal left square bracket
        or "\\N{LEFT SQUARE BRACKET}" in regex  # named Unicode escape for '['
        or "\\x5b" in regex  # hex escape for '[' (lowercase)
        or "\\x5B" in regex  # hex escape for '[' (uppercase)
        or "\\u005b" in regex  # 4-digit Unicode escape for '[' (lowercase)
        or "\\u005B" in regex  # 4-digit Unicode escape for '[' (uppercase)
        or "\\U0000005b" in regex  # 8-digit Unicode escape for '[' (lowercase)
        or "\\U0000005B" in regex  # 8-digit Unicode escape for '[' (uppercase)
        or "\\133" in regex  # octal escape for '['
    ):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the rare case where the regex might
        # contain a pathological regex character class that contains parenthesis, or when '[' is expressed via escapes.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    i = len(regex) - 2
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + _NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex
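
# A minimal sketch of the rewrite; escaped parens and lookaheads are left intact, and any regex containing '[' is
# returned unchanged (conservative fallback):
#
#     >>> replace_capturing_groups_with_non_capturing_groups(r"(.*/)?tmp(foo|bar)(?!public)\(")
#     '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
#     >>> replace_capturing_groups_with_non_capturing_groups("(?P<name>abc)")
#     '(?:abc)'
#     >>> replace_capturing_groups_with_non_capturing_groups("[(]")
#     '[(]'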


def get_home_directory() -> str:
    """Reliably detects the home dir without relying on the HOME env var."""
    # thread-safe version of: os.environ.pop('HOME', None); os.path.expanduser('~')
    return pwd.getpwuid(os.getuid()).pw_dir


def human_readable_bytes(num_bytes: float, separator: str = " ", precision: int | None = None) -> str:
    """Formats 'num_bytes' as a human-readable size; for example "567 MiB"."""
    sign = "-" if num_bytes < 0 else ""
    s = abs(num_bytes)
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    n = len(units) - 1
    i = 0
    while s >= 1024 and i < n:
        s /= 1024
        i += 1
    formatted_num = human_readable_float(s) if precision is None else f"{s:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_duration(duration: float, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling as needed; for example "567ms"."""
    sign = "-" if duration < 0 else ""
    t = abs(duration)
    units = ("ns", "μs", "ms", "s", "m", "h", "d")
    i = units.index(unit)
    if t < 1 and t != 0:
        nanos = (1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000)
        t *= nanos[i]
        i = 0
    while t >= 1000 and i < 3:
        t /= 1000
        i += 1
    if i >= 3:
        while t >= 60 and i < 5:
            t /= 60
            i += 1
        if i >= 5:
            while t >= 24 and i < len(units) - 1:
                t /= 24
                i += 1
    formatted_num = human_readable_float(t) if precision is None else f"{t:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_float(number: float) -> str:
    """Formats ``number`` with a variable precision depending on magnitude.

    This design mirrors the way humans round values when scanning logs.

    If the number has one digit before the decimal point (0 <= abs(number) < 10):
        Round and use two decimals after the decimal point (e.g., 3.14559 --> "3.15").

    If the number has two digits before the decimal point (10 <= abs(number) < 100):
        Round and use one decimal after the decimal point (e.g., 12.36 --> "12.4").

    If the number has three or more digits before the decimal point (abs(number) >= 100):
        Round and use zero decimals after the decimal point (e.g., 123.556 --> "124").

    Ensures no unnecessary trailing zeroes are retained: Example: 1.500 --> "1.5", 1.00 --> "1"
    """
    abs_number = abs(number)
    precision = 2 if abs_number < 10 else 1 if abs_number < 100 else 0
    if precision == 0:
        return str(round(number))
    result = f"{number:.{precision}f}"
    assert "." in result
    result = result.rstrip("0").rstrip(".")  # remove trailing zeros, and the decimal point if nothing follows it
    return "0" if result == "-0" else result
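
# A few worked conversions, following the scaling and rounding rules above:
#
#     >>> human_readable_bytes(567 * 1024 * 1024)
#     '567 MiB'
#     >>> human_readable_duration(90_000_000_000)  # 90e9 ns = 90 s = 1.5 minutes
#     '1.5m'
#     >>> human_readable_float(12.36), human_readable_float(1.500)
#     ('12.4', '1.5')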


def percent(number: int, total: int, print_total: bool = False) -> str:
    """Returns percentage string of ``number`` relative to ``total``."""
    tot: str = f"/{total}" if print_total else ""
    return f"{number}{tot}={'inf' if total == 0 else human_readable_float(100 * number / total)}%"


def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if the basename of the path is a symlink.

    Also allows specifying the permissions used on O_CREAT, and optionally verifies ownership.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC
    fd: int = os.open(path, flags=flags, mode=perm)
    try:
        if check_owner:
            st_uid: int = os.fstat(fd).st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        try:
            os.close(fd)
        except OSError:
            pass
        raise
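
# A minimal usage sketch (POSIX only; a throwaway temp directory stands in for a real state directory):
#
#     >>> import tempfile
#     >>> path = os.path.join(tempfile.mkdtemp(), "state.txt")
#     >>> with open_nofollow(path, "w") as f:  # created with FILE_PERMISSIONS, i.e. 0o600
#     ...     _ = f.write("hello")
#     >>> os.symlink(path, path + ".lnk")
#     >>> open_nofollow(path + ".lnk", "r")  # refuses to follow the symlink (ELOOP on Linux)
#     Traceback (most recent call last):
#         ...
#     OSError: [Errno 40] Too many levels of symbolic links: ...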


def close_quietly(fd: int) -> None:
    """Closes the given file descriptor while silently swallowing any OSError that might arise as part of this."""
    if fd >= 0:
        try:
            os.close(fd)
        except OSError:
            pass


_P = TypeVar("_P")


def find_match(
    seq: Sequence[_P],
    predicate: Callable[[_P], bool],
    start: int | None = None,
    end: int | None = None,
    reverse: bool = False,
    raises: bool | str | Callable[[], str] = False,  # raises: bool | str | Callable = False,  # python >= 3.10
) -> int:
    """Returns the integer index within seq of the first item (or last item if reverse == True) that matches the given
    predicate condition. If no matching item is found, returns -1 or raises ValueError, depending on the raises parameter,
    which is a bool indicating whether to raise an error, or a string containing the error message, but can also be a
    Callable/lambda in order to support efficient deferred generation of error messages. Analogous to str.find(), including
    slicing semantics with parameters start and end. For example, seq can be a list, tuple or str.

    Example usage:
        lst = ["a", "b", "-c", "d"]
        i = find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True)
        if i >= 0:
            print(lst[i])
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=f"Tag {tag} not found in {file}")
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=lambda: f"Tag {tag} not found in {file}")
    """
    offset: int = 0 if start is None else start if start >= 0 else len(seq) + start
    if start is not None or end is not None:
        seq = seq[start:end]
    for i, item in enumerate(reversed(seq) if reverse else seq):
        if predicate(item):
            if reverse:
                return len(seq) - i - 1 + offset
            else:
                return i + offset
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    if callable(raises):
        raises = raises()
    raise ValueError(raises)


def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` lies under ``of_root_dataset`` in the dataset hierarchy, or is the same."""
    return dataset == of_root_dataset or dataset.startswith(of_root_dataset + "/")


def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any adjacent items within the given sorted sequence are equal."""
    return any(map(operator.eq, sorted_list, itertools.islice(sorted_list, 1, None)))


def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings."""
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    skip_dataset: str = DONT_SKIP_DATASET
    parents: set[str] = set()
    for dataset in sorted_datasets:
        assert dataset
        parent = os.path.dirname(dataset)
        if parent in parents:
            return True  # I have a sibling if my parent already has another child
        parents.add(parent)
        if is_descendant(dataset, of_root_dataset=skip_dataset):
            continue
        if skip_dataset != DONT_SKIP_DATASET:
            return True  # I have a sibling if I am a root dataset and another root dataset already exists
        skip_dataset = dataset
    return False
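
# Worked examples (inputs must be sorted and duplicate-free):
#
#     >>> has_siblings(["a", "a/b"])  # single lineage --> no siblings
#     False
#     >>> has_siblings(["a", "a/b", "a/c"])  # "a/b" and "a/c" share parent "a"
#     True
#     >>> has_siblings(["a", "b"])  # two root datasets are siblings
#     True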


def dry(msg: str, is_dry_run: bool) -> str:
    """Prefixes ``msg`` with 'Dry' when in dry-run mode."""
    return "Dry " + msg if is_dry_run else msg


def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Converts an absolute dataset path to one relative to ``root_dataset``.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    return dataset[len(root_dataset) :]


def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    i: int = 0
    while i >= 0:
        i = dataset.find("/", i)
        if i < 0:
            yield dataset
        else:
            yield dataset[:i]
            i += 1


def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """In a string s, replaces a leading old_prefix string with new_prefix; assumes the leading string is present."""
    assert s.startswith(old_prefix)
    return new_prefix + s[len(old_prefix) :]


def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces ``old`` with ``new`` in-place for every string in ``lines``."""
    for i in range(len(lines)):
        lines[i] = lines[i].replace(old, new, count)


_TAPPEND = TypeVar("_TAPPEND")


def append_if_absent(lst: list[_TAPPEND], *items: _TAPPEND) -> list[_TAPPEND]:
    """Appends items to the list if they are not already present."""
    for item in items:
        if item not in lst:
            lst.append(item)
    return lst


def xappend(lst: list[_TAPPEND], *items: _TAPPEND | Iterable[_TAPPEND]) -> list[_TAPPEND]:
    """Appends each of the items to the given list if the item is "truthy", for example not None and not an empty string;
    if an item is an iterable, does so recursively, flattening the output."""
    for item in items:
        if isinstance(item, str) or not isinstance(item, collections.abc.Iterable):
            if item:
                lst.append(item)
        else:
            xappend(lst, *item)
    return lst


def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if the name matches at least one of the include regexes but none of the exclude regexes; else False.

    A regex that starts with a `!` is a negation - the regex matches if the regex without the `!` prefix does not match.
    """
    for regex, is_negation in exclude_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return False

    for regex, is_negation in include_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return True

    return False


def compile_regexes(regexes: list[str], suffix: str = "") -> RegexList:
    """Compiles regex strings and keeps track of negations."""
    assert isinstance(regexes, list)
    compiled_regexes: RegexList = []
    for regex in regexes:
        if suffix:  # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
            if regex.endswith("\\$"):
                pass  # trailing literal $ is ok
            elif regex.endswith("$"):
                regex = regex[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in regex:
                raise re.error("Must not use non-trailing '$' character", regex)
        if is_negation := regex.startswith("!"):
            regex = regex[1:]
        regex = replace_capturing_groups_with_non_capturing_groups(regex)
        if regex != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            regex = f"{regex}{suffix}"
        compiled_regexes.append((re.compile(regex), is_negation))
    return compiled_regexes
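
# A minimal sketch combining compile_regexes() and is_included(); a leading '!' negates a regex:
#
#     >>> includes = compile_regexes(["tank/.*"])
#     >>> excludes = compile_regexes(["tank/tmp.*"])
#     >>> is_included("tank/prod", includes, excludes)
#     True
#     >>> is_included("tank/tmp1", includes, excludes)
#     False
#     >>> is_included("tank/prod", compile_regexes(["!tank/tmp.*"]), [])
#     True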


def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Lazy formatter joining items with ``separator``, used to avoid overhead at disabled log levels."""

    @final
    class CustomListFormatter:
        """Formatter object that joins items when converted to ``str``."""

        def __str__(self) -> str:
            s = separator.join(map(str, iterable))
            return s.lstrip() if lstrip else s

    return CustomListFormatter()


def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Lazy pprint formatter used to avoid overhead at disabled log levels."""

    @final
    class PrettyPrintFormatter:
        """Formatter that pretty-prints the object on conversion to ``str``."""

        def __str__(self) -> str:
            import pprint  # lazy import for startup perf

            return pprint.pformat(vars(obj_to_format))

    return PrettyPrintFormatter()


def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    return str(stderr) if not isinstance(stderr, bytes) else stderr.decode("utf-8", errors="replace")


def xprint(log: logging.Logger, value: Any, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Optionally logs ``value`` at stdout/stderr level."""
    if run and value:
        value = value if end else str(value).rstrip()
        level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
        log.log(level, "%s", value)


def sha256_hex(text: str) -> str:
    """Returns the sha256 hex string for the given text."""
    return hashlib.sha256(text.encode()).hexdigest()


def sha256_urlsafe_base64(text: str, padding: bool = True) -> str:
    """Returns the URL-safe base64-encoded sha256 value for the given text."""
    digest: bytes = hashlib.sha256(text.encode()).digest()
    s: str = base64.urlsafe_b64encode(digest).decode()
    return s if padding else s.rstrip("=")


def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 2]


def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left one third portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 3]


def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 string encoding of the int value, assuming it is contained in the range [0..max_value]."""
    assert 0 <= value <= max_value
    max_bytes: int = (max_value.bit_length() + 7) // 8
    value_bytes: bytes = value.to_bytes(max_bytes, byteorder)
    s: str = base64.urlsafe_b64encode(value_bytes).decode()
    return s if padding else s.rstrip("=")
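
# Spot checks (the digest of the empty string is the well-known SHA-256 constant):
#
#     >>> sha256_hex("")
#     'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'
#     >>> urlsafe_base64(0)  # the default max_value of 2**64 - 1 needs 8 bytes --> 12 base64 chars
#     'AAAAAAAAAAA='
#     >>> urlsafe_base64(0, padding=False)
#     'AAAAAAAAAAA'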


def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Exits the program with ``exit_code`` after logging ``msg``."""
    if parser is None:
        ex = SystemExit(msg)
        ex.code = exit_code
        raise ex
    else:
        parser.error(msg)


def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior, except that it enhances cleanup on TimeoutExpired,
    provides optional child PID tracking, and optionally logs execution status via the ``log`` and ``loglevel`` params."""
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE

    log: logging.Logger | None = kwargs.pop("log", None)
    loglevel: int | None = kwargs.pop("loglevel", None)
    start_time_nanos: int = time.monotonic_ns()
    is_timeout: bool = False
    is_cancel: bool = False
    exitcode: int | None = None

    def log_status() -> None:
        if log is not None:
            _loglevel: int = loglevel if loglevel is not None else getenv_int("subprocess_run_loglevel", LOG_TRACE)
            if log.isEnabledFor(_loglevel):
                elapsed_time: str = human_readable_float((time.monotonic_ns() - start_time_nanos) / 1_000_000) + "ms"
                status = "cancel" if is_cancel else "timeout" if is_timeout else "success" if exitcode == 0 else "failure"
                cmd = kwargs["args"] if "args" in kwargs else (args[0] if args else None)
                cmd_str: str = " ".join(str(arg) for arg in iter(cmd)) if isinstance(cmd, (list, tuple)) else str(cmd)
                log.log(_loglevel, f"Executed [{status}] [{elapsed_time}]: %s", cmd_str)

    with xfinally(log_status):
        ctx: contextlib.AbstractContextManager[subprocess.Popen]
        if subprocesses is None:
            ctx = subprocess.Popen(*args, **kwargs)
        else:
            ctx = subprocesses.popen_and_track(*args, **kwargs)
        with ctx as proc:
            try:
                sp = subprocesses
                if sp is not None and sp._termination_event.is_set():  # noqa: SLF001  # pylint: disable=protected-access
                    is_cancel = True
                    timeout = 0.0
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        is_timeout = True
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child proc and descendants
                finally:
                    proc.kill()
                raise
            else:
                exitcode = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)
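
# A minimal usage sketch (assumes a POSIX 'echo' on the PATH); behaves like subprocess.run() for the common cases:
#
#     >>> r = subprocess_run(["echo", "hi"], stdout=PIPE, text=True, timeout=10, check=True)
#     >>> r.returncode, r.stdout
#     (0, 'hi\n')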


def terminate_process_subtree(
    except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes."""
    current_pid: int = os.getpid()
    root_pids = [current_pid] if root_pids is None else root_pids
    all_pids: list[list[int]] = _get_descendant_processes(root_pids)
    assert len(all_pids) == len(root_pids)
    for i, pids in enumerate(all_pids):
        root_pid = root_pids[i]
        if root_pid == current_pid:
            pids += [] if except_current_process else [current_pid]
        else:
            pids.insert(0, root_pid)
        for pid in pids:
            with contextlib.suppress(OSError):
                os.kill(pid, sig)


def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
    """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems."""
    if len(root_pids) == 0:
        return []
    cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
    try:
        lines: list[str] = subprocess.run(cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout.splitlines()
    except PermissionError:
        # degrade gracefully in sandbox environments that deny executing `ps` entirely
        return [[] for _ in root_pids]
    procs: dict[int, list[int]] = defaultdict(list)
    for line in lines[1:]:  # all lines except the header line
        splits: list[str] = line.split()
        assert len(splits) == 2
        pid = int(splits[0])
        ppid = int(splits[1])
        procs[ppid].append(pid)

    def recursive_append(ppid: int, descendants: list[int]) -> None:
        """Recursively collects descendant PIDs starting from ``ppid``."""
        for child_pid in procs[ppid]:
            descendants.append(child_pid)
            recursive_append(child_pid, descendants)

    all_descendants: list[list[int]] = []
    for root_pid in root_pids:
        descendants: list[int] = []
        recursive_append(root_pid, descendants)
        all_descendants.append(descendants)
    return all_descendants


@contextlib.contextmanager
def termination_signal_handler(
    termination_events: list[threading.Event],
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set all ``termination_events`` and, by default, terminate
    all descendant processes."""
    termination_events = list(termination_events)  # shallow copy

    def _handler(_sig: int, _frame: object) -> None:
        for event in termination_events:
            event.set()
        termination_handler()

    previous_int_handler = signal.signal(signal.SIGINT, _handler)  # install new signal handler
    previous_term_handler = signal.signal(signal.SIGTERM, _handler)  # install new signal handler
    try:
        yield  # run body of context manager
    finally:
        signal.signal(signal.SIGINT, previous_int_handler)  # restore original signal handler
        signal.signal(signal.SIGTERM, previous_term_handler)  # restore original signal handler


#############################################################################
@final
class Subprocesses:
    """Provides per-job tracking of child PIDs so a job can safely terminate only the subprocesses it spawned itself; used
    when multiple jobs run concurrently within the same Python process.

    Optionally binds to a termination_event to enforce asynchronous cancellation by forcing immediate timeouts for newly
    spawned subprocesses once cancellation is requested.
    """

    def __init__(self, termination_event: threading.Event | None = None) -> None:
        self._termination_event: Final[threading.Event] = termination_event or threading.Event()
        self._lock: Final[threading.Lock] = threading.Lock()
        self._child_pids: Final[dict[int, None]] = {}  # a set that preserves insertion order

    @contextlib.contextmanager
    def popen_and_track(self, *popen_args: Any, **popen_kwargs: Any) -> Iterator[subprocess.Popen]:
        """Context manager that calls subprocess.Popen() and tracks the child PID for per-job termination.

        Holds a lock across Popen+PID registration to prevent a race when terminate_process_subtrees() is invoked (e.g. from
        SIGINT/SIGTERM handlers), ensuring newly spawned child processes cannot escape termination. The child PID is
        unregistered on context exit.
        """
        with self._lock:
            proc: subprocess.Popen = subprocess.Popen(*popen_args, **popen_kwargs)
            self._child_pids[proc.pid] = None
        try:
            yield proc
        finally:
            with self._lock:
                self._child_pids.pop(proc.pid, None)

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Wrapper around utils.subprocess_run() that auto-registers/unregisters child PIDs for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Sends the given signal to all tracked child PIDs and their descendants, ignoring errors for dead PIDs."""
        with self._lock:
            pids: list[int] = list(self._child_pids)
            self._child_pids.clear()
            terminate_process_subtree(root_pids=pids, sig=sig)


#############################################################################
def pid_exists(pid: int) -> bool | None:
    """Returns True if a process with PID exists, False if not, or None on error."""
    if pid <= 0:
        return False
    try:  # with signal=0, no signal is actually sent, but error checking is still performed
        os.kill(pid, 0)  # ... which can be used to check for process existence on POSIX systems
    except OSError as err:
        if err.errno == errno.ESRCH:  # No such process
            return False
        if err.errno == errno.EPERM:  # Operation not permitted
            return True
        return None
    return True


def nprefix(s: str) -> str:
    """Returns a canonical snapshot prefix with trailing underscore."""
    return sys.intern(s + "_")


def ninfix(s: str) -> str:
    """Returns a canonical infix with trailing underscore when not empty."""
    return sys.intern(s + "_") if s else ""


def nsuffix(s: str) -> str:
    """Returns a canonical suffix with leading underscore when not empty."""
    return sys.intern("_" + s) if s else ""


def format_dict(dictionary: dict[Any, Any]) -> str:
    """Returns the dictionary formatted within double quotes, for consistent output."""
    return f'"{dictionary}"'


def format_obj(obj: object) -> str:
    """Returns the object formatted within double quotes, for consistent output."""
    return f'"{obj}"'


def validate_dataset_name(dataset: str, input_text: str) -> None:
    """'zfs create' CLI does not accept dataset names that are empty or start or end in a slash, etc."""
    # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424
    # and https://github.com/openzfs/zfs/issues/8798
    # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html
    invalid_chars: str = SHELL_CHARS
    if (
        dataset in ("", ".", "..")
        or dataset.startswith(("/", "./", "../"))
        or dataset.endswith(("/", "/.", "/.."))
        or any(substring in dataset for substring in ("//", "/./", "/../"))
        or any(char in invalid_chars or (char.isspace() and char != " ") for char in dataset)
        or not dataset[0].isalpha()
    ):
        die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'")


def validate_property_name(propname: str, input_text: str) -> str:
    """Checks that the ZFS property name contains no spaces or shell chars."""
    invalid_chars: str = SHELL_CHARS
    if not propname or any(char.isspace() or char in invalid_chars for char in propname):
        die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'")
    return propname


def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None:
    """Checks that the given path is not a symbolic link."""
    if os.path.islink(path):
        die(f"{msg}must not be a symlink: {path}", parser=parser)


def validate_file_permissions(path: str, mode: int) -> None:
    """Verifies permissions and that ownership is the current effective UID."""
    stats: os.stat_result = os.stat(path, follow_symlinks=False)
    st_uid: int = stats.st_uid
    if st_uid != os.geteuid():  # verify ownership is current effective UID
        die(f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}")
    st_mode = stat.S_IMODE(stats.st_mode)
    if st_mode != mode:
        die(
            f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]}, "
            f"not {mode:03o} aka {stat.filemode(mode)[1:]}"
        )


def parse_duration_to_milliseconds(duration: str, regex_suffix: str = "", context: str = "") -> int:
    """Parses human duration strings like '5 mins' or '2 hours' to milliseconds."""
    unit_milliseconds: dict[str, int] = {
        "milliseconds": 1,
        "millis": 1,
        "seconds": 1000,
        "secs": 1000,
        "minutes": 60 * 1000,
        "mins": 60 * 1000,
        "hours": 60 * 60 * 1000,
        "days": 86400 * 1000,
        "weeks": 7 * 86400 * 1000,
        "months": round(30.5 * 86400 * 1000),
        "years": 365 * 86400 * 1000,
    }
    match = re.fullmatch(
        r"(\d+)\s*(milliseconds|millis|seconds|secs|minutes|mins|hours|days|weeks|months|years)" + regex_suffix,
        duration,
    )
    if not match:
        if context:
            die(f"Invalid duration format: {duration} within {context}")
        else:
            raise ValueError(f"Invalid duration format: {duration}")
    assert match
    quantity: int = int(match.group(1))
    unit: str = match.group(2)
    return quantity * unit_milliseconds[unit]
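
# Worked conversions (note that only the unit spellings listed in the regex are accepted):
#
#     >>> parse_duration_to_milliseconds("5 mins")
#     300000
#     >>> parse_duration_to_milliseconds("2hours")
#     7200000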


def unixtime_fromisoformat(datetime_str: str) -> int:
    """Converts an ISO 8601 datetime string into UTC Unix time seconds."""
    return int(datetime.fromisoformat(datetime_str).timestamp())


def isotime_from_unixtime(unixtime_in_seconds: int) -> str:
    """Converts UTC Unix time seconds into an ISO 8601 datetime string."""
    tz: tzinfo = timezone.utc
    dt: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=tz)
    return dt.isoformat(sep="_", timespec="seconds")


def current_datetime(
    tz_spec: str | None = None,
    now_fn: Callable[[tzinfo | None], datetime] | None = None,
) -> datetime:
    """Returns the current time in the ``tz_spec`` timezone, or in the local timezone if unspecified."""
    if now_fn is None:
        now_fn = datetime.now
    return now_fn(get_timezone(tz_spec))


def get_timezone(tz_spec: str | None = None) -> tzinfo | None:
    """Returns the timezone for the given spec, or None (meaning the local timezone) if unspecified."""
    tz: tzinfo | None
    if tz_spec is None:
        tz = None
    elif tz_spec == "UTC":
        tz = timezone.utc
    else:
        if match := re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec):
            sign, hours, minutes = match.groups()
            offset: int = int(hours) * 60 + int(minutes)
            offset = -offset if sign == "-" else offset
            tz = timezone(timedelta(minutes=offset))
        elif "/" in tz_spec:
            from zoneinfo import ZoneInfo  # lazy import for startup perf

            tz = ZoneInfo(tz_spec)
        else:
            raise ValueError(f"Invalid timezone specification: {tz_spec}")
    return tz
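
# Spot checks:
#
#     >>> get_timezone() is None  # None --> local timezone
#     True
#     >>> get_timezone("UTC") is timezone.utc
#     True
#     >>> get_timezone("+05:30") == timezone(timedelta(hours=5, minutes=30))
#     True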


###############################################################################
@final
class SnapshotPeriods:  # thread-safe
    """Parses snapshot suffix strings and converts between durations."""

    def __init__(self) -> None:
        """Initializes lookup tables of suffixes and corresponding millis."""
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * 86400 * 1000,
            "monthly": round(30.5 * 86400 * 1000),
            "weekly": 7 * 86400 * 1000,
            "daily": 86400 * 1000,
            "hourly": 60 * 60 * 1000,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({'|'.join(self.suffix_milliseconds.keys())})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' to (10, 'minutely')."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Like :meth:`suffix_to_duration0` but expects an underscore prefix."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Example: Converts '2hourly' to (2, 'hourly') and 'hourly' to (1, 'hourly')."""
        if match := regex.fullmatch(suffix):
            duration_amount: int = int(match.group(1)) if match.group(1) else 1
            assert duration_amount > 0
            duration_unit: str = match.group(2)
            return duration_amount, duration_unit
        else:
            return 0, ""

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns the duration encoded in the ``snapshot`` suffix, in milliseconds."""
        i = snapshot.rfind("_")
        snapshot = "" if i < 0 else snapshot[i + 1 :]
        duration_amount, duration_unit = self._suffix_to_duration(snapshot, self._suffix_regex0)
        return duration_amount * self.suffix_milliseconds.get(duration_unit, 0)
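
# A minimal sketch of the suffix parsing:
#
#     >>> periods = SnapshotPeriods()
#     >>> periods.suffix_to_duration0("10minutely")
#     (10, 'minutely')
#     >>> periods.suffix_to_duration0("hourly")  # missing count defaults to 1
#     (1, 'hourly')
#     >>> periods.label_milliseconds("prod_onsite_daily")
#     86400000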



#############################################################################
@final
class JobStats:
    """Simple thread-safe counters summarizing job progress."""

    def __init__(self, jobs_all: int) -> None:
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self.jobs_all: int = jobs_all
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Counts a job submission."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Counts a job completion."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            self.jobs_failed += 1 if failed else 0
            self.sum_elapsed_nanos += elapsed_nanos
            msg = str(self)
            assert self.sum_elapsed_nanos >= 0, msg
            assert self.jobs_running >= 0, msg
            assert self.jobs_failed >= 0, msg
            assert self.jobs_failed <= self.jobs_completed, msg
            assert self.jobs_completed <= self.jobs_started, msg
            assert self.jobs_started <= self.jobs_all, msg
            return msg

    def __repr__(self) -> str:
        def pct(number: int) -> str:
            """Returns percentage string relative to total jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        al, started, completed, failed = self.jobs_all, self.jobs_started, self.jobs_completed, self.jobs_failed
        running = self.jobs_running
        t = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, completed))
        return f"all:{al}, started:{pct(started)}, completed:{pct(completed)}, failed:{pct(failed)}, running:{running}, {t}"


#############################################################################
class Comparable(Protocol):
    """Partial ordering protocol."""

    def __lt__(self, other: Any) -> bool: ...


TComparable = TypeVar("TComparable", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue



@final
class SmallPriorityQueue(Generic[TComparable]):
    """A priority queue that can handle updates to the priority of any element that is already contained in the queue, and
    does so very efficiently if there are a small number of elements in the queue (no more than thousands), as is the case
    for us.

    Could be implemented using a SortedList via https://github.com/grantjenks/python-sortedcontainers or using an indexed
    priority queue via https://github.com/nvictus/pqdict. But, to avoid an external dependency, it is actually implemented
    using a simple yet effective binary search-based sorted list that can handle updates to the priority of elements that
    are already contained in the queue, via removal of the element, followed by update of the element, followed by
    (re)insertion. Duplicate elements (if any) are maintained in their order of insertion relative to other duplicates.
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; sort order flips when ``reverse`` is True."""
        self._lst: Final[list[TComparable]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Removes all elements from the queue."""
        self._lst.clear()

    def push(self, element: TComparable) -> None:
        """Inserts ``element`` while maintaining sorted order."""
        bisect.insort(self._lst, element)

    def pop(self) -> TComparable:
        """Removes and returns the smallest (or largest if reverse == True) element from the queue."""
        return self._lst.pop() if self._reverse else self._lst.pop(0)

    def peek(self) -> TComparable:
        """Returns the smallest (or largest if reverse == True) element without removing it."""
        return self._lst[-1] if self._reverse else self._lst[0]

    def remove(self, element: TComparable) -> bool:
        """Removes the first occurrence (in insertion order aka FIFO) of ``element`` and returns True if it was present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        is_contained = i < len(lst) and lst[i] == element
        if is_contained:
            del lst[i]  # is an optimized memmove()
        return is_contained

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: TComparable) -> bool:
        """Returns ``True`` if ``element`` is present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        return i < len(lst) and lst[i] == element

    def __iter__(self) -> Iterator[TComparable]:
        """Iterates over queued elements in priority order."""
        return reversed(self._lst) if self._reverse else iter(self._lst)

    def __repr__(self) -> str:
        """Representation showing queue contents in current order."""
        return repr(list(reversed(self._lst))) if self._reverse else repr(self._lst)
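
# A minimal sketch; pop() always yields the smallest element unless reverse == True:
#
#     >>> q: SmallPriorityQueue[int] = SmallPriorityQueue()
#     >>> for x in (3, 1, 2):
#     ...     q.push(x)
#     >>> q.pop(), q.pop(), q.pop()
#     (1, 2, 3)
#     >>> q2: SmallPriorityQueue[int] = SmallPriorityQueue(reverse=True)
#     >>> q2.push(1); q2.push(2); q2.peek()
#     2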



###############################################################################
@final
class SortedInterner(Generic[TComparable]):
    """Same as sys.intern() except that it isn't global and that it assumes the input list is sorted (for binary search)."""

    def __init__(self, sorted_list: list[TComparable]) -> None:
        self._lst: Final[list[TComparable]] = sorted_list

    def interned(self, element: TComparable) -> TComparable:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        lst = self._lst
        i = binary_search(lst, element)
        return lst[i] if i >= 0 else element

    def __contains__(self, element: TComparable) -> bool:
        """Returns ``True`` if ``element`` is present."""
        return binary_search(self._lst, element) >= 0


def binary_search(sorted_list: list[TComparable], item: TComparable) -> int:
    """Java-style binary search; Returns index >= 0 if an equal item is found in the list, else '-insertion_point-1'; If it
    returns index >= 0, the index will be the left-most index in case multiple such equal items are contained."""
    i = bisect.bisect_left(sorted_list, item)
    return i if i < len(sorted_list) and sorted_list[i] == item else -i - 1
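
# Worked examples of the Java-style return convention:
#
#     >>> binary_search(["a", "b", "d"], "b")
#     1
#     >>> binary_search(["a", "b", "d"], "c")  # would insert at index 2 --> -2 - 1
#     -3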


###############################################################################
_S = TypeVar("_S")


@final
class HashedInterner(Generic[_S]):
    """Same as sys.intern() except that it isn't global and can also be used for types other than str."""

    def __init__(self, items: Iterable[_S] = frozenset()) -> None:
        self._items: Final[dict[_S, _S]] = {v: v for v in items}

    def intern(self, item: _S) -> _S:
        """Interns the given item."""
        return self._items.setdefault(item, item)

    def interned(self, item: _S) -> _S:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        return self._items.get(item, item)

    def __contains__(self, item: _S) -> bool:
        return item in self._items


#############################################################################
@final
class SynchronizedBool:
    """Thread-safe wrapper around a regular bool."""

    def __init__(self, val: bool) -> None:
        assert isinstance(val, bool)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._value: bool = val

    @property
    def value(self) -> bool:
        """Returns the current boolean value."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: bool) -> None:
        """Atomically assigns ``new_value``."""
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: bool) -> bool:
        """Swaps in ``new_value`` and returns the previous value."""
        with self._lock:
            old_value = self._value
            self._value = new_value
            return old_value

    def compare_and_set(self, expected_value: bool, new_value: bool) -> bool:
        """Sets to ``new_value`` only if the current value equals ``expected_value``."""
        with self._lock:
            eq: bool = self._value == expected_value
            if eq:
                self._value = new_value
            return eq

    def __bool__(self) -> bool:
        return self.value

    def __repr__(self) -> str:
        return repr(self.value)

    def __str__(self) -> str:
        return str(self.value)



#############################################################################
_K = TypeVar("_K")
_V = TypeVar("_V")


@final
class SynchronizedDict(Generic[_K, _V]):
    """Thread-safe wrapper around a regular dict."""

    def __init__(self, val: dict[_K, _V]) -> None:
        assert isinstance(val, dict)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._dict: Final[dict[_K, _V]] = val

    def __getitem__(self, key: _K) -> _V:
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key: _K, value: _V) -> None:
        with self._lock:
            self._dict[key] = value

    def __delitem__(self, key: _K) -> None:
        with self._lock:
            self._dict.pop(key)

    def __contains__(self, key: _K) -> bool:
        with self._lock:
            return key in self._dict

    def __len__(self) -> int:
        with self._lock:
            return len(self._dict)

    def __repr__(self) -> str:
        with self._lock:
            return repr(self._dict)

    def __str__(self) -> str:
        with self._lock:
            return str(self._dict)

    def get(self, key: _K, default: _V | None = None) -> _V | None:
        """Returns ``self[key]`` or ``default`` if missing."""
        with self._lock:
            return self._dict.get(key, default)

    def pop(self, key: _K, default: _V | None = None) -> _V | None:
        """Removes ``key`` and returns its value, or ``default`` if missing."""
        with self._lock:
            return self._dict.pop(key, default)

    def clear(self) -> None:
        """Removes all items atomically."""
        with self._lock:
            self._dict.clear()

    def items(self) -> ItemsView[_K, _V]:
        """Returns a snapshot of dictionary items."""
        with self._lock:
            return self._dict.copy().items()


#############################################################################
@final
class InterruptibleSleep:
    """Provides a sleep(timeout) function that can be interrupted by another thread; The underlying lock is configurable."""

    def __init__(self, lock: threading.Lock | None = None) -> None:
        self._is_stopping: bool = False
        self._lock: Final[threading.Lock] = lock if lock is not None else threading.Lock()
        self._condition: Final[threading.Condition] = threading.Condition(self._lock)

    def sleep(self, duration_nanos: int) -> bool:
        """Delays the current thread by the given number of nanoseconds; Returns True if the sleep got interrupted;
        Equivalent to threading.Event.wait()."""
        end_time_nanos: int = time.monotonic_ns() + duration_nanos
        with self._lock:
            while not self._is_stopping:
                diff_nanos: int = end_time_nanos - time.monotonic_ns()
                if diff_nanos <= 0:
                    return False
                self._condition.wait(timeout=diff_nanos / 1_000_000_000)  # release, then block until notified or timeout
            return True

    def interrupt(self) -> None:
        """Wakes sleeping threads and makes any future sleep()s a no-op; Equivalent to threading.Event.set()."""
        with self._lock:
            if not self._is_stopping:
                self._is_stopping = True
                self._condition.notify_all()

    def reset(self) -> None:
        """Makes any future sleep()s no longer a no-op; Equivalent to threading.Event.clear()."""
        with self._lock:
            self._is_stopping = False


#############################################################################
@final
class SynchronousExecutor(Executor):
    """Executor that runs tasks inline in the calling thread, sequentially."""

    def __init__(self) -> None:
        self._shutdown: bool = False

    def submit(self, fn: Callable[..., _R_], /, *args: Any, **kwargs: Any) -> Future[_R_]:
        """Executes `fn(*args, **kwargs)` immediately and returns its Future."""
        future: Future[_R_] = Future()
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        try:
            result: _R_ = fn(*args, **kwargs)
        except BaseException as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Prevents new submissions; no worker resources to join/cleanup."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Factory returning a SynchronousExecutor if 0 <= max_workers <= 1; else a ThreadPoolExecutor."""
        return cls() if 0 <= max_workers <= 1 else ThreadPoolExecutor(max_workers=max_workers)


#############################################################################
@final
class _XFinally(contextlib.AbstractContextManager):
    """Context manager ensuring cleanup code executes after ``with`` blocks."""

    def __init__(self, cleanup: Callable[[], None]) -> None:
        """Records the callable to run upon exit."""
        self._cleanup: Final = cleanup  # zero-argument callable executed after the `with` block exits

    def __exit__(
        self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None
    ) -> Literal[False]:
        """Runs cleanup and propagates any exceptions appropriately."""
        try:
            self._cleanup()
        except BaseException as cleanup_exc:
            if exc is None:
                raise  # no main error --> propagate cleanup error normally
            # Both failed
            # if sys.version_info >= (3, 11):
            #     raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None
            # <= 3.10: attach so it shows up in traceback but doesn't mask
            exc.__context__ = cleanup_exc
            return False  # reraise original exception
        return False  # propagate main exception if any


def xfinally(cleanup: Callable[[], None]) -> _XFinally:
    """Usage: with xfinally(lambda: cleanup()): ...

    Returns a context manager that guarantees that cleanup() runs on exit and guarantees any error in cleanup() will never
    mask an exception raised earlier inside the body of the `with` block, while still surfacing both problems when possible.

    Problem it solves
    -----------------
    A naive ``try ... finally`` may lose the original exception:

        try:
            work()
        finally:
            cleanup()  # <-- if this raises an exception, it replaces the real error!

    `_XFinally` preserves exception priority:

    * Body raises, cleanup succeeds --> original body exception is re-raised.
    * Body raises, cleanup also raises --> re-raises body exception; cleanup exception is linked via ``__context__``.
    * Body succeeds, cleanup raises --> cleanup exception propagates normally.

    Example:
    -------
    >>> with xfinally(lambda: release_resources()):  # doctest: +SKIP
    ...     run_tasks()

    The single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information.
    """
    return _XFinally(cleanup)
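
# A minimal sketch of the exception-priority guarantee:
#
#     >>> def failing_cleanup() -> None:
#     ...     raise RuntimeError("cleanup failed")
#     >>> try:
#     ...     with xfinally(failing_cleanup):
#     ...         raise ValueError("main failed")
#     ... except ValueError as e:
#     ...     print(type(e.__context__).__name__)  # cleanup error is attached, not masking
#     RuntimeError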