Coverage for bzfs_main / util / utils.py: 100%
789 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-24 10:16 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-24 10:16 +0000
1# Copyright 2024 Wolfgang Hoschek AT mac DOT com
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15"""Collection of helper functions used across bzfs; includes environment variable parsing, process management and lightweight
16concurrency primitives, etc.
18Everything in this module relies only on the Python standard library so other modules remain dependency free. Each utility
19favors simple, predictable behavior on all supported platforms.
20"""
22from __future__ import (
23 annotations,
24)
25import argparse
26import base64
27import bisect
28import contextlib
29import dataclasses
30import errno
31import hashlib
32import itertools
33import logging
34import operator
35import os
36import platform
37import pwd
38import random
39import re
40import signal
41import stat
42import subprocess
43import sys
44import threading
45import time
46import types
47from collections import (
48 defaultdict,
49 deque,
50)
51from collections.abc import (
52 ItemsView,
53 Iterable,
54 Iterator,
55 Sequence,
56)
57from concurrent.futures import (
58 Executor,
59 Future,
60 ThreadPoolExecutor,
61)
62from dataclasses import (
63 dataclass,
64)
65from datetime import (
66 datetime,
67 timedelta,
68 timezone,
69 tzinfo,
70)
71from subprocess import (
72 DEVNULL,
73 PIPE,
74)
75from typing import (
76 IO,
77 Any,
78 Callable,
79 Final,
80 Generic,
81 Literal,
82 NoReturn,
83 Protocol,
84 SupportsIndex,
85 TextIO,
86 TypeVar,
87 cast,
88 final,
89)
# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"  # all config env var names are namespaced, e.g. "bzfs_foo"
DIE_STATUS: Final[int] = 3  # default process exit status used by die()
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level is halfway in between
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level is halfway in between
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level is halfway in between
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""  # sentinel meaning "no dataset subtree is currently being skipped"
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"  # intentionally not included: -_.:/
SHELL_CHARS_AND_SLASH: Final[str] = SHELL_CHARS + "/"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # see Google for 'sun_path'

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias; the bool records whether the pattern is a "!" negation
def getenv_any(key: str, default: str | None = None, env_var_prefix: str = ENV_VAR_PREFIX) -> str | None:
    """Looks up ``key`` in the shell environment, namespaced with the configured prefix; returns ``default`` if unset."""
    return os.environ.get(env_var_prefix + key, default)
def getenv_int(key: str, default: int, env_var_prefix: str = ENV_VAR_PREFIX) -> int:
    """Returns environment variable ``key`` parsed as an int, falling back to ``default`` when unset."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return int(raw)
def getenv_bool(key: str, default: bool = False, env_var_prefix: str = ENV_VAR_PREFIX) -> bool:
    """Returns environment variable ``key`` as bool with ``default`` fallback; only "true" (case-insensitive) is truthy."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return raw.strip().lower() == "true"
def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Retains only column number 'field' in a list of TSV/CSV lines; Analog to Unix 'cut' CLI command.

    Raises ValueError for any ``field`` other than 1 or 2, and also (via str.index) if a line lacks the separator.
    """
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field not in (1, 2):
        raise ValueError(f"Invalid field value: {field}")
    results: list[str] = []
    for line in lines:
        boundary = line.index(separator)  # position of the first separator splits the two columns
        results.append(line[:boundary] if field == 1 else line[boundary + 1 :])
    return results
def drain(iterable: Iterable[Any]) -> None:
    """Consumes all items in the iterable, effectively draining it."""
    # A zero-length deque exhausts the iterator at C speed while discarding every item immediately (helps gc).
    deque(iterable, maxlen=0)
# Generic type variables shared by the dict/callable helpers in this module:
_K_ = TypeVar("_K_")  # dict key type
_V_ = TypeVar("_V_")  # dict value type
_R_ = TypeVar("_R_")  # generic result type (used by helpers outside this excerpt)
def shuffle_dict(dictionary: dict[_K_, _V_], /, rand: random.Random = random.SystemRandom()) -> dict[_K_, _V_]:  # noqa: B008
    """Returns a new dict containing the same items in a randomly shuffled order; the input dict is left untouched."""
    entries: list[tuple[_K_, _V_]] = list(dictionary.items())
    rand.shuffle(entries)  # in-place Fisher-Yates shuffle provided by the random module
    return dict(entries)
def sorted_dict(
    dictionary: dict[_K_, _V_], /, *, key: Callable[[tuple[_K_, _V_]], Any] | None = None, reverse: bool = False
) -> dict[_K_, _V_]:
    """Returns a new dict with items sorted, primarily by key and secondarily by value (unless a custom key is supplied)."""
    ordered: list[tuple[_K_, _V_]] = sorted(dictionary.items(), key=key, reverse=reverse)
    return dict(ordered)
def tail(file: str, *, n: int, errors: str | None = None) -> Sequence[str]:
    """Returns the last ``n`` lines of ``file`` without following symlinks; empty result if it's not a regular file."""
    if not os.path.isfile(file):
        return []
    with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False) as handle:
        # deque with maxlen keeps only the trailing n lines while streaming through the file
        return deque(handle, maxlen=n)
# Matches the opening of a named capturing group, e.g. "(?P<name>", anchored at the start of the string:
_NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))
_NUMERIC_BACKREFERENCE_REGEX: Final[re.Pattern[str]] = re.compile(r"\\\d+")  # example: \1
def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces parenthesis '(' followed by a char other than question mark '?', but not preceded by a backslash
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    i = regex.find("[")
    if i >= 0 and regex.find("(", i) >= 0:
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the case where the regex might
        # contain a regex character class that contains parenthesis.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex
    if "(?P=" in regex or "(?(" in regex or _NUMERIC_BACKREFERENCE_REGEX.search(regex):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely if the regex might contain a
        # (named or conditional or numeric) backreference; removing the referenced group would break it.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex
    # Scan right-to-left so earlier replacements never shift the positions still to be examined; start at len-2 because
    # a '(' in the very last position cannot open a group.
    i = len(regex) - 2
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):  # skip escaped literal parentheses like '\('
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + _NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex
def get_home_directory() -> str:
    """Reliably detects home dir without using HOME env var."""
    # Thread-safe alternative to: os.environ.pop('HOME', None); os.path.expanduser('~')
    uid: int = os.getuid()
    return pwd.getpwuid(uid).pw_dir
