Coverage for bzfs_main / argparse_actions.py: 100%

327 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 08:03 +0000

1# Copyright 2024 Wolfgang Hoschek AT mac DOT com 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# 

15"""Custom argparse actions shared by the 'bzfs' and 'bzfs_jobrunner' CLIs; These helpers validate and expand complex command 

16line syntax such as +file references, dataset pairs, and snapshot filters.""" 

17 

18from __future__ import ( 

19 annotations, 

20) 

21import argparse 

22import ast 

23import os 

24import re 

25from dataclasses import ( 

26 dataclass, 

27 field, 

28) 

29from datetime import ( 

30 timedelta, 

31) 

32from typing import ( 

33 Any, 

34 final, 

35) 

36 

37from bzfs_main.filter import ( 

38 SNAPSHOT_FILTERS_VAR, 

39 SNAPSHOT_REGEX_FILTER_NAME, 

40 SNAPSHOT_REGEX_FILTER_NAMES, 

41 RankRange, 

42 UnixTimeRange, 

43) 

44from bzfs_main.util.check_range import ( 

45 CheckRange, 

46) 

47from bzfs_main.util.utils import ( 

48 SHELL_CHARS, 

49 UNIX_TIME_INFINITY_SECS, 

50 YEAR_WITH_FOUR_DIGITS_REGEX, 

51 SnapshotPeriods, 

52 die, 

53 ninfix, 

54 nprefix, 

55 nsuffix, 

56 open_nofollow, 

57 parse_duration_to_milliseconds, 

58 unixtime_fromisoformat, 

59) 

60 

61 

62############################################################################# 

@dataclass(order=True)
@final
class SnapshotFilter:
    """One step of a snapshot filter execution plan.

    Instances compare and sort by ``(name, timerange)`` only; ``options`` is excluded from comparisons, so two filters
    of the same kind and time range compare equal even if their payloads differ.
    """

    # Filter kind, e.g. "include_snapshot_times" or "include_snapshot_times_and_ranks".
    name: str
    # Time range this filter applies to, or None; UnixTimeRange is defined in bzfs_main.filter.
    timerange: UnixTimeRange
    # Kind-specific payload (e.g. a list of regexes or a list of rank ranges); deliberately excluded from ==/< checks.
    options: Any = field(default=None, compare=False)

71 

72 

def _add_snapshot_filter(args: argparse.Namespace, _filter: SnapshotFilter) -> None:
    """Appends ``_filter`` to the most recent filter group stored on ``args``.

    If ``args`` does not carry the filter groups attribute yet, it is first initialized with a single empty group.
    """
    if hasattr(args, SNAPSHOT_FILTERS_VAR):
        groups: list[list[SnapshotFilter]] = args.snapshot_filters_var
    else:
        groups = [[]]
        args.snapshot_filters_var = groups
    groups[-1].append(_filter)

79 

80 

def _add_time_and_rank_snapshot_filter(
    args: argparse.Namespace, dst: str, timerange: UnixTimeRange, rankranges: list[RankRange]
) -> None:
    """Adds a snapshot filter for a time range that may be combined with rank ranges.

    A rank-aware filter named ``dst`` is added only when a time range is given, at least one rank range is given, and
    every rank range is non-empty (its lower bound differs from its upper bound). Otherwise the rank ranges are dropped
    and a plain 'include_snapshot_times' filter carrying only ``timerange`` (which may be None) is added instead.
    """
    use_ranks: bool = (
        timerange is not None and len(rankranges) > 0 and all(rankrange[0] != rankrange[1] for rankrange in rankranges)
    )
    if not use_ranks:
        _add_snapshot_filter(args, SnapshotFilter("include_snapshot_times", timerange, None))
        return
    assert timerange is not None
    _add_snapshot_filter(args, SnapshotFilter(dst, timerange, rankranges))

91 

92 

def has_timerange_filter(snapshot_filters: list[list[SnapshotFilter]]) -> bool:
    """Returns True if any filter in any filter group has a non-None ``timerange``.

    Counterpart of _add_time_and_rank_snapshot_filter(), which creates time-range filters, and
    optimize_snapshot_filters(), which prunes filters that have neither a time range nor options.
    """
    for filter_group in snapshot_filters:
        for snapshot_filter in filter_group:
            if snapshot_filter.timerange is not None:
                return True
    return False

