Coverage for bzfs_main/argparse_actions.py: 100%
330 statements
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-06 13:30 +0000
« prev ^ index » next coverage.py v7.10.2, created at 2025-08-06 13:30 +0000
1# Copyright 2024 Wolfgang Hoschek AT mac DOT com
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15"""Custom argparse actions shared by the 'bzfs' and 'bzfs_jobrunner' CLIs; These helpers validate and expand complex command
16line syntax such as +file references, dataset pairs, and snapshot filters."""
18from __future__ import annotations
19import argparse
20import ast
21import os
22import re
23from dataclasses import dataclass, field
24from datetime import timedelta
25from typing import Any
27from bzfs_main.check_range import CheckRange
28from bzfs_main.filter import (
29 SNAPSHOT_REGEX_FILTER_NAME,
30 SNAPSHOT_REGEX_FILTER_NAMES,
31 RankRange,
32 UnixTimeRange,
33)
34from bzfs_main.loggers import (
35 validate_log_config_variable,
36)
37from bzfs_main.utils import (
38 SHELL_CHARS,
39 SNAPSHOT_FILTERS_VAR,
40 UNIX_TIME_INFINITY_SECS,
41 YEAR_WITH_FOUR_DIGITS_REGEX,
42 SnapshotPeriods,
43 die,
44 getenv_bool,
45 ninfix,
46 nprefix,
47 nsuffix,
48 open_nofollow,
49 parse_duration_to_milliseconds,
50 unixtime_fromisoformat,
51)
54#############################################################################
@dataclass(order=True)
class SnapshotFilter:
    """A single step of a snapshot filter plan: a filter kind, an optional time range, and kind-specific options.

    Equality and ordering take only ``name`` and ``timerange`` into account; ``options`` is excluded from comparisons.
    """

    # Kind of filter; the helpers in this module use it to decide how filters get merged and reordered.
    name: str
    # Time range the filter applies to; UnixTimeRange is defined in bzfs_main.filter.
    timerange: UnixTimeRange
    # Kind-specific payload (e.g. a list of regexes or rank ranges); does not participate in comparisons.
    options: Any = field(default=None, compare=False)
def _add_snapshot_filter(args: argparse.Namespace, _filter: SnapshotFilter) -> None:
    """Appends ``_filter`` to the last filter group stored on ``args``, creating a first empty group if none exists."""
    if not hasattr(args, SNAPSHOT_FILTERS_VAR):
        args.snapshot_filters_var = [[]]  # list of filter groups, starting with one empty group
    groups = args.snapshot_filters_var
    groups[-1].append(_filter)
def _add_time_and_rank_snapshot_filter(
    args: argparse.Namespace, dst: str, timerange: UnixTimeRange, rankranges: list[RankRange]
) -> None:
    """Registers either a pure time filter or a combined time-and-rank filter on ``args``.

    A plain "include_snapshot_times" filter (without options) is registered if there is no time range, if no rank range
    was given, or if any rank range has identical lower and upper bounds. Otherwise a filter named ``dst`` carrying
    ``rankranges`` as its options is registered.
    """
    time_only: bool = (
        timerange is None or len(rankranges) == 0 or any(rankrange[0] == rankrange[1] for rankrange in rankranges)
    )
    if time_only:
        new_filter = SnapshotFilter("include_snapshot_times", timerange, None)
    else:
        assert timerange is not None
        new_filter = SnapshotFilter(dst, timerange, rankranges)
    _add_snapshot_filter(args, new_filter)
def has_timerange_filter(snapshot_filters: list[list[SnapshotFilter]]) -> bool:
    """Returns True if any filter in any of the given filter groups has a ``timerange`` that is not None."""
    for group in snapshot_filters:
        for snapshot_filter in group:
            if snapshot_filter.timerange is not None:
                return True
    return False
def optimize_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> list[SnapshotFilter]:
    """Applies basic optimizations to a snapshot filter execution plan and returns the optimized plan.

    The two merge steps modify ``snapshot_filters`` in place. The returned list is a new list that omits filters having
    neither a truthy ``timerange`` nor truthy ``options``, with time filters reordered within it.
    """
    _merge_adjacent_snapshot_filters(snapshot_filters)
    _merge_adjacent_snapshot_regexes(snapshot_filters)
    effective_filters: list[SnapshotFilter] = []
    for snapshot_filter in snapshot_filters:
        if snapshot_filter.timerange or snapshot_filter.options:  # drop filters that carry no constraint
            effective_filters.append(snapshot_filter)
    _reorder_snapshot_time_filters(effective_filters)
    return effective_filters