def human_readable_bytes(num_bytes: float, *, separator: str = " ", precision: int | None = None) -> str:
    """Formats 'num_bytes' as a human-readable size; for example "567 MiB".

    With precision=None the number is formatted via human_readable_float(); else with that many fixed decimals.
    """
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    sign = "-" if num_bytes < 0 else ""
    magnitude = abs(num_bytes)
    unit_index = 0
    largest = len(units) - 1
    while magnitude >= 1024 and unit_index < largest:  # scale down by powers of 1024 until below the next unit
        magnitude /= 1024
        unit_index += 1
    formatted = human_readable_float(magnitude) if precision is None else f"{magnitude:.{precision}f}"
    return f"{sign}{formatted}{separator}{units[unit_index]}"
def human_readable_duration(duration: float, *, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling as needed; for example "567ms".

    ``unit`` names the unit of the input value; fractional inputs are first converted down to nanoseconds so the
    upscaling logic below applies uniformly.
    """
    units = ("ns", "μs", "ms", "s", "m", "h", "d")
    sign = "-" if duration < 0 else ""
    value = abs(duration)
    index = units.index(unit)
    if value != 0 and value < 1:
        nanos_per_unit = (1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000)
        value *= nanos_per_unit[index]
        index = 0
    while value >= 1000 and index < 3:  # ns -> μs -> ms -> s
        value /= 1000
        index += 1
    if index >= 3:
        while value >= 60 and index < 5:  # s -> m -> h
            value /= 60
            index += 1
        if index >= 5:
            while value >= 24 and index < len(units) - 1:  # h -> d
                value /= 24
                index += 1
    formatted = human_readable_float(value) if precision is None else f"{value:.{precision}f}"
    return f"{sign}{formatted}{separator}{units[index]}"
def human_readable_float(number: float) -> str:
    """Formats ``number`` with fewer decimals the larger its magnitude, mirroring how humans round values in logs.

    abs < 10   --> two decimals   (3.14559 --> "3.15")
    abs < 100  --> one decimal    (12.36   --> "12.4")
    abs >= 100 --> zero decimals  (123.556 --> "124")

    Trailing zeros (and a dangling decimal point) are stripped: 1.500 --> "1.5", 1.00 --> "1".
    """
    magnitude = abs(number)
    if magnitude >= 100:
        return str(round(number))
    decimals = 2 if magnitude < 10 else 1
    text = f"{number:.{decimals}f}"
    assert "." in text
    text = text.rstrip("0").rstrip(".")  # drop trailing zeros first, then a now-dangling decimal point
    return "0" if text == "-0" else text  # normalize negative zero
def percent(number: int, total: int, *, print_total: bool = False) -> str:
    """Returns percentage string of ``number`` relative to ``total``; "inf" when ``total`` is zero."""
    total_part: str = f"/{total}" if print_total else ""
    ratio: str = "inf" if total == 0 else human_readable_float(100 * number / total)
    return f"{number}{total_part}={ratio}%"
def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if basename of path is a symlink.

    Also, can specify custom permissions on O_CREAT, and verify ownership.

    If check_owner=True, write-capable opens require ownership by the effective UID; read-only opens also allow ownership by
    uid 0 (root). This allows safe reads of root-owned system files while preventing writes to files not owned by the caller.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    # Translate the leading mode character into the equivalent os.open() flag combination:
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC  # O_NOFOLLOW rejects symlinks; O_CLOEXEC keeps the fd from leaking into children
    fd: int = os.open(path, flags=flags, mode=perm)  # 'perm' only takes effect if O_CREAT actually creates the file
    try:
        if check_owner:
            st_uid: int = os.fstat(fd).st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                if (flags & (os.O_WRONLY | os.O_RDWR)) != 0:  # require that writer owns the file
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
                elif st_uid != 0:  # it's ok for root to own a file that we'll merely read
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()} or 0", path)
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        try:  # don't leak the raw descriptor if the ownership check or fdopen() fails
            os.close(fd)
        except OSError:
            pass
        raise
def close_quietly(fd: int) -> None:
    """Closes the given file descriptor while silently swallowing any OSError that might arise as part of this."""
    if fd < 0:
        return  # negative descriptors are "not open" sentinels; nothing to do
    with contextlib.suppress(OSError):
        os.close(fd)
_P = TypeVar("_P")  # item type of the sequence accepted by find_match()
def find_match(
    seq: Sequence[_P],
    predicate: Callable[[_P], bool],
    start: SupportsIndex | None = None,
    end: SupportsIndex | None = None,
    *,
    reverse: bool = False,
    raises: bool | object | Callable[[], object] = False,  # raises: bool | object | Callable = False, # python >= 3.10
) -> int:
    """Returns the integer index within ``seq`` of the first item (or last item if reverse=True) that matches the given
    predicate condition.

    If no matching item is found returns -1 or ValueError, depending on the ``raises`` parameter, which is a bool indicating
    whether to raise an error, or an object containing the error message, but can also be a Callable/lambda in order to
    support efficient deferred generation of error messages.

    Analog to ``str.find()``, including slicing semantics with parameters start and end, i.e. respects Python slicing
    semantics for start/end (including clamping). For example, seq can be a list, tuple or str.

    Example usage:
        lst = ["a", "b", "-c", "d"]
        i = find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True)
        if i >= 0:
            print(lst[i])
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=f"Tag {tag} not found in {file}")
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=lambda: f"Tag {tag} not found in {file}")
    """
    if start is None and end is None:
        lo, hi = 0, len(seq)  # fast path: whole sequence, no slice computation needed
    else:
        lo, hi, _ = slice(start, end).indices(len(seq))  # clamp per Python slicing semantics
    indices = range(hi - 1, lo - 1, -1) if reverse else range(lo, hi)
    for index in indices:
        if predicate(seq[index]):
            return index
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    raise ValueError(raises() if callable(raises) else raises)  # lazily build the message if a callable was given
def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` lies under ``of_root_dataset`` in the dataset hierarchy, or is the same."""
    if dataset == of_root_dataset:
        return True
    return dataset.startswith(of_root_dataset + "/")
