Coverage for bzfs_main / util / utils.py: 100%
797 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 12:49 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-29 12:49 +0000
1# Copyright 2024 Wolfgang Hoschek AT mac DOT com
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15"""Collection of helper functions used across bzfs; includes environment variable parsing, process management and lightweight
16concurrency primitives, etc.
18Everything in this module relies only on the Python standard library so other modules remain dependency free. Each utility
19favors simple, predictable behavior on all supported platforms.
20"""
22from __future__ import (
23 annotations,
24)
25import argparse
26import base64
27import bisect
28import contextlib
29import dataclasses
30import errno
31import hashlib
32import itertools
33import logging
34import operator
35import os
36import platform
37import pwd
38import random
39import re
40import signal
41import stat
42import subprocess
43import sys
44import threading
45import time
46import types
47from collections import (
48 defaultdict,
49 deque,
50)
51from collections.abc import (
52 ItemsView,
53 Iterable,
54 Iterator,
55 Sequence,
56)
57from concurrent.futures import (
58 Executor,
59 Future,
60)
61from concurrent.futures.thread import (
62 ThreadPoolExecutor,
63)
64from dataclasses import (
65 dataclass,
66)
67from datetime import (
68 datetime,
69 timedelta,
70 timezone,
71 tzinfo,
72)
73from subprocess import (
74 DEVNULL,
75 PIPE,
76)
77from typing import (
78 IO,
79 Any,
80 Callable,
81 Final,
82 Generic,
83 Literal,
84 NoReturn,
85 Protocol,
86 SupportsIndex,
87 TextIO,
88 TypeVar,
89 cast,
90 final,
91)
# constants:
PROG_NAME: Final[str] = "bzfs"
ENV_VAR_PREFIX: Final[str] = PROG_NAME + "_"  # env vars read via getenv_*() are all named "bzfs_<key>"
DIE_STATUS: Final[int] = 3  # process exit status used by die()
DESCENDANTS_RE_SUFFIX: Final[str] = r"(?:/.*)?"  # also match descendants of a matching dataset
LOG_STDERR: Final[int] = (logging.INFO + logging.WARNING) // 2  # custom log level is halfway in between
LOG_STDOUT: Final[int] = (LOG_STDERR + logging.INFO) // 2  # custom log level is halfway in between
LOG_DEBUG: Final[int] = logging.DEBUG
LOG_TRACE: Final[int] = logging.DEBUG // 2  # custom log level is halfway in between
YEAR_WITH_FOUR_DIGITS_REGEX: Final[re.Pattern] = re.compile(r"[1-9][0-9][0-9][0-9]")  # empty shall not match nonempty target
UNIX_TIME_INFINITY_SECS: Final[int] = 2**64  # billions of years and to be extra safe, larger than the largest ZFS GUID
DONT_SKIP_DATASET: Final[str] = ""  # sentinel: empty string means "currently not skipping any dataset subtree"
SHELL_CHARS: Final[str] = '"' + "'`~!@#$%^&*()+={}[]|;<>?,\\"  # intentionally not included: -_.:/
SHELL_CHARS_AND_SLASH: Final[str] = SHELL_CHARS + "/"
FILE_PERMISSIONS: Final[int] = stat.S_IRUSR | stat.S_IWUSR  # rw------- (user read + write)
DIR_PERMISSIONS: Final[int] = stat.S_IRWXU  # rwx------ (user read + write + execute)
UMASK: Final[int] = (~DIR_PERMISSIONS) & 0o777  # so intermediate dirs created by os.makedirs() have stricter permissions
UNIX_DOMAIN_SOCKET_PATH_MAX_LENGTH: Final[int] = 107 if platform.system() == "Linux" else 103  # see Google for 'sun_path'

RegexList = list[tuple[re.Pattern[str], bool]]  # Type alias; the bool marks a '!'-negated regex
def getenv_any(key: str, default: str | None = None, env_var_prefix: str = ENV_VAR_PREFIX) -> str | None:
    """All shell environment variable names used for configuration start with this prefix."""
    return os.getenv(f"{env_var_prefix}{key}", default)
