Coverage for bzfs_main/utils.py: 100% (723 statements)
coverage.py v7.11.0, created at 2025-11-07 04:44 +0000

# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Collection of helper functions used across bzfs; includes environment variable parsing, process management and lightweight
concurrency primitives.

Everything in this module relies only on the standard library so other modules remain dependency free. Each utility favors
simple, predictable behavior on all supported platforms.
"""

from __future__ import (
    annotations,
)
import argparse
import base64
import bisect
import collections
import contextlib
import errno
import hashlib
import logging
import os
import platform
import pwd
import random
import re
import signal
import stat
import subprocess
import sys
import threading
import time
import types
from collections import (
    defaultdict,
    deque,
)
from collections.abc import (
    ItemsView,
    Iterable,
    Iterator,
    Sequence,
)
from concurrent.futures import (
    Executor,
    Future,
    ThreadPoolExecutor,
)
from datetime import (
    datetime,
    timedelta,
    timezone,
    tzinfo,
)
from subprocess import (
    DEVNULL,
    PIPE,
)
from typing import (
    IO,
    Any,
    Callable,
    Final,
    Generic,
    Literal,
    NoReturn,
    Protocol,
    TextIO,
    TypeVar,
    cast,
)

# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"
DIE_STATUS: Final[int] = 3
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level is halfway in between
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level is halfway in between
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level is halfway in between
SNAPSHOT_FILTERS_VAR: Final[str] = "snapshot_filters_var"
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # see Google for 'sun_path'

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias


def getenv_any(key: str, default: str | None = None) -> str | None:
    """Returns environment variable ``ENV_VAR_PREFIX + key`` with ``default`` fallback; all config env var names use this prefix."""
    return os.getenv(ENV_VAR_PREFIX + key, default)


def getenv_int(key: str, default: int) -> int:
    """Returns environment variable ``key`` as int with ``default`` fallback."""
    return int(cast(str, getenv_any(key, str(default))))


def getenv_bool(key: str, default: bool = False) -> bool:
    """Returns environment variable ``key`` as bool with ``default`` fallback."""
    return cast(str, getenv_any(key, str(default))).lower().strip() == "true"


def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Retains only column number 'field' in a list of TSV/CSV lines; analogous to the Unix 'cut' CLI command."""
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field == 1:
        return [line[0 : line.index(separator)] for line in lines]
    elif field == 2:
        return [line[line.index(separator) + 1 :] for line in lines]
    else:
        raise ValueError(f"Invalid field value: {field}")
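

# Illustrative sketch (the `_demo_cut` helper is hypothetical, not part of the original module): how cut() retains one
# column of two-column TSV lines, analogous to Unix `cut -f`.
def _demo_cut() -> None:
    lines = ["tank/foo\t12345", "tank/bar\t67890"]
    assert cut(1, lines=lines) == ["tank/foo", "tank/bar"]  # retain first column
    assert cut(2, lines=lines) == ["12345", "67890"]  # retain second column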


def drain(iterable: Iterable[Any]) -> None:
    """Consumes all items in the iterable, effectively draining it."""
    for _ in iterable:
        _ = None  # help gc (iterable can block)


K_ = TypeVar("K_")
V_ = TypeVar("V_")
R_ = TypeVar("R_")


def shuffle_dict(dictionary: dict[K_, V_], rand: random.Random = random.SystemRandom()) -> dict[K_, V_]:  # noqa: B008
    """Returns a new dict with items shuffled randomly."""
    items: list[tuple[K_, V_]] = list(dictionary.items())
    rand.shuffle(items)
    return dict(items)


def sorted_dict(dictionary: dict[K_, V_]) -> dict[K_, V_]:
    """Returns a new dict with items sorted primarily by key and secondarily by value."""
    return dict(sorted(dictionary.items()))


def tail(file: str, n: int, errors: str | None = None) -> Sequence[str]:
    """Return the last ``n`` lines of ``file`` without following symlinks."""
    if not os.path.isfile(file):
        return []
    with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False) as fd:
        return deque(fd, maxlen=n)


NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))


def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces parenthesis '(' followed by a char other than question mark '?', but not preceded by a backslash,
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash,
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    if "(" in regex and (
        "[" in regex  # literal left square bracket
        or "\\N{LEFT SQUARE BRACKET}" in regex  # named Unicode escape for '['
        or "\\x5b" in regex  # hex escape for '[' (lowercase)
        or "\\x5B" in regex  # hex escape for '[' (uppercase)
        or "\\u005b" in regex  # 4-digit Unicode escape for '[' (lowercase)
        or "\\u005B" in regex  # 4-digit Unicode escape for '[' (uppercase)
        or "\\U0000005b" in regex  # 8-digit Unicode escape for '[' (lowercase)
        or "\\U0000005B" in regex  # 8-digit Unicode escape for '[' (uppercase)
        or "\\133" in regex  # octal escape for '['
    ):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the rare case where the regex might
        # contain a pathological regex character class that contains parenthesis, or when '[' is expressed via escapes.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    i = len(regex) - 2
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex
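

# Illustrative sketch (the `_demo_replace_capturing_groups` helper is hypothetical, not part of the original module):
# the rewrite turns capturing groups into non-capturing ones, and skips regexes containing character classes.
def _demo_replace_capturing_groups() -> None:
    assert replace_capturing_groups_with_non_capturing_groups("(.*/)?tmp(foo|bar)") == "(?:.*/)?tmp(?:foo|bar)"
    assert replace_capturing_groups_with_non_capturing_groups("(?P<name>abc)") == "(?:abc)"
    assert replace_capturing_groups_with_non_capturing_groups("[(]") == "[(]"  # char class: rewrite is skipped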


def get_home_directory() -> str:
    """Reliably detects home dir without using HOME env var."""
    # thread-safe version of: os.environ.pop('HOME', None); os.path.expanduser('~')
    return pwd.getpwuid(os.getuid()).pw_dir


def human_readable_bytes(num_bytes: float, separator: str = " ", precision: int | None = None) -> str:
    """Formats 'num_bytes' as a human-readable size; for example "567 MiB"."""
    sign = "-" if num_bytes < 0 else ""
    s = abs(num_bytes)
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    n = len(units) - 1
    i = 0
    while s >= 1024 and i < n:
        s /= 1024
        i += 1
    formatted_num = human_readable_float(s) if precision is None else f"{s:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_duration(duration: float, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling as needed; for example "567ms"."""
    sign = "-" if duration < 0 else ""
    t = abs(duration)
    units = ("ns", "μs", "ms", "s", "m", "h", "d")
    i = units.index(unit)
    if t < 1 and t != 0:
        nanos = (1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000)
        t *= nanos[i]
        i = 0
    while t >= 1000 and i < 3:
        t /= 1000
        i += 1
    if i >= 3:
        while t >= 60 and i < 5:
            t /= 60
            i += 1
    if i >= 5:
        while t >= 24 and i < len(units) - 1:
            t /= 24
            i += 1
    formatted_num = human_readable_float(t) if precision is None else f"{t:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_float(number: float) -> str:
    """Formats ``number`` with a variable precision depending on magnitude.

    This design mirrors the way humans round values when scanning logs.

    If the number has one digit before the decimal point (0 <= abs(number) < 10):
        Round and use two decimals after the decimal point (e.g., 3.14559 --> "3.15").

    If the number has two digits before the decimal point (10 <= abs(number) < 100):
        Round and use one decimal after the decimal point (e.g., 12.36 --> "12.4").

    If the number has three or more digits before the decimal point (abs(number) >= 100):
        Round and use zero decimals after the decimal point (e.g., 123.556 --> "124").

    Ensures no unnecessary trailing zeroes are retained: Example: 1.500 --> "1.5", 1.00 --> "1"
    """
    abs_number = abs(number)
    precision = 2 if abs_number < 10 else 1 if abs_number < 100 else 0
    if precision == 0:
        return str(round(number))
    result = f"{number:.{precision}f}"
    assert "." in result
    result = result.rstrip("0").rstrip(".")  # Remove trailing zeros and trailing decimal point if empty
    return "0" if result == "-0" else result
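

# Illustrative sketch (the `_demo_human_readable_formatting` helper is hypothetical, not part of the original module):
# magnitude-dependent rounding and unit scaling in action, using the examples from the docstrings above.
def _demo_human_readable_formatting() -> None:
    assert human_readable_float(3.14559) == "3.15"  # abs < 10: two decimals
    assert human_readable_float(12.36) == "12.4"  # abs < 100: one decimal
    assert human_readable_float(123.556) == "124"  # abs >= 100: zero decimals
    assert human_readable_bytes(567 * 1024 * 1024) == "567 MiB"
    assert human_readable_duration(567_000_000) == "567ms"  # default input unit is "ns"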


def percent(number: int, total: int, print_total: bool = False) -> str:
    """Returns percentage string of ``number`` relative to ``total``."""
    tot: str = f"/{total}" if print_total else ""
    return f"{number}{tot}={'inf' if total == 0 else human_readable_float(100 * number / total)}%"


def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if basename of path is a symlink.

    Also, can specify permissions on O_CREAT, and verify ownership.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC
    fd: int = os.open(path, flags=flags, mode=perm)
    try:
        if check_owner:
            st_uid: int = os.fstat(fd).st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        try:
            os.close(fd)
        except OSError:
            pass
        raise
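

# Illustrative sketch (the `_demo_open_nofollow` helper and its `tmp_dir` parameter are hypothetical, not part of the
# original module): opening a symlinked path fails with ELOOP/EMLINK per the docstring, depending on the platform.
def _demo_open_nofollow(tmp_dir: str) -> None:
    real = os.path.join(tmp_dir, "real.txt")
    link = os.path.join(tmp_dir, "link.txt")
    with open_nofollow(real, "w") as f:  # created with FILE_PERMISSIONS, i.e. rw-------
        f.write("hello")
    os.symlink(real, link)
    try:
        open_nofollow(link, "r")
        raise AssertionError("expected OSError")
    except OSError as e:
        assert e.errno in (errno.ELOOP, errno.EMLINK)  # symlink refused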


def close_quietly(fd: int) -> None:
    """Closes the given file descriptor while silently swallowing any OSError that might arise as part of this."""
    if fd >= 0:
        try:
            os.close(fd)
        except OSError:
            pass


P = TypeVar("P")


def find_match(
    seq: Sequence[P],
    predicate: Callable[[P], bool],
    start: int | None = None,
    end: int | None = None,
    reverse: bool = False,
    raises: bool | str | Callable[[], str] = False,  # raises: bool | str | Callable = False, # python >= 3.10
) -> int:
    """Returns the integer index within seq of the first item (or last item if reverse==True) that matches the given
    predicate condition. If no matching item is found, returns -1 or raises ValueError, depending on the raises parameter,
    which is a bool indicating whether to raise an error, or a string containing the error message, but can also be a
    Callable/lambda in order to support efficient deferred generation of error messages. Analogous to str.find(), including
    slicing semantics with parameters start and end. For example, seq can be a list, tuple or str.

    Example usage:
        lst = ["a", "b", "-c", "d"]
        i = find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True)
        if i >= 0:
            ...
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=f"Tag {tag} not found in {file}")
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=lambda: f"Tag {tag} not found in {file}")
    """
    offset: int = 0 if start is None else start if start >= 0 else len(seq) + start
    if start is not None or end is not None:
        seq = seq[start:end]
    for i, item in enumerate(reversed(seq) if reverse else seq):
        if predicate(item):
            if reverse:
                return len(seq) - i - 1 + offset
            else:
                return i + offset
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    if callable(raises):
        raises = raises()
    raise ValueError(raises)


def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` lies under ``of_root_dataset`` in the dataset hierarchy, or is the same."""
    return dataset == of_root_dataset or dataset.startswith(of_root_dataset + "/")


def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any adjacent items within the given sorted sequence are equal."""
    return any(a == b for a, b in zip(sorted_list, sorted_list[1:]))


def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings."""
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    skip_dataset: str = DONT_SKIP_DATASET
    parents: set[str] = set()
    for dataset in sorted_datasets:
        assert dataset
        parent = os.path.dirname(dataset)
        if parent in parents:
            return True  # I have a sibling if my parent already has another child
        parents.add(parent)
        if is_descendant(dataset, of_root_dataset=skip_dataset):
            continue
        if skip_dataset != DONT_SKIP_DATASET:
            return True  # I have a sibling if I am a root dataset and another root dataset already exists
        skip_dataset = dataset
    return False


def dry(msg: str, is_dry_run: bool) -> str:
    """Prefixes ``msg`` with 'Dry' when in dry-run mode."""
    return "Dry " + msg if is_dry_run else msg


def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Converts an absolute dataset path to one relative to ``root_dataset``.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    return dataset[len(root_dataset) :]


def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    i: int = 0
    while i >= 0:
        i = dataset.find("/", i)
        if i < 0:
            yield dataset
        else:
            yield dataset[:i]
            i += 1


def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """In a string s, replaces a leading old_prefix string with new_prefix; assumes the leading string is present."""
    assert s.startswith(old_prefix)
    return new_prefix + s[len(old_prefix) :]


def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces ``old`` with ``new`` in-place for every string in ``lines``."""
    for i in range(len(lines)):
        lines[i] = lines[i].replace(old, new, count)


TAPPEND = TypeVar("TAPPEND")


def append_if_absent(lst: list[TAPPEND], *items: TAPPEND) -> list[TAPPEND]:
    """Appends items to list if they are not already present."""
    for item in items:
        if item not in lst:
            lst.append(item)
    return lst


def xappend(lst: list[TAPPEND], *items: TAPPEND | Iterable[TAPPEND]) -> list[TAPPEND]:
    """Appends each of the items to the given list if the item is "truthy", for example not None and not an empty string; If
    an item is an iterable does so recursively, flattening the output."""
    for item in items:
        if isinstance(item, str) or not isinstance(item, collections.abc.Iterable):
            if item:
                lst.append(item)
        else:
            xappend(lst, *item)
    return lst
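

# Illustrative sketch (the `_demo_xappend` helper is hypothetical, not part of the original module): xappend() drops
# falsy items and recursively flattens nested iterables, which is handy when assembling CLI argument lists.
def _demo_xappend() -> None:
    cmd: list[str] = ["zfs", "send"]
    xappend(cmd, "", None, "-v", ["--raw", ("-w",)])  # falsy items are skipped, nested iterables are flattened
    assert cmd == ["zfs", "send", "-v", "--raw", "-w"]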


def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if the name matches at least one of the include regexes but none of the exclude regexes; else False.

    A regex that starts with a `!` is a negation - the regex matches if the regex without the `!` prefix does not match.
    """
    for regex, is_negation in exclude_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return False

    for regex, is_negation in include_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return True

    return False


def compile_regexes(regexes: list[str], suffix: str = "") -> RegexList:
    """Compiles regex strings and keeps track of negations."""
    assert isinstance(regexes, list)
    compiled_regexes: RegexList = []
    for regex in regexes:
        if suffix:  # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
            if regex.endswith("\\$"):
                pass  # trailing literal $ is ok
            elif regex.endswith("$"):
                regex = regex[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in regex:
                raise re.error("Must not use non-trailing '$' character", regex)
        if is_negation := regex.startswith("!"):
            regex = regex[1:]
        regex = replace_capturing_groups_with_non_capturing_groups(regex)
        if regex != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            regex = f"{regex}{suffix}"
        compiled_regexes.append((re.compile(regex), is_negation))
    return compiled_regexes
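

# Illustrative sketch (the `_demo_compile_regexes_and_is_included` helper is hypothetical, not part of the original
# module): dataset filtering where DESCENDANTS_RE_SUFFIX makes each regex also match descendants of a matching dataset.
def _demo_compile_regexes_and_is_included() -> None:
    includes = compile_regexes(["tank/.*"], suffix=DESCENDANTS_RE_SUFFIX)
    excludes = compile_regexes(["tank/tmp"], suffix=DESCENDANTS_RE_SUFFIX)
    assert is_included("tank/data/backup", includes, excludes)
    assert not is_included("tank/tmp/scratch", includes, excludes)  # excluded along with its descendants
    assert not is_included("pool/data", includes, excludes)  # matches no include regex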


def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Lazy formatter joining items with ``separator``; used to avoid overhead in disabled log levels."""

    class CustomListFormatter:
        """Formatter object that joins items when converted to ``str``."""

        def __str__(self) -> str:
            s = separator.join(map(str, iterable))
            return s.lstrip() if lstrip else s

    return CustomListFormatter()


def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Lazy pprint formatter used to avoid overhead in disabled log levels."""

    class PrettyPrintFormatter:
        """Formatter that pretty-prints the object on conversion to ``str``."""

        def __str__(self) -> str:
            import pprint  # lazy import for startup perf

            return pprint.pformat(vars(obj_to_format))

    return PrettyPrintFormatter()


def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    return str(stderr) if not isinstance(stderr, bytes) else stderr.decode("utf-8", errors="replace")


def xprint(log: logging.Logger, value: Any, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Optionally logs ``value`` at stdout/stderr level."""
    if run and value:
        value = value if end else str(value).rstrip()
        level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
        log.log(level, "%s", value)


def sha256_hex(text: str) -> str:
    """Returns the sha256 hex string for the given text."""
    return hashlib.sha256(text.encode()).hexdigest()


def sha256_urlsafe_base64(text: str, padding: bool = True) -> str:
    """Returns the URL-safe base64-encoded sha256 value for the given text."""
    digest: bytes = hashlib.sha256(text.encode()).digest()
    s: str = base64.urlsafe_b64encode(digest).decode()
    return s if padding else s.rstrip("=")


def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 2]


def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left one third portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 3]


def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 string encoding of the int value, assuming it is contained in the range [0..max_value]."""
    assert 0 <= value <= max_value
    max_bytes: int = (max_value.bit_length() + 7) // 8
    value_bytes: bytes = value.to_bytes(max_bytes, byteorder)
    s: str = base64.urlsafe_b64encode(value_bytes).decode()
    return s if padding else s.rstrip("=")
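

# Illustrative sketch (the `_demo_urlsafe_base64` helper is hypothetical, not part of the original module): ints are
# encoded as fixed-width byte strings sized by max_value, so equal-range values yield equal-length encodings.
def _demo_urlsafe_base64() -> None:
    assert urlsafe_base64(0) == "AAAAAAAAAAA="  # default max_value needs 8 bytes --> 12 chars incl padding
    assert urlsafe_base64(1, padding=False) == "AAAAAAAAAAE"  # padding char stripped on request
    assert urlsafe_base64(255, max_value=255) == "_w=="  # a single byte suffices when max_value=255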


def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Exits the program with ``exit_code`` after logging ``msg``."""
    if parser is None:
        ex = SystemExit(msg)
        ex.code = exit_code
        raise ex
    else:
        parser.error(msg)


def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior except it enhances cleanup on TimeoutExpired, and
    provides optional child PID tracking."""
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE

    pid: int | None = None
    try:
        with subprocess.Popen(*args, **kwargs) as proc:
            pid = proc.pid
            if subprocesses is not None:
                subprocesses.register_child_pid(pid)
            try:
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child process and its descendants
                finally:
                    proc.kill()
                raise
            else:
                exitcode: int | None = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)
    finally:
        if subprocesses is not None and isinstance(pid, int):
            subprocesses.unregister_child_pid(pid)
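

# Illustrative sketch (the `_demo_subprocess_run` helper is hypothetical, not part of the original module):
# subprocess_run() accepts the familiar subprocess.run() keyword arguments and returns a CompletedProcess.
def _demo_subprocess_run() -> None:
    result = subprocess_run(["echo", "hi"], stdout=PIPE, stderr=PIPE, text=True, timeout=10, check=True)
    assert result.returncode == 0
    assert result.stdout == "hi\n"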


def terminate_process_subtree(
    except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes."""
    current_pid: int = os.getpid()
    root_pids = [current_pid] if root_pids is None else root_pids
    all_pids: list[list[int]] = _get_descendant_processes(root_pids)
    assert len(all_pids) == len(root_pids)
    for i, pids in enumerate(all_pids):
        root_pid = root_pids[i]
        if root_pid == current_pid:
            pids += [] if except_current_process else [current_pid]
        else:
            pids.insert(0, root_pid)
        for pid in pids:
            with contextlib.suppress(OSError):
                os.kill(pid, sig)


def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
    """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems."""
    if len(root_pids) == 0:
        return []
    cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
    try:
        lines: list[str] = subprocess.run(cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout.splitlines()
    except PermissionError:
        # degrade gracefully in sandbox environments that deny executing `ps` entirely
        return [[] for _ in root_pids]
    procs: dict[int, list[int]] = defaultdict(list)
    for line in lines[1:]:  # all lines except the header line
        splits: list[str] = line.split()
        assert len(splits) == 2
        pid = int(splits[0])
        ppid = int(splits[1])
        procs[ppid].append(pid)

    def recursive_append(ppid: int, descendants: list[int]) -> None:
        """Recursively collect descendant PIDs starting from ``ppid``."""
        for child_pid in procs[ppid]:
            descendants.append(child_pid)
            recursive_append(child_pid, descendants)

    all_descendants: list[list[int]] = []
    for root_pid in root_pids:
        descendants: list[int] = []
        recursive_append(root_pid, descendants)
        all_descendants.append(descendants)
    return all_descendants


@contextlib.contextmanager
def termination_signal_handler(
    termination_event: threading.Event,
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set ``termination_event`` and, by default, terminate all
    descendant processes."""
    assert termination_event is not None

    def _handler(_sig: int, _frame: object) -> None:
        termination_event.set()
        termination_handler()

    previous_int_handler = signal.signal(signal.SIGINT, _handler)  # install new signal handler
    previous_term_handler = signal.signal(signal.SIGTERM, _handler)  # install new signal handler
    try:
        yield  # run body of context manager
    finally:
        signal.signal(signal.SIGINT, previous_int_handler)  # restore original signal handler
        signal.signal(signal.SIGTERM, previous_term_handler)  # restore original signal handler


#############################################################################
class Subprocesses:
    """Provides per-job tracking of child PIDs so a job can safely terminate only the subprocesses it spawned itself; used
    when multiple jobs run concurrently within the same Python process."""

    def __init__(self) -> None:
        self._lock: Final[threading.Lock] = threading.Lock()
        self._child_pids: Final[dict[int, None]] = {}  # a set that preserves insertion order

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Wrapper around utils.subprocess_run() that auto-registers/unregisters child PIDs for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def register_child_pid(self, pid: int) -> None:
        """Registers a child PID as managed by this instance."""
        with self._lock:
            self._child_pids[pid] = None

    def unregister_child_pid(self, pid: int) -> None:
        """Unregisters a child PID that has exited or is no longer tracked."""
        with self._lock:
            self._child_pids.pop(pid, None)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Sends the given signal to all tracked child PIDs and their descendants, ignoring errors for dead PIDs."""
        with self._lock:
            pids: list[int] = list(self._child_pids)
            self._child_pids.clear()
        terminate_process_subtree(root_pids=pids, sig=sig)
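

# Illustrative sketch (the `_demo_subprocesses` helper is hypothetical, not part of the original module): a job routes
# its child processes through its own Subprocesses instance so it can later terminate only what it spawned itself.
def _demo_subprocesses() -> None:
    subprocesses = Subprocesses()
    result = subprocesses.subprocess_run(["sleep", "0"], stdout=DEVNULL, stderr=DEVNULL)
    assert result.returncode == 0
    subprocesses.terminate_process_subtrees()  # no-op here: the child already exited and was auto-unregistered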


#############################################################################
def pid_exists(pid: int) -> bool | None:
    """Returns True if a process with PID exists, False if not, or None on error."""
    if pid <= 0:
        return False
    try:  # with signal=0, no signal is actually sent, but error checking is still performed
        os.kill(pid, 0)  # ... which can be used to check for process existence on POSIX systems
    except OSError as err:
        if err.errno == errno.ESRCH:  # No such process
            return False
        if err.errno == errno.EPERM:  # Operation not permitted
            return True
        return None
    return True


def nprefix(s: str) -> str:
    """Returns a canonical snapshot prefix with trailing underscore."""
    return sys.intern(s + "_")


def ninfix(s: str) -> str:
    """Returns a canonical infix with trailing underscore when not empty."""
    return sys.intern(s + "_") if s else ""


def nsuffix(s: str) -> str:
    """Returns a canonical suffix with leading underscore when not empty."""
    return sys.intern("_" + s) if s else ""


def format_dict(dictionary: dict[Any, Any]) -> str:
    """Returns a formatted dictionary using repr for consistent output."""
    return f'"{dictionary}"'


def format_obj(obj: object) -> str:
    """Returns a formatted str using repr for consistent output."""
    return f'"{obj}"'


def validate_dataset_name(dataset: str, input_text: str) -> None:
    """'zfs create' CLI does not accept dataset names that are empty or start or end in a slash, etc."""
    # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424
    # and https://github.com/openzfs/zfs/issues/8798
    # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html
    invalid_chars: str = SHELL_CHARS
    if (
        dataset in ("", ".", "..")
        or dataset.startswith(("/", "./", "../"))
        or dataset.endswith(("/", "/.", "/.."))
        or any(substring in dataset for substring in ("//", "/./", "/../"))
        or any(char in invalid_chars or (char.isspace() and char != " ") for char in dataset)
        or not dataset[0].isalpha()
    ):
        die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'")


def validate_property_name(propname: str, input_text: str) -> str:
    """Checks that the ZFS property name contains no spaces or shell chars."""
    invalid_chars: str = SHELL_CHARS
    if not propname or any(char.isspace() or char in invalid_chars for char in propname):
        die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'")
    return propname


def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None:
    """Checks that the given path is not a symbolic link."""
    if os.path.islink(path):
        die(f"{msg}must not be a symlink: {path}", parser=parser)


def validate_file_permissions(path: str, mode: int) -> None:
    """Verifies permissions and that ownership is the current effective UID."""
    stats: os.stat_result = os.stat(path, follow_symlinks=False)
    st_uid: int = stats.st_uid
    if st_uid != os.geteuid():  # verify ownership is current effective UID
        die(f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}")
    st_mode = stat.S_IMODE(stats.st_mode)
    if st_mode != mode:
        die(
            f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]}, "
            f"not {mode:03o} aka {stat.filemode(mode)[1:]}"
        )


def parse_duration_to_milliseconds(duration: str, regex_suffix: str = "", context: str = "") -> int:
    """Parses human duration strings like '5 mins' or '2 hours' to milliseconds."""
    unit_milliseconds: dict[str, int] = {
        "milliseconds": 1,
        "millis": 1,
        "seconds": 1000,
        "secs": 1000,
        "minutes": 60 * 1000,
        "mins": 60 * 1000,
        "hours": 60 * 60 * 1000,
        "days": 86400 * 1000,
        "weeks": 7 * 86400 * 1000,
        "months": round(30.5 * 86400 * 1000),
        "years": 365 * 86400 * 1000,
    }
    match = re.fullmatch(
        r"(\d+)\s*(milliseconds|millis|seconds|secs|minutes|mins|hours|days|weeks|months|years)" + regex_suffix,
        duration,
    )
    if not match:
        if context:
            die(f"Invalid duration format: {duration} within {context}")
        else:
            raise ValueError(f"Invalid duration format: {duration}")
    assert match
    quantity: int = int(match.group(1))
    unit: str = match.group(2)
    return quantity * unit_milliseconds[unit]
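

# Illustrative sketch (the `_demo_parse_duration_to_milliseconds` helper is hypothetical, not part of the original
# module): units must be spelled out per the regex above; whitespace between number and unit is optional.
def _demo_parse_duration_to_milliseconds() -> None:
    assert parse_duration_to_milliseconds("5 mins") == 5 * 60 * 1000
    assert parse_duration_to_milliseconds("2hours") == 2 * 60 * 60 * 1000  # zero whitespace is also accepted
    try:
        parse_duration_to_milliseconds("5m")  # single-letter units do not match, e.g. use 'mins' or 'minutes'
        raise AssertionError("expected ValueError")
    except ValueError:
        pass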


def unixtime_fromisoformat(datetime_str: str) -> int:
    """Converts ISO 8601 datetime string into UTC Unix time seconds."""
    return int(datetime.fromisoformat(datetime_str).timestamp())


def isotime_from_unixtime(unixtime_in_seconds: int) -> str:
    """Converts UTC Unix time seconds into ISO 8601 datetime string."""
    tz: tzinfo = timezone.utc
    dt: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=tz)
    return dt.isoformat(sep="_", timespec="seconds")


def current_datetime(
    tz_spec: str | None = None,
    now_fn: Callable[[tzinfo | None], datetime] | None = None,
) -> datetime:
    """Returns current time in ``tz_spec`` timezone or local timezone."""
    if now_fn is None:
        now_fn = datetime.now
    return now_fn(get_timezone(tz_spec))


def get_timezone(tz_spec: str | None = None) -> tzinfo | None:
    """Returns timezone from spec or local timezone if unspecified."""
    tz: tzinfo | None
    if tz_spec is None:
        tz = None
    elif tz_spec == "UTC":
        tz = timezone.utc
    else:
        if match := re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec):
            sign, hours, minutes = match.groups()
            offset: int = int(hours) * 60 + int(minutes)
            offset = -offset if sign == "-" else offset
            tz = timezone(timedelta(minutes=offset))
        elif "/" in tz_spec:
            from zoneinfo import ZoneInfo  # lazy import for startup perf

            tz = ZoneInfo(tz_spec)
        else:
            raise ValueError(f"Invalid timezone specification: {tz_spec}")
    return tz
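

# Illustrative sketch (the `_demo_get_timezone` helper is hypothetical, not part of the original module): the three
# accepted timezone spec forms; the IANA example assumes zoneinfo data is available on the host.
def _demo_get_timezone() -> None:
    assert get_timezone() is None  # None means local timezone
    assert get_timezone("UTC") is timezone.utc
    assert get_timezone("+05:30") == timezone(timedelta(hours=5, minutes=30))  # fixed offset
    assert str(get_timezone("America/New_York")) == "America/New_York"  # IANA name via zoneinfo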


###############################################################################
class SnapshotPeriods:  # thread-safe
    """Parses snapshot suffix strings and converts between durations."""

    def __init__(self) -> None:
        """Initialize lookup tables of suffixes and corresponding millis."""
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * 86400 * 1000,
            "monthly": round(30.5 * 86400 * 1000),
            "weekly": 7 * 86400 * 1000,
            "daily": 86400 * 1000,
            "hourly": 60 * 60 * 1000,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({'|'.join(self.suffix_milliseconds.keys())})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' to (10, 'minutely')."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Like :meth:`suffix_to_duration0` but expects an underscore prefix."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Example: Converts '2hourly' to (2, 'hourly') and 'hourly' to (1, 'hourly')."""
        if match := regex.fullmatch(suffix):
            duration_amount: int = int(match.group(1)) if match.group(1) else 1
            assert duration_amount > 0
            duration_unit: str = match.group(2)
            return duration_amount, duration_unit
        else:
            return 0, ""

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns duration encoded in ``snapshot`` suffix, in milliseconds."""
        i = snapshot.rfind("_")
        snapshot = "" if i < 0 else snapshot[i + 1 :]
        duration_amount, duration_unit = self._suffix_to_duration(snapshot, self._suffix_regex0)
        return duration_amount * self.suffix_milliseconds.get(duration_unit, 0)
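

# Illustrative sketch (the `_demo_snapshot_periods` helper is hypothetical, not part of the original module): parsing
# period suffixes of snapshot names; a missing amount defaults to 1, and an unrecognized suffix yields 0 milliseconds.
def _demo_snapshot_periods() -> None:
    periods = SnapshotPeriods()
    assert periods.suffix_to_duration0("10minutely") == (10, "minutely")
    assert periods.suffix_to_duration1("_hourly") == (1, "hourly")  # underscore-prefixed variant
    assert periods.label_milliseconds("prod_onsite_2024-01-01_daily") == 86400 * 1000
    assert periods.label_milliseconds("no_period_suffix") == 0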


#############################################################################
class JobStats:
    """Simple thread-safe counters summarizing job progress."""

    def __init__(self, jobs_all: int) -> None:
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self.jobs_all: int = jobs_all
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Counts a job submission."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Counts a job completion."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            self.jobs_failed += 1 if failed else 0
            self.sum_elapsed_nanos += elapsed_nanos
            msg = str(self)
            assert self.sum_elapsed_nanos >= 0, msg
            assert self.jobs_running >= 0, msg
            assert self.jobs_failed >= 0, msg
            assert self.jobs_failed <= self.jobs_completed, msg
            assert self.jobs_completed <= self.jobs_started, msg
            assert self.jobs_started <= self.jobs_all, msg
            return msg

    def __repr__(self) -> str:
        def pct(number: int) -> str:
            """Returns percentage string relative to total jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        al, started, completed, failed = self.jobs_all, self.jobs_started, self.jobs_completed, self.jobs_failed
        running = self.jobs_running
        t = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, completed))
        return f"all:{al}, started:{pct(started)}, completed:{pct(completed)}, failed:{pct(failed)}, running:{running}, {t}"


#############################################################################
class Comparable(Protocol):
    """Partial ordering protocol."""

    def __lt__(self, other: Any) -> bool:  # pragma: no cover - behavior defined by implementer
        ...


T = TypeVar("T", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue


class SmallPriorityQueue(Generic[T]):
    """A priority queue that can handle updates to the priority of any element that is already contained in the queue, and
    does so very efficiently if there are a small number of elements in the queue (no more than thousands), as is the case
    for us.

    Could be implemented using a SortedList via https://github.com/grantjenks/python-sortedcontainers or using an indexed
    priority queue via https://github.com/nvictus/pqdict. But, to avoid an external dependency, is actually implemented
    using a simple yet effective binary search-based sorted list that can handle updates to the priority of elements that
    are already contained in the queue, via removal of the element, followed by update of the element, followed by
    (re)insertion. Duplicate elements (if any) are maintained in their order of insertion relative to other duplicates.
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; sort order flips when ``reverse`` is True."""
        self._lst: Final[list[T]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Removes all elements from the queue."""
        self._lst.clear()

    def push(self, element: T) -> None:
        """Inserts ``element`` while maintaining sorted order."""
        bisect.insort(self._lst, element)

    def pop(self) -> T:
        """Removes and returns the smallest (or largest if reverse == True) element from the queue."""
        return self._lst.pop() if self._reverse else self._lst.pop(0)

    def peek(self) -> T:
        """Returns the smallest (or largest if reverse == True) element without removing it."""
        return self._lst[-1] if self._reverse else self._lst[0]

    def remove(self, element: T) -> bool:
        """Removes the first occurrence of ``element`` and returns True if it was present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        is_contained = i < len(lst) and lst[i] == element
        if is_contained:
            del lst[i]  # is an optimized memmove()
        return is_contained

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: T) -> bool:
        """Returns ``True`` if ``element`` is present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        return i < len(lst) and lst[i] == element

    def __iter__(self) -> Iterator[T]:
        """Iterates over queued elements in priority order."""
        return reversed(self._lst) if self._reverse else iter(self._lst)

    def __repr__(self) -> str:
        """Representation showing queue contents in current order."""
        return repr(list(reversed(self._lst))) if self._reverse else repr(self._lst)
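

# Illustrative sketch (the `_demo_small_priority_queue` helper is hypothetical, not part of the original module):
# priority updates are expressed as remove + update + push, as described in the class docstring.
def _demo_small_priority_queue() -> None:
    queue: SmallPriorityQueue[int] = SmallPriorityQueue()
    queue.push(3)
    queue.push(1)
    queue.push(2)
    assert queue.peek() == 1  # smallest first by default
    assert queue.remove(3)  # to change an element's priority: remove it, update it, then push it again
    queue.push(0)
    assert list(queue) == [0, 1, 2]
    assert queue.pop() == 0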


###############################################################################
class SortedInterner(Generic[T]):
    """Same as sys.intern() except that it isn't global and that it assumes the input list is sorted (for binary search)."""

    def __init__(self, sorted_list: list[T]) -> None:
        self._lst: Final[list[T]] = sorted_list

    def interned(self, element: T) -> T:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        lst = self._lst
        i = binary_search(lst, element)
        return lst[i] if i >= 0 else element

    def __contains__(self, element: T) -> bool:
        """Returns ``True`` if ``element`` is present."""
        return binary_search(self._lst, element) >= 0


def binary_search(sorted_list: list[T], item: T) -> int:
    """Java-style binary search; Returns index >= 0 if an equal item is found in the list, else '-insertion_point - 1'; If
    it returns index >= 0, the index will be the left-most index in case multiple such equal items are contained."""
    i = bisect.bisect_left(sorted_list, item)
    return i if i < len(sorted_list) and sorted_list[i] == item else -i - 1
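

# Illustrative sketch (the `_demo_binary_search` helper is hypothetical, not part of the original module): the
# Java-style return convention of binary_search(), and SortedInterner deduping via the same search.
def _demo_binary_search() -> None:
    lst = ["a", "c", "e"]
    assert binary_search(lst, "c") == 1  # found: index >= 0
    assert binary_search(lst, "b") == -2  # not found: -insertion_point - 1, i.e. would insert at index 1
    interner = SortedInterner(lst)
    assert interner.interned("c") is lst[1]  # equal item is returned from the sorted list, deduping the argument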


###############################################################################
S = TypeVar("S")


class Interner(Generic[S]):
    """Same as sys.intern() except that it isn't global and can also be used for types other than str."""

    def __init__(self, items: Iterable[S] = frozenset()) -> None:
        self._items: Final[dict[S, S]] = {v: v for v in items}

    def intern(self, item: S) -> S:
        """Interns the given item."""
        return self._items.setdefault(item, item)

    def interned(self, item: S) -> S:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        return self._items.get(item, item)

    def __contains__(self, item: S) -> bool:
        return item in self._items


#############################################################################
class SynchronizedBool:
    """Thread-safe wrapper around a regular bool."""

    def __init__(self, val: bool) -> None:
        assert isinstance(val, bool)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._value: bool = val

    @property
    def value(self) -> bool:
        """Returns the current boolean value."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: bool) -> None:
        """Atomically assigns ``new_value``."""
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: bool) -> bool:
        """Swaps in ``new_value`` and returns the previous value."""
        with self._lock:
            old_value = self._value
            self._value = new_value
            return old_value

    def compare_and_set(self, expected_value: bool, new_value: bool) -> bool:
        """Sets to ``new_value`` only if the current value equals ``expected_value``."""
        with self._lock:
            eq: bool = self._value == expected_value
            if eq:
                self._value = new_value
            return eq

    def __bool__(self) -> bool:
        return self.value

    def __repr__(self) -> str:
        return repr(self.value)

    def __str__(self) -> str:
        return str(self.value)


#############################################################################
K = TypeVar("K")
V = TypeVar("V")


class SynchronizedDict(Generic[K, V]):
    """Thread-safe wrapper around a regular dict."""

    def __init__(self, val: dict[K, V]) -> None:
        assert isinstance(val, dict)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._dict: Final[dict[K, V]] = val

    def __getitem__(self, key: K) -> V:
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key: K, value: V) -> None:
        with self._lock:
            self._dict[key] = value

    def __delitem__(self, key: K) -> None:
        with self._lock:
            self._dict.pop(key)

    def __contains__(self, key: K) -> bool:
        with self._lock:
            return key in self._dict

    def __len__(self) -> int:
        with self._lock:
            return len(self._dict)

    def __repr__(self) -> str:
        with self._lock:
            return repr(self._dict)

    def __str__(self) -> str:
        with self._lock:
            return str(self._dict)

    def get(self, key: K, default: V | None = None) -> V | None:
        """Returns ``self[key]`` or ``default`` if missing."""
        with self._lock:
            return self._dict.get(key, default)

    def pop(self, key: K, default: V | None = None) -> V | None:
        """Removes ``key`` and returns its value, or ``default`` if missing."""
        with self._lock:
            return self._dict.pop(key, default)

    def clear(self) -> None:
        """Removes all items atomically."""
        with self._lock:
            self._dict.clear()

    def items(self) -> ItemsView[K, V]:
        """Returns a snapshot of dictionary items."""
        with self._lock:
            return self._dict.copy().items()


#############################################################################
class InterruptibleSleep:
    """Provides a sleep(timeout) function that can be interrupted by another thread; The underlying lock is configurable."""

    def __init__(self, lock: threading.Lock | None = None) -> None:
        self._is_stopping: bool = False
        self._lock: Final[threading.Lock] = lock if lock is not None else threading.Lock()
        self._condition: Final[threading.Condition] = threading.Condition(self._lock)

    def sleep(self, duration_nanos: int) -> bool:
        """Delays the current thread by the given number of nanoseconds; Returns True if the sleep got interrupted;
        Equivalent to threading.Event.wait()."""
        end_time_nanos: int = time.monotonic_ns() + duration_nanos
        with self._lock:
            while not self._is_stopping:
                diff_nanos: int = end_time_nanos - time.monotonic_ns()
                if diff_nanos <= 0:
                    return False
                self._condition.wait(timeout=diff_nanos / 1_000_000_000)  # release, then block until notified or timeout
            return True

    def interrupt(self) -> None:
        """Wakes sleeping threads and makes any future sleep()s a no-op; Equivalent to threading.Event.set()."""
        with self._lock:
            if not self._is_stopping:
                self._is_stopping = True
                self._condition.notify_all()

    def reset(self) -> None:
        """Makes any future sleep()s no longer a no-op; Equivalent to threading.Event.clear()."""
        with self._lock:
            self._is_stopping = False


#############################################################################
class SynchronousExecutor(Executor):
    """Executor that runs tasks inline in the calling thread, sequentially."""

    def __init__(self) -> None:
        self._shutdown: bool = False

    def submit(self, fn: Callable[..., R_], /, *args: Any, **kwargs: Any) -> Future[R_]:
        """Executes `fn(*args, **kwargs)` immediately and returns its Future."""
        future: Future[R_] = Future()
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        try:
            result: R_ = fn(*args, **kwargs)
        except BaseException as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Prevents new submissions; no worker resources to join/cleanup."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Factory returning a SynchronousExecutor if 0 <= max_workers <= 1; else a ThreadPoolExecutor."""
        return cls() if 0 <= max_workers <= 1 else ThreadPoolExecutor(max_workers=max_workers)
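

# Illustrative sketch (the `_demo_synchronous_executor` helper is hypothetical, not part of the original module):
# the factory picks inline execution for <= 1 workers, so callers get a uniform Executor/Future API either way.
def _demo_synchronous_executor() -> None:
    executor = SynchronousExecutor.executor_for(max_workers=1)  # inline execution, no threads spawned
    future = executor.submit(lambda x: x * 2, 21)
    assert future.result() == 42
    executor.shutdown()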


#############################################################################
class _XFinally(contextlib.AbstractContextManager):
    """Context manager ensuring cleanup code executes after ``with`` blocks."""

    def __init__(self, cleanup: Callable[[], None]) -> None:
        """Records the callable to run upon exit."""
        self._cleanup: Final = cleanup  # Zero-argument callable executed after the `with` block exits.

    def __exit__(
        self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None
    ) -> Literal[False]:
        """Runs cleanup and propagates any exceptions appropriately."""
        try:
            self._cleanup()
        except BaseException as cleanup_exc:
            if exc is None:
                raise  # No main error --> propagate cleanup error normally
            # Both failed
            # if sys.version_info >= (3, 11):
            #     raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None
            # <= 3.10: attach so it shows up in traceback but doesn't mask
            exc.__context__ = cleanup_exc
            return False  # reraise original exception
        return False  # propagate main exception if any


def xfinally(cleanup: Callable[[], None]) -> _XFinally:
    """Usage: with xfinally(lambda: cleanup()): ...
    Returns a context manager that guarantees that cleanup() runs on exit and guarantees any error in cleanup() will never
    mask an exception raised earlier inside the body of the `with` block, while still surfacing both problems when possible.

    Problem it solves
    -----------------
    A naive ``try ... finally`` may lose the original exception:

        try:
            work()
        finally:
            cleanup()  # <-- if this raises an exception, it replaces the real error!

    `_XFinally` preserves exception priority:

    * Body raises, cleanup succeeds --> original body exception is re-raised.
    * Body raises, cleanup also raises --> re-raises body exception; cleanup exception is linked via ``__context__``.
    * Body succeeds, cleanup raises --> cleanup exception propagates normally.

    Example:
    -------
    >>> with xfinally(lambda: release_resources()):  # doctest: +SKIP
    ...     run_tasks()

    The single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information.
    """
    return _XFinally(cleanup)