97 

98 

def optimize_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> list[SnapshotFilter]:
    """Applies basic optimizations to a snapshot filter execution plan and returns the optimized plan.

    The steps are as follows:
    - Merge adjacent equal filters.
    - Combine neighboring regex filters.
    - Drop filters that carry neither a time range nor options (presumably no-ops).
    - Move time filters to the front of each plan section.

    The two merge steps modify ``snapshot_filters`` in place.
    """
    _merge_adjacent_snapshot_filters(snapshot_filters)
    _merge_adjacent_snapshot_regexes(snapshot_filters)
    pruned: list[SnapshotFilter] = []
    for snapshot_filter in snapshot_filters:
        if not (snapshot_filter.timerange or snapshot_filter.options):
            continue  # neither a time range nor options to filter on
        pruned.append(snapshot_filter)
    _reorder_snapshot_time_filters(pruned)
    return pruned

107 

108 

109def _merge_adjacent_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> None: 

110 """Merge adjacent filters of the same type if possible.""" 

111 

112 i = len(snapshot_filters) - 1 

113 while i >= 0: 

114 filter_i: SnapshotFilter = snapshot_filters[i] 

115 if isinstance(filter_i.options, list): 

116 j = i - 1 

117 if j >= 0 and snapshot_filters[j] == filter_i: 

118 lst: list = snapshot_filters[j].options 

119 assert isinstance(lst, list) 

120 lst += filter_i.options 

121 snapshot_filters.pop(i) 

122 i -= 1 

123 

124 

def _merge_adjacent_snapshot_regexes(snapshot_filters: list[SnapshotFilter]) -> None:
    """Combine consecutive regex filters of the same kind for efficiency.

    Operates in place in two passes over ``snapshot_filters``:

    1. A "regex filter" is one whose name is in SNAPSHOT_REGEX_FILTER_NAMES. Within each run of consecutive regex
       filters, a filter is merged into the nearest earlier filter of the same name in that run. Its options (a list)
       are appended to that filter's list, and the later filter is removed. Afterwards each run holds at most one
       filter per regex filter name; the assertion in pass 2 relies on this.
    2. Each remaining regex filter becomes one filter named SNAPSHOT_REGEX_FILTER_NAME whose options are the tuple
       ``(exclude_regexes, include_regexes)``. It is paired with the directly preceding regex filter of the other name
       if there is one, otherwise with an empty placeholder of the other name.

    NOTE(review): assumes SNAPSHOT_REGEX_FILTER_NAMES holds exactly two names, and that the exclude name sorts before
    the include name, as the unpacking into ``exclude_regexes, include_regexes`` implies. Confirm in bzfs_main.filter.
    """

    # Pass 1: walk backwards; merge each regex filter into the nearest earlier same-name filter within the same run.
    i = len(snapshot_filters) - 1
    while i >= 0:
        filter_i: SnapshotFilter = snapshot_filters[i]
        if filter_i.name in SNAPSHOT_REGEX_FILTER_NAMES:
            assert isinstance(filter_i.options, list)
            j = i - 1
            # only search backwards while still inside the current run of regex filters
            while j >= 0 and snapshot_filters[j].name in SNAPSHOT_REGEX_FILTER_NAMES:
                if snapshot_filters[j].name == filter_i.name:
                    lst: list[object] = snapshot_filters[j].options
                    assert isinstance(lst, list)
                    lst += filter_i.options  # extends filter j's options list in place
                    snapshot_filters.pop(i)
                    break
                j -= 1
        i -= 1

    # Pass 2: walk backwards again; collapse each (exclude, include) pair into one combined regex filter.
    i = len(snapshot_filters) - 1
    while i >= 0:
        filter_i = snapshot_filters[i]
        name: str = filter_i.name
        if name in SNAPSHOT_REGEX_FILTER_NAMES:
            j = i - 1
            if j >= 0 and snapshot_filters[j].name in SNAPSHOT_REGEX_FILTER_NAMES:
                # the predecessor is the partner of the other name; drop filter i and write the result into slot j
                filter_j = snapshot_filters[j]
                assert filter_j.name != name
                snapshot_filters.pop(i)
                i -= 1
            else:
                # no partner present: pair with an empty filter of the other regex filter name
                name_j: str = next(iter(SNAPSHOT_REGEX_FILTER_NAMES.difference({name})))
                filter_j = SnapshotFilter(name_j, None, [])
            # SnapshotFilter orders by (name, timerange); the names differ, so sorting orders the pair by name
            sorted_filters: list[SnapshotFilter] = sorted([filter_i, filter_j])
            exclude_regexes, include_regexes = (sorted_filters[0].options, sorted_filters[1].options)
            snapshot_filters[i] = SnapshotFilter(SNAPSHOT_REGEX_FILTER_NAME, None, (exclude_regexes, include_regexes))
        i -= 1

