Coverage for bzfs_main/argparse_actions.py: 100%
316 statements
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-07 04:44 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2025-11-07 04:44 +0000
1# Copyright 2024 Wolfgang Hoschek AT mac DOT com
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15"""Custom argparse actions shared by the 'bzfs' and 'bzfs_jobrunner' CLIs; These helpers validate and expand complex command
16line syntax such as +file references, dataset pairs, and snapshot filters."""
18from __future__ import (
19 annotations,
20)
21import argparse
22import ast
23import os
24import re
25from dataclasses import (
26 dataclass,
27 field,
28)
29from datetime import (
30 timedelta,
31)
32from typing import (
33 Any,
34)
36from bzfs_main.check_range import (
37 CheckRange,
38)
39from bzfs_main.filter import (
40 SNAPSHOT_REGEX_FILTER_NAME,
41 SNAPSHOT_REGEX_FILTER_NAMES,
42 RankRange,
43 UnixTimeRange,
44)
45from bzfs_main.utils import (
46 SHELL_CHARS,
47 SNAPSHOT_FILTERS_VAR,
48 UNIX_TIME_INFINITY_SECS,
49 YEAR_WITH_FOUR_DIGITS_REGEX,
50 SnapshotPeriods,
51 die,
52 ninfix,
53 nprefix,
54 nsuffix,
55 open_nofollow,
56 parse_duration_to_milliseconds,
57 unixtime_fromisoformat,
58)
61#############################################################################
@dataclass(order=True)
class SnapshotFilter:
    """One step of a snapshot filter plan: a filter kind, an optional time range, and a kind-specific payload.

    Equality and ordering are derived from ``name`` first and ``timerange`` second; ``options`` never takes part in
    comparisons.
    """

    name: str  # filter kind, e.g. "include_snapshot_times"
    timerange: UnixTimeRange  # defined in bzfs_main.filter
    options: Any = field(default=None, compare=False)  # payload such as regex lists or rank ranges
def _add_snapshot_filter(args: argparse.Namespace, _filter: SnapshotFilter) -> None:
    """Adds ``_filter`` to the most recent filter group on ``args``; creates the first (empty) group if none exists yet."""
    groups_exist = hasattr(args, SNAPSHOT_FILTERS_VAR)
    if not groups_exist:
        args.snapshot_filters_var = [[]]
    latest_group = args.snapshot_filters_var[-1]
    latest_group.append(_filter)
def _add_time_and_rank_snapshot_filter(
    args: argparse.Namespace, dst: str, timerange: UnixTimeRange, rankranges: list[RankRange]
) -> None:
    """Registers a filter built from ``timerange`` and ``rankranges`` on ``args``.

    A plain "include_snapshot_times" filter (without rank options) is registered when there is no time range, no rank
    range at all, or some rank range whose lower and upper bound are equal. Otherwise the filter is named ``dst`` and
    carries the rank ranges as its options.
    """
    time_only = timerange is None or not rankranges or any(rr[0] == rr[1] for rr in rankranges)
    if time_only:
        new_filter = SnapshotFilter("include_snapshot_times", timerange, None)
    else:
        assert timerange is not None
        new_filter = SnapshotFilter(dst, timerange, rankranges)
    _add_snapshot_filter(args, new_filter)
def has_timerange_filter(snapshot_filters: list[list[SnapshotFilter]]) -> bool:
    """Returns True if any filter in any of the given filter groups has a timerange that is not None."""
    for group in snapshot_filters:
        for candidate in group:
            if candidate.timerange is not None:
                return True
    return False
def optimize_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> list[SnapshotFilter]:
    """Returns an optimized filter plan: merges neighbours, drops no-op filters, then moves time filters forward.

    The merge passes modify ``snapshot_filters`` in place; the returned list is a new list that omits every filter
    whose ``timerange`` and ``options`` are both falsy.
    """
    for merge_pass in (_merge_adjacent_snapshot_filters, _merge_adjacent_snapshot_regexes):
        merge_pass(snapshot_filters)
    effective_filters = [f for f in snapshot_filters if f.timerange or f.options]
    _reorder_snapshot_time_filters(effective_filters)
    return effective_filters