def getenv_int(key: str, default: int, env_var_prefix: str = ENV_VAR_PREFIX) -> int:
    """Returns environment variable ``key`` as int with ``default`` fallback."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return int(raw)
def getenv_bool(key: str, default: bool = False, env_var_prefix: str = ENV_VAR_PREFIX) -> bool:
    """Returns environment variable ``key`` as bool with ``default`` fallback."""
    raw: str = cast(str, getenv_any(key, default=str(default), env_var_prefix=env_var_prefix))
    return raw.strip().lower() == "true"
def cut(field: int, separator: str = "\t", *, lines: list[str]) -> list[str]:
    """Retains only column number 'field' in a list of TSV/CSV lines; Analog to Unix 'cut' CLI command."""
    assert lines is not None
    assert isinstance(lines, list)
    assert len(separator) == 1
    if field not in (1, 2):
        raise ValueError(f"Invalid field value: {field}")
    if field == 1:  # keep everything before the first separator
        return [line[: line.index(separator)] for line in lines]
    return [line[line.index(separator) + 1 :] for line in lines]  # keep everything after the first separator
def drain(iterable: Iterable[Any]) -> None:
    """Consumes all items in the iterable, effectively draining it."""
    for item in iterable:
        del item  # drop the reference promptly to help gc (iterable can block)
_K_ = TypeVar("_K_")  # generic dict key type
_V_ = TypeVar("_V_")  # generic dict value type
_R_ = TypeVar("_R_")  # generic result type
def shuffle_dict(dictionary: dict[_K_, _V_], /, rand: random.Random = random.SystemRandom()) -> dict[_K_, _V_]:  # noqa: B008
    """Returns a new dict with items shuffled randomly; the input dict is not modified."""
    entries: list[tuple[_K_, _V_]] = [entry for entry in dictionary.items()]
    rand.shuffle(entries)
    return dict(entries)
def sorted_dict(
    dictionary: dict[_K_, _V_], /, *, key: Callable[[tuple[_K_, _V_]], Any] | None = None, reverse: bool = False
) -> dict[_K_, _V_]:
    """Returns a new dict with items sorted, primarily by key and secondarily by value (unless a custom key is supplied)."""
    ordered_items = sorted(dictionary.items(), key=key, reverse=reverse)
    return dict(ordered_items)
def tail(file: str, *, n: int, errors: str | None = None) -> Sequence[str]:
    """Return the last ``n`` lines of ``file`` without following symlinks."""
    if os.path.isfile(file):
        with open_nofollow(file, "r", encoding="utf-8", errors=errors, check_owner=False, check_perm=False) as handle:
            return deque(handle, maxlen=n)  # deque keeps only the trailing n lines while streaming
    return []
# matches a named capturing group opener like '(?P<name>' anchored at the start of the string
_NAMED_CAPTURING_GROUP: Final[re.Pattern[str]] = re.compile(r"^" + re.escape("(?P<") + r"[^\W\d]\w*" + re.escape(">"))
_NUMERIC_BACKREFERENCE_REGEX: Final[re.Pattern[str]] = re.compile(r"\\\d+")  # example: \1


def replace_capturing_groups_with_non_capturing_groups(regex: str) -> str:
    """Replaces regex capturing groups with non-capturing groups for better matching performance (unless it's tricky).

    Unnamed capturing groups example: '(.*/)?tmp(foo|bar)(?!public)\\(' --> '(?:.*/)?tmp(?:foo|bar)(?!public)\\('
    Aka replaces parenthesis '(' followed by a char other than question mark '?', but not preceded by a backslash
    with the replacement string '(?:'

    Named capturing group example: '(?P<name>abc)' --> '(?:abc)'
    Aka replaces '(?P<' followed by a valid name followed by '>', but not preceded by a backslash
    with the replacement string '(?:'

    Also see https://docs.python.org/3/howto/regex.html#non-capturing-and-named-groups
    """
    i = regex.find("[")
    if i >= 0 and regex.find("(", i) >= 0:
        # Conservative fallback to minimize code complexity: skip the rewrite entirely in the case where the regex might
        # contain a regex character class that contains parenthesis.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    if "(?P=" in regex or "(?(" in regex or _NUMERIC_BACKREFERENCE_REGEX.search(regex):
        # Conservative fallback to minimize code complexity: skip the rewrite entirely if the regex might contain a
        # (named or conditional or numeric) backreference.
        # Rewriting a regex is a performance optimization; correctness comes first.
        return regex

    # Scan right-to-left so that insertions made at position i never shift the not-yet-visited portion of the string.
    i = len(regex) - 2  # a '(' in the last position cannot open a group that needs rewriting
    while i >= 0:
        i = regex.rfind("(", 0, i + 1)  # find next '(' at or left of position i
        if i >= 0 and (i == 0 or regex[i - 1] != "\\"):  # ignore escaped parentheses like '\('
            if regex[i + 1] != "?":
                regex = f"{regex[0:i]}(?:{regex[i + 1:]}"  # unnamed capturing group
            else:  # potentially a valid named capturing group
                regex = regex[0:i] + _NAMED_CAPTURING_GROUP.sub(repl="(?:", string=regex[i:], count=1)
        i -= 1
    return regex
def get_home_directory() -> str:
    """Reliably detects home dir without using HOME env var."""
    # Thread-safe equivalent of: os.environ.pop('HOME', None); os.path.expanduser('~')
    passwd_entry = pwd.getpwuid(os.getuid())
    return passwd_entry.pw_dir
def human_readable_bytes(num_bytes: float, *, separator: str = " ", precision: int | None = None) -> str:
    """Formats 'num_bytes' as a human-readable size; for example "567 MiB"."""
    units = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB", "RiB", "QiB")
    sign = "" if num_bytes >= 0 else "-"
    magnitude = abs(num_bytes)
    unit_index = 0
    last_unit = len(units) - 1
    while magnitude >= 1024 and unit_index < last_unit:  # scale up until < 1024 or largest unit reached
        magnitude /= 1024
        unit_index += 1
    num_str = f"{magnitude:.{precision}f}" if precision is not None else human_readable_float(magnitude)
    return f"{sign}{num_str}{separator}{units[unit_index]}"
def human_readable_duration(duration: float, *, unit: str = "ns", separator: str = "", precision: int | None = None) -> str:
    """Formats a duration in human units, automatically scaling as needed; for example "567ms"."""
    units = ("ns", "μs", "ms", "s", "m", "h", "d")
    sign = "" if duration >= 0 else "-"
    t = abs(duration)
    i = units.index(unit)
    if 0 != t < 1:
        # sub-unit values are first expressed in nanoseconds so upscaling below starts from a common base
        nanos = (1, 1_000, 1_000_000, 1_000_000_000, 60 * 1_000_000_000, 60 * 60 * 1_000_000_000, 3600 * 24 * 1_000_000_000)
        t, i = t * nanos[i], 0
    while i < 3 and t >= 1000:  # ns -> μs -> ms -> s scale by 1000
        t /= 1000
        i += 1
    if i >= 3:
        while i < 5 and t >= 60:  # s -> m -> h scale by 60
            t /= 60
            i += 1
        if i >= 5:
            while i < len(units) - 1 and t >= 24:  # h -> d scale by 24
                t /= 24
                i += 1
    num_str = f"{t:.{precision}f}" if precision is not None else human_readable_float(t)
    return f"{sign}{num_str}{separator}{units[i]}"
def human_readable_float(number: float) -> str:
    """Formats ``number`` with a variable precision depending on magnitude, the way humans round values in logs.

    One digit before the decimal point (0 <= abs < 10): two decimals, e.g. 3.14559 --> "3.15".
    Two digits before the decimal point (10 <= abs < 100): one decimal, e.g. 12.36 --> "12.4".
    Three or more digits before the decimal point (abs >= 100): zero decimals, e.g. 123.556 --> "124".
    Unnecessary trailing zeroes are dropped: 1.500 --> "1.5", 1.00 --> "1".
    """
    magnitude = abs(number)
    if magnitude >= 100:
        return str(round(number))
    digits = 2 if magnitude < 10 else 1
    text = f"{number:.{digits}f}"
    assert "." in text
    text = text.rstrip("0").rstrip(".")  # drop trailing zeros, then a dangling decimal point
    return "0" if text == "-0" else text  # avoid the confusing "-0" rendering
def percent(number: int, total: int, *, print_total: bool = False) -> str:
    """Returns percentage string of ``number`` relative to ``total``."""
    denominator: str = f"/{total}" if print_total else ""
    ratio: str = "inf" if total == 0 else human_readable_float(100 * number / total)
    return f"{number}{denominator}={ratio}%"
def open_nofollow(
    path: str,
    mode: str = "r",
    buffering: int = -1,
    encoding: str | None = None,
    errors: str | None = None,
    newline: str | None = None,
    *,
    perm: int = FILE_PERMISSIONS,
    check_owner: bool = True,
    check_perm: bool = True,
    **kwargs: Any,
) -> IO[Any]:
    """Behaves exactly like built-in open(), except that it refuses to follow symlinks, i.e. raises OSError with
    errno.ELOOP/EMLINK if basename of path is a symlink.

    Also, can specify custom permissions on O_CREAT, and verify secure ownership and mode bits.

    If check_owner=True, write-capable opens require ownership by the effective UID; read-only opens also allow ownership by
    uid 0 (root). This allows safe reads of root-owned system files while preventing writes to files not owned by the caller.

    If check_perm=True, the file must also not be writable by group or others, and must not be executable by anyone.
    """
    if not mode:
        raise ValueError("Must have exactly one of create/read/write/append mode and at most one plus")
    # translate the primary open() mode character into the equivalent low-level os.open() flags
    flags = {
        "r": os.O_RDONLY,
        "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        "a": os.O_WRONLY | os.O_CREAT | os.O_APPEND,
        "x": os.O_WRONLY | os.O_CREAT | os.O_EXCL,
    }.get(mode[0])
    if flags is None:
        raise ValueError(f"invalid mode {mode!r}")
    if "+" in mode:  # enable read-write access for r+, w+, a+, x+
        flags = (flags & ~os.O_WRONLY) | os.O_RDWR  # clear os.O_WRONLY and set os.O_RDWR while preserving all other flags
    flags |= os.O_NOFOLLOW | os.O_CLOEXEC  # O_NOFOLLOW: fail if the final path component is a symlink
    fd: int = os.open(path, flags=flags, mode=perm)  # perm only takes effect when O_CREAT actually creates the file
    try:
        stats: os.stat_result | None = None  # lazily fetched; shared between the owner check and the perm check
        if check_owner:
            stats = os.fstat(fd)
            st_uid: int = stats.st_uid
            if st_uid != os.geteuid():  # verify ownership is current effective UID
                if (flags & (os.O_WRONLY | os.O_RDWR)) != 0:  # require that writer owns the file
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}", path)
                elif st_uid != 0:  # it's ok for root to own a file that we'll merely read
                    raise PermissionError(errno.EPERM, f"{path!r} is owned by uid {st_uid}, not {os.geteuid()} or 0", path)
        if check_perm:
            stats = os.fstat(fd) if stats is None else stats
            # reject files that are group/other-writable or executable by anyone
            if (stats.st_mode & (stat.S_IWGRP | stat.S_IWOTH | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)) != 0:
                st_mode: int = stat.S_IMODE(stats.st_mode)
                raise PermissionError(
                    errno.EPERM,
                    f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]} "
                    "but must not be writable by group or others, or executable by anyone",
                    path,
                )
        return os.fdopen(fd, mode, buffering=buffering, encoding=encoding, errors=errors, newline=newline, **kwargs)
    except Exception:
        # on any validation or fdopen failure, close the raw fd to avoid leaking it, then re-raise
        try:
            os.close(fd)
        except OSError:
            pass
        raise
def close_quietly(fd: int) -> None:
    """Closes the given file descriptor while silently swallowing any OSError that might arise as part of this."""
    if fd < 0:
        return  # negative values indicate "no file descriptor"; nothing to close
    with contextlib.suppress(OSError):
        os.close(fd)
_P = TypeVar("_P")


def find_match(
    seq: Sequence[_P],
    predicate: Callable[[_P], bool],
    start: SupportsIndex | None = None,
    end: SupportsIndex | None = None,
    *,
    reverse: bool = False,
    raises: bool | object | Callable[[], object] = False,  # raises: bool | object | Callable = False, # python >= 3.10
) -> int:
    """Returns the integer index within ``seq`` of the first item (or last item if reverse=True) that satisfies
    ``predicate``.

    If no matching item is found returns -1 or ValueError, depending on the ``raises`` parameter, which is a bool indicating
    whether to raise an error, or an object containing the error message, but can also be a Callable/lambda in order to
    support efficient deferred generation of error messages.

    Analog to ``str.find()``, including slicing semantics with parameters start and end, i.e. respects Python slicing
    semantics for start/end (including clamping). For example, seq can be a list, tuple or str.

    Example usage:
        lst = ["a", "b", "-c", "d"]
        i = find_match(lst, lambda arg: arg.startswith("-"), start=1, end=3, reverse=True)
        if i >= 0:
            print(lst[i])
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=f"Tag {tag} not found in {file}")
        i = find_match(lst, lambda arg: arg.startswith("-"), raises=lambda: f"Tag {tag} not found in {file}")
    """
    if start is None and end is None:
        lo, hi = 0, len(seq)  # fast path: whole sequence, no slice computation
    else:
        lo, hi, _ = slice(start, end).indices(len(seq))  # clamp per Python slicing semantics
    candidates = range(hi - 1, lo - 1, -1) if reverse else range(lo, hi)
    for idx in candidates:
        if predicate(seq[idx]):
            return idx
    if raises is False or raises is None:
        return -1
    if raises is True:
        raise ValueError("No matching item found in sequence")
    message = raises() if callable(raises) else raises  # deferred message generation via Callable
    raise ValueError(message)
def is_descendant(dataset: str, of_root_dataset: str) -> bool:
    """Returns True if ZFS ``dataset`` lies under ``of_root_dataset`` in the dataset hierarchy, or is the same."""
    if dataset == of_root_dataset:
        return True
    return dataset.startswith(of_root_dataset + "/")  # the "/" guard rejects mere name prefixes like "tank/ab" vs "tank/a"
def has_duplicates(sorted_list: list[Any]) -> bool:
    """Returns True if any adjacent items within the given sorted sequence are equal."""
    return any(left == right for left, right in zip(sorted_list, sorted_list[1:]))
def has_siblings(sorted_datasets: list[str], is_test_mode: bool = False) -> bool:
    """Returns whether the (sorted) list of ZFS input datasets contains any siblings."""
    assert (not is_test_mode) or sorted_datasets == sorted(sorted_datasets), "List is not sorted"
    assert (not is_test_mode) or not has_duplicates(sorted_datasets), "List contains duplicates"
    parents: set[str] = set()
    skip_dataset: str = DONT_SKIP_DATASET
    for dataset in sorted_datasets:
        assert dataset
        parent = os.path.dirname(dataset)
        if parent in parents:
            return True  # my parent already has another child, so I have a sibling
        parents.add(parent)
        if not is_descendant(dataset, of_root_dataset=skip_dataset):
            if skip_dataset != DONT_SKIP_DATASET:
                return True  # a second distinct root dataset implies siblings at the root level
            skip_dataset = dataset  # remember the first root dataset; its descendants are skipped above
    return False
def dry(msg: str, is_dry_run: bool) -> str:
    """Prefix ``msg`` with 'Dry' when in dry-run mode."""
    if is_dry_run:
        return "Dry " + msg
    return msg
def relativize_dataset(dataset: str, root_dataset: str) -> str:
    """Converts an absolute dataset path to one relative to ``root_dataset``.

    Example: root_dataset=tank/foo, dataset=tank/foo/bar/baz --> relative_path=/bar/baz.
    """
    prefix_len = len(root_dataset)
    return dataset[prefix_len:]
def dataset_paths(dataset: str) -> Iterator[str]:
    """Enumerates all paths of a valid ZFS dataset name; Example: "a/b/c" --> yields "a", "a/b", "a/b/c"."""
    offset: int = 0
    while (sep := dataset.find("/", offset)) >= 0:  # yield every proper prefix ending just before a '/'
        yield dataset[:sep]
        offset = sep + 1
    yield dataset  # finally, the full dataset name itself
def replace_prefix(s: str, old_prefix: str, new_prefix: str) -> str:
    """In a string s, replaces a leading old_prefix string with new_prefix; assumes the leading string is present."""
    assert s.startswith(old_prefix)
    remainder = s[len(old_prefix) :]
    return new_prefix + remainder
def replace_in_lines(lines: list[str], old: str, new: str, count: int = -1) -> None:
    """Replaces ``old`` with ``new`` in-place for every string in ``lines``."""
    lines[:] = [line.replace(old, new, count) for line in lines]  # slice assignment mutates the caller's list
_TAPPEND = TypeVar("_TAPPEND")


def append_if_absent(lst: list[_TAPPEND], *items: _TAPPEND) -> list[_TAPPEND]:
    """Appends each given item to the list unless it is already contained, then returns the (mutated) list."""
    for candidate in items:
        if candidate not in lst:
            lst.append(candidate)
    return lst
def xappend(lst: list[_TAPPEND], *items: _TAPPEND | Iterable[_TAPPEND]) -> list[_TAPPEND]:
    """Appends each of the items to the given list if the item is "truthy", for example not None and not an empty string; If
    an item is an iterable does so recursively, flattening the output."""
    for item in items:
        if isinstance(item, Iterable) and not isinstance(item, str):
            xappend(lst, *item)  # recurse to flatten nested iterables
        elif item:
            lst.append(item)
    return lst
def is_included(name: str, include_regexes: RegexList, exclude_regexes: RegexList) -> bool:
    """Returns True if the name matches at least one of the include regexes but none of the exclude regexes; else False.

    A regex that starts with a `!` is a negation - the regex matches if the regex without the `!` prefix does not match.
    """

    def matches(regex: re.Pattern[str], is_negation: bool) -> bool:
        # ".*" trivially matches everything; skip the engine for that common case
        hit = True if regex.pattern == ".*" else bool(regex.fullmatch(name))
        return (not hit) if is_negation else hit

    if any(matches(regex, is_negation) for regex, is_negation in exclude_regexes):
        return False  # exclusion wins over inclusion
    return any(matches(regex, is_negation) for regex, is_negation in include_regexes)
def compile_regexes(regexes: list[str], *, suffix: str = "") -> RegexList:
    """Compiles regex strings and keeps track of negations."""
    assert isinstance(regexes, list)
    results: RegexList = []
    for regex in regexes:
        # disallow non-trailing end-of-str symbol in dataset regexes to ensure descendants will also match
        if suffix and not regex.endswith("\\$"):  # a trailing escaped (literal) dollar is harmless
            if regex.endswith("$"):
                regex = regex[0:-1]  # ok because all users of compile_regexes() call re.fullmatch()
            elif "$" in regex:
                raise re.error("Must not use non-trailing '$' character", regex)
        if is_negation := regex.startswith("!"):
            regex = regex[1:]  # strip the '!' marker; remember the negation separately
        regex = replace_capturing_groups_with_non_capturing_groups(regex)
        if regex != ".*" or not (suffix.startswith("(") and suffix.endswith(")?")):
            regex = f"{regex}{suffix}"  # ".*" followed by an optional group is redundant, so skip the suffix there
        results.append((re.compile(regex), is_negation))
    return results
def list_formatter(iterable: Iterable[Any], separator: str = " ", lstrip: bool = False) -> Any:
    """Lazy formatter joining items with ``separator``; the join only happens when str() is called, which avoids overhead
    in disabled log levels."""

    @final
    class _LazyJoin:
        """Joins the items on demand when rendered via ``str``."""

        def __str__(self) -> str:
            text = separator.join(str(item) for item in iterable)
            return text.lstrip() if lstrip else text

    return _LazyJoin()
def pretty_print_formatter(obj_to_format: Any) -> Any:
    """Lazy pprint formatter used to avoid overhead in disabled log levels."""

    @final
    class _LazyPretty:
        """Pretty-prints the wrapped object's attributes on conversion to ``str``."""

        def __str__(self) -> str:
            import pprint  # deferred import keeps program startup fast

            return pprint.pformat(vars(obj_to_format))

    return _LazyPretty()