162 

163 

164def _reorder_snapshot_time_filters(snapshot_filters: list[SnapshotFilter]) -> None: 

165 """Reorder time filters before regex filters within execution plan sections.""" 

166 

167 def reorder_time_filters_within_section(i: int, j: int) -> None: 

168 while j > i: 

169 filter_j: SnapshotFilter = snapshot_filters[j] 

170 if filter_j.name == "include_snapshot_times": 

171 snapshot_filters.pop(j) 

172 snapshot_filters.insert(i + 1, filter_j) 

173 j -= 1 

174 

175 i = len(snapshot_filters) - 1 

176 j = i 

177 while i >= 0: 

178 name: str = snapshot_filters[i].name 

179 if name == "include_snapshot_times_and_ranks": 

180 reorder_time_filters_within_section(i, j) 

181 j = i - 1 

182 i -= 1 

183 reorder_time_filters_within_section(i, j) 

184 

185 

def validate_no_argument_file(
    path: str, namespace: argparse.Namespace, err_prefix: str, parser: argparse.ArgumentParser | None = None
) -> None:
    """Reports an error via die() if '+file' argument inclusion has been disabled.

    Reads ``namespace.no_argument_file``, which is treated as False when absent. ``path`` and ``err_prefix`` are only
    used to build the error message, and ``parser`` is forwarded to die().
    """
    inclusion_disabled = getattr(namespace, "no_argument_file", False)
    if not inclusion_disabled:
        return
    die(f"{err_prefix}Argument file inclusion is disabled: {path}", parser=parser)

192 

193 

194############################################################################# 

@final
class NonEmptyStringAction(argparse.Action):
    """Argparse action that stores a whitespace-stripped string and rejects empty results."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Stores ``values`` with surrounding whitespace removed; errors out if nothing is left."""
        stripped: str = values.strip()
        if not stripped:
            parser.error(f"{option_string}: Empty string is not valid")
        setattr(namespace, self.dest, stripped)

207 

208 

209############################################################################# 

@final
class DatasetPairsAction(argparse.Action):
    """Parses alternating source/destination dataset arguments."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Validates dataset pair arguments and expands '+file' notation.

        Plain values are taken as-is and paired up in order (SRC_DATASET, DST_DATASET, SRC_DATASET, ...). A value of
        the form ``+PATH`` is replaced by the pairs listed in PATH, whose basename must contain 'bzfs_argument_file'.
        Each non-blank line of that file not starting with '#' must hold SRC_DATASET and DST_DATASET separated by a tab.
        The result is stored on the namespace as a list of ``(src, dst)`` tuples. Errors are reported via
        ``parser.error()``; line numbers in those messages are 1-based, matching what text editors display.
        """
        datasets: list[str] = []
        err_prefix: str = f"{option_string or self.dest}: "

        for value in values:
            if not value.startswith("+"):
                datasets.append(value)
                continue
            path: str = value[1:]
            validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
            if "bzfs_argument_file" not in os.path.basename(path):
                parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
            try:
                with open_nofollow(path, "r", encoding="utf-8") as fd:
                    # start=1: report line numbers the way users see them in an editor (previously 0-based)
                    for lineno, line in enumerate(fd.read().splitlines(), start=1):
                        if line.startswith("#") or not line.strip():
                            continue  # skip comment lines and blank lines
                        splits: list[str] = line.split("\t", 1)
                        if len(splits) <= 1:
                            parser.error(
                                f"{err_prefix}Line must contain tab-separated SRC_DATASET and DST_DATASET: {lineno}"
                            )
                        src_root_dataset, dst_root_dataset = splits
                        if not src_root_dataset.strip() or not dst_root_dataset.strip():
                            parser.error(
                                f"{err_prefix}SRC_DATASET and DST_DATASET must not be empty or whitespace-only: {lineno}"
                            )
                        datasets.append(src_root_dataset)
                        datasets.append(dst_root_dataset)
            except OSError as e:
                parser.error(f"{err_prefix}{e}")

        if len(datasets) % 2 != 0:
            parser.error(f"{err_prefix}Each SRC_DATASET must have a corresponding DST_DATASET: {datasets}")
        root_dataset_pairs: list[tuple[str, str]] = list(zip(datasets[0::2], datasets[1::2]))
        setattr(namespace, self.dest, root_dataset_pairs)