def _merge_adjacent_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> None:
    """Folds each list-valued filter into its immediate predecessor when the two compare equal; works in place.

    The list is walked from the back so that removing an element never shifts the positions still to be visited.
    """
    for idx in range(len(snapshot_filters) - 1, 0, -1):
        current = snapshot_filters[idx]
        if not isinstance(current.options, list):
            continue
        predecessor = snapshot_filters[idx - 1]
        if predecessor == current:
            merged_options: list = predecessor.options
            assert isinstance(merged_options, list)
            merged_options += current.options
            snapshot_filters.pop(idx)
def _merge_adjacent_snapshot_regexes(snapshot_filters: list[SnapshotFilter]) -> None:
    """Collapses each run of consecutive regex filters into a single combined regex filter; works in place.

    Pass 1 merges, within a run of regex filters, every filter into the nearest earlier filter of the same name.
    Pass 2 replaces what is left of each run (one or two filters) with one ``SNAPSHOT_REGEX_FILTER_NAME`` filter whose
    options are a 2-tuple taken from the two regex filters in sorted order (exclude first, then include, going by the
    original variable names).
    """
    # Pass 1: walk backwards; fold same-named regex filters of the same run together.
    pos = len(snapshot_filters) - 1
    while pos >= 0:
        current: SnapshotFilter = snapshot_filters[pos]
        if current.name in SNAPSHOT_REGEX_FILTER_NAMES:
            assert isinstance(current.options, list)
            probe = pos - 1
            while probe >= 0 and snapshot_filters[probe].name in SNAPSHOT_REGEX_FILTER_NAMES:
                earlier = snapshot_filters[probe]
                if earlier.name == current.name:
                    target: list[object] = earlier.options
                    assert isinstance(target, list)
                    target += current.options
                    snapshot_filters.pop(pos)
                    break
                probe -= 1
        pos -= 1

    # Pass 2: after pass 1 a run holds at most one filter per regex name; fuse the run into one combined filter.
    pos = len(snapshot_filters) - 1
    while pos >= 0:
        current = snapshot_filters[pos]
        kind: str = current.name
        if kind in SNAPSHOT_REGEX_FILTER_NAMES:
            if pos >= 1 and snapshot_filters[pos - 1].name in SNAPSHOT_REGEX_FILTER_NAMES:
                partner = snapshot_filters[pos - 1]
                assert partner.name != kind
                snapshot_filters.pop(pos)
                pos -= 1  # the combined filter takes the partner's slot
            else:
                # No neighbour of the other kind: pair with an empty placeholder of the remaining regex name.
                other_kind: str = next(iter(SNAPSHOT_REGEX_FILTER_NAMES.difference({kind})))
                partner = SnapshotFilter(other_kind, None, [])
            ordered: list[SnapshotFilter] = sorted([current, partner])
            exclude_regexes, include_regexes = (ordered[0].options, ordered[1].options)
            snapshot_filters[pos] = SnapshotFilter(SNAPSHOT_REGEX_FILTER_NAME, None, (exclude_regexes, include_regexes))
        pos -= 1
def _reorder_snapshot_time_filters(snapshot_filters: list[SnapshotFilter]) -> None:
    """Moves "include_snapshot_times" filters to the front of their section of the execution plan; works in place.

    Sections are the maximal runs of filters between "include_snapshot_times_and_ranks" filters; those separator
    filters keep their position. Within a section the relative order of the time filters, and of the remaining
    filters, is preserved.

    The previous pop/insert loop skipped the element that shifted into the just-vacated slot, so two adjacent time
    filters following another filter (e.g. ``[regex, T1, T2]``) left one time filter behind the regex filter.
    """
    separators = [
        idx for idx, candidate in enumerate(snapshot_filters) if candidate.name == "include_snapshot_times_and_ranks"
    ]
    section_start = 0
    for section_end in separators + [len(snapshot_filters)]:
        section = snapshot_filters[section_start:section_end]
        time_filters = [f for f in section if f.name == "include_snapshot_times"]
        other_filters = [f for f in section if f.name != "include_snapshot_times"]
        # Same-length slice assignment: indices of later sections and separators are unaffected.
        snapshot_filters[section_start:section_end] = time_filters + other_filters
        section_start = section_end + 1
def validate_no_argument_file(
    path: str, namespace: argparse.Namespace, err_prefix: str, parser: argparse.ArgumentParser | None = None
) -> None:
    """Calls ``die`` if ``namespace.no_argument_file`` is truthy, i.e. if '+file' inclusion has been disabled.

    ``path`` and ``err_prefix`` are only used to build the error message; ``parser`` is forwarded to ``die``.
    """
    argument_files_disabled = getattr(namespace, "no_argument_file", False)
    if not argument_files_disabled:
        return
    die(f"{err_prefix}Argument file inclusion is disabled: {path}", parser=parser)