def stderr_to_str(stderr: Any) -> str:
    """Workaround for https://github.com/python/cpython/issues/87597."""
    if isinstance(stderr, bytes):
        return stderr.decode("utf-8", errors="replace")
    return str(stderr)
def xprint(log: logging.Logger, value: Any, *, run: bool = True, end: str = "\n", file: TextIO | None = None) -> None:
    """Optionally logs ``value`` at stdout/stderr level."""
    if not run or not value:
        return  # disabled, or nothing to print
    rendered = value if end else str(value).rstrip()  # an empty ``end`` mimics print(..., end="") by trimming the newline
    level = LOG_STDOUT if file is sys.stdout else LOG_STDERR
    log.log(level, "%s", rendered)
def sha256_hex(text: str) -> str:
    """Returns the sha256 hex string for the given text."""
    hasher = hashlib.sha256(text.encode())
    return hasher.hexdigest()
def sha256_urlsafe_base64(text: str, *, padding: bool = True) -> str:
    """Returns the URL-safe base64-encoded sha256 value for the given text."""
    raw_digest: bytes = hashlib.sha256(text.encode()).digest()
    encoded: str = base64.urlsafe_b64encode(raw_digest).decode()
    if not padding:
        encoded = encoded.rstrip("=")  # drop the trailing '=' padding chars
    return encoded
