Coverage for bzfs_main / argparse_actions.py: 100%
327 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 08:03 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 08:03 +0000
1# Copyright 2024 Wolfgang Hoschek AT mac DOT com
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15"""Custom argparse actions shared by the 'bzfs' and 'bzfs_jobrunner' CLIs; These helpers validate and expand complex command
16line syntax such as +file references, dataset pairs, and snapshot filters."""
18from __future__ import (
19 annotations,
20)
21import argparse
22import ast
23import os
24import re
25from dataclasses import (
26 dataclass,
27 field,
28)
29from datetime import (
30 timedelta,
31)
32from typing import (
33 Any,
34 final,
35)
37from bzfs_main.filter import (
38 SNAPSHOT_FILTERS_VAR,
39 SNAPSHOT_REGEX_FILTER_NAME,
40 SNAPSHOT_REGEX_FILTER_NAMES,
41 RankRange,
42 UnixTimeRange,
43)
44from bzfs_main.util.check_range import (
45 CheckRange,
46)
47from bzfs_main.util.utils import (
48 SHELL_CHARS,
49 UNIX_TIME_INFINITY_SECS,
50 YEAR_WITH_FOUR_DIGITS_REGEX,
51 SnapshotPeriods,
52 die,
53 ninfix,
54 nprefix,
55 nsuffix,
56 open_nofollow,
57 parse_duration_to_milliseconds,
58 unixtime_fromisoformat,
59)
62#############################################################################
@dataclass(order=True)
@final
class SnapshotFilter:
    """One step of a snapshot filter plan: a filter name, an optional time range, and filter-specific options.

    Equality and ordering are derived from ``name`` and ``timerange`` only; ``options`` is excluded from comparisons.
    """

    # Filter kind, for example "include_snapshot_times".
    name: str
    # Type is defined in bzfs_main.filter; code in this module passes None when no time range applies.
    timerange: UnixTimeRange
    # Filter payload (e.g. a list of regexes or rank ranges); does not take part in ==, <, etc.
    options: Any = field(default=None, compare=False)
def _add_snapshot_filter(args: argparse.Namespace, _filter: SnapshotFilter) -> None:
    """Appends ``_filter`` to the most recent filter group on ``args``.

    The namespace attribute holds a list of groups (each group is a list of filters); it is initialised to a single empty
    group the first time a filter is added.
    """
    if not hasattr(args, SNAPSHOT_FILTERS_VAR):
        args.snapshot_filters_var = [[]]
    current_group = args.snapshot_filters_var[-1]
    current_group.append(_filter)
def _add_time_and_rank_snapshot_filter(
    args: argparse.Namespace, dst: str, timerange: UnixTimeRange, rankranges: list[RankRange]
) -> None:
    """Registers either a combined time-and-rank filter or a plain time filter on ``args``.

    A filter named ``dst`` carrying ``rankranges`` is only added when a time range is present, at least one rank range
    was given, and no rank range has identical lower and upper bounds. In every other case a plain
    "include_snapshot_times" filter without options is added instead.
    """
    ranks_apply = timerange is not None and bool(rankranges) and all(rr[0] != rr[1] for rr in rankranges)
    if ranks_apply:
        new_filter = SnapshotFilter(dst, timerange, rankranges)
    else:
        new_filter = SnapshotFilter("include_snapshot_times", timerange, None)
    _add_snapshot_filter(args, new_filter)
def has_timerange_filter(snapshot_filters: list[list[SnapshotFilter]]) -> bool:
    """Returns True if any filter in any filter group carries a time range (``timerange is not None``)."""
    for group in snapshot_filters:
        for candidate in group:
            if candidate.timerange is not None:
                return True
    return False
def optimize_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> list[SnapshotFilter]:
    """Applies basic optimizations to one group of the snapshot filter execution plan.

    The merge passes modify ``snapshot_filters`` in place. Afterwards a new list is built that drops filters having
    neither a truthy time range nor truthy options; that new list is reordered and returned.
    """
    for merge_pass in (_merge_adjacent_snapshot_filters, _merge_adjacent_snapshot_regexes):
        merge_pass(snapshot_filters)
    effective_filters = [f for f in snapshot_filters if f.timerange or f.options]
    _reorder_snapshot_time_filters(effective_filters)
    return effective_filters