192#############################################################################
class NonEmptyStringAction(argparse.Action):
    """Argparse action that stores the whitespace-stripped value and rejects values that are empty after stripping."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips surrounding whitespace, reports an error for an empty result, otherwise stores it under ``dest``."""
        stripped = values.strip()
        if not stripped:
            parser.error(f"{option_string}: Empty string is not valid")
        setattr(namespace, self.dest, stripped)
206#############################################################################
class DatasetPairsAction(argparse.Action):
    """Collects alternating SRC_DATASET / DST_DATASET arguments into a list of (src, dst) tuples."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Gathers dataset names from literals and '+file' references, then stores them pairwise under ``dest``.

        A value starting with '+' names a file whose basename must contain 'bzfs_argument_file'. Each line of that
        file that is neither blank nor starts with '#' must hold a tab-separated, non-blank SRC_DATASET and
        DST_DATASET. All violations, as well as OSError while reading, are reported via ``parser.error``.
        """
        err_prefix: str = f"{option_string or self.dest}: "
        datasets: list[str] = []

        for value in values:
            if not value.startswith("+"):
                datasets.append(value)
                continue
            path: str = value[1:]
            validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
            basename = os.path.basename(path)
            if "bzfs_argument_file" not in basename:
                parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
            try:
                with open_nofollow(path, "r", encoding="utf-8") as fd:
                    lines = fd.read().splitlines()
                    for line_index, line in enumerate(lines):
                        if line.startswith("#") or not line.strip():
                            continue
                        # Only the first tab separates src from dst; later tabs remain part of dst.
                        src_root_dataset, separator, dst_root_dataset = line.partition("\t")
                        if not separator:
                            parser.error(
                                f"{err_prefix}Line must contain tab-separated SRC_DATASET and DST_DATASET: {line_index}"
                            )
                        if not (src_root_dataset.strip() and dst_root_dataset.strip()):
                            parser.error(
                                f"{err_prefix}SRC_DATASET and DST_DATASET must not be empty or whitespace-only: {line_index}"
                            )
                        datasets += [src_root_dataset, dst_root_dataset]
            except OSError as e:
                parser.error(f"{err_prefix}{e}")

        if len(datasets) % 2 != 0:
            parser.error(f"{err_prefix}Each SRC_DATASET must have a corresponding DST_DATASET: {datasets}")
        root_dataset_pairs: list[tuple[str, str]] = [(datasets[k], datasets[k + 1]) for k in range(0, len(datasets), 2)]
        setattr(namespace, self.dest, root_dataset_pairs)
249#############################################################################
class SSHConfigFileNameAction(argparse.Action):
    """Argparse action for an SSH config file name: rejects empty names, whitespace and characters in SHELL_CHARS."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, validates it character by character, and stores it under ``dest``."""
        values = values.strip()
        if not values:
            parser.error(f"{option_string}: Empty string is not valid")
        has_forbidden_char = any(char.isspace() or char in SHELL_CHARS for char in values)
        if has_forbidden_char:
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain whitespace or special chars.")
        setattr(namespace, self.dest, values)
266#############################################################################
class SafeFileNameAction(argparse.Action):
    """Argparse action for a bare file name: no '..', no path separators, no whitespace other than plain space."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Reports an error for names containing '..', '/', '\\' or non-space whitespace; otherwise stores the value."""
        forbidden_fragments = ("..", "/", "\\")
        if any(fragment in values for fragment in forbidden_fragments):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain '..' or '/' or '\\'.")
        if any(char != " " and char.isspace() for char in values):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, values)