def sha256_128_urlsafe_base64(text: str) -> str:
    """Returns the left half portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    full: str = sha256_urlsafe_base64(text, padding=False)
    half: int = len(full) // 2
    return full[:half]
def sha256_85_urlsafe_base64(text: str) -> str:
    """Returns the left one third portion of the unpadded URL-safe base64-encoded sha256 value for the given text."""
    full: str = sha256_urlsafe_base64(text, padding=False)
    third: int = len(full) // 3
    return full[:third]
def urlsafe_base64(
    value: int, max_value: int = 2**64 - 1, *, padding: bool = True, byteorder: Literal["little", "big"] = "big"
) -> str:
    """Returns the URL-safe base64 string encoding of the int value, assuming it is contained in the range [0..max_value]."""
    assert 0 <= value <= max_value
    nbytes: int = (max_value.bit_length() + 7) // 8  # smallest byte count that can represent max_value
    encoded: str = base64.urlsafe_b64encode(value.to_bytes(nbytes, byteorder)).decode()
    return encoded if padding else encoded.rstrip("=")
def die(msg: str, exit_code: int = DIE_STATUS, parser: argparse.ArgumentParser | None = None) -> NoReturn:
    """Exits the program with ``exit_code`` after logging ``msg``."""
    if parser is not None:
        parser.error(msg)  # argparse prints usage plus msg and then exits on its own
    else:
        error = SystemExit(msg)
        error.code = exit_code  # override the default SystemExit code with the requested one
        raise error
def subprocess_run(*args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
    """Drop-in replacement for subprocess.run() that mimics its behavior except it enhances cleanup on TimeoutExpired, and
    provides optional child PID tracking, and optional logging of execution status via ``log`` and ``loglevel`` params."""
    # pop the extension params so the remaining kwargs can be passed verbatim to subprocess.Popen()
    input_value = kwargs.pop("input", None)
    timeout = kwargs.pop("timeout", None)
    check = kwargs.pop("check", False)
    subprocesses: Subprocesses | None = kwargs.pop("subprocesses", None)  # optional per-job child PID tracker
    if input_value is not None:
        if kwargs.get("stdin") is not None:
            raise ValueError("input and stdin are mutually exclusive")
        kwargs["stdin"] = subprocess.PIPE  # input is fed to the child via a pipe, just like subprocess.run() does

    log: logging.Logger | None = kwargs.pop("log", None)
    loglevel: int | None = kwargs.pop("loglevel", None)
    start_time_nanos: int = time.monotonic_ns()
    is_timeout: bool = False
    is_cancel: bool = False
    exitcode: int | None = None

    def log_status() -> None:
        """Logs one line with outcome, elapsed time and the command, if logging is enabled for the effective level."""
        if log is not None:
            _loglevel: int = loglevel if loglevel is not None else getenv_int("subprocess_run_loglevel", LOG_TRACE)
            if log.isEnabledFor(_loglevel):
                elapsed_time: str = human_readable_float((time.monotonic_ns() - start_time_nanos) / 1_000_000) + "ms"
                status = "cancel" if is_cancel else "timeout" if is_timeout else "success" if exitcode == 0 else "failure"
                cmd = kwargs["args"] if "args" in kwargs else (args[0] if args else None)
                cmd_str: str = " ".join(str(arg) for arg in iter(cmd)) if isinstance(cmd, (list, tuple)) else str(cmd)
                log.log(_loglevel, f"Executed [{status}] [{elapsed_time}]: %s", cmd_str)

    with xfinally(log_status):  # ensure the status line is logged on every exit path, including exceptions
        ctx: contextlib.AbstractContextManager[subprocess.Popen]
        if subprocesses is None:
            ctx = subprocess.Popen(*args, **kwargs)
        else:
            ctx = subprocesses.popen_and_track(*args, **kwargs)  # also registers the child PID for per-job termination
        with ctx as proc:
            try:
                sp = subprocesses
                if sp is not None and sp._is_terminated():  # noqa: SLF001 pylint: disable=protected-access
                    is_cancel = True
                    timeout = 0.0  # force an immediate TimeoutExpired to enforce async cancellation
                stdout, stderr = proc.communicate(input_value, timeout=timeout)
            except BaseException as e:
                try:
                    if isinstance(e, subprocess.TimeoutExpired):
                        is_timeout = True
                        terminate_process_subtree(root_pids=[proc.pid])  # send SIGTERM to child proc and descendants
                finally:
                    proc.kill()  # unconditionally kill the direct child before re-raising
                raise
            else:
                exitcode = proc.poll()
                assert exitcode is not None
                if check and exitcode:
                    raise subprocess.CalledProcessError(exitcode, proc.args, output=stdout, stderr=stderr)
                return subprocess.CompletedProcess(proc.args, exitcode, stdout, stderr)
def terminate_process_subtree(
    *, except_current_process: bool = True, root_pids: list[int] | None = None, sig: signal.Signals = signal.SIGTERM
) -> None:
    """For each root PID: Sends the given signal to the root PID and all its descendant processes."""
    self_pid: int = os.getpid()
    if root_pids is None:
        root_pids = [self_pid]  # default: operate on the current process tree
    descendants_per_root: list[list[int]] = _get_descendant_processes(root_pids)
    assert len(descendants_per_root) == len(root_pids)
    for root_pid, pids in zip(root_pids, descendants_per_root):
        if root_pid != self_pid:
            pids.insert(0, root_pid)  # signal the foreign root itself, before its descendants
        elif not except_current_process:
            pids.append(self_pid)  # signal ourselves last so we first reach all descendants
        for pid in pids:
            with contextlib.suppress(OSError):  # the process may already be gone; that's fine
                os.kill(pid, sig)
715def _get_descendant_processes(root_pids: list[int]) -> list[list[int]]:
716 """For each root PID, returns the list of all descendant process IDs for the given root PID, on POSIX systems."""
717 if len(root_pids) == 0:
718 return []
719 cmd: list[str] = ["ps", "-Ao", "pid,ppid"]
720 try:
721 lines: list[str] = subprocess.run(cmd, stdin=DEVNULL, stdout=PIPE, text=True, check=True).stdout.splitlines()
722 except PermissionError:
723 # degrade gracefully in sandbox environments that deny executing `ps` entirely
724 return [[] for _ in root_pids]
725 procs: dict[int, list[int]] = defaultdict(list)
726 for line in lines[1:]: # all lines except the header line
727 splits: list[str] = line.split()
728 assert len(splits) == 2
729 pid = int(splits[0])
730 ppid = int(splits[1])
731 procs[ppid].append(pid)
733 def recursive_append(ppid: int, descendants: list[int]) -> None:
734 """Recursively collect descendant PIDs starting from ``ppid``."""
735 for child_pid in procs[ppid]:
736 descendants.append(child_pid)
737 recursive_append(child_pid, descendants)
739 all_descendants: list[list[int]] = []
740 for root_pid in root_pids:
741 descendants: list[int] = []
742 recursive_append(root_pid, descendants)
743 all_descendants.append(descendants)
744 return all_descendants
@contextlib.contextmanager
def termination_signal_handler(
    termination_events: list[threading.Event],
    *,
    termination_handler: Callable[[], None] = lambda: terminate_process_subtree(),
) -> Iterator[None]:
    """Context manager that installs SIGINT/SIGTERM handlers that set all ``termination_events`` and, by default, terminate
    all descendant processes; the previous handlers are restored on exit."""
    events = list(termination_events)  # shallow copy guards against later caller-side mutation

    def on_signal(_signum: int, _frame: object) -> None:
        for event in events:
            event.set()
        termination_handler()

    previous = {sig: signal.signal(sig, on_signal) for sig in (signal.SIGINT, signal.SIGTERM)}  # install new handlers
    try:
        yield  # run body of context manager
    finally:
        for sig, old_handler in previous.items():  # restore original signal handlers
            signal.signal(sig, old_handler)
def return_false() -> bool:
    """Always returns ``False``; a named function (unlike a lambda) so it is picklable."""
    return False
def sleep_nanos(delay_nanos: int) -> None:
    """Same as time.sleep() but expects a relative sleep duration in nanoseconds as input value; picklable."""
    seconds = delay_nanos / 1_000_000_000  # convert ns -> s for time.sleep()
    time.sleep(seconds)
#############################################################################
@dataclass(frozen=True)
@final
class TaskTiming:
    """Customizable callbacks for reading the current monotonic time, sleeping and optional async termination; immutable."""

    # reads the current monotonic clock in nanoseconds; default is the real clock
    monotonic_ns: Callable[[], int] = time.monotonic_ns

    is_terminated: Callable[[], bool] = return_false
    """Returns whether a predicate has become true; can be used to indicate system shutdown or similar cancellation
    conditions; default is to always return ``False``."""

    sleep: Callable[[int], None] = sleep_nanos
    """Sleeps N nanoseconds; thread-safe."""

    def copy(self, **override_kwargs: Any) -> TaskTiming:
        """Creates a new object copying an existing one with the specified fields overridden for customization; thread-
        safe."""
        return dataclasses.replace(self, **override_kwargs)

    @staticmethod
    def make_from(termination_event: threading.Event | None) -> TaskTiming:
        """Convenience factory that creates an object that performs async termination when ``termination_event`` is set."""
        if termination_event is None:
            return TaskTiming()  # plain defaults: real clock, real sleep, never terminated

        def _sleep(delay_nanos: int) -> None:
            termination_event.wait(delay_nanos / 1_000_000_000)  # allow early wakeup on async termination

        return TaskTiming(is_terminated=termination_event.is_set, sleep=_sleep)
813#############################################################################
@final
class Subprocesses:
    """Provides per-job tracking of child PIDs so a job can safely terminate only the subprocesses it spawned itself; used
    when multiple jobs run concurrently within the same Python process.

    Optionally binds to an ``_is_terminated`` predicate to enforce async cancellation by forcing immediate timeouts for newly
    spawned subprocesses once cancellation is requested.
    """

    def __init__(self, is_terminated: Callable[[], bool] = return_false) -> None:
        # NOTE(review): the predicate is stored but not read within this class; presumably consumed by
        # utils.subprocess_run() via the `subprocesses=` parameter passed in subprocess_run() below — TODO confirm.
        self._is_terminated: Final[Callable[[], bool]] = is_terminated
        self._lock: Final[threading.Lock] = threading.Lock()  # guards _child_pids and spawn+register atomicity
        self._child_pids: Final[dict[int, None]] = {}  # a set that preserves insertion order

    @contextlib.contextmanager
    def popen_and_track(self, *popen_args: Any, **popen_kwargs: Any) -> Iterator[subprocess.Popen]:
        """Context manager that calls subprocess.Popen() and tracks the child PID for per-job termination.

        Holds a lock across Popen+PID registration to prevent a race when terminate_process_subtrees() is invoked (e.g. from
        SIGINT/SIGTERM handlers), ensuring newly spawned child processes cannot escape termination. The child PID is
        unregistered on context exit.
        """
        with self._lock:
            # Spawning inside the lock guarantees the PID is registered before any concurrent termination can run.
            proc: subprocess.Popen = subprocess.Popen(*popen_args, **popen_kwargs)
            self._child_pids[proc.pid] = None
        try:
            yield proc
        finally:
            with self._lock:
                self._child_pids.pop(proc.pid, None)  # tolerate PIDs already removed by terminate_process_subtrees()

    def subprocess_run(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess:
        """Wrapper around utils.subprocess_run() that auto-registers/unregisters child PIDs for per-job termination."""
        return subprocess_run(*args, **kwargs, subprocesses=self)

    def terminate_process_subtrees(self, sig: signal.Signals = signal.SIGTERM) -> None:
        """Sends the given signal to all tracked child PIDs and their descendants, ignoring errors for dead PIDs."""
        with self._lock:
            pids: list[int] = list(self._child_pids)  # snapshot taken under the lock
            self._child_pids.clear()  # avoid signaling the same PIDs twice on repeated calls
        # signal delivery happens outside the lock so new spawns are not blocked while killing subtrees
        terminate_process_subtree(root_pids=pids, sig=sig)
857#############################################################################
858def pid_exists(pid: int) -> bool | None:
859 """Returns True if a process with PID exists, False if not, or None on error."""
860 if pid <= 0:
861 return False
862 try: # with signal=0, no signal is actually sent, but error checking is still performed
863 os.kill(pid, 0) # ... which can be used to check for process existence on POSIX systems
864 except OSError as err:
865 if err.errno == errno.ESRCH: # No such process
866 return False
867 if err.errno == errno.EPERM: # Operation not permitted
868 return True
869 return None
870 return True
def nprefix(s: str) -> str:
    """Returns a canonical snapshot prefix with trailing underscore."""
    return sys.intern(f"{s}_")
def ninfix(s: str) -> str:
    """Returns a canonical infix with trailing underscore when not empty."""
    if not s:
        return ""
    return sys.intern(f"{s}_")
def nsuffix(s: str) -> str:
    """Returns a canonical suffix with leading underscore when not empty."""
    if not s:
        return ""
    return sys.intern(f"_{s}")
def format_dict(dictionary: dict[Any, Any]) -> str:
    """Returns the dict's standard string form wrapped in double quotes, for consistent log output."""
    return '"{}"'.format(dictionary)