def _merge_adjacent_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> None:
    """Folds each list-carrying filter into its direct predecessor when the two compare equal; works in place.

    The list is walked from the back so that removing an element never shifts the entries that are still to be visited.
    Index 0 has no predecessor and therefore is never a merge source.
    """
    for idx in range(len(snapshot_filters) - 1, 0, -1):
        current: SnapshotFilter = snapshot_filters[idx]
        if not isinstance(current.options, list):
            continue
        previous: SnapshotFilter = snapshot_filters[idx - 1]
        if previous == current:
            merged: list = previous.options
            assert isinstance(merged, list)
            merged += current.options  # extend the predecessor's list in place
            snapshot_filters.pop(idx)
def _merge_adjacent_snapshot_regexes(snapshot_filters: list[SnapshotFilter]) -> None:
    """Combines regex filters that sit next to each other; works in place in two passes.

    Pass 1: within an uninterrupted run of regex filters, a filter's regex list is appended to the nearest earlier filter
    of the same name, and the later filter is removed.

    Pass 2: each remaining regex filter (together with its direct regex predecessor, if any) is replaced by a single
    filter named ``SNAPSHOT_REGEX_FILTER_NAME`` whose options are a 2-tuple of regex lists, ordered by filter name.
    """
    # Pass 1: merge same-named regex filters that belong to the same run of regex filters.
    for idx in range(len(snapshot_filters) - 1, -1, -1):
        current: SnapshotFilter = snapshot_filters[idx]
        if current.name not in SNAPSHOT_REGEX_FILTER_NAMES:
            continue
        assert isinstance(current.options, list)
        probe = idx - 1
        while probe >= 0 and snapshot_filters[probe].name in SNAPSHOT_REGEX_FILTER_NAMES:
            earlier: SnapshotFilter = snapshot_filters[probe]
            if earlier.name == current.name:
                merged: list[object] = earlier.options
                assert isinstance(merged, list)
                merged += current.options
                snapshot_filters.pop(idx)
                break
            probe -= 1

    # Pass 2: fuse the (at most two) differently named regex filters of a run into one combined filter.
    idx = len(snapshot_filters) - 1
    while idx >= 0:
        current = snapshot_filters[idx]
        current_name: str = current.name
        if current_name in SNAPSHOT_REGEX_FILTER_NAMES:
            if idx > 0 and snapshot_filters[idx - 1].name in SNAPSHOT_REGEX_FILTER_NAMES:
                partner: SnapshotFilter = snapshot_filters[idx - 1]
                assert partner.name != current_name  # pass 1 already merged same-named neighbours
                snapshot_filters.pop(idx)
                idx -= 1  # the combined filter is written to the partner's slot
            else:
                # No neighbour of the other kind: pair up with an empty filter of the other regex name.
                # NOTE(review): presumes SNAPSHOT_REGEX_FILTER_NAMES holds exactly two names - confirm in bzfs_main.filter
                other_name: str = next(iter(SNAPSHOT_REGEX_FILTER_NAMES.difference({current_name})))
                partner = SnapshotFilter(other_name, None, [])
            first, second = sorted([current, partner])
            # The pair is ordered by filter name; presumably that puts the exclude regexes first and the include
            # regexes second - verify against the names in SNAPSHOT_REGEX_FILTER_NAMES.
            exclude_regexes, include_regexes = first.options, second.options
            snapshot_filters[idx] = SnapshotFilter(SNAPSHOT_REGEX_FILTER_NAME, None, (exclude_regexes, include_regexes))
        idx -= 1
