Coverage for bzfs_main/util/utils.py: 100%
759 statements

# Copyright 2024 Wolfgang Hoschek AT mac DOT com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Collection of helper functions used across bzfs; includes environment variable parsing, process management, and
lightweight concurrency primitives, etc.

Everything in this module relies only on the Python standard library so other modules remain dependency-free. Each utility
favors simple, predictable behavior on all supported platforms.
"""

from __future__ import (
    annotations,
)
import argparse
import base64
import bisect
import collections
import contextlib
import errno
import hashlib
import itertools
import logging
import operator
import os
import platform
import pwd
import random
import re
import signal
import stat
import subprocess
import sys
import threading
import time
import types
from collections import (
    defaultdict,
    deque,
)
from collections.abc import (
    ItemsView,
    Iterable,
    Iterator,
    Sequence,
)
from concurrent.futures import (
    Executor,
    Future,
    ThreadPoolExecutor,
)
from datetime import (
    datetime,
    timedelta,
    timezone,
    tzinfo,
)
from subprocess import (
    DEVNULL,
    PIPE,
)
from typing import (
    IO,
    Any,
    Callable,
    Final,
    Generic,
    Literal,
    NoReturn,
    Protocol,
    TextIO,
    TypeVar,
    cast,
    final,
)

# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"
DIE_STATUS: Final[int] = 3
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level is halfway in between
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level is halfway in between
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level is halfway in between
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years, and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # see Google for 'sun_path'

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias


def getenv_any(key: str, default: str | None = None, env_var_prefix: str = ENV_VAR_PREFIX) -> str | None:
    """All shell environment variable names used for configuration start with this prefix."""
    return os.getenv(env_var_prefix + key, default)


def getenv_int(key: str, default: int, env_var_prefix: str = ENV_VAR_PREFIX) -> int:
    """Returns environment variable ``key`` as int with ``default`` fallback."""
    return int(cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix)))


def getenv_bool(key: str, default: bool = False, env_var_prefix: str = ENV_VAR_PREFIX) -> bool:
    """Returns environment variable ``key`` as bool with ``default`` fallback."""
    return cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix)).lower().strip() == "true"


def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Retains only column number 'field' in a list of TSV/CSV lines; analogous to the Unix 'cut' CLI command."""
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field == 1:
        return [line[0 : line.index(separator)] for line in lines]
    elif field == 2:
        return [line[line.index(separator) + 1 :] for line in lines]
    else:
        raise ValueError(f"Invalid field value: {field}")


def drain(iterable: Iterable[Any]) -> None:
    """Consumes all items in the iterable, effectively draining it."""
    for _ in iterable:
        del _  # help gc (iterable can block)


_K_ = TypeVar("_K_")
_V_ = TypeVar("_V_")
_R_ = TypeVar("_R_")


def shuffle_dict(dictionary: dict[_K_, _V_], rand: random.Random = random.SystemRandom()) -> dict[_K_, _V_]:  # noqa: B008
    """Returns a new dict with items shuffled randomly."""
    items: list[tuple[_K_, _V_]] = list(dictionary.items())
    rand.shuffle(items)
    return dict(items)


def sorted_dict(dictionary: dict[_K_, _V_]) -> dict[_K_, _V_]:
    """Returns a new dict with items sorted primarily by key and secondarily by value."""
    return dict(sorted(dictionary.items()))


def tail(file: str, n: int, errors: str | None = None) -> Sequence[str]:
    """Returns the last ``n`` lines of ``file`` without following symlinks."""
    if not os.path.isfile(file):
        return []
    with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False) as fd:
        return deque(fd, maxlen=n)


_NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))


def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces a parenthesis '(' that is followed by a char other than question mark '?', but not preceded by a backslash,
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash,
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    if "(" in regex and (
        "[" in regex  # literal left square bracket
        or "\\N{LEFT SQUARE BRACKET}" in regex  # named Unicode escape for '['
        or "\\x5b" in regex  # hex escape for '[' (lowercase)
        or "\\x5B" in regex  # hex escape for '[' (uppercase)
        or "\\u005b" in regex  # 4-digit Unicode escape for '[' (lowercase)
        or "\\u005B" in regex  # 4-digit Unicode escape for '[' (uppercase)
        or "\\U0000005b" in regex  # 8-digit Unicode escape for '[' (lowercase)
        or "\\U0000005B" in regex  # 8-digit Unicode escape for '[' (uppercase)
        or "\\133" in regex  # octal escape for '['
    ):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the rare case where the regex might
        # contain a pathological regex character class that contains parenthesis, or where '[' is expressed via escapes.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    i = len(regex) - 2
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + _NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex
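
# A minimal illustration, not part of the original module; the expected values below follow directly from the
# substitution rules documented in the docstring above:
#
#   >>> replace_capturing_groups_with_non_capturing_groups(r"(.*/)?tmp(foo|bar)(?!public)\(")
#   '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
#   >>> replace_capturing_groups_with_non_capturing_groups(r"(?P<name>abc)")
#   '(?:abc)'
#   >>> replace_capturing_groups_with_non_capturing_groups(r"[(]abc")  # contains '[' --> returned unchanged
#   '[(]abc'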


def get_home_directory() -> str:
    """Reliably detects the home dir without using the HOME env var."""
    # thread-safe version of: os.environ.pop('HOME', None); os.path.expanduser('~')
    return pwd.getpwuid(os.getuid()).pw_dir


def human_readable_bytes(num_bytes: float, separator: str = " ", precision: int | None = None) -> str:
    """Formats 'num_bytes' as a human-readable size; for example "567 MiB"."""
    sign = "-" if num_bytes < 0 else ""
    s = abs(num_bytes)
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    n = len(units) - 1
    i = 0
    while s >= 1024 and i < n:
        s /= 1024
        i += 1
    formatted_num = human_readable_float(s) if precision is None else f"{s:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_duration(duration: float, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling as needed; for example "567ms"."""
    sign = "-" if duration < 0 else ""
    t = abs(duration)
    units = ("ns", "μs", "ms", "s", "m", "h", "d")
    i = units.index(unit)
    if t < 1 and t != 0:
        nanos = (1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000)
        t *= nanos[i]
        i = 0
    while t >= 1000 and i < 3:
        t /= 1000
        i += 1
    if i >= 3:
        while t >= 60 and i < 5:
            t /= 60
            i += 1
        if i >= 5:
            while t >= 24 and i < len(units) - 1:
                t /= 24
                i += 1
    formatted_num = human_readable_float(t) if precision is None else f"{t:.{precision}f}"
    return f"{sign}{formatted_num}{separator}{units[i]}"


def human_readable_float(number: float) -> str:
    """Formats ``number`` with a variable precision depending on magnitude.

    This design mirrors the way humans round values when scanning logs.

    If the number has one digit before the decimal point (0 <= abs(number) < 10):
        Round and use two decimals after the decimal point (e.g., 3.14559 --> "3.15").

    If the number has two digits before the decimal point (10 <= abs(number) < 100):
        Round and use one decimal after the decimal point (e.g., 12.36 --> "12.4").

    If the number has three or more digits before the decimal point (abs(number) >= 100):
        Round and use zero decimals after the decimal point (e.g., 123.556 --> "124").

    Ensures no unnecessary trailing zeroes are retained: Example: 1.500 --> "1.5", 1.00 --> "1"
    """
    abs_number = abs(number)
    precision = 2 if abs_number < 10 else 1 if abs_number < 100 else 0
    if precision == 0:
        return str(round(number))
    result = f"{number:.{precision}f}"
    assert "." in result
    result = result.rstrip("0").rstrip(".")  # remove trailing zeros, and the trailing decimal point if nothing follows it
    return "0" if result == "-0" else result
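
# A minimal illustration, not part of the original module; values follow from the rounding rules documented above:
#
#   >>> human_readable_float(3.14559), human_readable_float(12.36), human_readable_float(123.556)
#   ('3.15', '12.4', '124')
#   >>> human_readable_bytes(567 * 1024 * 1024)
#   '567 MiB'
#   >>> human_readable_duration(567_000_000)  # default input unit is "ns"
#   '567ms'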


def percent(number: int, total: int, print_total: bool = False) -> str:
    """Returns percentage string of ``number`` relative to ``total``."""
    tot: str = f"/{total}" if print_total else ""
    return f"{number}{tot}={'inf' if total == 0 else human_readable_float(100 * number / total)}%"


def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like the built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if the basename of the path is a symlink.

    Also, permissions can be specified for O_CREAT, and ownership can be verified.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC
    fd: int = os.open(path, flags=flags, mode=perm)
    try:
        if check_owner:
            st_uid: int = os.fstat(fd).st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        try:
            os.close(fd)
        except OSError:
            pass
        raise
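
# A hedged usage sketch, not part of the original module; "app.lock" and handle_error() are hypothetical:
#
#   try:
#       with open_nofollow("app.lock", "w") as f:  # created with FILE_PERMISSIONS (0o600) if missing
#           f.write("some state")
#   except OSError as e:
#       handle_error(e)  # hypothetical; e.errno indicates ELOOP (or EMLINK on some BSDs) if the path is a symlink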


def close_quietly(fd: int) -> None:
    """Closes the given file descriptor while silently swallowing any OSError that might arise as part of this."""
    if fd >= 0:
        try:
            os.close(fd)
        except OSError:
            pass


_P = TypeVar("_P")


def find_match(
    seq: Sequence[_P],
    predicate: Callable[[_P], bool],
    start: int | None = None,
    end: int | None = None,
    reverse: bool = False,
    raises: bool | str | Callable[[], str] = False,  # raises: bool | str | Callable = False, # python >= 3.10
) -> int:
    """Returns the integer index within seq of the first item (or the last item if reverse==True) that matches the given
    predicate condition. If no matching item is found, returns -1 or raises ValueError, depending on the raises parameter,
    which is a bool indicating whether to raise an error, or a string containing the error message, but can also be a
    Callable/lambda in order to support efficient deferred generation of error messages. Analogous to str.find(), including
    slicing semantics with parameters start and end. For example, seq can be a list, tuple or str.

    Example usage:
        lst = ["a", "b", "-c", "d"]
        i = find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True)
        if i >= 0:
            print(lst[i])
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=f"Tag {tag} not found in {file}")
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=lambda: f"Tag {tag} not found in {file}")
    """
    offset: int = 0 if start is None else start if start >= 0 else len(seq) + start
    if start is not None or end is not None:
        seq = seq[start:end]
    for i, item in enumerate(reversed(seq) if reverse else seq):
        if predicate(item):
            if reverse:
                return len(seq) - i - 1 + offset
            else:
                return i + offset
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    if callable(raises):
        raises = raises()
    raise ValueError(raises)


def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` lies under ``of_root_dataset`` in the dataset hierarchy, or is the same."""
    return dataset == of_root_dataset or dataset.startswith(of_root_dataset + "/")


def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any adjacent items within the given sorted sequence are equal."""
    return any(map(operator.eq, sorted_list, itertools.islice(sorted_list, 1, None)))


def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings."""
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    skip_dataset: str = DONT_SKIP_DATASET
    parents: set[str] = set()
    for dataset in sorted_datasets:
        assert dataset
        parent = os.path.dirname(dataset)
        if parent in parents:
            return True  # I have a sibling if my parent already has another child
        parents.add(parent)
        if is_descendant(dataset, of_root_dataset=skip_dataset):
            continue
        if skip_dataset != DONT_SKIP_DATASET:
            return True  # I have a sibling if I am a root dataset and another root dataset already exists
        skip_dataset = dataset
    return False


def dry(msg: str, is_dry_run: bool) -> str:
    """Prefixes ``msg`` with 'Dry' when in dry-run mode."""
    return "Dry " + msg if is_dry_run else msg


def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Converts an absolute dataset path to one relative to ``root_dataset``.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    return dataset[len(root_dataset) :]


def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    i: int = 0
    while i >= 0:
        i = dataset.find("/", i)
        if i < 0:
            yield dataset
        else:
            yield dataset[:i]
            i += 1


def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """In a string s, replaces a leading old_prefix string with new_prefix; assumes the leading string is present."""
    assert s.startswith(old_prefix)
    return new_prefix + s[len(old_prefix) :]


def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces ``old`` with ``new`` in-place for every string in ``lines``."""
    for i in range(len(lines)):
        lines[i] = lines[i].replace(old, new, count)


_TAPPEND = TypeVar("_TAPPEND")


def append_if_absent(lst: list[_TAPPEND], *items: _TAPPEND) -> list[_TAPPEND]:
    """Appends items to the list if they are not already present."""
    for item in items:
        if item not in lst:
            lst.append(item)
    return lst


def xappend(lst: list[_TAPPEND], *items: _TAPPEND | Iterable[_TAPPEND]) -> list[_TAPPEND]:
    """Appends each of the items to the given list if the item is "truthy", for example not None and not an empty string;
    if an item is an iterable, does so recursively, flattening the output."""
    for item in items:
        if isinstance(item, str) or not isinstance(item, collections.abc.Iterable):
            if item:
                lst.append(item)
        else:
            xappend(lst, *item)
    return lst


def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if the name matches at least one of the include regexes but none of the exclude regexes; else False.

    A regex that starts with a `!` is a negation - the regex matches if the regex without the `!` prefix does not match.
    """
    for regex, is_negation in exclude_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return False

    for regex, is_negation in include_regexes:
        is_match = regex.fullmatch(name) if regex.pattern != ".*" else True
        if is_negation:
            is_match = not is_match
        if is_match:
            return True

    return False


def compile_regexes(regexes: list[str], suffix: str = "") -> RegexList:
    """Compiles regex strings and keeps track of negations."""
    assert isinstance(regexes, list)
    compiled_regexes: RegexList = []
    for regex in regexes:
        if suffix:  # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
            if regex.endswith("\\$"):
                pass  # trailing literal $ is ok
            elif regex.endswith("$"):
                regex = regex[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in regex:
                raise re.error("Must not use non-trailing '$' character", regex)
        if is_negation := regex.startswith("!"):
            regex = regex[1:]
        regex = replace_capturing_groups_with_non_capturing_groups(regex)
        if regex != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            regex = f"{regex}{suffix}"
        compiled_regexes.append((re.compile(regex), is_negation))
    return compiled_regexes
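
# A minimal illustration, not part of the original module, showing how compile_regexes() and is_included() compose;
# DESCENDANTS_RE_SUFFIX makes each dataset regex also match descendants:
#
#   >>> includes = compile_regexes(["tank/prod.*"], suffix=DESCENDANTS_RE_SUFFIX)
#   >>> excludes = compile_regexes(["tank/prod/tmp.*"], suffix=DESCENDANTS_RE_SUFFIX)
#   >>> is_included("tank/prod/web", includes, excludes)
#   True
#   >>> is_included("tank/prod/tmp/scratch", includes, excludes)
#   False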


def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Lazy formatter joining items with ``separator``; used to avoid overhead in disabled log levels."""

    @final
    class CustomListFormatter:
        """Formatter object that joins items when converted to ``str``."""

        def __str__(self) -> str:
            s = separator.join(map(str, iterable))
            return s.lstrip() if lstrip else s

    return CustomListFormatter()


def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Lazy pprint formatter used to avoid overhead in disabled log levels."""

    @final
    class PrettyPrintFormatter:
        """Formatter that pretty-prints the object on conversion to ``str``."""

        def __str__(self) -> str:
            import pprint  # lazy import for startup perf

            return pprint.pformat(vars(obj_to_format))

    return PrettyPrintFormatter()


def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    return str(stderr) if not isinstance(stderr, bytes) else stderr.decode("utf-8", errors="replace")


def xprint(log: logging.Logger, value: Any, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Optionally logs ``value`` at stdout/stderr level."""
    if run and value:
        value = value if end else str(value).rstrip()
        level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
        log.log(level, "%s", value)


def sha256_hex(text: str) -> str:
    """Returns the sha256 hex string for the given text."""
    return hashlib.sha256(text.encode()).hexdigest()


def sha256_urlsafe_base64(text: str, padding: bool = True) -> str:
    """Returns the URL-safe base64-encoded sha256 value for the given text."""
    digest: bytes = hashlib.sha256(text.encode()).digest()
    s: str = base64.urlsafe_b64encode(digest).decode()
    return s if padding else s.rstrip("=")


def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 2]


def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left third of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    s: str = sha256_urlsafe_base64(text, padding=False)
    return s[: len(s) // 3]


def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 string encoding of the int value, assuming it is contained in the range [0..max_value]."""
    assert 0 <= value <= max_value
    max_bytes: int = (max_value.bit_length() + 7) // 8
    value_bytes: bytes = value.to_bytes(max_bytes, byteorder)
    s: str = base64.urlsafe_b64encode(value_bytes).decode()
    return s if padding else s.rstrip("=")
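
# A minimal illustration, not part of the original module; lengths follow from RFC 4648 base64 (a 32-byte sha256
# digest encodes to 43 chars unpadded, whose half and third are 21 and 14 chars):
#
#   >>> len(sha256_urlsafe_base64("x", padding=False)), len(sha256_128_urlsafe_base64("x")), len(sha256_85_urlsafe_base64("x"))
#   (43, 21, 14)
#   >>> urlsafe_base64(0, max_value=2**64 - 1)  # 8 bytes --> 12 chars incl. one '=' pad
#   'AAAAAAAAAAA='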


def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Exits the program with ``exit_code`` after logging ``msg``."""
    if parser is None:
        ex = SystemExit(msg)
        ex.code = exit_code
        raise ex
    else:
        parser.error(msg)


def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior, except that it enhances cleanup on TimeoutExpired,
    provides optional child PID tracking, and optionally logs execution status via the ``log`` and ``loglevel`` params."""
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE

    log: logging.Logger | None = kwargs.pop("log", None)
    loglevel: int | None = kwargs.pop("loglevel", None)
    start_time_nanos: int = time.monotonic_ns()
    is_timeout: bool = False
    is_cancel: bool = False
    exitcode: int | None = None

    def log_status() -> None:
        if log is not None:
            _loglevel: int = loglevel if loglevel is not None else getenv_int("subprocess_run_loglevel", LOG_TRACE)
            if log.isEnabledFor(_loglevel):
                elapsed_time: str = human_readable_float((time.monotonic_ns() - start_time_nanos) / 1_000_000) + "ms"
                status = "cancel" if is_cancel else "timeout" if is_timeout else "success" if exitcode == 0 else "failure"
                cmd = kwargs["args"] if "args" in kwargs else (args[0] if args else None)
                cmd_str: str = " ".join(str(arg) for arg in iter(cmd)) if isinstance(cmd, (list, tuple)) else str(cmd)
                log.log(_loglevel, f"Executed [{status}] [{elapsed_time}]: %s", cmd_str)

    with xfinally(log_status):
        ctx: contextlib.AbstractContextManager[subprocess.Popen]
        if subprocesses is None:
            ctx = subprocess.Popen(*args, **kwargs)
        else:
            ctx = subprocesses.popen_and_track(*args, **kwargs)
        with ctx as proc:
            try:
                sp = subprocesses
                if sp is not None and sp._termination_event.is_set():  # noqa: SLF001  # pylint: disable=protected-access
                    is_cancel = True
                    timeout = 0.0
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        is_timeout = True
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child proc and descendants
                finally:
                    proc.kill()
                raise
            else:
                exitcode = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)


def terminate_process_subtree(
    except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes."""
    current_pid: int = os.getpid()
    root_pids = [current_pid] if root_pids is None else root_pids
    all_pids: list[list[int]] = _get_descendant_processes(root_pids)
    assert len(all_pids) == len(root_pids)
    for i, pids in enumerate(all_pids):
        root_pid = root_pids[i]
        if root_pid == current_pid:
            pids += [] if except_current_process else [current_pid]
        else:
            pids.insert(0, root_pid)
        for pid in pids:
            with contextlib.suppress(OSError):
                os.kill(pid, sig)


def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
    """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems."""
    if len(root_pids) == 0:
        return []
    cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
    try:
        lines: list[str] = subprocess.run(cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout.splitlines()
    except PermissionError:
        # degrade gracefully in sandbox environments that deny executing `ps` entirely
        return [[] for _ in root_pids]
    procs: dict[int, list[int]] = defaultdict(list)
    for line in lines[1:]:  # all lines except the header line
        splits: list[str] = line.split()
        assert len(splits) == 2
        pid = int(splits[0])
        ppid = int(splits[1])
        procs[ppid].append(pid)

    def recursive_append(ppid: int, descendants: list[int]) -> None:
        """Recursively collects descendant PIDs starting from ``ppid``."""
        for child_pid in procs[ppid]:
            descendants.append(child_pid)
            recursive_append(child_pid, descendants)

    all_descendants: list[list[int]] = []
    for root_pid in root_pids:
        descendants: list[int] = []
        recursive_append(root_pid, descendants)
        all_descendants.append(descendants)
    return all_descendants


@contextlib.contextmanager
def termination_signal_handler(
    termination_events: list[threading.Event],
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set all ``termination_events`` and, by default, terminate
    all descendant processes."""
    termination_events = list(termination_events)  # shallow copy

    def _handler(_sig: int, _frame: object) -> None:
        for event in termination_events:
            event.set()
        termination_handler()

    previous_int_handler = signal.signal(signal.SIGINT, _handler)  # install new signal handler
    previous_term_handler = signal.signal(signal.SIGTERM, _handler)  # install new signal handler
    try:
        yield  # run body of context manager
    finally:
        signal.signal(signal.SIGINT, previous_int_handler)  # restore original signal handler
        signal.signal(signal.SIGTERM, previous_term_handler)  # restore original signal handler
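
# A hedged usage sketch, not part of the original module; stop_event and do_one_unit_of_work() are hypothetical:
#
#   stop_event = threading.Event()
#   with termination_signal_handler([stop_event]):
#       while not stop_event.is_set():
#           do_one_unit_of_work()  # SIGINT/SIGTERM sets stop_event and, by default, terminates descendant processes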


#############################################################################
@final
class Subprocesses:
    """Provides per-job tracking of child PIDs so a job can safely terminate only the subprocesses it spawned itself; used
    when multiple jobs run concurrently within the same Python process.

    Optionally binds to a termination_event to enforce asynchronous cancellation by forcing immediate timeouts for newly
    spawned subprocesses once cancellation is requested.
    """

    def __init__(self, termination_event: threading.Event | None = None) -> None:
        self._termination_event: Final[threading.Event] = termination_event or threading.Event()
        self._lock: Final[threading.Lock] = threading.Lock()
        self._child_pids: Final[dict[int, None]] = {}  # a set that preserves insertion order

    @contextlib.contextmanager
    def popen_and_track(self, *popen_args: Any, **popen_kwargs: Any) -> Iterator[subprocess.Popen]:
        """Context manager that calls subprocess.Popen() and tracks the child PID for per-job termination.

        Holds a lock across Popen+PID registration to prevent a race when terminate_process_subtrees() is invoked (e.g. from
        SIGINT/SIGTERM handlers), ensuring newly spawned child processes cannot escape termination. The child PID is
        unregistered on context exit.
        """
        with self._lock:
            proc: subprocess.Popen = subprocess.Popen(*popen_args, **popen_kwargs)
            self._child_pids[proc.pid] = None
        try:
            yield proc
        finally:
            with self._lock:
                self._child_pids.pop(proc.pid, None)

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Wrapper around utils.subprocess_run() that auto-registers/unregisters child PIDs for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Sends the given signal to all tracked child PIDs and their descendants, ignoring errors for dead PIDs."""
        with self._lock:
            pids: list[int] = list(self._child_pids)
            self._child_pids.clear()
            terminate_process_subtree(root_pids=pids, sig=sig)
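
# A hedged usage sketch, not part of the original module, mirroring the class docstring above:
#
#   term_event = threading.Event()
#   subs = Subprocesses(termination_event=term_event)
#   result = subs.subprocess_run(["sleep", "1"], stdout=PIPE, stderr=PIPE)  # child PID is tracked while running
#   subs.terminate_process_subtrees()  # SIGTERM to all still-tracked children and their descendants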


#############################################################################
def pid_exists(pid: int) -> bool | None:
    """Returns True if a process with the given PID exists, False if not, or None on error."""
    if pid <= 0:
        return False
    try:  # with signal=0, no signal is actually sent, but error checking is still performed ...
        os.kill(pid, 0)  # ... which can be used to check for process existence on POSIX systems
    except OSError as err:
        if err.errno == errno.ESRCH:  # No such process
            return False
        if err.errno == errno.EPERM:  # Operation not permitted
            return True
        return None
    return True


def nprefix(s: str) -> str:
    """Returns a canonical snapshot prefix with trailing underscore."""
    return sys.intern(s + "_")


def ninfix(s: str) -> str:
    """Returns a canonical infix with trailing underscore when not empty."""
    return sys.intern(s + "_") if s else ""


def nsuffix(s: str) -> str:
    """Returns a canonical suffix with leading underscore when not empty."""
    return sys.intern("_" + s) if s else ""


def format_dict(dictionary: dict[Any, Any]) -> str:
    """Returns a formatted dictionary using repr for consistent output."""
    return f'"{dictionary}"'


def format_obj(obj: object) -> str:
    """Returns a formatted str using repr for consistent output."""
    return f'"{obj}"'


def validate_dataset_name(dataset: str, input_text: str) -> None:
    """The 'zfs create' CLI does not accept dataset names that are empty or start or end in a slash, etc."""
    # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424
    # and https://github.com/openzfs/zfs/issues/8798
    # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html
    invalid_chars: str = SHELL_CHARS
    if (
        dataset in ("", ".", "..")
        or dataset.startswith(("/", "./", "../"))
        or dataset.endswith(("/", "/.", "/.."))
        or any(substring in dataset for substring in ("//", "/./", "/../"))
        or any(char in invalid_chars or (char.isspace() and char != " ") for char in dataset)
        or not dataset[0].isalpha()
    ):
        die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'")


def validate_property_name(propname: str, input_text: str) -> str:
    """Checks that the ZFS property name contains no spaces or shell chars."""
    invalid_chars: str = SHELL_CHARS
    if not propname or any(char.isspace() or char in invalid_chars for char in propname):
        die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'")
    return propname


def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None:
    """Checks that the given path is not a symbolic link."""
    if os.path.islink(path):
        die(f"{msg}must not be a symlink: {path}", parser=parser)


def validate_file_permissions(path: str, mode: int) -> None:
    """Verifies permissions and that ownership is the current effective UID."""
    stats: os.stat_result = os.stat(path, follow_symlinks=False)
    st_uid: int = stats.st_uid
    if st_uid != os.geteuid():  # verify ownership is current effective UID
        die(f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}")
    st_mode = stat.S_IMODE(stats.st_mode)
    if st_mode != mode:
        die(
            f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]}, "
            f"not {mode:03o} aka {stat.filemode(mode)[1:]}"
        )


def parse_duration_to_milliseconds(duration: str, regex_suffix: str = "", context: str = "") -> int:
    """Parses human duration strings like '5 mins' or '2 hours' to milliseconds."""
    unit_milliseconds: dict[str, int] = {
        "milliseconds": 1,
        "millis": 1,
        "seconds": 1000,
        "secs": 1000,
        "minutes": 60 * 1000,
        "mins": 60 * 1000,
        "hours": 60 * 60 * 1000,
        "days": 86400 * 1000,
        "weeks": 7 * 86400 * 1000,
        "months": round(30.5 * 86400 * 1000),
        "years": 365 * 86400 * 1000,
    }
    match = re.fullmatch(
        r"(\d+)\s*(milliseconds|millis|seconds|secs|minutes|mins|hours|days|weeks|months|years)" + regex_suffix,
        duration,
    )
    if not match:
        if context:
            die(f"Invalid duration format: {duration} within {context}")
        else:
            raise ValueError(f"Invalid duration format: {duration}")
    assert match
    quantity: int = int(match.group(1))
    unit: str = match.group(2)
    return quantity * unit_milliseconds[unit]
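
# A minimal illustration, not part of the original module; note that units must be spelled out as listed in the
# regex above, so '5m' is rejected while '5 mins' parses:
#
#   >>> parse_duration_to_milliseconds("5 mins")
#   300000
#   >>> parse_duration_to_milliseconds("2hours")
#   7200000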


def unixtime_fromisoformat(datetime_str: str) -> int:
    """Converts an ISO 8601 datetime string into UTC Unix time seconds."""
    return int(datetime.fromisoformat(datetime_str).timestamp())


def isotime_from_unixtime(unixtime_in_seconds: int) -> str:
    """Converts UTC Unix time seconds into an ISO 8601 datetime string."""
    tz: tzinfo = timezone.utc
    dt: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=tz)
    return dt.isoformat(sep="_", timespec="seconds")


def current_datetime(
    tz_spec: str | None = None,
    now_fn: Callable[[tzinfo | None], datetime] | None = None,
) -> datetime:
    """Returns the current time in the ``tz_spec`` timezone, or in the local timezone if unspecified."""
    if now_fn is None:
        now_fn = datetime.now
    return now_fn(get_timezone(tz_spec))


def get_timezone(tz_spec: str | None = None) -> tzinfo | None:
    """Returns the timezone from the given spec, or the local timezone if unspecified."""
    tz: tzinfo | None
    if tz_spec is None:
        tz = None
    elif tz_spec == "UTC":
        tz = timezone.utc
    else:
        if match := re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec):
            sign, hours, minutes = match.groups()
            offset: int = int(hours) * 60 + int(minutes)
            offset = -offset if sign == "-" else offset
            tz = timezone(timedelta(minutes=offset))
        elif "/" in tz_spec:
            from zoneinfo import ZoneInfo  # lazy import for startup perf

            tz = ZoneInfo(tz_spec)
        else:
            raise ValueError(f"Invalid timezone specification: {tz_spec}")
    return tz
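
# A minimal illustration, not part of the original module, covering the branches above:
#
#   >>> get_timezone("UTC")
#   datetime.timezone.utc
#   >>> get_timezone("+05:30")
#   datetime.timezone(datetime.timedelta(seconds=19800))
#   >>> get_timezone("Europe/Vienna")  # IANA name, loaded via zoneinfo
#   zoneinfo.ZoneInfo(key='Europe/Vienna')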


###############################################################################
@final
class SnapshotPeriods:  # thread-safe
    """Parses snapshot suffix strings and converts between durations."""

    def __init__(self) -> None:
        """Initializes lookup tables of suffixes and corresponding millis."""
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * 86400 * 1000,
            "monthly": round(30.5 * 86400 * 1000),
            "weekly": 7 * 86400 * 1000,
            "daily": 86400 * 1000,
            "hourly": 60 * 60 * 1000,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({'|'.join(self.suffix_milliseconds.keys())})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' to (10, 'minutely')."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Like :meth:`suffix_to_duration0` but expects an underscore prefix."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Example: Converts '2hourly' to (2, 'hourly') and 'hourly' to (1, 'hourly')."""
        if match := regex.fullmatch(suffix):
            duration_amount: int = int(match.group(1)) if match.group(1) else 1
            assert duration_amount > 0
            duration_unit: str = match.group(2)
            return duration_amount, duration_unit
        else:
            return 0, ""

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns the duration encoded in the ``snapshot`` suffix, in milliseconds."""
        i = snapshot.rfind("_")
        snapshot = "" if i < 0 else snapshot[i + 1 :]
        duration_amount, duration_unit = self._suffix_to_duration(snapshot, self._suffix_regex0)
        return duration_amount * self.suffix_milliseconds.get(duration_unit, 0)
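
# A minimal illustration, not part of the original module; results follow from the regexes above:
#
#   >>> periods = SnapshotPeriods()
#   >>> periods.suffix_to_duration0("10minutely")
#   (10, 'minutely')
#   >>> periods.suffix_to_duration1("_hourly")
#   (1, 'hourly')
#   >>> periods.label_milliseconds("prod_onsite_daily")
#   86400000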


#############################################################################
@final
class JobStats:
    """Simple thread-safe counters summarizing job progress."""

    def __init__(self, jobs_all: int) -> None:
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self.jobs_all: int = jobs_all
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Counts a job submission."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Counts a job completion."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            self.jobs_failed += 1 if failed else 0
            self.sum_elapsed_nanos += elapsed_nanos
            msg = str(self)
            assert self.sum_elapsed_nanos >= 0, msg
            assert self.jobs_running >= 0, msg
            assert self.jobs_failed >= 0, msg
            assert self.jobs_failed <= self.jobs_completed, msg
            assert self.jobs_completed <= self.jobs_started, msg
            assert self.jobs_started <= self.jobs_all, msg
            return msg

    def __repr__(self) -> str:
        def pct(number: int) -> str:
            """Returns percentage string relative to total jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        al, started, completed, failed = self.jobs_all, self.jobs_started, self.jobs_completed, self.jobs_failed
        running = self.jobs_running
        t = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, completed))
        return f"all:{al}, started:{pct(started)}, completed:{pct(completed)}, failed:{pct(failed)}, running:{running}, {t}"


#############################################################################
class Comparable(Protocol):
    """Partial ordering protocol."""

    def __lt__(self, other: Any) -> bool: ...


TComparable = TypeVar("TComparable", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue


@final
class SmallPriorityQueue(Generic[TComparable]):
    """A priority queue that can handle updates to the priority of any element that is already contained in the queue, and
    does so very efficiently if there is a small number of elements in the queue (no more than thousands), as is the case
    for us.

    Could be implemented using a SortedList via https://github.com/grantjenks/python-sortedcontainers, or using an indexed
    priority queue via https://github.com/nvictus/pqdict. But, to avoid an external dependency, it is actually implemented
    using a simple yet effective binary search-based sorted list that can handle updates to the priority of elements that
    are already contained in the queue, via removal of the element, followed by update of the element, followed by
    (re)insertion. Duplicate elements (if any) are maintained in their order of insertion relative to other duplicates.
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; sort order flips when ``reverse`` is True."""
        self._lst: Final[list[TComparable]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Removes all elements from the queue."""
        self._lst.clear()

    def push(self, element: TComparable) -> None:
        """Inserts ``element`` while maintaining sorted order."""
        bisect.insort(self._lst, element)

    def pop(self) -> TComparable:
        """Removes and returns the smallest (or largest if reverse == True) element from the queue."""
        return self._lst.pop() if self._reverse else self._lst.pop(0)

    def peek(self) -> TComparable:
        """Returns the smallest (or largest if reverse == True) element without removing it."""
        return self._lst[-1] if self._reverse else self._lst[0]

    def remove(self, element: TComparable) -> bool:
        """Removes the first occurrence (in insertion order aka FIFO) of ``element`` and returns True if it was present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        is_contained = i < len(lst) and lst[i] == element
        if is_contained:
            del lst[i]  # is an optimized memmove()
        return is_contained

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: TComparable) -> bool:
        """Returns ``True`` if ``element`` is present."""
        lst = self._lst
        i = bisect.bisect_left(lst, element)
        return i < len(lst) and lst[i] == element

    def __iter__(self) -> Iterator[TComparable]:
        """Iterates over queued elements in priority order."""
        return reversed(self._lst) if self._reverse else iter(self._lst)

    def __repr__(self) -> str:
        """Representation showing queue contents in current order."""
        return repr(list(reversed(self._lst))) if self._reverse else repr(self._lst)
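
# A minimal illustration, not part of the original module, of the remove-update-reinsert idiom described in the
# class docstring above:
#
#   >>> q: SmallPriorityQueue[tuple[int, str]] = SmallPriorityQueue()
#   >>> q.push((2, "b")); q.push((1, "a"))
#   >>> q.remove((2, "b"))  # priority update step 1: remove the stale entry ...
#   True
#   >>> q.push((0, "b"))  # ... step 2: reinsert it with its new priority
#   >>> q.pop()
#   (0, 'b')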


###############################################################################
@final
class SortedInterner(Generic[TComparable]):
    """Same as sys.intern() except that it isn't global and that it assumes the input list is sorted (for binary search)."""

    def __init__(self, sorted_list: list[TComparable]) -> None:
        self._lst: Final[list[TComparable]] = sorted_list

    def interned(self, element: TComparable) -> TComparable:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        lst = self._lst
        i = binary_search(lst, element)
        return lst[i] if i >= 0 else element

    def __contains__(self, element: TComparable) -> bool:
        """Returns ``True`` if ``element`` is present."""
        return binary_search(self._lst, element) >= 0


def binary_search(sorted_list: list[TComparable], item: TComparable) -> int:
    """Java-style binary search; Returns index >= 0 if an equal item is found in list, else '-insertion_point-1'; If it
    returns index >= 0, the index will be the left-most index in case multiple such equal items are contained."""
    i = bisect.bisect_left(sorted_list, item)
    return i if i < len(sorted_list) and sorted_list[i] == item else -i - 1
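
# A minimal illustration, not part of the original module, of the Java-style return convention above:
#
#   >>> binary_search([10, 20, 30], 20)
#   1
#   >>> binary_search([10, 20, 30], 25)  # absent: -insertion_point - 1 = -(2) - 1
#   -3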


###############################################################################
_S = TypeVar("_S")


@final
class HashedInterner(Generic[_S]):
    """Same as sys.intern() except that it isn't global and can also be used for types other than str."""

    def __init__(self, items: Iterable[_S] = frozenset()) -> None:
        self._items: Final[dict[_S, _S]] = {v: v for v in items}

    def intern(self, item: _S) -> _S:
        """Interns the given item."""
        return self._items.setdefault(item, item)

    def interned(self, item: _S) -> _S:
        """Returns the interned (aka deduped) item if an equal item is contained, else returns the non-interned item."""
        return self._items.get(item, item)

    def __contains__(self, item: _S) -> bool:
        return item in self._items


#############################################################################
@final
class SynchronizedBool:
    """Thread-safe wrapper around a regular bool."""

    def __init__(self, val: bool) -> None:
        assert isinstance(val, bool)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._value: bool = val

    @property
    def value(self) -> bool:
        """Returns the current boolean value."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: bool) -> None:
        """Atomically assigns ``new_value``."""
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: bool) -> bool:
        """Swaps in ``new_value`` and returns the previous value."""
        with self._lock:
            old_value = self._value
            self._value = new_value
            return old_value

    def compare_and_set(self, expected_value: bool, new_value: bool) -> bool:
        """Sets to ``new_value`` only if the current value equals ``expected_value``."""
        with self._lock:
            eq: bool = self._value == expected_value
            if eq:
                self._value = new_value
            return eq

    def __bool__(self) -> bool:
        return self.value

    def __repr__(self) -> str:
        return repr(self.value)

    def __str__(self) -> str:
        return str(self.value)


#############################################################################
_K = TypeVar("_K")
_V = TypeVar("_V")


@final
class SynchronizedDict(Generic[_K, _V]):
    """Thread-safe wrapper around a regular dict."""

    def __init__(self, val: dict[_K, _V]) -> None:
        assert isinstance(val, dict)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._dict: Final[dict[_K, _V]] = val

    def __getitem__(self, key: _K) -> _V:
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key: _K, value: _V) -> None:
        with self._lock:
            self._dict[key] = value

    def __delitem__(self, key: _K) -> None:
        with self._lock:
            self._dict.pop(key)

    def __contains__(self, key: _K) -> bool:
        with self._lock:
            return key in self._dict

    def __len__(self) -> int:
        with self._lock:
            return len(self._dict)

    def __repr__(self) -> str:
        with self._lock:
            return repr(self._dict)

    def __str__(self) -> str:
        with self._lock:
            return str(self._dict)

    def get(self, key: _K, default: _V | None = None) -> _V | None:
        """Returns ``self[key]`` or ``default`` if missing."""
        with self._lock:
            return self._dict.get(key, default)

    def pop(self, key: _K, default: _V | None = None) -> _V | None:
        """Removes ``key`` and returns its value, or ``default`` if missing."""
        with self._lock:
            return self._dict.pop(key, default)

    def clear(self) -> None:
        """Removes all items atomically."""
        with self._lock:
            self._dict.clear()

    def items(self) -> ItemsView[_K, _V]:
        """Returns a snapshot of dictionary items."""
        with self._lock:
            return self._dict.copy().items()


#############################################################################
@final
class InterruptibleSleep:
    """Provides a sleep(timeout) function that can be interrupted by another thread; The underlying lock is configurable."""

    def __init__(self, lock: threading.Lock | None = None) -> None:
        self._is_stopping: bool = False
        self._lock: Final[threading.Lock] = lock if lock is not None else threading.Lock()
        self._condition: Final[threading.Condition] = threading.Condition(self._lock)

    def sleep(self, duration_nanos: int) -> bool:
        """Delays the current thread by the given number of nanoseconds; Returns True if the sleep got interrupted;
        Equivalent to threading.Event.wait()."""
        end_time_nanos: int = time.monotonic_ns() + duration_nanos
        with self._lock:
            while not self._is_stopping:
                diff_nanos: int = end_time_nanos - time.monotonic_ns()
                if diff_nanos <= 0:
                    return False
                self._condition.wait(timeout=diff_nanos / 1_000_000_000)  # release, then block until notified or timeout
            return True

    def interrupt(self) -> None:
        """Wakes sleeping threads and makes any future sleep()s a no-op; Equivalent to threading.Event.set()."""
        with self._lock:
            if not self._is_stopping:
                self._is_stopping = True
                self._condition.notify_all()

    def reset(self) -> None:
        """Makes any future sleep()s no longer a no-op; Equivalent to threading.Event.clear()."""
        with self._lock:
            self._is_stopping = False
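
# A hedged usage sketch, not part of the original module; one thread sleeps while a timer thread interrupts it:
#
#   sleeper = InterruptibleSleep()
#   threading.Timer(0.1, sleeper.interrupt).start()
#   was_interrupted = sleeper.sleep(10 * 1_000_000_000)  # returns True well before 10 seconds elapse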


#############################################################################
@final
class SynchronousExecutor(Executor):
    """Executor that runs tasks inline in the calling thread, sequentially."""

    def __init__(self) -> None:
        self._shutdown: bool = False

    def submit(self, fn: Callable[..., _R_], /, *args: Any, **kwargs: Any) -> Future[_R_]:
        """Executes `fn(*args, **kwargs)` immediately and returns its Future."""
        future: Future[_R_] = Future()
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        try:
            result: _R_ = fn(*args, **kwargs)
        except BaseException as exc:
            future.set_exception(exc)
        else:
            future.set_result(result)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Prevents new submissions; there are no worker resources to join/cleanup."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Factory returning a SynchronousExecutor if 0 <= max_workers <= 1, else a ThreadPoolExecutor."""
        return cls() if 0 <= max_workers <= 1 else ThreadPoolExecutor(max_workers=max_workers)
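
# A minimal illustration, not part of the original module; the factory above avoids thread overhead for
# single-worker workloads:
#
#   >>> executor = SynchronousExecutor.executor_for(max_workers=1)
#   >>> executor.submit(lambda x: x * 2, 21).result()
#   42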


#############################################################################
@final
class _XFinally(contextlib.AbstractContextManager):
    """Context manager ensuring cleanup code executes after ``with`` blocks."""

    def __init__(self, cleanup: Callable[[], None]) -> None:
        """Records the callable to run upon exit."""
        self._cleanup: Final = cleanup  # zero-argument callable executed after the `with` block exits

    def __exit__(
        self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None
    ) -> Literal[False]:
        """Runs cleanup and propagates any exceptions appropriately."""
        try:
            self._cleanup()
        except BaseException as cleanup_exc:
            if exc is None:
                raise  # no main error --> propagate cleanup error normally
            # Both failed
            # if sys.version_info >= (3, 11):
            #     raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None
            # <= 3.10: attach so it shows up in traceback but doesn't mask
            exc.__context__ = cleanup_exc
            return False  # reraise original exception
        return False  # propagate main exception if any


def xfinally(cleanup: Callable[[], None]) -> _XFinally:
    """Usage: with xfinally(lambda: cleanup()): ...

    Returns a context manager that guarantees that cleanup() runs on exit, and guarantees that any error in cleanup() will
    never mask an exception raised earlier inside the body of the `with` block, while still surfacing both problems when
    possible.

    Problem it solves
    -----------------
    A naive ``try ... finally`` may lose the original exception:

        try:
            work()
        finally:
            cleanup()  # <-- if this raises an exception, it replaces the real error!

    `_XFinally` preserves exception priority:

    * Body raises, cleanup succeeds --> original body exception is re-raised.
    * Body raises, cleanup also raises --> re-raises body exception; cleanup exception is linked via ``__context__``.
    * Body succeeds, cleanup raises --> cleanup exception propagates normally.

    Example:
    -------
    >>> with xfinally(lambda: release_resources()):  # doctest: +SKIP
    ...     run_tasks()

    The single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information.
    """
    return _XFinally(cleanup)