def format_obj(obj: object) -> str:
    """Returns the object's string form wrapped in double quotes, for consistent log output."""
    return '"{}"'.format(obj)
def validate_dataset_name(dataset: str, input_text: str) -> None:
    """'zfs create' CLI does not accept dataset names that are empty or start or end in a slash, etc."""
    # Also see https://github.com/openzfs/zfs/issues/439#issuecomment-2784424
    # and https://github.com/openzfs/zfs/issues/8798
    # and (by now no longer accurate): https://docs.oracle.com/cd/E26505_01/html/E37384/gbcpt.html
    def _is_invalid() -> bool:
        if dataset in ("", ".", ".."):
            return True
        if dataset.startswith(("/", "./", "../")) or dataset.endswith(("/", "/.", "/..")):
            return True
        if any(part in dataset for part in ("//", "/./", "/../")):
            return True
        if any(ch in SHELL_CHARS or (ch.isspace() and ch != " ") for ch in dataset):
            return True
        return not dataset[0].isalpha()  # safe: empty string was rejected above

    if _is_invalid():
        die(f"Invalid ZFS dataset name: '{dataset}' for: '{input_text}'")
def validate_property_name(propname: str, input_text: str) -> str:
    """Checks that the ZFS property name is non-empty and contains no leading dash, whitespace or shell chars."""
    bad: bool = (not propname) or propname.startswith("-")
    if not bad:
        bad = any(ch.isspace() or ch in SHELL_CHARS for ch in propname)
    if bad:
        die(f"Invalid ZFS property name: '{propname}' for: '{input_text}'")
    return propname