def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any adjacent items within the given sorted sequence are equal."""
    # Compare each item with its successor; islice avoids materializing a shifted copy of the list.
    return any(a == b for a, b in zip(sorted_list, itertools.islice(sorted_list, 1, None)))
def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings.

    Assumes ``sorted_datasets`` is sorted and duplicate-free; ``is_test_mode`` enables assertions verifying these
    preconditions. Siblings are datasets sharing the same parent, or two distinct root datasets.
    """
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    skip_dataset: str = DONT_SKIP_DATASET  # root dataset whose subtree the scan is currently inside of
    parents: set[str] = set()
    for dataset in sorted_datasets:
        assert dataset
        parent = os.path.dirname(dataset)
        if parent in parents:
            return True  # I have a sibling if my parent already has another child
        parents.add(parent)
        if is_descendant(dataset, of_root_dataset=skip_dataset):
            continue  # still within the subtree of the most recently seen root dataset
        if skip_dataset != DONT_SKIP_DATASET:
            return True  # I have a sibling if I am a root dataset and another root dataset already exists
        skip_dataset = dataset
    return False
def dry(msg: str, is_dry_run: bool) -> str:
    """Prefixes ``msg`` with 'Dry ' when in dry-run mode; returns it unchanged otherwise."""
    if not is_dry_run:
        return msg
    return "Dry " + msg
def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Converts an absolute dataset path to one relative to ``root_dataset``.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    prefix_length = len(root_dataset)  # drop exactly this many leading characters (prefix match is assumed by callers)
    return dataset[prefix_length:]
def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    slash = dataset.find("/")
    while slash >= 0:  # yield each proper prefix ending just before a '/'
        yield dataset[:slash]
        slash = dataset.find("/", slash + 1)
    yield dataset  # finally the full dataset name itself
def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """In a string s, replaces a leading old_prefix string with new_prefix; assumes the leading string is present."""
    assert s.startswith(old_prefix)
    suffix = s[len(old_prefix) :]  # unconditional slice keeps behavior fixed even if asserts are disabled via -O
    return new_prefix + suffix
def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces up to ``count`` occurrences of ``old`` with ``new`` in-place for every string in ``lines``."""
    for index, line in enumerate(lines):
        lines[index] = line.replace(old, new, count)
_TAPPEND = TypeVar("_TAPPEND")  # element type shared by the list-append helpers below
def append_if_absent(lst: list[_TAPPEND], *items: _TAPPEND) -> list[_TAPPEND]:
    """Appends each item to the list unless an equal item is already present; returns the (mutated) list."""
    for candidate in items:
        if candidate not in lst:  # linear membership scan; fine for the small lists this is used with
            lst.append(candidate)
    return lst
def xappend(lst: list[_TAPPEND], *items: _TAPPEND | Iterable[_TAPPEND]) -> list[_TAPPEND]:
    """Appends each of the items to the given list if the item is "truthy", for example not None and not an empty string; If
    an item is an iterable does so recursively, flattening the output."""
    for item in items:
        # Strings are iterable but are treated as scalar values, not flattened char by char:
        if isinstance(item, Iterable) and not isinstance(item, str):
            xappend(lst, *item)
        elif item:
            lst.append(item)
    return lst
def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if the name matches at least one of the include regexes but none of the exclude regexes; else False.

    A regex that starts with a `!` is a negation - the regex matches if the regex without the `!` prefix does not match.
    """
    for regex, is_negation in exclude_regexes:
        # the catch-all ".*" pattern is special-cased to skip the (comparatively expensive) fullmatch call
        matched = True if regex.pattern == ".*" else bool(regex.fullmatch(name))
        if matched != is_negation:  # XOR: a negation flag inverts the match result
            return False

    for regex, is_negation in include_regexes:
        matched = True if regex.pattern == ".*" else bool(regex.fullmatch(name))
        if matched != is_negation:
            return True

    return False