def _reorder_snapshot_time_filters(snapshot_filters: list[SnapshotFilter]) -> None:
    """Moves "include_snapshot_times" filters to the front of their section of the execution plan; works in place.

    Sections are delimited by "include_snapshot_times_and_ranks" filters, which keep their position. Within each section
    all time filters are placed before all other filters, and the relative order inside both sub-groups is preserved.
    """

    def reorder_time_filters_within_section(i: int, j: int) -> None:
        """Stable-partitions ``snapshot_filters[i + 1 : j + 1]`` so that time filters come first."""
        if j <= i:
            return  # empty section
        section = snapshot_filters[i + 1 : j + 1]
        time_filters = [f for f in section if f.name == "include_snapshot_times"]
        if not time_filters:
            return
        other_filters = [f for f in section if f.name != "include_snapshot_times"]
        # Same length as the replaced slice, so indices outside the section stay valid. A partition (rather than
        # repeated pop/insert while walking backwards) guarantees that no time filter of the section is skipped.
        snapshot_filters[i + 1 : j + 1] = time_filters + other_filters

    i = len(snapshot_filters) - 1
    j = i  # index of the last element of the section currently being scanned
    while i >= 0:
        if snapshot_filters[i].name == "include_snapshot_times_and_ranks":
            reorder_time_filters_within_section(i, j)
            j = i - 1
        i -= 1
    reorder_time_filters_within_section(i, j)  # leading section; i == -1 here
def validate_no_argument_file(
    path: str, namespace: argparse.Namespace, err_prefix: str, parser: argparse.ArgumentParser | None = None
) -> None:
    """Aborts via ``die()`` if ``namespace.no_argument_file`` is truthy, i.e. '+file' inclusion has been disabled.

    Does nothing when the attribute is missing or falsy. ``path`` is only used for the error message.
    """
    argument_files_disabled = getattr(namespace, "no_argument_file", False)
    if not argument_files_disabled:
        return
    die(f"{err_prefix}Argument file inclusion is disabled: {path}", parser=parser)
194#############################################################################
@final
class NonEmptyStringAction(argparse.Action):
    """Argparse action that stores the whitespace-stripped value and rejects values that are empty after stripping."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips surrounding whitespace, reports an error for an empty result, else stores it under ``self.dest``."""
        stripped = values.strip()
        if not stripped:
            parser.error(f"{option_string}: Empty string is not valid")
        setattr(namespace, self.dest, stripped)
209#############################################################################
@final
class DatasetPairsAction(argparse.Action):
    """Argparse action that turns alternating SRC_DATASET / DST_DATASET arguments into a list of (src, dst) tuples."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Collects dataset names from ``values``, expanding '+file' arguments, and stores them as pairs.

        A value starting with '+' names a UTF-8 text file whose basename must contain 'bzfs_argument_file'. Each line of
        that file that is neither blank nor starts with '#' must hold a source and a destination dataset separated by a
        tab. Any violation, an I/O error, or an odd total number of datasets is reported via ``parser.error``.
        """
        err_prefix: str = f"{option_string or self.dest}: "
        flat_datasets: list[str] = []

        for value in values:
            if value.startswith("+"):
                arg_file: str = value[1:]
                validate_no_argument_file(arg_file, namespace, err_prefix=err_prefix, parser=parser)
                if "bzfs_argument_file" not in os.path.basename(arg_file):
                    parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {arg_file}")
                try:
                    with open_nofollow(arg_file, "r", encoding="utf-8") as fd:
                        for line_index, line in enumerate(fd.read().splitlines()):
                            is_comment_or_blank = line.startswith("#") or not line.strip()
                            if is_comment_or_blank:
                                continue
                            columns: list[str] = line.split("\t", 1)  # only the first tab separates src from dst
                            if len(columns) < 2:
                                parser.error(
                                    f"{err_prefix}Line must contain tab-separated SRC_DATASET and DST_DATASET: {line_index}"
                                )
                            src_root_dataset, dst_root_dataset = columns
                            if not (src_root_dataset.strip() and dst_root_dataset.strip()):
                                parser.error(
                                    f"{err_prefix}SRC_DATASET and DST_DATASET must not be empty or whitespace-only: "
                                    f"{line_index}"
                                )
                            flat_datasets.extend((src_root_dataset, dst_root_dataset))
                except OSError as e:
                    parser.error(f"{err_prefix}{e}")
            else:
                flat_datasets.append(value)

        if len(flat_datasets) % 2 != 0:
            parser.error(f"{err_prefix}Each SRC_DATASET must have a corresponding DST_DATASET: {flat_datasets}")
        # Even positions are sources, odd positions the matching destinations.
        root_dataset_pairs: list[tuple[str, str]] = list(zip(flat_datasets[0::2], flat_datasets[1::2]))
        setattr(namespace, self.dest, root_dataset_pairs)