923def validate_is_not_a_symlink(msg: str, path: str, parser: argparse.ArgumentParser | None = None) -> None:
924 """Checks that the given path is not a symbolic link."""
925 if os.path.islink(path):
926 die(f"{msg}must not be a symlink: {path}", parser=parser)
def validate_file_permissions(path: str, mode: int) -> None:
    """Verifies that ``path`` is owned by the current effective UID and has exactly permission bits ``mode``.

    Calls die() with a diagnostic message on any mismatch; returns None on success. Symlinks are not followed, so
    the link itself (not its target) is checked.
    """
    stats: os.stat_result = os.stat(path, follow_symlinks=False)
    st_uid: int = stats.st_uid
    if st_uid != os.geteuid():  # verify ownership is current effective UID
        die(f"{path!r} is owned by uid {st_uid}, not {os.geteuid()}")
    st_mode: int = stat.S_IMODE(stats.st_mode)
    if st_mode != mode:
        # bugfix: removed stray ')' that previously trailed the error message
        die(
            f"{path!r} has permissions {st_mode:03o} aka {stat.filemode(st_mode)[1:]}, "
            f"not {mode:03o} aka {stat.filemode(mode)[1:]}"
        )
def parse_duration_to_milliseconds(duration: str, *, regex_suffix: str = "", context: str = "") -> int:
    """Parses human duration strings like '5minutes' or '2 hours' to milliseconds.

    If the input does not parse, calls die() when ``context`` is given, else raises ValueError.
    """
    units: dict[str, int] = {
        "milliseconds": 1,
        "millis": 1,
        "seconds": 1000,
        "secs": 1000,
        "minutes": 60 * 1000,
        "mins": 60 * 1000,
        "hours": 60 * 60 * 1000,
        "days": 86400 * 1000,
        "weeks": 7 * 86400 * 1000,
        "months": round(30.5 * 86400 * 1000),
        "years": 365 * 86400 * 1000,
    }
    pattern = r"(\d+)\s*(milliseconds|millis|seconds|secs|minutes|mins|hours|days|weeks|months|years)" + regex_suffix
    match = re.fullmatch(pattern, duration)
    if match is None:
        if context:
            die(f"Invalid duration format: {duration} within {context}")  # die() does not return
        raise ValueError(f"Invalid duration format: {duration}")
    amount: int = int(match.group(1))
    return amount * units[match.group(2)]
def unixtime_fromisoformat(datetime_str: str) -> int:
    """Converts an ISO 8601 datetime string into UTC Unix time in integer seconds."""
    parsed: datetime = datetime.fromisoformat(datetime_str)
    return int(parsed.timestamp())
def isotime_from_unixtime(unixtime_in_seconds: int) -> str:
    """Converts UTC Unix time seconds into an ISO 8601 datetime string with '_' separating date and time."""
    moment: datetime = datetime.fromtimestamp(unixtime_in_seconds, tz=timezone.utc)
    return moment.isoformat(sep="_", timespec="seconds")
def current_datetime(
    tz_spec: str | None = None,
    now_fn: Callable[[tzinfo | None], datetime] | None = None,
) -> datetime:
    """Returns the current time in the ``tz_spec`` timezone, or in the local timezone when ``tz_spec`` is None."""
    clock: Callable[[tzinfo | None], datetime] = datetime.now if now_fn is None else now_fn
    return clock(get_timezone(tz_spec))
995def get_timezone(tz_spec: str | None = None) -> tzinfo | None:
996 """Returns timezone from spec or local timezone if unspecified."""
997 tz: tzinfo | None
998 if tz_spec is None:
999 tz = None
1000 elif tz_spec == "UTC":
1001 tz = timezone.utc
1002 else:
1003 if match := re.fullmatch(r"([+-])(\d\d):?(\d\d)", tz_spec):
1004 sign, hours, minutes = match.groups()
1005 offset: int = int(hours) * 60 + int(minutes)
1006 offset = -offset if sign == "-" else offset
1007 tz = timezone(timedelta(minutes=offset))
1008 elif "/" in tz_spec:
1009 from zoneinfo import ZoneInfo # lazy import for startup perf
1011 tz = ZoneInfo(tz_spec)
1012 else:
1013 raise ValueError(f"Invalid timezone specification: {tz_spec}")
1014 return tz
1017###############################################################################
@final
class SnapshotPeriods:  # thread-safe
    """Parses snapshot suffix strings and converts between durations."""

    def __init__(self) -> None:
        """Builds lookup tables mapping period suffixes to milliseconds and to human-readable labels."""
        self.suffix_milliseconds: Final[dict[str, int]] = {
            "yearly": 365 * 86400 * 1000,
            "monthly": round(30.5 * 86400 * 1000),
            "weekly": 7 * 86400 * 1000,
            "daily": 86400 * 1000,
            "hourly": 60 * 60 * 1000,
            "minutely": 60 * 1000,
            "secondly": 1000,
            "millisecondly": 1,
        }
        self.period_labels: Final[dict[str, str]] = {
            "yearly": "years",
            "monthly": "months",
            "weekly": "weeks",
            "daily": "days",
            "hourly": "hours",
            "minutely": "minutes",
            "secondly": "seconds",
            "millisecondly": "milliseconds",
        }
        alternatives: str = "|".join(self.suffix_milliseconds.keys())
        self._suffix_regex0: Final[re.Pattern] = re.compile(rf"([1-9][0-9]*)?({alternatives})")
        self._suffix_regex1: Final[re.Pattern] = re.compile("_" + self._suffix_regex0.pattern)

    def suffix_to_duration0(self, suffix: str) -> tuple[int, str]:
        """Parses a suffix like '10minutely' into (10, 'minutely')."""
        return self._suffix_to_duration(suffix, self._suffix_regex0)

    def suffix_to_duration1(self, suffix: str) -> tuple[int, str]:
        """Like :meth:`suffix_to_duration0` but expects a leading underscore, e.g. '_10minutely'."""
        return self._suffix_to_duration(suffix, self._suffix_regex1)

    @staticmethod
    def _suffix_to_duration(suffix: str, regex: re.Pattern) -> tuple[int, str]:
        """Example: converts '2hourly' to (2, 'hourly') and 'hourly' to (1, 'hourly'); returns (0, '') on mismatch."""
        match = regex.fullmatch(suffix)
        if match is None:
            return 0, ""
        amount: int = int(match.group(1)) if match.group(1) else 1  # missing count means 1
        assert amount > 0
        return amount, match.group(2)

    def label_milliseconds(self, snapshot: str) -> int:
        """Returns the duration encoded after the last '_' of ``snapshot``, in milliseconds (0 if unparseable)."""
        idx = snapshot.rfind("_")
        tail = snapshot[idx + 1 :] if idx >= 0 else ""
        amount, unit = self._suffix_to_duration(tail, self._suffix_regex0)
        return amount * self.suffix_milliseconds.get(unit, 0)