def compile_regexes(regexes: list[str], *, suffix: str = "") -> RegexList:
    """Compiles regex strings and keeps track of negations.

    A leading "!" marks a regex as a negation; the "!" is stripped and the flag recorded in the returned tuple.
    ``suffix`` is appended to each regex (e.g. DESCENDANTS_RE_SUFFIX so descendants of a matching dataset also match);
    appending is skipped for the catch-all ".*" pattern when the suffix is itself an optional group.
    """
    assert isinstance(regexes, list)
    compiled_regexes: RegexList = []
    for regex in regexes:
        if suffix:  # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
            if regex.endswith("\\$"):
                pass  # trailing literal $ is ok
            elif regex.endswith("$"):
                regex = regex[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in regex:
                raise re.error("Must not use non-trailing '$' character", regex)
        if is_negation := regex.startswith("!"):
            regex = regex[1:]  # strip the "!" marker; negation is applied later by is_included()
        regex = replace_capturing_groups_with_non_capturing_groups(regex)
        if regex != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            regex = f"{regex}{suffix}"
        compiled_regexes.append((re.compile(regex), is_negation))
    return compiled_regexes
def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Lazy formatter joining items with ``separator``; defers all work until ``str()`` so disabled log levels pay nothing."""

    @final
    class _LazyJoin:
        """Joins the items only when rendered via ``str``."""

        def __str__(self) -> str:
            joined = separator.join(str(item) for item in iterable)
            return joined.lstrip() if lstrip else joined

    return _LazyJoin()
def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Lazy pprint formatter; defers all work until ``str()`` so disabled log levels pay nothing."""

    @final
    class _LazyPretty:
        """Pretty-prints ``vars(obj_to_format)`` only when rendered via ``str``."""

        def __str__(self) -> str:
            import pprint  # deferred import keeps module startup fast

            return pprint.pformat(vars(obj_to_format))

    return _LazyPretty()
def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    if isinstance(stderr, bytes):
        return stderr.decode("utf-8", errors="replace")  # replace undecodable bytes instead of raising
    return str(stderr)
def xprint(log: logging.Logger, value: Any, *, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Optionally logs ``value`` at stdout/stderr level; no-op when ``run`` is False or ``value`` is falsy."""
    if not (run and value):
        return
    payload = value if end else str(value).rstrip()  # empty ``end`` mimics print(end="") by trimming the trailing newline
    level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
    log.log(level, "%s", payload)
def sha256_hex(text: str) -> str:
    """Returns the sha256 hex string for the given text."""
    digest = hashlib.sha256(text.encode("utf-8"))
    return digest.hexdigest()
def sha256_urlsafe_base64(text: str, *, padding: bool = True) -> str:
    """Returns the URL-safe base64-encoded sha256 value for the given text; optionally strips '=' padding."""
    digest: bytes = hashlib.sha256(text.encode()).digest()
    encoded: str = base64.urlsafe_b64encode(digest).decode()
    if padding:
        return encoded
    return encoded.rstrip("=")
def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    encoded: str = sha256_urlsafe_base64(text, padding=False)
    half = len(encoded) // 2
    return encoded[:half]
def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left one third portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    encoded: str = sha256_urlsafe_base64(text, padding=False)
    third = len(encoded) // 3
    return encoded[:third]
def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, *, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 string encoding of the int value, assuming it is contained in the range [0..max_value]."""
    assert 0 <= value <= max_value
    width: int = (max_value.bit_length() + 7) // 8  # fixed width derived from max_value keeps encodings length-stable
    encoded: str = base64.urlsafe_b64encode(value.to_bytes(width, byteorder)).decode()
    return encoded if padding else encoded.rstrip("=")
def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Exits the program with ``exit_code`` after logging ``msg``; delegates to ``parser.error`` when a parser is given."""
    if parser is not None:
        parser.error(msg)  # argparse prints usage + msg and raises SystemExit itself
    error = SystemExit(msg)
    error.code = exit_code  # override the default code so callers/shells see the configured status
    raise error
def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior except it enhances cleanup on TimeoutExpired, and
    provides optional child PID tracking, and optional logging of execution status via ``log`` and ``loglevel`` params."""
    # Pop the options this wrapper handles itself; everything else is forwarded verbatim to subprocess.Popen():
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE  # 'input' is fed to the child via a pipe, as in subprocess.run()
    log: logging.Logger | None = kwargs.pop("log", None)
    loglevel: int | None = kwargs.pop("loglevel", None)
    start_time_nanos: int = time.monotonic_ns()
    is_timeout: bool = False
    is_cancel: bool = False
    exitcode: int | None = None

    def log_status() -> None:
        """Logs one line with the command, its outcome and elapsed wall-clock time, if logging is enabled."""
        if log is not None:
            _loglevel: int = loglevel if loglevel is not None else getenv_int("subprocess_run_loglevel", LOG_TRACE)
            if log.isEnabledFor(_loglevel):
                elapsed_time: str = human_readable_float((time.monotonic_ns() - start_time_nanos) / 1_000_000) + "ms"
                status = "cancel" if is_cancel else "timeout" if is_timeout else "success" if exitcode == 0 else "failure"
                cmd = kwargs["args"] if "args" in kwargs else (args[0] if args else None)
                cmd_str: str = " ".join(str(arg) for arg in iter(cmd)) if isinstance(cmd, (list, tuple)) else str(cmd)
                log.log(_loglevel, f"Executed [{status}] [{elapsed_time}]: %s", cmd_str)

    with xfinally(log_status):  # ensure the status line is emitted even if we raise below
        ctx: contextlib.AbstractContextManager[subprocess.Popen]
        if subprocesses is None:
            ctx = subprocess.Popen(*args, **kwargs)
        else:
            ctx = subprocesses.popen_and_track(*args, **kwargs)  # also registers the child PID for per-job termination
        with ctx as proc:
            try:
                sp = subprocesses
                if sp is not None and sp._is_terminated():  # noqa: SLF001 pylint: disable=protected-access
                    # async cancellation was requested; force an immediate timeout instead of waiting for the child
                    is_cancel = True
                    timeout = 0.0
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        is_timeout = True
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child proc and descendants
                finally:
                    proc.kill()  # mirror subprocess.run() cleanup: make sure the direct child is not left running
                raise
            else:
                exitcode = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)
def terminate_process_subtree(
    *, except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes."""
    current_pid: int = os.getpid()
    if root_pids is None:
        root_pids = [current_pid]
    descendant_lists: list[list[int]] = _get_descendant_processes(root_pids)
    assert len(descendant_lists) == len(root_pids)
    for root_pid, descendants in zip(root_pids, descendant_lists):
        if root_pid != current_pid:
            targets = [root_pid] + descendants  # signal the root first, then its descendants
        elif except_current_process:
            targets = descendants  # never signal ourselves in this mode
        else:
            targets = descendants + [current_pid]  # signal ourselves last so all children are notified first
        for pid in targets:
            with contextlib.suppress(OSError):  # pid may already be gone; that's fine
                os.kill(pid, sig)
def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
    """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems."""
    if not root_pids:
        return []
    cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
    try:
        output: str = subprocess.run(cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout
    except PermissionError:
        # degrade gracefully in sandbox environments that deny executing `ps` entirely
        return [[] for _ in root_pids]
    children_of: dict[int, list[int]] = defaultdict(list)
    for row in output.splitlines()[1:]:  # skip the header row
        fields: list[str] = row.split()
        assert len(fields) == 2
        child, parent = int(fields[0]), int(fields[1])
        children_of[parent].append(child)

    def _collect(parent: int, acc: list[int]) -> None:
        """Depth-first preorder collection of descendant PIDs starting from ``parent``."""
        for child in children_of[parent]:
            acc.append(child)
            _collect(child, acc)

    results: list[list[int]] = []
    for root_pid in root_pids:
        descendants: list[int] = []
        _collect(root_pid, descendants)
        results.append(descendants)
    return results
@contextlib.contextmanager
def termination_signal_handler(
    termination_events: list[threading.Event],
    *,
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set all ``termination_events`` and, by default, terminate
    all descendant processes; the previously installed handlers are restored on exit."""
    termination_events = list(termination_events)  # shallow copy guards against the caller mutating the list later

    def _handler(_sig: int, _frame: object) -> None:
        """Signal handler: flags all termination events, then runs the termination callback."""
        for event in termination_events:
            event.set()
        termination_handler()

    previous_int_handler = signal.signal(signal.SIGINT, _handler)  # install new signal handler
    previous_term_handler = signal.signal(signal.SIGTERM, _handler)  # install new signal handler
    try:
        yield  # run body of context manager
    finally:
        signal.signal(signal.SIGINT, previous_int_handler)  # restore original signal handler
        signal.signal(signal.SIGTERM, previous_term_handler)  # restore original signal handler
def return_false() -> bool:
    """Always returns ``False``; defined at module level so it stays picklable (unlike a lambda)."""
    return False
def sleep_nanos(delay_nanos: int) -> None:
    """Same as time.sleep() but expects a relative sleep duration in nanoseconds as input value; picklable."""
    seconds = delay_nanos / 1_000_000_000
    time.sleep(seconds)
#############################################################################
@dataclass(frozen=True)
@final
class TaskTiming:
    """Customizable callbacks for reading the current monotonic time, sleeping and optional async termination; immutable."""

    # Clock used for time measurements; defaults to the monotonic nanosecond clock.
    monotonic_ns: Callable[[], int] = time.monotonic_ns

    is_terminated: Callable[[], bool] = return_false
    """Returns whether a predicate has become true; can be used to indicate system shutdown or similar cancellation
    conditions; default is to always return ``False``."""

    sleep: Callable[[int], None] = sleep_nanos
    """Sleeps N nanoseconds; thread-safe."""

    def copy(self, **override_kwargs: Any) -> TaskTiming:
        """Creates a new object copying an existing one with the specified fields overridden for customization; thread-
        safe."""
        return dataclasses.replace(self, **override_kwargs)

    @staticmethod
    def make_from(termination_event: threading.Event | None) -> TaskTiming:
        """Convenience factory that creates an object that performs async termination when ``termination_event`` is set."""
        if termination_event is None:
            return TaskTiming()  # no event: default callbacks (plain sleep, never terminated)

        def _sleep(delay_nanos: int) -> None:
            termination_event.wait(delay_nanos / 1_000_000_000)  # allow early wakeup on async termination

        return TaskTiming(is_terminated=termination_event.is_set, sleep=_sleep)
#############################################################################
@final
class Subprocesses:
    """Provides per-job tracking of child PIDs so a job can safely terminate only the subprocesses it spawned itself; used
    when multiple jobs run concurrently within the same Python process.

    Optionally binds to an ``_is_terminated`` predicate to enforce async cancellation by forcing immediate timeouts for newly
    spawned subprocesses once cancellation is requested.
    """

    def __init__(self, is_terminated: Callable[[], bool] = return_false) -> None:
        # Async-cancellation predicate; polled by subprocess_run() before communicating with a child.
        self._is_terminated: Final[Callable[[], bool]] = is_terminated
        self._lock: Final[threading.Lock] = threading.Lock()  # guards _child_pids against concurrent spawn/terminate
        self._child_pids: Final[dict[int, None]] = {}  # a set that preserves insertion order

    @contextlib.contextmanager
    def popen_and_track(self, *popen_args: Any, **popen_kwargs: Any) -> Iterator[subprocess.Popen]:
        """Context manager that calls subprocess.Popen() and tracks the child PID for per-job termination.

        Holds a lock across Popen+PID registration to prevent a race when terminate_process_subtrees() is invoked (e.g. from
        SIGINT/SIGTERM handlers), ensuring newly spawned child processes cannot escape termination. The child PID is
        unregistered on context exit.
        """
        with self._lock:
            proc: subprocess.Popen = subprocess.Popen(*popen_args, **popen_kwargs)
            self._child_pids[proc.pid] = None
        try:
            yield proc
        finally:
            with self._lock:  # stop tracking the PID once the caller is done with the child
                self._child_pids.pop(proc.pid, None)

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Wrapper around utils.subprocess_run() that auto-registers/unregisters child PIDs for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Sends the given signal to all tracked child PIDs and their descendants, ignoring errors for dead PIDs."""
        with self._lock:  # snapshot and clear the tracked PIDs atomically
            pids: list[int] = list(self._child_pids)
            self._child_pids.clear()
        terminate_process_subtree(root_pids=pids, sig=sig)
840#############################################################################
def pid_exists(pid: int) -> bool | None:
    """Returns True if a process with PID exists, False if not, or None on error."""
    if pid <= 0:
        return False
    try:
        # Signal number 0 delivers nothing, yet the kernel still validates the target PID and our permission to signal
        # it, which makes os.kill() a portable POSIX existence probe.
        os.kill(pid, 0)
        return True
    except OSError as err:
        if err.errno == errno.ESRCH:  # "No such process": definitely gone
            return False
        if err.errno == errno.EPERM:  # "Operation not permitted": exists, but owned by someone else
            return True
        return None  # any other error: existence cannot be determined
def nprefix(s: str) -> str:
    """Returns a canonical snapshot prefix with trailing underscore."""
    return sys.intern(f"{s}_")
def ninfix(s: str) -> str:
    """Returns a canonical infix with trailing underscore when not empty."""
    if not s:
        return ""
    return sys.intern(f"{s}_")
def nsuffix(s: str) -> str:
    """Returns a canonical suffix with leading underscore when not empty."""
    if not s:
        return ""
    return sys.intern(f"_{s}")
def format_dict(dictionary: dict[Any, Any]) -> str:
    """Returns the dictionary rendered via str() and wrapped in double quotes, for consistent log output."""
    return '"' + str(dictionary) + '"'
def format_obj(obj: object) -> str:
    """Returns the object rendered via str() and wrapped in double quotes, for consistent log output."""
    return '"' + str(obj) + '"'
def validate_dataset_name(dataset: str, input_text: str) -> None:
    """'zfs create' CLI does not accept dataset names that are empty or start or end in a slash, etc.

    Reports any violation via die(), quoting both the offending name and the larger ``input_text`` it came from.
    """
    # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424
    # and https://github.com/openzfs/zfs/issues/8798
    # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html
    invalid_chars: str = SHELL_CHARS
    is_bad: bool = (
        dataset in ("", ".", "..")  # empty or relative path components
        or dataset.startswith(("/", "./", "../"))
        or dataset.endswith(("/", "/.", "/.."))
        or any(part in dataset for part in ("//", "/./", "/../"))  # embedded empty/relative components
        or any(ch in invalid_chars or (ch.isspace() and ch != " ") for ch in dataset)
        or not dataset[0].isalpha()
    )
    if is_bad:
        die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'")
def validate_property_name(propname: str, input_text: str) -> str:
    """Checks that the ZFS property name contains no spaces or shell chars, etc; returns ``propname`` when valid."""
    invalid_chars: str = SHELL_CHARS
    has_bad_char: bool = any(ch.isspace() or ch in invalid_chars for ch in propname)
    if (not propname) or propname.startswith("-") or has_bad_char:
        die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'")
    return propname
def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None:
    """Fails via die() when ``path`` is a symbolic link; ``msg`` prefixes the resulting error message."""
    if not os.path.islink(path):
        return  # regular files, dirs and nonexistent paths are all acceptable here
    die(f"{msg}must not be a symlink: {path}", parser=parser)
def validate_file_permissions(path: str, mode: int) -> None:
    """Verifies that ``path`` is owned by the current effective UID and has exactly the permission bits ``mode``.

    Calls die() with a diagnostic message on any mismatch; the path itself is inspected (symlinks are not followed).
    """
    stats: os.stat_result = os.stat(path, follow_symlinks=False)
    st_uid: int = stats.st_uid
    if st_uid != os.geteuid():  # verify ownership is current effective UID
        die(f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}")
    st_mode = stat.S_IMODE(stats.st_mode)
    if st_mode != mode:
        die(  # bugfix: removed stray trailing ')' that unbalanced the error message
            f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]}, "
            f"not {mode:03o} aka {stat.filemode(mode)[1:]}"
        )
def parse_duration_to_milliseconds(duration: str, *, regex_suffix: str = "", context: str = "") -> int:
    """Parses human duration strings like '5minutes' or '2 hours' to milliseconds.

    Malformed input raises ValueError, or is reported via die() mentioning ``context`` when ``context`` is non-empty;
    ``regex_suffix`` optionally extends the accepted grammar with a trailing regex fragment.
    """
    second_ms = 1000
    day_ms = 86400 * second_ms
    unit_milliseconds: dict[str, int] = {
        "milliseconds": 1,
        "millis": 1,
        "seconds": second_ms,
        "secs": second_ms,
        "minutes": 60 * second_ms,
        "mins": 60 * second_ms,
        "hours": 3600 * second_ms,
        "days": day_ms,
        "weeks": 7 * day_ms,
        "months": round(30.5 * day_ms),  # average month length
        "years": 365 * day_ms,
    }
    match = re.fullmatch(
        r"(\d+)\s*(milliseconds|millis|seconds|secs|minutes|mins|hours|days|weeks|months|years)" + regex_suffix,
        duration,
    )
    if not match:
        if context:
            die(f"Invalid duration format: {duration} within {context}")
        else:
            raise ValueError(f"Invalid duration format: {duration}")
    assert match  # for the type checker; die() does not return
    amount_str, unit = match.group(1), match.group(2)
    return int(amount_str) * unit_milliseconds[unit]
def unixtime_fromisoformat(datetime_str: str) -> int:
    """Converts an ISO 8601 datetime string into UTC Unix time in integer seconds."""
    parsed: datetime = datetime.fromisoformat(datetime_str)
    return int(parsed.timestamp())
def isotime_from_unixtime(unixtime_in_seconds: int) -> str:
    """Converts UTC Unix time seconds into an ISO 8601 datetime string, using '_' as the date/time separator."""
    moment: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=timezone.utc)
    return moment.isoformat(sep="_", timespec="seconds")
def current_datetime(
    tz_spec: str | None = None,
    now_fn: Callable[[tzinfo | None], datetime] | None = None,
) -> datetime:
    """Returns the current time in the ``tz_spec`` timezone, or in the local timezone when unspecified.

    ``now_fn`` may inject a custom clock with a ``datetime.now``-compatible signature.
    """
    clock = datetime.now if now_fn is None else now_fn
    return clock(get_timezone(tz_spec))
def get_timezone(tz_spec: str | None = None) -> tzinfo | None:
    """Returns the timezone denoted by ``tz_spec``, or None (meaning local timezone) when unspecified.

    Accepts 'UTC', fixed offsets like '+05:30' or '-0200', and IANA names like 'Europe/Vienna'; raises ValueError for
    anything else.
    """
    if tz_spec is None:
        return None  # callers interpret None as the local timezone
    if tz_spec == "UTC":
        return timezone.utc
    offset_match = re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec)
    if offset_match is not None:
        sign, hh, mm = offset_match.groups()
        total_minutes: int = int(hh) * 60 + int(mm)
        if sign == "-":
            total_minutes = -total_minutes
        return timezone(timedelta(minutes=total_minutes))
    if "/" in tz_spec:  # IANA zone names always contain a slash, e.g. 'America/New_York'
        from zoneinfo import ZoneInfo  # lazy import for startup perf

        return ZoneInfo(tz_spec)
    raise ValueError(f"Invalid timezone specification: {tz_spec}")
1000###############################################################################
@final
class SnapshotPeriods:  # thread-safe
    """Parses snapshot suffix strings and converts between durations."""

    def __init__(self) -> None:
        """Builds the suffix-to-milliseconds and suffix-to-label lookup tables plus the parsing regexes."""
        hour_ms = 60 * 60 * 1000
        day_ms = 86400 * 1000
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * day_ms,
            "monthly": round(30.5 * day_ms),  # average month length
            "weekly": 7 * day_ms,
            "daily": day_ms,
            "hourly": hour_ms,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        alternatives = "|".join(self.suffix_milliseconds)
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({alternatives})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' into (10, 'minutely')."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Like :meth:`suffix_to_duration0` but requires a leading underscore, e.g. '_10minutely'."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Returns (amount, unit) for a matching suffix, with a missing count defaulting to 1; (0, '') otherwise."""
        match = regex.fullmatch(suffix)
        if match is None:
            return 0, ""
        amount_str, unit = match.group(1), match.group(2)
        amount: int = int(amount_str) if amount_str else 1  # bare unit like 'hourly' means exactly one period
        assert amount > 0
        return amount, unit

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns the duration encoded after the last '_' of ``snapshot``, in milliseconds (0 when unparseable)."""
        _, sep, candidate = snapshot.rpartition("_")
        amount, unit = self._suffix_to_duration(candidate if sep else "", self._suffix_regex0)
        return amount * self.suffix_milliseconds.get(unit, 0)
1057#############################################################################
@final
class JobStats:
    """Simple thread-safe counters summarizing job progress."""

    def __init__(self, jobs_all: int) -> None:
        """Initializes all counters to zero for a run of ``jobs_all`` total jobs."""
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()  # guards every mutable counter below
        self.jobs_all: int = jobs_all
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0  # cumulative wall time across completed jobs
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Counts a job submission; returns a progress summary string."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Counts one job completion (failed or not) and its duration; returns a progress summary string."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            if failed:
                self.jobs_failed += 1
            self.sum_elapsed_nanos += elapsed_nanos
            msg = str(self)
            # sanity invariants; the summary string aids debugging when one of them is violated
            assert self.sum_elapsed_nanos >= 0, msg
            assert self.jobs_running >= 0, msg
            assert self.jobs_failed >= 0, msg
            assert self.jobs_failed <= self.jobs_completed, msg
            assert self.jobs_completed <= self.jobs_started, msg
            assert self.jobs_started <= self.jobs_all, msg
            return msg

    def __repr__(self) -> str:
        def pct(number: int) -> str:
            """Formats ``number`` as a percentage of all jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        avg = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, self.jobs_completed))
        return (
            f"all:{self.jobs_all}, started:{pct(self.jobs_started)}, completed:{pct(self.jobs_completed)}, "
            f"failed:{pct(self.jobs_failed)}, running:{self.jobs_running}, {avg}"
        )