251 

252 

253############################################################################# 

@final
class SSHConfigFileNameAction(argparse.Action):
    """Argparse action for an SSH config file name free of whitespace and shell metacharacters."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value and rejects it if empty or if it contains whitespace or any character in SHELL_CHARS."""
        file_name: str = values.strip()
        if not file_name:
            parser.error(f"{option_string}: Empty string is not valid")
        if any(ch.isspace() or ch in SHELL_CHARS for ch in file_name):
            parser.error(f"{option_string}: Invalid file name '{file_name}': must not contain whitespace or special chars.")
        setattr(namespace, self.dest, file_name)

269 

270 

271############################################################################# 

@final
class SafeFileNameAction(argparse.Action):
    """Argparse action for a bare file name: no path components and no whitespace except plain spaces."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Stores ``values`` unchanged after rejecting '..', '/', '\\' and any whitespace other than ' '."""
        forbidden_substrings = ("..", "/", "\\")
        if any(token in values for token in forbidden_substrings):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain '..' or '/' or '\\'.")
        if any(ch != " " and ch.isspace() for ch in values):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, values)

285 

286 

287############################################################################# 

@final
class SafeDirectoryNameAction(argparse.Action):
    """Argparse action for a non-empty directory name whose only permitted whitespace is the plain space."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips surrounding whitespace, then rejects an empty name or any whitespace other than ' '."""
        dir_name: str = values.strip()
        if not dir_name:
            parser.error(f"{option_string}: Empty string is not valid")
        if any(ch != " " and ch.isspace() for ch in dir_name):
            parser.error(f"{option_string}: Invalid dir name '{dir_name}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, dir_name)

302 

303 

304############################################################################# 

@final
class NewSnapshotFilterGroupAction(argparse.Action):
    """Begins a new snapshot filter group each time the option appears on the command line."""

    def __call__(
        self, parser: argparse.ArgumentParser, args: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Ensures the most recent filter group is empty, so subsequent filters land in a fresh group.

        Initializes the groups attribute with one empty group if it is absent. Otherwise appends a new empty group,
        unless the last group is still empty; repeated occurrences therefore never stack up empty groups.
        """
        if hasattr(args, SNAPSHOT_FILTERS_VAR):
            filter_groups: list[list[SnapshotFilter]] = args.snapshot_filters_var
            if filter_groups[-1]:
                filter_groups.append([])
        else:
            args.snapshot_filters_var = [[]]

317 

318 

319############################################################################# 

@final
class FileOrLiteralAction(argparse.Action):
    """Allows '@file' style argument expansion with '+' prefix."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Expands file arguments and appends them to the namespace.

        A value of the form ``+PATH`` is replaced by the non-blank lines of PATH that do not start with '#'. PATH's
        basename must contain 'bzfs_argument_file'. Any other value is taken literally. The expanded values are
        appended to the list already stored under ``self.dest``. If ``self.dest`` names a snapshot regex filter, a
        SnapshotFilter holding only this invocation's values is also added to the current filter group. Errors are
        reported via ``parser.error()``.
        """
        # Build a new list rather than extending the existing one in place. argparse stores the action's ``default``
        # object itself on the namespace, so an in-place ``+=`` could leak values into a shared default list and from
        # there into subsequent parse_args() calls.
        existing_values: list[str] | None = getattr(namespace, self.dest, None)
        current_values: list[str] = [] if existing_values is None else list(existing_values)
        extra_values: list[str] = []
        err_prefix: str = f"{option_string or self.dest}: "
        for value in values:
            if not value.startswith("+"):
                extra_values.append(value)
                continue
            path: str = value[1:]
            validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
            if "bzfs_argument_file" not in os.path.basename(path):
                parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
            try:
                with open_nofollow(path, "r", encoding="utf-8") as fd:
                    for line in fd.read().splitlines():
                        if line.startswith("#") or not line.strip():
                            continue  # skip comment lines and blank lines
                        extra_values.append(line)
            except OSError as e:
                parser.error(f"{err_prefix}{e}")
        current_values += extra_values
        setattr(namespace, self.dest, current_values)
        if self.dest in SNAPSHOT_REGEX_FILTER_NAMES:
            _add_snapshot_filter(namespace, SnapshotFilter(self.dest, None, extra_values))

354 

355 

356############################################################################# 

class IncludeSnapshotPlanAction(argparse.Action):
    """Parses include plan dictionaries from the command line."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Builds a list of snapshot filter options from a serialized plan and appends it to ``self.dest``.

        If the plan yields no filter clause at all, a new filter group with '--include-snapshot-regex=!.*' is appended
        instead. This is presumably a regex that matches no snapshot.
        """
        existing_opts: list[str] | None = getattr(namespace, self.dest, None)
        # Copy instead of extending in place: argparse puts the action's ``default`` object itself on the namespace,
        # so in-place mutation could leak options into a shared default list.
        opts: list[str] = [] if existing_opts is None else list(existing_opts)
        if not self._add_opts(opts, parser, values, option_string=option_string):
            opts += ["--new-snapshot-filter-group", "--include-snapshot-regex=!.*"]
        setattr(namespace, self.dest, opts)

    @staticmethod
    def _parse_plan(parser: argparse.ArgumentParser, values: str, option_string: str | None = None) -> dict[Any, Any]:
        """Safely parses a plan literal and validates its ``{org: {target: {period_unit: amount}}}`` nesting.

        Uses ast.literal_eval(), which only evaluates Python literals and never executes code. Malformed input is
        reported via ``parser.error()`` instead of escaping as a raw traceback.
        """
        try:
            plan: Any = ast.literal_eval(values)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
            parser.error(f"{option_string}: Invalid plan: {e}: {values}")
        if not isinstance(plan, dict):
            parser.error(f"{option_string}: Invalid plan: Must be a dict: {values}")
        for org, target_periods in plan.items():
            if not isinstance(target_periods, dict):
                parser.error(f"{option_string}: Invalid plan: Targets of org {org!r} must be a dict: {target_periods}")
            for target, periods in target_periods.items():
                if not isinstance(periods, dict):
                    parser.error(f"{option_string}: Invalid plan: Periods of target {target!r} must be a dict: {periods}")
                for period_amount in periods.values():
                    # bool is a subclass of int, but True/False are not meaningful period amounts
                    if isinstance(period_amount, bool) or not isinstance(period_amount, int) or period_amount < 0:
                        parser.error(f"{option_string}: Period amount must be a non-negative integer: {period_amount}")
        return plan

    def _add_opts(
        self,
        opts: list[str],
        parser: argparse.ArgumentParser,
        values: str,
        option_string: str | None = None,
    ) -> bool:
        """Appends to ``opts`` (in place) the filter options implementing the plan given in ``values``.

        Each (org, target, period_unit, amount) entry of the plan appends the following:
        - A new filter group.
        - A name regex built from nprefix(org), then ninfix(target) (or YEAR_WITH_FOUR_DIGITS_REGEX if target is
          empty), then '.*', then nsuffix(period_unit).
        - An '--include-snapshot-times-and-ranks' clause. Its time range is '<amount periods>ago..anytime', or
          'notime' if the period has no duration label or the duration is zero. Its rank is 'latest<amount>'.

        Returns True if at least one filter clause was appended.
        """
        plan: dict[Any, Any] = self._parse_plan(parser, values, option_string=option_string)
        xperiods: SnapshotPeriods = SnapshotPeriods()
        has_at_least_one_filter_clause: bool = False
        for org, target_periods in plan.items():
            prefix: str = re.escape(nprefix(org))
            for target, periods in target_periods.items():
                infix: str = re.escape(ninfix(target)) if target else YEAR_WITH_FOUR_DIGITS_REGEX.pattern
                for period_unit, period_amount in periods.items():
                    suffix: str = re.escape(nsuffix(period_unit))
                    regex: str = f"{prefix}{infix}.*{suffix}"
                    opts += ["--new-snapshot-filter-group", f"--include-snapshot-regex={regex}"]
                    duration_amount, duration_unit = xperiods.suffix_to_duration0(period_unit)
                    duration_unit_label: str | None = xperiods.period_labels.get(duration_unit)
                    opts += [
                        "--include-snapshot-times-and-ranks",
                        (
                            "notime"
                            if duration_unit_label is None or duration_amount * period_amount == 0
                            else f"{duration_amount * period_amount}{duration_unit_label}ago..anytime"
                        ),
                        f"latest{period_amount}",
                    ]
                    has_at_least_one_filter_clause = True
        return has_at_least_one_filter_clause

402 

403 

404############################################################################# 

@final
class DeleteDstSnapshotsExceptPlanAction(IncludeSnapshotPlanAction):
    """Specialized include plan used to decide which dst snapshots to keep."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Parses plan while preventing disasters.

        Appends '--delete-dst-snapshots-except' followed by the plan's filter options to ``self.dest``. A plan that
        yields no filter clause would mean 'retain no snapshots', so it is rejected via ``parser.error()``.
        """
        existing_opts: list[str] | None = getattr(namespace, self.dest, None)
        # Copy instead of extending in place, so that a list-valued argparse ``default`` is never mutated.
        opts: list[str] = [] if existing_opts is None else list(existing_opts)
        opts += ["--delete-dst-snapshots-except"]
        if not self._add_opts(opts, parser, values, option_string=option_string):
            parser.error(
                f"{option_string}: Cowardly refusing to delete all snapshots on "
                f"--delete-dst-snapshots-except-plan='{values}' (which means 'retain no snapshots' aka "
                "'delete all snapshots'). Assuming this is an unintended pilot error rather than intended carnage. "
                "Aborting. If this is really what is intended, use `--delete-dst-snapshots --include-snapshot-regex=.*` "
                "instead to force the deletion."
            )
        setattr(namespace, self.dest, opts)

425 

426 

427############################################################################# 

@final
class TimeRangeAndRankRangeAction(argparse.Action):
    """Parses --include-snapshot-times-and-ranks option values."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Converts user-supplied time and rank ranges into snapshot filters.

        ``values[0]`` is a time range ``LO..HI``, or the keyword ``notime``, which is shorthand for ``0..0``. Each
        remaining value is a rank range. The parsed specs are stored on the namespace attribute, and a matching
        snapshot filter is added via _add_time_and_rank_snapshot_filter().
        """

        def parse_time(time_spec: str) -> int | timedelta | None:
            """Parses one end of a time range.

            Returns None for '*' or 'anytime', an int for a Unix time in seconds, or a timedelta for a duration
            ending in 'ago'. Otherwise returns whatever unixtime_fromisoformat() yields for an ISO 8601 datetime.
            """
            time_spec = time_spec.strip()
            if time_spec == "*" or time_spec == "anytime":
                return None
            # isdecimal() rather than isdigit(): isdigit() also accepts characters such as '²' that int() rejects
            # with ValueError. That error would escape as a traceback instead of being reported via parser.error().
            if time_spec.isdecimal():
                return int(time_spec)
            try:
                return timedelta(milliseconds=parse_duration_to_milliseconds(time_spec, regex_suffix=r"\s*ago"))
            except ValueError:
                try:
                    return unixtime_fromisoformat(time_spec)
                except ValueError:
                    parser.error(f"{option_string}: Invalid duration, Unix time, or ISO 8601 datetime: {time_spec}")

        assert isinstance(values, list)
        assert len(values) > 0
        value: str = values[0].strip()
        if value == "notime":
            value = "0..0"
        if ".." not in value:
            parser.error(f"{option_string}: Invalid time range: Missing '..' separator: {value}")
        timerange_specs: list[int | timedelta | None] = [parse_time(time_spec) for time_spec in value.split("..", 1)]
        rankranges: list[RankRange] = self._parse_rankranges(parser, values[1:], option_string=option_string)
        setattr(namespace, self.dest, [timerange_specs] + rankranges)
        timerange: UnixTimeRange = self._get_include_snapshot_times(timerange_specs)
        _add_time_and_rank_snapshot_filter(namespace, self.dest, timerange, rankranges)

    @staticmethod
    def _get_include_snapshot_times(times: list[timedelta | int | None]) -> UnixTimeRange:
        """Convert start and end times to ``UnixTimeRange`` for filtering.

        Returns None if both ends are unbounded (None). Otherwise a missing lower end becomes 0, and a missing upper end
        becomes UNIX_TIME_INFINITY_SECS. When both ends are absolute Unix times (ints), they are swapped if given in
        descending order. Relative (timedelta) ends are returned as-is.
        """
        lo, hi = times
        if lo is None and hi is None:
            return None
        lo = 0 if lo is None else lo
        hi = UNIX_TIME_INFINITY_SECS if hi is None else hi
        if isinstance(lo, int) and isinstance(hi, int):
            return (lo, hi) if lo <= hi else (hi, lo)
        return lo, hi

    @staticmethod
    def _parse_rankranges(parser: argparse.ArgumentParser, values: Any, option_string: str | None = None) -> list[RankRange]:
        """Parses rank range strings like 'latest 3..latest 5' into tuples.

        Each result is ``((kind, lo_num, lo_is_percent), (kind, hi_num, hi_is_percent))``. Single ranks expand as
        follows (and symmetrically for 'oldest'):
        - 'latest N' becomes 'latest 0..latest N'.
        - 'all except latest N' becomes 'latest N..latest 100%'.
        - 'all except latest N%' becomes 'oldest 0..oldest (100-N)%'.
        """

        def parse_rank(spec: str) -> tuple[bool, str, int, bool]:
            """Parses one rank into (is_except, kind, num, is_percent); percentages above 100 are rejected."""
            spec = spec.strip()
            if not (match := re.fullmatch(r"(all\s*except\s*)?(oldest|latest)\s*(\d+)%?", spec)):
                parser.error(f"{option_string}: Invalid rank format: {spec}")
            assert match
            is_except: bool = bool(match.group(1))
            kind: str = match.group(2)
            num: int = int(match.group(3))
            is_percent: bool = spec.endswith("%")
            if is_percent and num > 100:
                parser.error(f"{option_string}: Invalid rank: Percent must not be greater than 100: {spec}")
            return is_except, kind, num, is_percent

        rankranges: list[RankRange] = []
        for value in values:
            value = value.strip()
            if ".." in value:
                lo_split, hi_split = value.split("..", 1)
                lo = parse_rank(lo_split)
                hi = parse_rank(hi_split)
                if lo[0] or hi[0]:
                    parser.error(f"{option_string}: Invalid rank range: {value}")  # 'all except' is not allowed in a range
                if lo[1] != hi[1]:
                    parser.error(f"{option_string}: Ambiguous rank range: Must not compare oldest with latest: {value}")
            else:
                hi = parse_rank(value)
                is_except, kind, num, is_percent = hi
                if is_except:
                    if is_percent:
                        negated_kind: str = "oldest" if kind == "latest" else "latest"
                        lo = parse_rank(f"{negated_kind}0")
                        hi = parse_rank(f"{negated_kind}{100-num}%")
                    else:
                        lo = parse_rank(f"{kind}{num}")
                        hi = parse_rank(f"{kind}100%")
                else:
                    lo = parse_rank(f"{kind}0")
            rankranges.append((lo[1:], hi[1:]))  # drop the is_except flag
        return rankranges

527 

528 

529############################################################################# 

@final
class CheckPercentRange(CheckRange):
    """Argparse action accepting a plain number or a percentage such as '50%', range-checked via CheckRange."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Parses ``values`` as a float with an optional trailing '%' and validates it via CheckRange.

        After validation, the namespace attribute is replaced by a tuple of the value CheckRange stored (presumably the
        float) and a flag telling whether a '%' suffix was present.
        """
        assert isinstance(values, str)
        text: str = values.strip()
        is_percent: bool = text.endswith("%")
        number_text: str = text[:-1] if is_percent else text
        try:
            number: float = float(number_text)
        except ValueError:
            parser.error(f"{option_string}: Invalid percentage or number: {values}")
        super().__call__(parser, namespace, number, option_string=option_string)
        setattr(namespace, self.dest, (getattr(namespace, self.dest), is_percent))