100def _merge_adjacent_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> None:
101 """Merge adjacent filters of the same type if possible."""
103 i = len(snapshot_filters) - 1
104 while i >= 0:
105 filter_i: SnapshotFilter = snapshot_filters[i]
106 if isinstance(filter_i.options, list):
107 j = i - 1
108 if j >= 0 and snapshot_filters[j] == filter_i:
109 lst: list = snapshot_filters[j].options
110 assert isinstance(lst, list)
111 lst += filter_i.options
112 snapshot_filters.pop(i)
113 i -= 1
def _merge_adjacent_snapshot_regexes(snapshot_filters: list[SnapshotFilter]) -> None:
    """Combines consecutive regex filters, in place, into single combined regex filters.

    Pass 1: within each run of consecutive regex filters, a filter's options are appended to the nearest earlier filter
    of the same name in that run, and the later filter is removed.
    Pass 2: each remaining regex filter, together with a directly preceding regex filter of the other name if there is
    one, is replaced by one filter named SNAPSHOT_REGEX_FILTER_NAME whose options are a 2-tuple of regex lists.
    """
    # Pass 1: merge same-named regex filters that sit in the same run of consecutive regex filters.
    for idx in range(len(snapshot_filters) - 1, -1, -1):
        current: SnapshotFilter = snapshot_filters[idx]
        if current.name not in SNAPSHOT_REGEX_FILTER_NAMES:
            continue
        assert isinstance(current.options, list)
        for prev_idx in range(idx - 1, -1, -1):
            candidate: SnapshotFilter = snapshot_filters[prev_idx]
            if candidate.name not in SNAPSHOT_REGEX_FILTER_NAMES:
                break  # the run of consecutive regex filters ends here
            if candidate.name == current.name:
                merged: list[object] = candidate.options
                assert isinstance(merged, list)
                merged += current.options
                snapshot_filters.pop(idx)
                break

    # Pass 2: pair up what is left. A `while` loop is needed because consuming a pair moves the cursor by two.
    pos = len(snapshot_filters) - 1
    while pos >= 0:
        current = snapshot_filters[pos]
        name: str = current.name
        if name in SNAPSHOT_REGEX_FILTER_NAMES:
            prev_pos = pos - 1
            if prev_pos >= 0 and snapshot_filters[prev_pos].name in SNAPSHOT_REGEX_FILTER_NAMES:
                partner = snapshot_filters[prev_pos]
                assert partner.name != name  # same-named neighbors were already merged by pass 1
                snapshot_filters.pop(pos)
                pos -= 1  # the combined filter is stored in the partner's slot
            else:
                # No partner present: synthesize an empty filter of the other kind.
                other_name: str = next(iter(SNAPSHOT_REGEX_FILTER_NAMES.difference({name})))
                partner = SnapshotFilter(other_name, None, [])
            # Sorting orders the pair by filter name; presumably the exclude name sorts before the include name.
            ordered: list[SnapshotFilter] = sorted([current, partner])
            exclude_regexes, include_regexes = (ordered[0].options, ordered[1].options)
            snapshot_filters[pos] = SnapshotFilter(SNAPSHOT_REGEX_FILTER_NAME, None, (exclude_regexes, include_regexes))
        pos -= 1
155def _reorder_snapshot_time_filters(snapshot_filters: list[SnapshotFilter]) -> None:
156 """Reorder time filters before regex filters within execution plan sections."""
158 def reorder_time_filters_within_section(i: int, j: int) -> None:
159 while j > i:
160 filter_j: SnapshotFilter = snapshot_filters[j]
161 if filter_j.name == "include_snapshot_times":
162 snapshot_filters.pop(j)
163 snapshot_filters.insert(i + 1, filter_j)
164 j -= 1
166 i = len(snapshot_filters) - 1
167 j = i
168 while i >= 0:
169 name: str = snapshot_filters[i].name
170 if name == "include_snapshot_times_and_ranks":
171 reorder_time_filters_within_section(i, j)
172 j = i - 1
173 i -= 1
174 reorder_time_filters_within_section(i, j)
def validate_no_argument_file(
    path: str, namespace: argparse.Namespace, err_prefix: str, parser: argparse.ArgumentParser | None = None
) -> None:
    """Aborts via ``die()`` if ``namespace.no_argument_file`` is set to a truthy value; otherwise does nothing.

    ``path`` is the argument file that was requested; it is only used for the error message.
    """
    if not getattr(namespace, "no_argument_file", False):
        return
    die(f"{err_prefix}Argument file inclusion is disabled: {path}", parser=parser)