1109#############################################################################
class Comparable(Protocol):
    """Structural typing protocol for values comparable via ``<`` (partial ordering)."""

    def __lt__(self, other: Any) -> bool: ...


TComparable = TypeVar("TComparable", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue


@final
class SmallPriorityQueue(Generic[TComparable]):
    """Binary-search based priority queue that also supports efficient priority updates of queued elements, optimized
    for small queues (no more than thousands of elements), as is the case for us.

    Unlike heapq, an element whose priority changes can be handled via remove(), then mutate, then push() again.
    Could be implemented using a SortedList via https://github.com/grantjenks/python-sortedcontainers or using an
    indexed priority queue via https://github.com/nvictus/pqdict, but to avoid an external dependency is implemented
    with a plain sorted list maintained via the bisect module. Duplicate elements (if any) are maintained in their
    order of insertion relative to other duplicates.
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; ``reverse=True`` flips pop()/peek()/iteration to largest-first order."""
        self._lst: Final[list[TComparable]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Discards every queued element."""
        self._lst.clear()

    def push(self, element: TComparable) -> None:
        """Adds ``element``, keeping the backing list sorted."""
        bisect.insort(self._lst, element)

    def pop(self) -> TComparable:
        """Removes and returns the highest-priority element (smallest, or largest when reversed)."""
        if self._reverse:
            return self._lst.pop()
        return self._lst.pop(0)

    def peek(self) -> TComparable:
        """Returns the highest-priority element without removing it."""
        if self._reverse:
            return self._lst[-1]
        return self._lst[0]

    def remove(self, element: TComparable) -> bool:
        """Deletes the earliest-inserted occurrence of ``element``; returns whether it was present."""
        items = self._lst
        pos = bisect.bisect_left(items, element)
        if pos < len(items) and items[pos] == element:
            del items[pos]  # is an optimized memmove()
            return True
        return False

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: TComparable) -> bool:
        """Membership test via binary search."""
        items = self._lst
        pos = bisect.bisect_left(items, element)
        return pos < len(items) and items[pos] == element

    def __iter__(self) -> Iterator[TComparable]:
        """Yields elements in priority order."""
        if self._reverse:
            return reversed(self._lst)
        return iter(self._lst)

    def __repr__(self) -> str:
        """Shows contents in the same order in which __iter__ yields them."""
        ordered = list(reversed(self._lst)) if self._reverse else self._lst
        return repr(ordered)
1183###############################################################################
@final
class SortedInterner(Generic[TComparable]):
    """Same as sys.intern() except that it isn't global and that it assumes the input list is sorted (for binary search)."""

    def __init__(self, sorted_list: list[TComparable]) -> None:
        """Wraps ``sorted_list`` (must already be sorted ascending) as the dedup pool; the list is not copied."""
        self._lst: Final[list[TComparable]] = sorted_list

    def interned(self, element: TComparable) -> TComparable:
        """Returns the canonical equal item from the pool when present, else returns ``element`` itself."""
        pool = self._lst
        idx = binary_search(pool, element)
        if idx < 0:
            return element  # not contained; caller keeps its own instance
        return pool[idx]

    def __contains__(self, element: TComparable) -> bool:
        """Binary-search membership test."""
        return binary_search(self._lst, element) >= 0
def binary_search(sorted_list: list[TComparable], item: TComparable) -> int:
    """Java-style binary search: returns the index of an equal item when found (the left-most index if several equal
    items are contained), else ``-insertion_point - 1``."""
    pos = bisect.bisect_left(sorted_list, item)
    found = pos < len(sorted_list) and sorted_list[pos] == item
    return pos if found else -pos - 1
1209###############################################################################
_S = TypeVar("_S")


@final
class HashedInterner(Generic[_S]):
    """Same as sys.intern() except that it isn't global and can also be used for types other than str."""

    def __init__(self, items: Iterable[_S] = frozenset()) -> None:
        """Seeds the pool with ``items``; each one becomes its own canonical representative."""
        self._items: Final[dict[_S, _S]] = {item: item for item in items}  # identity map keyed by equality/hash

    def intern(self, item: _S) -> _S:
        """Returns the canonical equal instance, adding ``item`` as the canonical one when absent."""
        return self._items.setdefault(item, item)

    def interned(self, item: _S) -> _S:
        """Returns the canonical equal instance when present, else ``item`` unchanged; the pool is not modified."""
        return self._items.get(item, item)

    def __contains__(self, item: _S) -> bool:
        """Hash-based membership test."""
        return item in self._items
1232#############################################################################
@final
class SynchronizedBool:
    """Thread-safe wrapper around a regular bool."""

    def __init__(self, val: bool) -> None:
        """Wraps ``val``; all reads and writes happen under an internal lock."""
        assert isinstance(val, bool)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._value: bool = val

    @property
    def value(self) -> bool:
        """Current boolean value (atomic read)."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: bool) -> None:
        """Atomic write of ``new_value``."""
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: bool) -> bool:
        """Atomically stores ``new_value`` and returns what was stored before."""
        with self._lock:
            previous = self._value
            self._value = new_value
            return previous

    def compare_and_set(self, expected_value: bool, new_value: bool) -> bool:
        """Atomically stores ``new_value`` iff the current value equals ``expected_value``; returns whether it did."""
        with self._lock:
            matched: bool = self._value == expected_value
            if matched:
                self._value = new_value
            return matched

    def __bool__(self) -> bool:
        return self.value

    def __repr__(self) -> str:
        return repr(self.value)

    def __str__(self) -> str:
        return str(self.value)
1279#############################################################################
_K = TypeVar("_K")
_V = TypeVar("_V")


@final
class SynchronizedDict(Generic[_K, _V]):
    """Thread-safe wrapper around a regular dict."""

    def __init__(self, val: dict[_K, _V]) -> None:
        """Wraps ``val`` (not copied); all access happens under an internal lock."""
        assert isinstance(val, dict)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._dict: Final[dict[_K, _V]] = val

    def __getitem__(self, key: _K) -> _V:
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key: _K, value: _V) -> None:
        with self._lock:
            self._dict[key] = value

    def __delitem__(self, key: _K) -> None:
        with self._lock:
            del self._dict[key]

    def __contains__(self, key: _K) -> bool:
        with self._lock:
            return key in self._dict

    def __len__(self) -> int:
        with self._lock:
            return len(self._dict)

    def __repr__(self) -> str:
        with self._lock:
            return repr(self._dict)

    def __str__(self) -> str:
        with self._lock:
            return str(self._dict)

    def get(self, key: _K, default: _V | None = None) -> _V | None:
        """Returns ``self[key]`` or ``default`` if missing."""
        with self._lock:
            return self._dict.get(key, default)

    def pop(self, key: _K, default: _V | None = None) -> _V | None:
        """Removes ``key`` and returns its value, or ``default`` if missing."""
        with self._lock:
            return self._dict.pop(key, default)

    def clear(self) -> None:
        """Removes all items atomically."""
        with self._lock:
            self._dict.clear()

    def items(self) -> ItemsView[_K, _V]:
        """Returns an items view over a snapshot copy, safe to iterate without the lock."""
        with self._lock:
            return dict(self._dict).items()
1342#############################################################################
@final
class InterruptibleSleep:
    """Provides a sleep(timeout) function that can be interrupted by another thread; The underlying lock is configurable."""

    def __init__(self, lock: threading.Lock | None = None) -> None:
        """Creates the sleeper; callers may share an external ``lock`` or let one be allocated internally."""
        self._is_stopping: bool = False  # once True, sleep() returns immediately until reset() is called
        self._lock: Final[threading.Lock] = threading.Lock() if lock is None else lock
        self._condition: Final[threading.Condition] = threading.Condition(self._lock)

    def sleep(self, duration_nanos: int) -> bool:
        """Blocks the calling thread for up to ``duration_nanos``; returns True if interrupted, False on full timeout;
        Equivalent to threading.Event.wait()."""
        deadline_nanos: int = time.monotonic_ns() + duration_nanos
        with self._lock:
            while not self._is_stopping:
                remaining_nanos: int = deadline_nanos - time.monotonic_ns()
                if remaining_nanos <= 0:
                    return False  # slept the full duration without interruption
                # wait() releases the lock while blocking, then wakes on notify_all() or timeout
                self._condition.wait(timeout=remaining_nanos / 1_000_000_000)
            return True  # interrupted

    def interrupt(self) -> None:
        """Wakes sleeping threads and makes any future sleep()s a no-op; Equivalent to threading.Event.set()."""
        with self._lock:
            if self._is_stopping:
                return  # already interrupted; nothing to do
            self._is_stopping = True
            self._condition.notify_all()

    def reset(self) -> None:
        """Makes any future sleep()s no longer a no-op; Equivalent to threading.Event.clear()."""
        with self._lock:
            self._is_stopping = False
1377#############################################################################
@final
class SynchronousExecutor(Executor):
    """Executor that runs tasks inline in the calling thread, sequentially.

    Implements the concurrent.futures.Executor contract without spawning worker threads, avoiding thread startup
    overhead when a job is configured for sequential execution.
    """

    def __init__(self) -> None:
        self._shutdown: bool = False  # once True, submit() permanently refuses new work

    def submit(self, fn: Callable[..., _R_], /, *args: Any, **kwargs: Any) -> Future[_R_]:
        """Executes `fn(*args, **kwargs)` immediately in the calling thread and returns its completed Future.

        Raises RuntimeError when called after shutdown(), mirroring ThreadPoolExecutor semantics.
        """
        if self._shutdown:  # check before allocating the Future, like ThreadPoolExecutor does
            raise RuntimeError("cannot schedule new futures after shutdown")
        future: Future[_R_] = Future()
        try:
            result: _R_ = fn(*args, **kwargs)
        except BaseException as exc:
            future.set_exception(exc)  # store the failure; future.result() will re-raise it
        else:
            future.set_result(result)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Prevents new submissions; no worker resources to join/cleanup."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Factory returning a SynchronousExecutor if 0 <= max_workers <= 1; else a ThreadPoolExecutor."""
        return cls() if 0 <= max_workers <= 1 else ThreadPoolExecutor(max_workers=max_workers)
1408#############################################################################
1409@final
1410class _XFinally(contextlib.AbstractContextManager):
1411 """Context manager ensuring cleanup code executes after ``with`` blocks."""
1413 def __init__(self, cleanup: Callable[[], None]) -> None:
1414 """Records the callable to run upon exit."""
1415 self._cleanup: Final = cleanup # Zero-argument callable executed after the `with` block exits.
1417 def __exit__(
1418 self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None
1419 ) -> Literal[False]:
1420 """Runs cleanup and propagate any exceptions appropriately."""
1421 try:
1422 self._cleanup()
1423 except BaseException as cleanup_exc:
1424 if exc is None:
1425 raise # No main error --> propagate cleanup error normally
1426 # Both failed
1427 # if sys.version_info >= (3, 11):
1428 # raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None
1429 # <= 3.10: attach so it shows up in traceback but doesn't mask
1430 exc.__context__ = cleanup_exc
1431 return False # reraise original exception
1432 return False # propagate main exception if any
def xfinally(cleanup: Callable[[], None]) -> _XFinally:
    """Returns a context manager guaranteeing that ``cleanup()`` runs on exit without ever masking a body exception.

    A naive ``try ... finally`` can lose the original error: if ``cleanup()`` raises inside the ``finally`` clause, its
    exception replaces the real one from the body. The returned manager instead preserves exception priority:

    * Body raises, cleanup succeeds --> the original body exception is re-raised.
    * Body raises, cleanup also raises --> the body exception is re-raised; the cleanup exception is attached via
      ``__context__`` so both appear in the traceback.
    * Body succeeds, cleanup raises --> the cleanup exception propagates normally.

    Example:
    -------
    >>> with xfinally(lambda: release_resources()):  # doctest: +SKIP
    ...     run_tasks()

    The single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information.
    """
    return _XFinally(cleanup)