281#############################################################################
class SafeDirectoryNameAction(argparse.Action):
    """Argparse action for a directory name: must be non-empty and contain no whitespace other than plain space."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, reports an error if it is empty or has non-space whitespace, otherwise stores it."""
        dir_name = values.strip()
        if not dir_name:
            parser.error(f"{option_string}: Empty string is not valid")
        if any(char != " " and char.isspace() for char in dir_name):
            parser.error(f"{option_string}: Invalid dir name '{dir_name}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, dir_name)
297#############################################################################
class NewSnapshotFilterGroupAction(argparse.Action):
    """Argparse action that opens a new snapshot filter group for the filters that follow on the command line."""

    def __call__(
        self, parser: argparse.ArgumentParser, args: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Creates the initial group if none exists; otherwise appends an empty group unless the last one is empty."""
        if not hasattr(args, SNAPSHOT_FILTERS_VAR):
            args.snapshot_filters_var = [[]]
            return
        groups = args.snapshot_filters_var
        if groups[-1]:  # never stack consecutive empty groups
            groups.append([])
311#############################################################################
class FileOrLiteralAction(argparse.Action):
    """Accumulates literal values plus the lines of any '+file' references (similar to '@file' expansion)."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Expands '+file' arguments and appends the resulting values to the list stored under ``dest``.

        A value starting with '+' names a file whose basename must contain 'bzfs_argument_file'; every line that is
        neither blank nor starts with '#' becomes one value. Violations and OSError while reading are reported via
        ``parser.error``. If ``dest`` is one of SNAPSHOT_REGEX_FILTER_NAMES, the values contributed by this invocation
        are additionally registered as a snapshot filter.
        """
        previous_values: list[str] | None = getattr(namespace, self.dest, None)
        # Copy instead of extending in place: the list found on the namespace may be the very object that was passed
        # as the argument's ``default``, and mutating it would leak values into later parse_args() calls.
        current_values: list[str] = [] if previous_values is None else list(previous_values)
        extra_values: list[str] = []
        err_prefix: str = f"{option_string or self.dest}: "
        for value in values:
            if not value.startswith("+"):
                extra_values.append(value)
                continue
            path: str = value[1:]
            validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
            if "bzfs_argument_file" not in os.path.basename(path):
                parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
            try:
                with open_nofollow(path, "r", encoding="utf-8") as fd:
                    for line in fd.read().splitlines():
                        if line.startswith("#") or not line.strip():
                            continue
                        extra_values.append(line)
            except OSError as e:
                parser.error(f"{err_prefix}{e}")
        current_values += extra_values
        setattr(namespace, self.dest, current_values)
        if self.dest in SNAPSHOT_REGEX_FILTER_NAMES:
            _add_snapshot_filter(namespace, SnapshotFilter(self.dest, None, extra_values))
347#############################################################################
class IncludeSnapshotPlanAction(argparse.Action):
    """Translates a serialized include plan (a dict literal) into equivalent snapshot filter command line options."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends the options generated from the plan to the list stored under ``dest``.

        If the plan yields no filter clause at all, a filter group with ``--include-snapshot-regex=!.*`` is appended
        instead.
        """
        opts: list[str] | None = getattr(namespace, self.dest, None)
        # Copy so that a list object shared with the argument's ``default`` is never mutated in place.
        opts = [] if opts is None else list(opts)
        if not self._add_opts(opts, parser, values, option_string=option_string):
            opts += ["--new-snapshot-filter-group", "--include-snapshot-regex=!.*"]
        setattr(namespace, self.dest, opts)

    def _add_opts(
        self,
        opts: list[str],
        parser: argparse.ArgumentParser,
        values: str,
        option_string: str | None = None,
    ) -> bool:
        """Appends to ``opts`` one filter group per (org, target, period unit) entry of the plan.

        ``values`` is parsed with ``ast.literal_eval`` and is expected to be a nested dict of the shape
        ``{org: {target: {period_unit: period_amount}}}``. Each entry contributes an ``--include-snapshot-regex``
        built from the escaped prefix/infix/suffix, followed by ``--include-snapshot-times-and-ranks`` with a time
        range and ``latest<period_amount>``.

        Returns True if at least one entry was processed, False for a plan without any period entries.
        """
        xperiods: SnapshotPeriods = SnapshotPeriods()
        has_at_least_one_filter_clause: bool = False
        for org, target_periods in ast.literal_eval(values).items():
            prefix: str = re.escape(nprefix(org))
            for target, periods in target_periods.items():
                infix: str = re.escape(ninfix(target)) if target else YEAR_WITH_FOUR_DIGITS_REGEX.pattern
                for period_unit, period_amount in periods.items():
                    # bool is a subclass of int; without the explicit check True/False would slip through and yield
                    # rank specs such as "latestTrue" that only fail later with a misleading message.
                    if isinstance(period_amount, bool) or not isinstance(period_amount, int) or period_amount < 0:
                        parser.error(f"{option_string}: Period amount must be a non-negative integer: {period_amount}")
                    suffix: str = re.escape(nsuffix(period_unit))
                    regex: str = f"{prefix}{infix}.*{suffix}"
                    opts += ["--new-snapshot-filter-group", f"--include-snapshot-regex={regex}"]
                    duration_amount, duration_unit = xperiods.suffix_to_duration0(period_unit)
                    duration_unit_label: str | None = xperiods.period_labels.get(duration_unit)
                    total_duration = duration_amount * period_amount
                    opts += [
                        "--include-snapshot-times-and-ranks",
                        (
                            # No usable time span (unknown unit label or zero duration): select by rank only.
                            "notime"
                            if duration_unit_label is None or total_duration == 0
                            else f"{total_duration}{duration_unit_label}ago..anytime"
                        ),
                        f"latest{period_amount}",
                    ]
                    has_at_least_one_filter_clause = True
        return has_at_least_one_filter_clause
395#############################################################################
class DeleteDstSnapshotsExceptPlanAction(IncludeSnapshotPlanAction):
    """Include plan variant that describes which destination snapshots to retain when deleting."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Emits ``--delete-dst-snapshots-except`` plus the plan's filter options and stores them under ``dest``.

        A plan without any filter clause would retain nothing, i.e. delete every snapshot; that case is refused via
        ``parser.error``.
        """
        opts: list[str] | None = getattr(namespace, self.dest, None)
        # Copy so that a list object shared with the argument's ``default`` is never mutated in place.
        opts = [] if opts is None else list(opts)
        opts += ["--delete-dst-snapshots-except"]
        if not self._add_opts(opts, parser, values, option_string=option_string):
            parser.error(
                # Trailing space after "on" keeps the option name from being glued to the preceding word.
                f"{option_string}: Cowardly refusing to delete all snapshots on "
                f"--delete-dst-snapshots-except-plan='{values}' (which means 'retain no snapshots' aka "
                "'delete all snapshots'). Assuming this is an unintended pilot error rather than intended carnage. "
                "Aborting. If this is really what is intended, use `--delete-dst-snapshots --include-snapshot-regex=.*` "
                "instead to force the deletion."
            )
        setattr(namespace, self.dest, opts)
417#############################################################################
class TimeRangeAndRankRangeAction(argparse.Action):
    """Parses --include-snapshot-times-and-ranks option values.

    The first value is a time range ``START..END`` (or the literal ``notime``); all further values are rank ranges.
    """

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Converts user-supplied time and rank ranges into snapshot filters.

        Stores ``[timerange_specs] + rankranges`` under ``dest`` and registers a corresponding snapshot filter on
        ``namespace``. Malformed input is reported via ``parser.error``.
        """

        def parse_time(time_spec: str) -> int | timedelta | None:
            """Parses one end of the time range.

            Returns None for '*' or 'anytime', an int for a plain number, a timedelta for a duration (optionally
            suffixed with 'ago'), and otherwise the result of ``unixtime_fromisoformat``.
            """
            time_spec = time_spec.strip()
            if time_spec == "*" or time_spec == "anytime":
                return None
            # isdecimal() rather than isdigit(): isdigit() is also True for characters such as superscript digits,
            # which int() rejects with a ValueError that would escape as a traceback instead of a parser error.
            if time_spec.isdecimal():
                return int(time_spec)
            try:
                return timedelta(milliseconds=parse_duration_to_milliseconds(time_spec, regex_suffix=r"\s*ago"))
            except ValueError:
                try:
                    return unixtime_fromisoformat(time_spec)
                except ValueError:
                    parser.error(f"{option_string}: Invalid duration, Unix time, or ISO 8601 datetime: {time_spec}")

        assert isinstance(values, list)
        assert len(values) > 0
        value: str = values[0].strip()
        if value == "notime":
            value = "0..0"
        if ".." not in value:
            parser.error(f"{option_string}: Invalid time range: Missing '..' separator: {value}")
        timerange_specs: list[int | timedelta | None] = [parse_time(time_spec) for time_spec in value.split("..", 1)]
        rankranges: list[RankRange] = self._parse_rankranges(parser, values[1:], option_string=option_string)
        setattr(namespace, self.dest, [timerange_specs] + rankranges)
        timerange: UnixTimeRange = self._get_include_snapshot_times(timerange_specs)
        _add_time_and_rank_snapshot_filter(namespace, self.dest, timerange, rankranges)

    @staticmethod
    def _get_include_snapshot_times(times: list[timedelta | int | None]) -> UnixTimeRange:
        """Converts a [start, end] pair to a ``UnixTimeRange``.

        Returns None if both ends are None. A missing start defaults to 0 and a missing end to
        UNIX_TIME_INFINITY_SECS. If both ends are ints they are returned in ascending order; if either end is a
        timedelta the pair is returned as given.
        """

        def utc_unix_time_in_seconds(time_spec: timedelta | int | None, default: int) -> timedelta | int:
            if isinstance(time_spec, timedelta):
                return time_spec
            if isinstance(time_spec, int):
                return int(time_spec)
            return default

        lo, hi = times
        if lo is None and hi is None:
            return None
        lo = utc_unix_time_in_seconds(lo, default=0)
        hi = utc_unix_time_in_seconds(hi, default=UNIX_TIME_INFINITY_SECS)
        if isinstance(lo, int) and isinstance(hi, int):
            return (lo, hi) if lo <= hi else (hi, lo)
        return lo, hi

    @staticmethod
    def _parse_rankranges(parser: argparse.ArgumentParser, values: Any, option_string: str | None = None) -> list[RankRange]:
        """Parses rank range strings like 'latest 3..latest 5' into tuples.

        Each result item is ``(lo, hi)`` where both ends are ``(kind, num, is_percent)`` with kind 'oldest' or
        'latest'. A single rank 'latest N' becomes the range 'latest 0..latest N'. 'all except ...' is only accepted
        for a single rank and is rewritten into the complementary range.
        """

        def parse_rank(spec: str) -> tuple[bool, str, int, bool]:
            """Returns ``(is_except, kind, num, is_percent)`` for a single rank spec."""
            spec = spec.strip()
            if not (match := re.fullmatch(r"(all\s*except\s*)?(oldest|latest)\s*(\d+)%?", spec)):
                parser.error(f"{option_string}: Invalid rank format: {spec}")
            assert match
            is_except: bool = bool(match.group(1))
            kind: str = match.group(2)
            num: int = int(match.group(3))
            is_percent: bool = spec.endswith("%")
            if is_percent and num > 100:
                parser.error(f"{option_string}: Invalid rank: Percent must not be greater than 100: {spec}")
            return is_except, kind, num, is_percent

        rankranges: list[RankRange] = []
        for value in values:
            value = value.strip()
            if ".." in value:
                lo_split, hi_split = value.split("..", 1)
                lo = parse_rank(lo_split)
                hi = parse_rank(hi_split)
                if lo[0] or hi[0]:  # 'all except' is not supported inside an explicit lo..hi range
                    parser.error(f"{option_string}: Invalid rank range: {value}")
                if lo[1] != hi[1]:
                    parser.error(f"{option_string}: Ambiguous rank range: Must not compare oldest with latest: {value}")
            else:
                hi = parse_rank(value)
                is_except, kind, num, is_percent = hi
                if is_except:
                    if is_percent:
                        # 'all except latest N%' is expressed as the first (100-N)% counted from the opposite end.
                        negated_kind: str = "oldest" if kind == "latest" else "latest"
                        lo = parse_rank(f"{negated_kind}0")
                        hi = parse_rank(f"{negated_kind}{100-num}%")
                    else:
                        # 'all except latest N' is expressed as everything from rank N up to 100%.
                        lo = parse_rank(f"{kind}{num}")
                        hi = parse_rank(f"{kind}100%")
                else:
                    lo = parse_rank(f"{kind}0")
            rankranges.append((lo[1:], hi[1:]))  # drop the is_except flag
        return rankranges
518#############################################################################
class CheckPercentRange(CheckRange):
    """Argparse action that accepts a number with an optional trailing '%' and range-checks it via ``CheckRange``."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Parses the value as a float, delegates validation to the base class, then stores ``(value, is_percent)``.

        ``value`` is whatever the base class stored under ``dest``; ``is_percent`` tells whether the input ended
        with '%'. Input that is not a number is reported via ``parser.error``.
        """
        assert isinstance(values, str)
        original = values
        values = values.strip()
        is_percent: bool = values.endswith("%")
        if is_percent:
            values = values[:-1]
        try:
            values = float(values)
        except ValueError:
            parser.error(f"{option_string}: Invalid percentage or number: {original}")
        super().__call__(parser, namespace, values, option_string=option_string)
        checked_value = getattr(namespace, self.dest)
        setattr(namespace, self.dest, (checked_value, is_percent))