1074#############################################################################
@final
class JobStats:
    """Simple thread-safe counters summarizing job progress."""

    def __init__(self, jobs_all: int) -> None:
        assert jobs_all >= 0
        self.lock: Final[threading.Lock] = threading.Lock()
        self.jobs_all: int = jobs_all  # total number of jobs expected
        self.jobs_started: int = 0
        self.jobs_completed: int = 0
        self.jobs_failed: int = 0
        self.jobs_running: int = 0
        self.sum_elapsed_nanos: int = 0  # accumulated wall time across completed jobs
        self.started_job_names: Final[set[str]] = set()

    def submit_job(self, job_name: str) -> str:
        """Counts a job submission; returns a progress summary string."""
        with self.lock:
            self.jobs_started += 1
            self.jobs_running += 1
            self.started_job_names.add(job_name)
            return str(self)

    def complete_job(self, failed: bool, elapsed_nanos: int) -> str:
        """Counts a job completion; returns a progress summary string."""
        assert elapsed_nanos >= 0
        with self.lock:
            self.jobs_running -= 1
            self.jobs_completed += 1
            if failed:
                self.jobs_failed += 1
            self.sum_elapsed_nanos += elapsed_nanos
            msg = str(self)
            # sanity-check counter invariants while still holding the lock
            assert self.sum_elapsed_nanos >= 0, msg
            assert self.jobs_running >= 0, msg
            assert self.jobs_failed >= 0, msg
            assert self.jobs_failed <= self.jobs_completed, msg
            assert self.jobs_completed <= self.jobs_started, msg
            assert self.jobs_started <= self.jobs_all, msg
            return msg

    def __repr__(self) -> str:
        def pct(number: int) -> str:
            """Formats ``number`` as a percentage of all jobs."""
            return percent(number, total=self.jobs_all, print_total=True)

        avg = "avg_completion_time:" + human_readable_duration(self.sum_elapsed_nanos / max(1, self.jobs_completed))
        return (
            f"all:{self.jobs_all}, started:{pct(self.jobs_started)}, completed:{pct(self.jobs_completed)}, "
            f"failed:{pct(self.jobs_failed)}, running:{self.jobs_running}, {avg}"
        )
1126#############################################################################
class Comparable(Protocol):
    """Structural protocol for types that support partial ordering via ``<``; lets sort/bisect-based containers type-check."""

    def __lt__(self, other: Any) -> bool: ...


TComparable = TypeVar("TComparable", bound=Comparable)  # Generic type variable for elements stored in a SmallPriorityQueue
@final
class SmallPriorityQueue(Generic[TComparable]):
    """Binary-search-backed priority queue that efficiently supports priority updates for small queues (no more than
    thousands of elements), which is our use case.

    An element's priority is updated by removing the element, mutating it, then (re)inserting it. Duplicate elements
    (if any) keep their insertion order relative to other duplicates. Implemented with a plain sorted list to avoid
    external dependencies such as sortedcontainers (https://github.com/grantjenks/python-sortedcontainers) or pqdict
    (https://github.com/nvictus/pqdict).
    """

    def __init__(self, reverse: bool = False) -> None:
        """Creates an empty queue; priority order flips (largest first) when ``reverse`` is True."""
        self._lst: Final[list[TComparable]] = []
        self._reverse: Final[bool] = reverse

    def clear(self) -> None:
        """Empties the queue."""
        self._lst.clear()

    def push(self, element: TComparable) -> None:
        """Adds ``element``, keeping the backing list sorted."""
        bisect.insort(self._lst, element)

    def pop(self) -> TComparable:
        """Removes and returns the highest-priority element (smallest, or largest when reversed)."""
        if self._reverse:
            return self._lst.pop()
        return self._lst.pop(0)

    def peek(self) -> TComparable:
        """Returns the highest-priority element without removing it."""
        if self._reverse:
            return self._lst[-1]
        return self._lst[0]

    def remove(self, element: TComparable) -> bool:
        """Removes the first occurrence (in insertion order aka FIFO) of ``element``; returns True if it was present."""
        items = self._lst
        pos = bisect.bisect_left(items, element)
        found: bool = pos < len(items) and items[pos] == element
        if found:
            del items[pos]  # single optimized memmove()
        return found

    def __len__(self) -> int:
        """Returns the number of queued elements."""
        return len(self._lst)

    def __contains__(self, element: TComparable) -> bool:
        """Membership test via binary search."""
        items = self._lst
        pos = bisect.bisect_left(items, element)
        return pos < len(items) and items[pos] == element

    def __iter__(self) -> Iterator[TComparable]:
        """Yields elements in priority order."""
        if self._reverse:
            return reversed(self._lst)
        return iter(self._lst)

    def __repr__(self) -> str:
        """Shows queue contents in priority order."""
        ordered = list(reversed(self._lst)) if self._reverse else self._lst
        return repr(ordered)
1200###############################################################################
@final
class SortedInterner(Generic[TComparable]):
    """Deduplicates items against a pre-sorted list via binary search; a local, non-global analog of sys.intern()."""

    def __init__(self, sorted_list: list[TComparable]) -> None:
        self._lst: Final[list[TComparable]] = sorted_list  # must already be sorted for binary search to work

    def interned(self, element: TComparable) -> TComparable:
        """Returns the canonical equal item from the sorted list if present, else ``element`` itself."""
        idx = binary_search(self._lst, element)
        if idx < 0:
            return element
        return self._lst[idx]

    def __contains__(self, element: TComparable) -> bool:
        """Returns ``True`` if an equal element is contained."""
        return binary_search(self._lst, element) >= 0
def binary_search(sorted_list: list[TComparable], item: TComparable) -> int:
    """Java-style binary search; returns the left-most index of an equal item if found, else ``-insertion_point - 1``."""
    idx = bisect.bisect_left(sorted_list, item)
    if idx < len(sorted_list) and sorted_list[idx] == item:
        return idx
    return -idx - 1
1226###############################################################################
_S = TypeVar("_S")


@final
class HashedInterner(Generic[_S]):
    """Hash-based local deduplication of equal objects; like sys.intern() but non-global and usable with any hashable type."""

    def __init__(self, items: Iterable[_S] = frozenset()) -> None:
        self._items: Final[dict[_S, _S]] = {v: v for v in items}  # each value is its own canonical instance

    def intern(self, item: _S) -> _S:
        """Registers ``item`` if absent and returns the canonical instance."""
        table = self._items
        if item in table:
            return table[item]
        table[item] = item
        return item

    def interned(self, item: _S) -> _S:
        """Returns the canonical equal instance if present, else ``item`` unchanged (without registering it)."""
        return self._items.get(item, item)

    def __contains__(self, item: _S) -> bool:
        return item in self._items