185#############################################################################
class NonEmptyStringAction(argparse.Action):
    """Argparse action that stores the whitespace-stripped value and refuses values that are empty after stripping."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips surrounding whitespace, reports an error via ``parser.error`` if nothing is left, else stores it."""
        values = values.strip()
        if not values:
            parser.error(f"{option_string}: Empty string is not valid")
        setattr(namespace, self.dest, values)
199#############################################################################
class DatasetPairsAction(argparse.Action):
    """Collects SRC_DATASET/DST_DATASET arguments, given literally or via '+file', into a list of (src, dst) tuples."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Expands '+file' arguments, checks that datasets come in pairs, and stores the list of pairs on ``namespace``."""
        err_prefix: str = f"{option_string or self.dest}: "
        datasets: list[str] = []
        for value in values:
            if value.startswith("+"):
                self._append_pairs_from_file(value[1:], parser, namespace, err_prefix, datasets)
            else:
                datasets.append(value)

        if len(datasets) % 2 != 0:
            parser.error(f"{err_prefix}Each SRC_DATASET must have a corresponding DST_DATASET: {datasets}")
        root_dataset_pairs: list[tuple[str, str]] = [(datasets[k], datasets[k + 1]) for k in range(0, len(datasets), 2)]
        setattr(namespace, self.dest, root_dataset_pairs)

    @staticmethod
    def _append_pairs_from_file(
        path: str, parser: argparse.ArgumentParser, namespace: argparse.Namespace, err_prefix: str, datasets: list[str]
    ) -> None:
        """Reads tab-separated SRC/DST lines from the argument file at ``path`` and appends both parts to ``datasets``.

        Lines that start with '#' and lines that are blank are skipped. The basename of ``path`` must contain the
        substring 'bzfs_argument_file'. I/O errors are reported via ``parser.error``.
        """
        validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
        if "bzfs_argument_file" not in os.path.basename(path):
            parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
        try:
            with open_nofollow(path, "r", encoding="utf-8") as fd:
                for i, line in enumerate(fd.read().splitlines()):
                    if not line.strip() or line.startswith("#"):
                        continue
                    splits: list[str] = line.split("\t", 1)
                    if len(splits) <= 1:
                        parser.error(f"{err_prefix}Line must contain tab-separated SRC_DATASET and DST_DATASET: {i}")
                    src_root_dataset, dst_root_dataset = splits
                    if not src_root_dataset.strip() or not dst_root_dataset.strip():
                        parser.error(
                            f"{err_prefix}SRC_DATASET and DST_DATASET must not be empty or whitespace-only: {i}"
                        )
                    datasets.extend((src_root_dataset, dst_root_dataset))
        except OSError as e:
            parser.error(f"{err_prefix}{e}")
242#############################################################################
class SSHConfigFileNameAction(argparse.Action):
    """Argparse action that accepts an SSH config file name only if it has no whitespace and no SHELL_CHARS chars."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, rejects it if it is empty or contains whitespace or shell characters, else stores it."""
        values = values.strip()
        if not values:
            parser.error(f"{option_string}: Empty string is not valid")
        for char in values:
            if char in SHELL_CHARS or char.isspace():
                parser.error(f"{option_string}: Invalid file name '{values}': must not contain whitespace or special chars.")
                break  # report once, as a single any() check would
        setattr(namespace, self.dest, values)
259#############################################################################
class SafeFileNameAction(argparse.Action):
    """Argparse action that accepts a file name only if it has no '..', no path separators and no odd whitespace."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Rejects values containing '..', '/' or a backslash, or any whitespace other than a plain space.

        The value is stored unmodified (it is not stripped).
        """
        forbidden_tokens = ("..", "/", "\\")
        if any(token in values for token in forbidden_tokens):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain '..' or '/' or '\\'.")
        if any(char != " " and char.isspace() for char in values):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, values)