253#############################################################################
@final
class SSHConfigFileNameAction(argparse.Action):
    """Argparse action that accepts an SSH config file name only if it has no whitespace and no ``SHELL_CHARS`` chars."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, rejects empty names and names containing whitespace or a char in ``SHELL_CHARS``."""
        file_name = values.strip()
        if not file_name:
            parser.error(f"{option_string}: Empty string is not valid")
        has_forbidden_char = any(ch.isspace() or ch in SHELL_CHARS for ch in file_name)
        if has_forbidden_char:
            parser.error(f"{option_string}: Invalid file name '{file_name}': must not contain whitespace or special chars.")
        setattr(namespace, self.dest, file_name)
271#############################################################################
@final
class SafeFileNameAction(argparse.Action):
    """Argparse action that rejects file names containing '..', path separators, or whitespace other than ' '."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Validates the file name and stores it unchanged (no stripping) under ``self.dest``."""
        forbidden_tokens = ("..", "/", "\\")
        if any(token in values for token in forbidden_tokens):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain '..' or '/' or '\\'.")
        has_odd_whitespace = any(ch != " " and ch.isspace() for ch in values)
        if has_odd_whitespace:
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, values)
287#############################################################################
@final
class SafeDirectoryNameAction(argparse.Action):
    """Argparse action that accepts a directory name whose only permitted whitespace is the plain space character."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, rejects empty names and names with tabs, newlines or other non-space whitespace."""
        dir_name = values.strip()
        if not dir_name:
            parser.error(f"{option_string}: Empty string is not valid")
        has_odd_whitespace = any(ch != " " and ch.isspace() for ch in dir_name)
        if has_odd_whitespace:
            parser.error(f"{option_string}: Invalid dir name '{dir_name}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, dir_name)
304#############################################################################
@final
class NewSnapshotFilterGroupAction(argparse.Action):
    """Argparse action that opens a new snapshot filter group when its option appears on the command line."""

    def __call__(
        self, parser: argparse.ArgumentParser, args: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Ensures the last filter group on ``args`` is empty, so that subsequent filters land in a fresh group.

        Creates the group list if it does not exist yet; otherwise appends a new empty group, unless the last group is
        still empty (in which case nothing changes).
        """
        if hasattr(args, SNAPSHOT_FILTERS_VAR):
            groups = args.snapshot_filters_var
            if len(groups[-1]) > 0:
                groups.append([])
        else:
            args.snapshot_filters_var = [[]]
319#############################################################################
@final
class FileOrLiteralAction(argparse.Action):
    """Argparse action that accumulates literal values and the lines of '+file' argument files."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends the given values, with '+file' arguments expanded, to the list stored under ``self.dest``.

        A value starting with '+' names a UTF-8 text file whose basename must contain 'bzfs_argument_file'; every line of
        it that is neither blank nor starts with '#' becomes one value. If ``self.dest`` is one of the snapshot regex
        filter names, the values added by this call are additionally registered as a snapshot filter.
        """
        err_prefix: str = f"{option_string or self.dest}: "
        new_values: list[str] = []
        for value in values:
            if value.startswith("+"):
                arg_file: str = value[1:]
                validate_no_argument_file(arg_file, namespace, err_prefix=err_prefix, parser=parser)
                if "bzfs_argument_file" not in os.path.basename(arg_file):
                    parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {arg_file}")
                try:
                    with open_nofollow(arg_file, "r", encoding="utf-8") as fd:
                        for line in fd.read().splitlines():
                            if line.startswith("#") or not line.strip():
                                continue  # skip comments and blank lines
                            new_values.append(line)
                except OSError as e:
                    parser.error(f"{err_prefix}{e}")
            else:
                new_values.append(value)

        accumulated: list[str] | None = getattr(namespace, self.dest, None)
        if accumulated is None:
            accumulated = []
        accumulated += new_values  # in place, so an already stored list object keeps being reused
        setattr(namespace, self.dest, accumulated)
        if self.dest in SNAPSHOT_REGEX_FILTER_NAMES:
            _add_snapshot_filter(namespace, SnapshotFilter(self.dest, None, new_values))
356#############################################################################
class IncludeSnapshotPlanAction(argparse.Action):
    """Argparse action that expands a serialized include plan (a dict literal) into equivalent CLI options."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends the options generated from the plan in ``values`` to the option list stored under ``self.dest``.

        If the plan yields no filter clause at all, a filter group with the regex '!.*' is appended instead.
        """
        opts: list[str] | None = getattr(namespace, self.dest, None)
        opts = [] if opts is None else opts
        if not self._add_opts(opts, parser, values, option_string=option_string):
            opts += ["--new-snapshot-filter-group", "--include-snapshot-regex=!.*"]
        setattr(namespace, self.dest, opts)

    def _add_opts(
        self,
        opts: list[str],
        parser: argparse.ArgumentParser,
        values: str,
        option_string: str | None = None,
    ) -> bool:
        """Translates the plan into options appended to ``opts``; returns True if at least one clause was generated.

        ``values`` is evaluated with ``ast.literal_eval`` and is expected to be a dict of the shape
        ``{org: {target: {period_unit: period_amount}}}``. Each innermost entry produces a new filter group with an
        include regex built from the three names, followed by an '--include-snapshot-times-and-ranks' option.
        ``period_amount`` must be a non-negative ``int``; ``bool`` is rejected explicitly because it is a subclass of
        ``int`` and would otherwise end up verbatim in the generated rank option (e.g. 'latestTrue').
        Exceptions raised by ``ast.literal_eval`` for malformed input are not caught here.
        """
        xperiods: SnapshotPeriods = SnapshotPeriods()
        has_at_least_one_filter_clause: bool = False
        for org, target_periods in ast.literal_eval(values).items():
            prefix: str = re.escape(nprefix(org))
            for target, periods in target_periods.items():
                # An empty target falls back to the four-digit-year pattern instead of an escaped infix.
                infix: str = re.escape(ninfix(target)) if target else YEAR_WITH_FOUR_DIGITS_REGEX.pattern
                for period_unit, period_amount in periods.items():
                    is_valid_amount = (
                        isinstance(period_amount, int) and not isinstance(period_amount, bool) and period_amount >= 0
                    )
                    if not is_valid_amount:
                        parser.error(f"{option_string}: Period amount must be a non-negative integer: {period_amount}")
                    suffix: str = re.escape(nsuffix(period_unit))
                    regex: str = f"{prefix}{infix}.*{suffix}"
                    opts += ["--new-snapshot-filter-group", f"--include-snapshot-regex={regex}"]
                    duration_amount, duration_unit = xperiods.suffix_to_duration0(period_unit)
                    duration_unit_label: str | None = xperiods.period_labels.get(duration_unit)
                    opts += [
                        "--include-snapshot-times-and-ranks",
                        (
                            # No known label or a zero-length duration: no time range, selection is by rank only.
                            "notime"
                            if duration_unit_label is None or duration_amount * period_amount == 0
                            else f"{duration_amount * period_amount}{duration_unit_label}ago..anytime"
                        ),
                        f"latest{period_amount}",
                    ]
                    has_at_least_one_filter_clause = True
        return has_at_least_one_filter_clause
404#############################################################################
@final
class DeleteDstSnapshotsExceptPlanAction(IncludeSnapshotPlanAction):
    """Include-plan variant that describes which destination snapshots to retain when deleting."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends '--delete-dst-snapshots-except' plus the options generated from the plan to ``self.dest``.

        A plan that generates no filter clause would retain nothing, i.e. delete every snapshot; this is treated as a
        user mistake and reported via ``parser.error`` instead of being carried out.
        """
        opts: list[str] | None = getattr(namespace, self.dest, None)
        opts = [] if opts is None else opts
        opts += ["--delete-dst-snapshots-except"]
        if not self._add_opts(opts, parser, values, option_string=option_string):
            parser.error(
                # Note the trailing space: adjacent literals are concatenated without any separator.
                f"{option_string}: Cowardly refusing to delete all snapshots on "
                f"--delete-dst-snapshots-except-plan='{values}' (which means 'retain no snapshots' aka "
                "'delete all snapshots'). Assuming this is an unintended pilot error rather than intended carnage. "
                "Aborting. If this is really what is intended, use `--delete-dst-snapshots --include-snapshot-regex=.*` "
                "instead to force the deletion."
            )
        setattr(namespace, self.dest, opts)
427#############################################################################
@final
class TimeRangeAndRankRangeAction(argparse.Action):
    """Argparse action that parses the values of the --include-snapshot-times-and-ranks option."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Parses ``values[0]`` as a time range and ``values[1:]`` as rank ranges, then registers a snapshot filter.

        The time range is 'START..END' or the keyword 'notime' (treated as '0..0'). The parsed specs are stored under
        ``self.dest`` as ``[timerange_specs] + rankranges`` and also added to the namespace's snapshot filters.
        Invalid input is reported via ``parser.error``.
        """

        def parse_time(time_spec: str) -> int | timedelta | None:
            """Parses one end of the range: None for '*'/'anytime', int for Unix seconds, timedelta for durations."""
            time_spec = time_spec.strip()
            if time_spec == "*" or time_spec == "anytime":
                return None
            # isdecimal() rather than isdigit(): isdigit() is also true for characters such as superscript digits,
            # which int() cannot convert and which would therefore raise an uncaught ValueError here.
            if time_spec.isdecimal():
                return int(time_spec)
            try:
                return timedelta(milliseconds=parse_duration_to_milliseconds(time_spec, regex_suffix=r"\s*ago"))
            except ValueError:
                try:
                    return unixtime_fromisoformat(time_spec)
                except ValueError:
                    parser.error(f"{option_string}: Invalid duration, Unix time, or ISO 8601 datetime: {time_spec}")

        assert isinstance(values, list)
        assert len(values) > 0
        value: str = values[0].strip()
        if value == "notime":
            value = "0..0"
        if ".." not in value:
            parser.error(f"{option_string}: Invalid time range: Missing '..' separator: {value}")
        timerange_specs: list[int | timedelta | None] = [parse_time(time_spec) for time_spec in value.split("..", 1)]
        rankranges: list[RankRange] = self._parse_rankranges(parser, values[1:], option_string=option_string)
        setattr(namespace, self.dest, [timerange_specs] + rankranges)
        timerange: UnixTimeRange = self._get_include_snapshot_times(timerange_specs)
        _add_time_and_rank_snapshot_filter(namespace, self.dest, timerange, rankranges)

    @staticmethod
    def _get_include_snapshot_times(times: list[timedelta | int | None]) -> UnixTimeRange:
        """Converts parsed start and end specs into a ``UnixTimeRange``.

        Returns None if both ends are unbounded. Otherwise a missing start defaults to 0 and a missing end to
        ``UNIX_TIME_INFINITY_SECS``. If both ends are ints they are returned in ascending order; ends given as
        ``timedelta`` are passed through unchanged and in their original order.
        """

        def utc_unix_time_in_seconds(time_spec: timedelta | int | None, default: int) -> timedelta | int:
            if isinstance(time_spec, timedelta):
                return time_spec
            if isinstance(time_spec, int):
                return int(time_spec)
            return default

        lo, hi = times
        if lo is None and hi is None:
            return None
        lo = utc_unix_time_in_seconds(lo, default=0)
        hi = utc_unix_time_in_seconds(hi, default=UNIX_TIME_INFINITY_SECS)
        if isinstance(lo, int) and isinstance(hi, int):
            return (lo, hi) if lo <= hi else (hi, lo)
        return lo, hi

    @staticmethod
    def _parse_rankranges(parser: argparse.ArgumentParser, values: Any, option_string: str | None = None) -> list[RankRange]:
        """Parses rank range strings such as 'latest 3..latest 5' into ``((kind, num, is_percent), (...))`` tuples.

        Accepted forms per value:
        - 'KIND N[%]..KIND M[%]' with the same KIND ('oldest' or 'latest') on both sides and no 'all except' prefix;
        - 'KIND N[%]', shorthand for 'KIND 0..KIND N[%]';
        - 'all except KIND N%', rewritten to '<other kind> 0..<other kind> (100-N)%';
        - 'all except KIND N', rewritten to 'KIND N..KIND 100%'.
        Invalid input is reported via ``parser.error``.
        """

        def parse_rank(spec: str) -> tuple[bool, str, int, bool]:
            """Returns (is_except, kind, num, is_percent) for a single rank spec."""
            spec = spec.strip()
            if not (match := re.fullmatch(r"(all\s*except\s*)?(oldest|latest)\s*(\d+)%?", spec)):
                parser.error(f"{option_string}: Invalid rank format: {spec}")
            assert match
            is_except: bool = bool(match.group(1))
            kind: str = match.group(2)
            num: int = int(match.group(3))
            is_percent: bool = spec.endswith("%")
            if is_percent and num > 100:
                parser.error(f"{option_string}: Invalid rank: Percent must not be greater than 100: {spec}")
            return is_except, kind, num, is_percent

        rankranges: list[RankRange] = []
        for value in values:
            value = value.strip()
            if ".." in value:
                lo_split, hi_split = value.split("..", 1)
                lo = parse_rank(lo_split)
                hi = parse_rank(hi_split)
                if lo[0] or hi[0]:  # 'all except' is not allowed inside an explicit range
                    parser.error(f"{option_string}: Invalid rank range: {value}")
                if lo[1] != hi[1]:
                    parser.error(f"{option_string}: Ambiguous rank range: Must not compare oldest with latest: {value}")
            else:
                hi = parse_rank(value)
                is_except, kind, num, is_percent = hi
                if is_except:
                    if is_percent:
                        # Everything except N% of one end equals the first (100-N)% counted from the other end.
                        negated_kind: str = "oldest" if kind == "latest" else "latest"
                        lo = parse_rank(f"{negated_kind}0")
                        hi = parse_rank(f"{negated_kind}{100-num}%")
                    else:
                        # Everything except the first N of one end equals N..100% of that same end.
                        lo = parse_rank(f"{kind}{num}")
                        hi = parse_rank(f"{kind}100%")
                else:
                    lo = parse_rank(f"{kind}0")
            rankranges.append((lo[1:], hi[1:]))  # drop the is_except flag
        return rankranges
529#############################################################################
@final
class CheckPercentRange(CheckRange):
    """Argparse action for values given as a plain number or as a percentage with a trailing '%'."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Parses the number, delegates validation to ``CheckRange`` and stores ``(value, is_percent)``.

        The tuple's first element is whatever the base class left under ``self.dest``; the second tells whether the
        input ended with '%'. Text that ``float()`` cannot parse is reported via ``parser.error``.
        """
        assert isinstance(values, str)
        raw_text = values
        candidate: Any = raw_text.strip()
        is_percent: bool = candidate.endswith("%")
        if is_percent:
            candidate = candidate[:-1]  # drop the percent sign before numeric conversion
        try:
            candidate = float(candidate)
        except ValueError:
            parser.error(f"{option_string}: Invalid percentage or number: {raw_text}")
        super().__call__(parser, namespace, candidate, option_string=option_string)
        checked_value = getattr(namespace, self.dest)
        setattr(namespace, self.dest, (checked_value, is_percent))