1249#############################################################################
@final
class SynchronizedBool:
    """Mutex-guarded boolean flag offering atomic read, write, swap and compare-and-set operations."""

    def __init__(self, val: bool) -> None:
        assert isinstance(val, bool)
        self._lock: Final[threading.Lock] = threading.Lock()
        self._value: bool = val

    @property
    def value(self) -> bool:
        """Atomically reads the flag."""
        with self._lock:
            return self._value

    @value.setter
    def value(self, new_value: bool) -> None:
        """Atomically writes the flag."""
        with self._lock:
            self._value = new_value

    def get_and_set(self, new_value: bool) -> bool:
        """Atomically writes ``new_value`` and returns the prior value."""
        with self._lock:
            previous, self._value = self._value, new_value
            return previous

    def compare_and_set(self, expected_value: bool, new_value: bool) -> bool:
        """Atomically writes ``new_value`` iff the current value equals ``expected_value``; returns whether it matched."""
        with self._lock:
            matched: bool = self._value == expected_value
            if matched:
                self._value = new_value
            return matched

    def __bool__(self) -> bool:
        return self.value

    def __repr__(self) -> str:
        return repr(self.value)

    def __str__(self) -> str:
        return str(self.value)
1296#############################################################################
1297_K = TypeVar("_K")
1298_V = TypeVar("_V")
1301@final
1302class SynchronizedDict(Generic[_K, _V]):
1303 """Thread-safe wrapper around a regular dict."""
1305 def __init__(self, val: dict[_K, _V]) -> None:
1306 assert isinstance(val, dict)
1307 self._lock: Final[threading.Lock] = threading.Lock()
1308 self._dict: Final[dict[_K, _V]] = val
1310 def __getitem__(self, key: _K) -> _V:
1311 with self._lock:
1312 return self._dict[key]
1314 def __setitem__(self, key: _K, value: _V) -> None:
1315 with self._lock:
1316 self._dict[key] = value
1318 def __delitem__(self, key: _K) -> None:
1319 with self._lock:
1320 self._dict.pop(key)
1322 def __contains__(self, key: _K) -> bool:
1323 with self._lock:
1324 return key in self._dict
1326 def __len__(self) -> int:
1327 with self._lock:
1328 return len(self._dict)
1330 def __repr__(self) -> str:
1331 with self._lock:
1332 return repr(self._dict)
1334 def __str__(self) -> str:
1335 with self._lock:
1336 return str(self._dict)
1338 def get(self, key: _K, default: _V | None = None) -> _V | None:
1339 """Returns ``self[key]`` or ``default`` if missing."""
1340 with self._lock:
1341 return self._dict.get(key, default)
1343 def pop(self, key: _K, default: _V | None = None) -> _V | None:
1344 """Removes ``key`` and returns its value."""
1345 with self._lock:
1346 return self._dict.pop(key, default)
1348 def clear(self) -> None:
1349 """Removes all items atomically."""
1350 with self._lock:
1351 self._dict.clear()
1353 def items(self) -> ItemsView[_K, _V]:
1354 """Returns a snapshot of dictionary items."""
1355 with self._lock:
1356 return self._dict.copy().items()
1359#############################################################################
1360@final
1361class InterruptibleSleep:
1362 """Provides a sleep(timeout) function that can be interrupted by another thread; The underlying lock is configurable."""
1364 def __init__(self, lock: threading.Lock | None = None) -> None:
1365 self._is_stopping: bool = False
1366 self._lock: Final[threading.Lock] = lock if lock is not None else threading.Lock()
1367 self._condition: Final[threading.Condition] = threading.Condition(self._lock)
1369 def sleep(self, duration_nanos: int) -> bool:
1370 """Delays the current thread by the given number of nanoseconds; Returns True if the sleep got interrupted;
1371 Equivalent to threading.Event.wait()."""
1372 end_time_nanos: int = time.monotonic_ns() + duration_nanos
1373 with self._lock:
1374 while not self._is_stopping:
1375 diff_nanos: int = end_time_nanos - time.monotonic_ns()
1376 if diff_nanos <= 0:
1377 return False
1378 self._condition.wait(timeout=diff_nanos / 1_000_000_000) # release, then block until notified or timeout
1379 return True
1381 def interrupt(self) -> None:
1382 """Wakes sleeping threads and makes any future sleep()s a no-op; Equivalent to threading.Event.set()."""
1383 with self._lock:
1384 if not self._is_stopping:
1385 self._is_stopping = True
1386 self._condition.notify_all()
1388 def reset(self) -> None:
1389 """Makes any future sleep()s no longer a no-op; Equivalent to threading.Event.clear()."""
1390 with self._lock:
1391 self._is_stopping = False
1394#############################################################################
@final
class SynchronousExecutor(Executor):
    """Executor that executes each submitted task immediately on the calling thread, one at a time."""

    def __init__(self) -> None:
        self._shutdown: bool = False

    def submit(self, fn: Callable[..., _R_], /, *args: Any, **kwargs: Any) -> Future[_R_]:
        """Runs `fn(*args, **kwargs)` inline and returns an already-completed Future."""
        future: Future[_R_] = Future()
        if self._shutdown:
            raise RuntimeError("cannot schedule new futures after shutdown")
        try:
            outcome: _R_ = fn(*args, **kwargs)
        except BaseException as exc:
            future.set_exception(exc)  # failure is captured in the Future, mirroring ThreadPoolExecutor semantics
        else:
            future.set_result(outcome)
        return future

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        """Marks the executor closed; there are no worker threads to wind down."""
        self._shutdown = True

    @classmethod
    def executor_for(cls, max_workers: int) -> Executor:
        """Returns a SynchronousExecutor when 0 <= max_workers <= 1, else a ThreadPoolExecutor with that many workers."""
        if 0 <= max_workers <= 1:
            return cls()
        return ThreadPoolExecutor(max_workers=max_workers)
1425#############################################################################
1426@final
1427class _XFinally(contextlib.AbstractContextManager):
1428 """Context manager ensuring cleanup code executes after ``with`` blocks."""
1430 def __init__(self, cleanup: Callable[[], None]) -> None:
1431 """Records the callable to run upon exit."""
1432 self._cleanup: Final = cleanup # Zero-argument callable executed after the `with` block exits.
1434 def __exit__(
1435 self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: types.TracebackType | None
1436 ) -> Literal[False]:
1437 """Runs cleanup and propagate any exceptions appropriately."""
1438 try:
1439 self._cleanup()
1440 except BaseException as cleanup_exc:
1441 if exc is None:
1442 raise # No main error --> propagate cleanup error normally
1443 # Both failed
1444 # if sys.version_info >= (3, 11):
1445 # raise ExceptionGroup("main error and cleanup error", [exc, cleanup_exc]) from None
1446 # <= 3.10: attach so it shows up in traceback but doesn't mask
1447 exc.__context__ = cleanup_exc
1448 return False # reraise original exception
1449 return False # propagate main exception if any
def xfinally(cleanup: Callable[[], None]) -> _XFinally:
    """Returns a context manager guaranteeing ``cleanup()`` runs on exit without masking the body's exception.

    Usage: ``with xfinally(lambda: cleanup()): ...``

    Problem it solves
    -----------------
    A naive ``try ... finally`` loses the original exception whenever the finally-clause itself raises:

        try:
            work()
        finally:
            cleanup()  # <-- if this raises an exception, it replaces the real error!

    `_XFinally` preserves exception priority instead:

    * Body raises, cleanup succeeds --> the body's exception is re-raised.
    * Body raises, cleanup also raises --> the body's exception is re-raised; the cleanup error is linked via
      ``__context__``.
    * Body succeeds, cleanup raises --> the cleanup exception propagates normally.

    Example:
    -------
    >>> with xfinally(lambda: release_resources()):  # doctest: +SKIP
    ...     run_tasks()

    The single *with* line replaces verbose ``try/except/finally`` boilerplate while preserving full error information.
    """
    return _XFinally(cleanup)