274#############################################################################
class SafeDirectoryNameAction(argparse.Action):
    """Argparse action that accepts a non-empty directory name whose only permitted whitespace is the plain space."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, rejects it if it is empty or contains whitespace other than a space, else stores it."""
        values = values.strip()
        if not values:
            parser.error(f"{option_string}: Empty string is not valid")
        if any(char != " " and char.isspace() for char in values):
            parser.error(f"{option_string}: Invalid dir name '{values}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, values)
290#############################################################################
class NewSnapshotFilterGroupAction(argparse.Action):
    """Argparse action that opens a new, empty snapshot filter group for the filters that follow it."""

    def __call__(
        self, parser: argparse.ArgumentParser, args: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Ensures the last filter group on ``args`` is empty, adding a new group only if the last one is in use."""
        if not hasattr(args, SNAPSHOT_FILTERS_VAR):
            args.snapshot_filters_var = [[]]
            return
        groups = args.snapshot_filters_var
        if len(groups[-1]) > 0:  # do not stack up consecutive empty groups
            groups.append([])
304#############################################################################
class FileOrLiteralAction(argparse.Action):
    """Accumulates literal values and the lines of '+file' argument files, similar to '@file' argument expansion."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Expands '+file' arguments and appends all resulting values to the values collected so far.

        A value starting with '+' names an argument file whose basename must contain 'bzfs_argument_file'; each of its
        lines becomes one value, except for lines that start with '#' and blank lines. I/O errors are reported via
        ``parser.error``. If ``self.dest`` is one of SNAPSHOT_REGEX_FILTER_NAMES, the values added by this invocation
        are additionally registered as a snapshot filter.
        """
        existing: list[str] | None = getattr(namespace, self.dest, None)
        # Copy rather than extend in place: the list found on the namespace may be the option's `default` object, and
        # mutating it would leak values into later parses that use the same parser.
        current_values: list[str] = [] if existing is None else list(existing)
        extra_values: list[str] = []
        err_prefix: str = f"{option_string or self.dest}: "
        for value in values:
            if not value.startswith("+"):
                extra_values.append(value)
                continue
            path: str = value[1:]
            validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
            if "bzfs_argument_file" not in os.path.basename(path):
                parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
            try:
                with open_nofollow(path, "r", encoding="utf-8") as fd:
                    for line in fd.read().splitlines():
                        if line.startswith("#") or not line.strip():
                            continue
                        extra_values.append(line)
            except OSError as e:
                parser.error(f"{err_prefix}{e}")
        current_values += extra_values
        setattr(namespace, self.dest, current_values)
        if self.dest in SNAPSHOT_REGEX_FILTER_NAMES:
            _add_snapshot_filter(namespace, SnapshotFilter(self.dest, None, extra_values))
340#############################################################################
class IncludeSnapshotPlanAction(argparse.Action):
    """Translates a serialized include plan (a nested dict literal) into equivalent snapshot filter CLI options."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends the CLI options generated from the plan to the option list stored under ``self.dest``.

        If the plan yields no filter clause at all, a filter group with the regex '!.*' is appended instead.
        """
        existing: list[str] | None = getattr(namespace, self.dest, None)
        # Copy rather than extend in place: the list found on the namespace may be the option's `default` object, and
        # mutating it would leak options into later parses that use the same parser.
        opts: list[str] = [] if existing is None else list(existing)
        include_snapshot_times_and_ranks: bool = getenv_bool("include_snapshot_plan_excludes_outdated_snapshots", True)
        if not self._add_opts(opts, include_snapshot_times_and_ranks, parser, values, option_string=option_string):
            opts += ["--new-snapshot-filter-group", "--include-snapshot-regex=!.*"]
        setattr(namespace, self.dest, opts)

    def _add_opts(
        self,
        opts: list[str],
        include_snapshot_times_and_ranks: bool,
        parser: argparse.ArgumentParser,
        values: str,
        option_string: str | None = None,
    ) -> bool:
        """Appends to ``opts`` one filter group per period entry of the plan; returns True if any entry was seen.

        ``values`` is evaluated with ``ast.literal_eval`` and is iterated as a dict of organization -> target ->
        period unit -> period amount. Each entry yields an include regex built from the normalized organization,
        target and period unit and, if ``include_snapshot_times_and_ranks`` is set, a matching
        '--include-snapshot-times-and-ranks' clause. Period amounts must be non-negative integers.
        """
        xperiods: SnapshotPeriods = SnapshotPeriods()
        has_at_least_one_filter_clause: bool = False
        for org, target_periods in ast.literal_eval(values).items():
            prefix: str = re.escape(nprefix(org))
            for target, periods in target_periods.items():
                # An empty target falls back to the four-digit-year regex pattern as the infix.
                infix: str = re.escape(ninfix(target)) if target else YEAR_WITH_FOUR_DIGITS_REGEX.pattern
                for period_unit, period_amount in periods.items():
                    if not isinstance(period_amount, int) or period_amount < 0:
                        parser.error(f"{option_string}: Period amount must be a non-negative integer: {period_amount}")
                    suffix: str = re.escape(nsuffix(period_unit))
                    regex: str = f"{prefix}{infix}.*{suffix}"
                    opts += ["--new-snapshot-filter-group", f"--include-snapshot-regex={regex}"]
                    if include_snapshot_times_and_ranks:
                        duration_amount, duration_unit = xperiods.suffix_to_duration0(period_unit)
                        duration_unit_label: str | None = xperiods.period_labels.get(duration_unit)
                        total_amount = duration_amount * period_amount
                        if duration_unit_label is None or total_amount == 0:
                            timerange_spec = "notime"
                        else:
                            timerange_spec = f"{total_amount}{duration_unit_label}ago..anytime"
                        opts += ["--include-snapshot-times-and-ranks", timerange_spec, f"latest{period_amount}"]
                    has_at_least_one_filter_clause = True
        return has_at_least_one_filter_clause
391#############################################################################
class DeleteDstSnapshotsExceptPlanAction(IncludeSnapshotPlanAction):
    """Include plan variant that determines which destination snapshots to retain when deleting all others."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends '--delete-dst-snapshots-except' plus the options generated from the plan to ``self.dest``.

        Time-and-rank clauses are always generated. A plan that yields no filter clause would mean 'retain no
        snapshots', which is refused via ``parser.error`` as a safety measure.
        """
        existing: list[str] | None = getattr(namespace, self.dest, None)
        # Copy rather than extend in place: the list found on the namespace may be the option's `default` object, and
        # mutating it would leak options into later parses that use the same parser.
        opts: list[str] = [] if existing is None else list(existing)
        opts += ["--delete-dst-snapshots-except"]
        if not self._add_opts(opts, True, parser, values, option_string=option_string):
            # Note the trailing space after "on": without it the message ran "on" into the option name.
            parser.error(
                f"{option_string}: Cowardly refusing to delete all snapshots on "
                f"--delete-dst-snapshots-except-plan='{values}' (which means 'retain no snapshots' aka "
                "'delete all snapshots'). Assuming this is an unintended pilot error rather than intended carnage. "
                "Aborting. If this is really what is intended, use `--delete-dst-snapshots --include-snapshot-regex=.*` "
                "instead to force the deletion."
            )
        setattr(namespace, self.dest, opts)
413#############################################################################
414class TimeRangeAndRankRangeAction(argparse.Action):
415 """Parses --include-snapshot-times-and-ranks option values."""
417 def __call__(
418 self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
419 ) -> None:
420 """Converts user-supplied time and rank ranges into snapshot filters."""
422 def parse_time(time_spec: str) -> int | timedelta | None:
423 time_spec = time_spec.strip()
424 if time_spec == "*" or time_spec == "anytime":
425 return None
426 if time_spec.isdigit():
427 return int(time_spec)
428 try:
429 return timedelta(milliseconds=parse_duration_to_milliseconds(time_spec, regex_suffix=r"\s*ago"))
430 except ValueError:
431 try:
432 return unixtime_fromisoformat(time_spec)
433 except ValueError:
434 parser.error(f"{option_string}: Invalid duration, Unix time, or ISO 8601 datetime: {time_spec}")
436 assert isinstance(values, list)
437 assert len(values) > 0
438 value: str = values[0].strip()
439 if value == "notime":
440 value = "0..0"
441 if ".." not in value:
442 parser.error(f"{option_string}: Invalid time range: Missing '..' separator: {value}")
443 timerange_specs: list[int | timedelta | None] = [parse_time(time_spec) for time_spec in value.split("..", 1)]
444 rankranges: list[RankRange] = self._parse_rankranges(parser, values[1:], option_string=option_string)
445 setattr(namespace, self.dest, [timerange_specs] + rankranges)
446 timerange: UnixTimeRange = self._get_include_snapshot_times(timerange_specs)
447 _add_time_and_rank_snapshot_filter(namespace, self.dest, timerange, rankranges)
449 @staticmethod
450 def _get_include_snapshot_times(times: list[timedelta | int | None]) -> UnixTimeRange:
451 """Convert start and end times to ``UnixTimeRange`` for filtering."""
453 def utc_unix_time_in_seconds(time_spec: timedelta | int | None, default: int) -> timedelta | int:
454 if isinstance(time_spec, timedelta):
455 return time_spec
456 if isinstance(time_spec, int):
457 return int(time_spec)
458 return default
460 lo, hi = times
461 if lo is None and hi is None:
462 return None
463 lo = utc_unix_time_in_seconds(lo, default=0)
464 hi = utc_unix_time_in_seconds(hi, default=UNIX_TIME_INFINITY_SECS)
465 if isinstance(lo, int) and isinstance(hi, int):
466 return (lo, hi) if lo <= hi else (hi, lo)
467 return lo, hi
469 @staticmethod
470 def _parse_rankranges(parser: argparse.ArgumentParser, values: Any, option_string: str | None = None) -> list[RankRange]:
471 """Parses rank range strings like 'latest 3..latest 5' into tuples."""
473 def parse_rank(spec: str) -> tuple[bool, str, int, bool]:
474 spec = spec.strip()
475 if not (match := re.fullmatch(r"(all\s*except\s*)?(oldest|latest)\s*(\d+)%?", spec)):
476 parser.error(f"{option_string}: Invalid rank format: {spec}")
477 assert match
478 is_except: bool = bool(match.group(1))
479 kind: str = match.group(2)
480 num: int = int(match.group(3))
481 is_percent: bool = spec.endswith("%")
482 if is_percent and num > 100:
483 parser.error(f"{option_string}: Invalid rank: Percent must not be greater than 100: {spec}")
484 return is_except, kind, num, is_percent
486 rankranges: list[RankRange] = []
487 for value in values:
488 value = value.strip()
489 if ".." in value:
490 lo_split, hi_split = value.split("..", 1)
491 lo = parse_rank(lo_split)
492 hi = parse_rank(hi_split)
493 if lo[0] or hi[0]:
494 parser.error(f"{option_string}: Invalid rank range: {value}")
495 if lo[1] != hi[1]:
496 parser.error(f"{option_string}: Ambiguous rank range: Must not compare oldest with latest: {value}")
497 else:
498 hi = parse_rank(value)
499 is_except, kind, num, is_percent = hi
500 if is_except:
501 if is_percent:
502 negated_kind: str = "oldest" if kind == "latest" else "latest"
503 lo = parse_rank(f"{negated_kind}0")
504 hi = parse_rank(f"{negated_kind}{100-num}%")
505 else:
506 lo = parse_rank(f"{kind}{num}")
507 hi = parse_rank(f"{kind}100%")
508 else:
509 lo = parse_rank(f"{kind}0")
510 rankranges.append((lo[1:], hi[1:]))
511 return rankranges
514#############################################################################
class LogConfigVariablesAction(argparse.Action):
    """Collects --log-config-var NAME:VALUE pairs for later substitution."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Validates each entry with ``validate_log_config_variable`` and appends it to the entries collected so far.

        An entry for which the validator returns an error message is reported via ``parser.error``.
        """
        existing: list[str] | None = getattr(namespace, self.dest, None)
        # Copy rather than append in place: the list found on the namespace may be the option's `default` object, and
        # mutating it would leak entries into later parses that use the same parser.
        current_values: list[str] = [] if existing is None else list(existing)
        for variable in values:
            error_msg: str | None = validate_log_config_variable(variable)
            if error_msg:
                parser.error(error_msg)
            current_values.append(variable)
        setattr(namespace, self.dest, current_values)
533#############################################################################
class CheckPercentRange(CheckRange):
    """Argparse action that accepts a number with an optional trailing '%' and delegates range checks to CheckRange."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Parses the value as a float, lets ``CheckRange`` process it, then stores ``(stored_value, is_percent)``.

        ``is_percent`` tells whether the stripped input ended with '%'. Unparsable input is reported via
        ``parser.error`` quoting the original, unstripped text.
        """
        assert isinstance(values, str)
        original = values
        number: Any = values.strip()
        is_percent: bool = number.endswith("%")
        if is_percent:
            number = number[:-1]
        try:
            number = float(number)
        except ValueError:
            parser.error(f"{option_string}: Invalid percentage or number: {original}")
        super().__call__(parser, namespace, number, option_string=option_string)
        checked = getattr(namespace, self.dest)  # the value as stored by CheckRange
        setattr(namespace, self.dest, (checked, is_percent))