Coverage for bzfs_main/argparse_actions.py: 100%

316 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2025-11-07 04:44 +0000

1# Copyright 2024 Wolfgang Hoschek AT mac DOT com 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# 

15"""Custom argparse actions shared by the 'bzfs' and 'bzfs_jobrunner' CLIs; These helpers validate and expand complex command 

16line syntax such as +file references, dataset pairs, and snapshot filters.""" 

17 

18from __future__ import ( 

19 annotations, 

20) 

21import argparse 

22import ast 

23import os 

24import re 

25from dataclasses import ( 

26 dataclass, 

27 field, 

28) 

29from datetime import ( 

30 timedelta, 

31) 

32from typing import ( 

33 Any, 

34) 

35 

36from bzfs_main.check_range import ( 

37 CheckRange, 

38) 

39from bzfs_main.filter import ( 

40 SNAPSHOT_REGEX_FILTER_NAME, 

41 SNAPSHOT_REGEX_FILTER_NAMES, 

42 RankRange, 

43 UnixTimeRange, 

44) 

45from bzfs_main.utils import ( 

46 SHELL_CHARS, 

47 SNAPSHOT_FILTERS_VAR, 

48 UNIX_TIME_INFINITY_SECS, 

49 YEAR_WITH_FOUR_DIGITS_REGEX, 

50 SnapshotPeriods, 

51 die, 

52 ninfix, 

53 nprefix, 

54 nsuffix, 

55 open_nofollow, 

56 parse_duration_to_milliseconds, 

57 unixtime_fromisoformat, 

58) 

59 

60 

61############################################################################# 

@dataclass(order=True)
class SnapshotFilter:
    """One step of a snapshot filter plan: a filter name, a time range, and filter-specific options.

    Equality and ordering are derived from ``name`` and ``timerange`` only; ``options`` is excluded from comparisons.
    """

    # Identifies the kind of filter.
    name: str
    # Time range of the filter (type is defined in bzfs_main.filter); callers in this module also pass None.
    timerange: UnixTimeRange
    # Filter-specific payload; never takes part in ==, <, etc.
    options: Any = field(default=None, compare=False)

69 

70 

def _add_snapshot_filter(args: argparse.Namespace, _filter: SnapshotFilter) -> None:
    """Adds ``_filter`` to the last filter group stored on ``args``; the group list is created on first use."""

    if not hasattr(args, SNAPSHOT_FILTERS_VAR):
        # First filter seen so far: start with a single empty group.
        args.snapshot_filters_var = [[]]
    current_group: list[SnapshotFilter] = args.snapshot_filters_var[-1]
    current_group.append(_filter)

77 

78 

def _add_time_and_rank_snapshot_filter(
    args: argparse.Namespace, dst: str, timerange: UnixTimeRange, rankranges: list[RankRange]
) -> None:
    """Registers either a time-only filter or a combined time-and-rank filter on ``args``.

    A plain "include_snapshot_times" filter (without options) is registered if there is no time range, if there are no
    rank ranges, or if at least one rank range has identical lower and upper bounds. Otherwise a filter named ``dst``
    carrying ``rankranges`` as its options is registered.
    """

    time_only: bool = timerange is None or not rankranges or any(rr[0] == rr[1] for rr in rankranges)
    if time_only:
        snapshot_filter = SnapshotFilter("include_snapshot_times", timerange, None)
    else:
        assert timerange is not None
        snapshot_filter = SnapshotFilter(dst, timerange, rankranges)
    _add_snapshot_filter(args, snapshot_filter)

89 

90 

def has_timerange_filter(snapshot_filters: list[list[SnapshotFilter]]) -> bool:
    """Returns True if any filter in any of the given filter groups has a ``timerange`` that is not None."""

    for group in snapshot_filters:
        for snapshot_filter in group:
            if snapshot_filter.timerange is not None:
                return True
    return False

95 

96 

def optimize_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> list[SnapshotFilter]:
    """Applies basic optimizations to one filter group and returns the resulting filter list.

    The two merge passes modify the given list in place. The returned list is a new list that keeps only filters having
    a truthy ``timerange`` or truthy ``options``, with time filters reordered within it.
    """

    for merge_pass in (_merge_adjacent_snapshot_filters, _merge_adjacent_snapshot_regexes):
        merge_pass(snapshot_filters)

    effective_filters: list[SnapshotFilter] = []
    for candidate in snapshot_filters:
        if candidate.timerange or candidate.options:  # drop filters that carry neither a time range nor options
            effective_filters.append(candidate)

    _reorder_snapshot_time_filters(effective_filters)
    return effective_filters

105 

106 

107def _merge_adjacent_snapshot_filters(snapshot_filters: list[SnapshotFilter]) -> None: 

108 """Merge adjacent filters of the same type if possible.""" 

109 

110 i = len(snapshot_filters) - 1 

111 while i >= 0: 

112 filter_i: SnapshotFilter = snapshot_filters[i] 

113 if isinstance(filter_i.options, list): 

114 j = i - 1 

115 if j >= 0 and snapshot_filters[j] == filter_i: 

116 lst: list = snapshot_filters[j].options 

117 assert isinstance(lst, list) 

118 lst += filter_i.options 

119 snapshot_filters.pop(i) 

120 i -= 1 

121 

122 

def _merge_adjacent_snapshot_regexes(snapshot_filters: list[SnapshotFilter]) -> None:
    """Collapses each run of consecutive regex filters into one combined regex filter; modifies the list in place.

    Pass 1: within a run of filters whose name is in SNAPSHOT_REGEX_FILTER_NAMES, a filter's options are appended to the
    nearest earlier filter of the same name, and the later filter is removed.
    Pass 2: every remaining regex filter (paired with its differently named neighbor, or with an empty stand-in if it
    has none) is replaced by a single SNAPSHOT_REGEX_FILTER_NAME filter whose options are a 2-tuple of the options of
    the two filters in sorted order.
    """

    # Pass 1: merge same-named regex filters inside each consecutive run, scanning back to front.
    for i in range(len(snapshot_filters) - 1, -1, -1):
        current: SnapshotFilter = snapshot_filters[i]
        if current.name not in SNAPSHOT_REGEX_FILTER_NAMES:
            continue
        assert isinstance(current.options, list)
        for j in range(i - 1, -1, -1):
            earlier: SnapshotFilter = snapshot_filters[j]
            if earlier.name not in SNAPSHOT_REGEX_FILTER_NAMES:
                break  # the run of regex filters ends here
            if earlier.name == current.name:
                target: list[object] = earlier.options
                assert isinstance(target, list)
                target += current.options
                snapshot_filters.pop(i)
                break

    # Pass 2: after pass 1 a run holds at most one filter per regex name; fuse each run into one combined filter.
    i = len(snapshot_filters) - 1
    while i >= 0:
        current = snapshot_filters[i]
        kind: str = current.name
        if kind in SNAPSHOT_REGEX_FILTER_NAMES:
            has_regex_neighbor: bool = i > 0 and snapshot_filters[i - 1].name in SNAPSHOT_REGEX_FILTER_NAMES
            if has_regex_neighbor:
                partner: SnapshotFilter = snapshot_filters[i - 1]
                assert partner.name != kind
                snapshot_filters.pop(i)
                i -= 1  # the combined filter below takes the neighbor's slot
            else:
                other_kind: str = next(iter(SNAPSHOT_REGEX_FILTER_NAMES.difference({kind})))
                partner = SnapshotFilter(other_kind, None, [])  # empty stand-in for the missing counterpart
            first, second = sorted([current, partner])
            snapshot_filters[i] = SnapshotFilter(SNAPSHOT_REGEX_FILTER_NAME, None, (first.options, second.options))
        i -= 1

160 

161 

162def _reorder_snapshot_time_filters(snapshot_filters: list[SnapshotFilter]) -> None: 

163 """Reorder time filters before regex filters within execution plan sections.""" 

164 

165 def reorder_time_filters_within_section(i: int, j: int) -> None: 

166 while j > i: 

167 filter_j: SnapshotFilter = snapshot_filters[j] 

168 if filter_j.name == "include_snapshot_times": 

169 snapshot_filters.pop(j) 

170 snapshot_filters.insert(i + 1, filter_j) 

171 j -= 1 

172 

173 i = len(snapshot_filters) - 1 

174 j = i 

175 while i >= 0: 

176 name: str = snapshot_filters[i].name 

177 if name == "include_snapshot_times_and_ranks": 

178 reorder_time_filters_within_section(i, j) 

179 j = i - 1 

180 i -= 1 

181 reorder_time_filters_within_section(i, j) 

182 

183 

def validate_no_argument_file(
    path: str, namespace: argparse.Namespace, err_prefix: str, parser: argparse.ArgumentParser | None = None
) -> None:
    """Fails via die() if ``namespace.no_argument_file`` is set to a truthy value; otherwise does nothing.

    ``path`` and ``err_prefix`` are only used to build the error message; ``parser`` is forwarded to die().
    """
    inclusion_disabled = getattr(namespace, "no_argument_file", False)
    if not inclusion_disabled:
        return
    die(f"{err_prefix}Argument file inclusion is disabled: {path}", parser=parser)

190 

191 

192############################################################################# 

class NonEmptyStringAction(argparse.Action):
    """Argparse action that stores the whitespace-stripped value and refuses values that are empty after stripping."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips surrounding whitespace, reports an error via ``parser.error`` if nothing is left, stores the result."""
        stripped = values.strip()
        if not stripped:
            parser.error(f"{option_string}: Empty string is not valid")
        setattr(namespace, self.dest, stripped)

204 

205 

206############################################################################# 

class DatasetPairsAction(argparse.Action):
    """Argparse action that turns a flat SRC DST SRC DST ... argument list into a list of (src, dst) tuples."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Collects dataset names from literal values and '+file' references, then stores them pairwise.

        A value starting with '+' names a file whose basename must contain 'bzfs_argument_file'. Each line of that file
        that is neither blank nor starts with '#' must hold a tab-separated SRC_DATASET and DST_DATASET. All problems,
        including OSError while reading the file, are reported via ``parser.error``.
        """
        err_prefix: str = f"{option_string or self.dest}: "
        flat: list[str] = []

        for value in values:
            if value.startswith("+"):
                path: str = value[1:]
                validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
                if "bzfs_argument_file" not in os.path.basename(path):
                    parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
                try:
                    with open_nofollow(path, "r", encoding="utf-8") as fd:
                        lines: list[str] = fd.read().splitlines()
                        for i, line in enumerate(lines):
                            is_blank = not line.strip()
                            if is_blank or line.startswith("#"):
                                continue  # skip blank lines and comment lines
                            splits: list[str] = line.split("\t", 1)
                            if len(splits) < 2:
                                parser.error(f"{err_prefix}Line must contain tab-separated SRC_DATASET and DST_DATASET: {i}")
                            src_root_dataset, dst_root_dataset = splits
                            if not (src_root_dataset.strip() and dst_root_dataset.strip()):
                                parser.error(
                                    f"{err_prefix}SRC_DATASET and DST_DATASET must not be empty or whitespace-only: {i}"
                                )
                            flat.extend((src_root_dataset, dst_root_dataset))
                except OSError as e:
                    parser.error(f"{err_prefix}{e}")
            else:
                flat.append(value)

        if len(flat) % 2 != 0:
            parser.error(f"{err_prefix}Each SRC_DATASET must have a corresponding DST_DATASET: {flat}")
        root_dataset_pairs: list[tuple[str, str]] = [(flat[k], flat[k + 1]) for k in range(0, len(flat), 2)]
        setattr(namespace, self.dest, root_dataset_pairs)

247 

248 

249############################################################################# 

class SSHConfigFileNameAction(argparse.Action):
    """Argparse action that validates an SSH config file name: non-empty, no whitespace, no chars from SHELL_CHARS."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, reports empty or unsafe names via ``parser.error``, then stores the stripped value."""

        file_name = values.strip()
        if not file_name:
            parser.error(f"{option_string}: Empty string is not valid")
        has_forbidden_char = any(char.isspace() or char in SHELL_CHARS for char in file_name)
        if has_forbidden_char:
            parser.error(f"{option_string}: Invalid file name '{file_name}': must not contain whitespace or special chars.")
        setattr(namespace, self.dest, file_name)

264 

265 

266############################################################################# 

class SafeFileNameAction(argparse.Action):
    """Argparse action that accepts a file name only if it has no '..', no path separator, and no odd whitespace."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Reports names containing '..', '/', a backslash, or whitespace other than ' ' via ``parser.error``.

        The value is stored as given; it is not stripped.
        """
        if any(token in values for token in ("..", "/", "\\")):
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain '..' or '/' or '\\'.")
        has_odd_whitespace = any(char != " " and char.isspace() for char in values)
        if has_odd_whitespace:
            parser.error(f"{option_string}: Invalid file name '{values}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, values)

279 

280 

281############################################################################# 

class SafeDirectoryNameAction(argparse.Action):
    """Argparse action that validates a directory name: non-empty, and no whitespace other than plain spaces."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Strips the value, reports empty names or odd whitespace via ``parser.error``, stores the stripped value."""
        dir_name = values.strip()
        if not dir_name:
            parser.error(f"{option_string}: Empty string is not valid")
        has_odd_whitespace = any(char != " " and char.isspace() for char in dir_name)
        if has_odd_whitespace:
            parser.error(f"{option_string}: Invalid dir name '{dir_name}': must not contain whitespace other than space.")
        setattr(namespace, self.dest, dir_name)

295 

296 

297############################################################################# 

class NewSnapshotFilterGroupAction(argparse.Action):
    """Argparse action that opens a new, empty snapshot filter group on the namespace."""

    def __call__(
        self, parser: argparse.ArgumentParser, args: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Creates the group list if absent; otherwise appends an empty group unless the last group is still empty."""
        if not hasattr(args, SNAPSHOT_FILTERS_VAR):
            args.snapshot_filters_var = [[]]
            return
        groups: list[list[SnapshotFilter]] = args.snapshot_filters_var
        if groups[-1]:  # avoid stacking up consecutive empty groups
            groups.append([])

309 

310 

311############################################################################# 

class FileOrLiteralAction(argparse.Action):
    """Argparse action that accumulates literal values plus values read from '+file' references."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Expands '+file' values, appends all values to the destination list, and registers regex filters.

        A value starting with '+' names a file whose basename must contain 'bzfs_argument_file'; each line of that file
        that is neither blank nor starts with '#' becomes one value. Problems, including OSError while reading the file,
        are reported via ``parser.error``. If the destination name is one of SNAPSHOT_REGEX_FILTER_NAMES, the values of
        this invocation are additionally registered as a snapshot filter on the namespace.
        """

        previous_values: list[str] | None = getattr(namespace, self.dest, None)
        # Work on a copy: the list found on the namespace may be the very list object that was passed as the argument's
        # ``default``, and extending that one in place would leak values into later parse_args() calls of the parser.
        current_values: list[str] = [] if previous_values is None else list(previous_values)
        extra_values: list[str] = []
        err_prefix: str = f"{option_string or self.dest}: "
        for value in values:
            if not value.startswith("+"):
                extra_values.append(value)
            else:
                path: str = value[1:]
                validate_no_argument_file(path, namespace, err_prefix=err_prefix, parser=parser)
                if "bzfs_argument_file" not in os.path.basename(path):
                    parser.error(f"{err_prefix}basename must contain substring 'bzfs_argument_file': {path}")
                try:
                    with open_nofollow(path, "r", encoding="utf-8") as fd:
                        for line in fd.read().splitlines():
                            if line.startswith("#") or not line.strip():
                                continue  # skip comment lines and blank lines
                            extra_values.append(line)
                except OSError as e:
                    parser.error(f"{err_prefix}{e}")
        current_values += extra_values
        setattr(namespace, self.dest, current_values)
        if self.dest in SNAPSHOT_REGEX_FILTER_NAMES:
            # Only the values of this invocation form the filter; earlier invocations registered their own filter.
            _add_snapshot_filter(namespace, SnapshotFilter(self.dest, None, extra_values))

345 

346 

347############################################################################# 

class IncludeSnapshotPlanAction(argparse.Action):
    """Argparse action that translates a plan given as a dict literal string into equivalent filter CLI options."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends the options generated for the plan to the destination list.

        If the plan yields no filter clause at all, options for a new filter group with the regex '!.*' are appended
        instead.
        """
        opts: list[str] | None = getattr(namespace, self.dest, None)
        if opts is None:
            opts = []
        plan_has_clauses: bool = self._add_opts(opts, parser, values, option_string=option_string)
        if not plan_has_clauses:
            opts.extend(["--new-snapshot-filter-group", "--include-snapshot-regex=!.*"])
        setattr(namespace, self.dest, opts)

    def _add_opts(
        self,
        opts: list[str],
        parser: argparse.ArgumentParser,
        values: str,
        option_string: str | None = None,
    ) -> bool:
        """Appends the options for every (org, target, period) entry of the plan to ``opts``, in place.

        ``values`` is parsed with ``ast.literal_eval`` and walked as a three-level mapping: org -> target -> period unit
        -> period amount. Each entry contributes a new filter group with a name regex, followed by a time-and-rank option
        that keeps the latest ``period_amount`` snapshots. Amounts that are not non-negative ints are reported via
        ``parser.error``. Returns True if at least one entry was processed.
        """
        xperiods: SnapshotPeriods = SnapshotPeriods()
        clause_count: int = 0
        plan = ast.literal_eval(values)
        for org, target_periods in plan.items():
            prefix: str = re.escape(nprefix(org))
            for target, periods in target_periods.items():
                if target:
                    infix: str = re.escape(ninfix(target))
                else:
                    infix = YEAR_WITH_FOUR_DIGITS_REGEX.pattern
                for period_unit, period_amount in periods.items():
                    if not (isinstance(period_amount, int) and period_amount >= 0):
                        parser.error(f"{option_string}: Period amount must be a non-negative integer: {period_amount}")
                    suffix: str = re.escape(nsuffix(period_unit))
                    regex: str = f"{prefix}{infix}.*{suffix}"
                    opts.extend(["--new-snapshot-filter-group", f"--include-snapshot-regex={regex}"])
                    duration_amount, duration_unit = xperiods.suffix_to_duration0(period_unit)
                    duration_unit_label: str | None = xperiods.period_labels.get(duration_unit)
                    if duration_unit_label is None or duration_amount * period_amount == 0:
                        time_spec: str = "notime"  # no usable duration: select by rank only
                    else:
                        time_spec = f"{duration_amount * period_amount}{duration_unit_label}ago..anytime"
                    opts.extend(["--include-snapshot-times-and-ranks", time_spec, f"latest{period_amount}"])
                    clause_count += 1
        return clause_count > 0

393 

394 

395############################################################################# 

class DeleteDstSnapshotsExceptPlanAction(IncludeSnapshotPlanAction):
    """Variant of the include plan action whose plan names the destination snapshots that shall be retained."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Appends '--delete-dst-snapshots-except' plus the options generated for the plan to the destination list.

        A plan without any filter clause would mean 'retain no snapshots', i.e. 'delete all snapshots'; that case is
        refused and reported via ``parser.error``.
        """
        opts: list[str] | None = getattr(namespace, self.dest, None)
        opts = [] if opts is None else opts
        opts += ["--delete-dst-snapshots-except"]
        if not self._add_opts(opts, parser, values, option_string=option_string):
            # Adjacent string literals are concatenated verbatim, so each piece must carry its own separating space.
            parser.error(
                f"{option_string}: Cowardly refusing to delete all snapshots on "
                f"--delete-dst-snapshots-except-plan='{values}' (which means 'retain no snapshots' aka "
                "'delete all snapshots'). Assuming this is an unintended pilot error rather than intended carnage. "
                "Aborting. If this is really what is intended, use `--delete-dst-snapshots --include-snapshot-regex=.*` "
                "instead to force the deletion."
            )
        setattr(namespace, self.dest, opts)

415 

416 

417############################################################################# 

class TimeRangeAndRankRangeAction(argparse.Action):
    """Parses --include-snapshot-times-and-ranks option values."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Converts user-supplied time and rank ranges into snapshot filters.

        ``values[0]`` is the time range 'START..END' (or 'notime', which is treated as '0..0'); the remaining values are
        rank range specs. The parsed specs are stored on the namespace as ``[timerange_specs] + rankranges``, and a
        corresponding snapshot filter is registered on the namespace. Invalid input is reported via ``parser.error``.
        """

        def parse_time(time_spec: str) -> int | timedelta | None:
            """Parses one endpoint: None for '*'/'anytime', int for integer Unix time, timedelta for a duration."""
            time_spec = time_spec.strip()
            if time_spec == "*" or time_spec == "anytime":
                return None
            # isdecimal() rather than isdigit(): isdigit() also accepts chars such as superscript digits, which int()
            # rejects with a ValueError that nothing here would catch.
            if time_spec.isdecimal():
                return int(time_spec)
            try:
                return timedelta(milliseconds=parse_duration_to_milliseconds(time_spec, regex_suffix=r"\s*ago"))
            except ValueError:
                try:
                    return unixtime_fromisoformat(time_spec)
                except ValueError:
                    parser.error(f"{option_string}: Invalid duration, Unix time, or ISO 8601 datetime: {time_spec}")

        assert isinstance(values, list)
        assert len(values) > 0
        value: str = values[0].strip()
        if value == "notime":
            value = "0..0"
        if ".." not in value:
            parser.error(f"{option_string}: Invalid time range: Missing '..' separator: {value}")
        timerange_specs: list[int | timedelta | None] = [parse_time(time_spec) for time_spec in value.split("..", 1)]
        rankranges: list[RankRange] = self._parse_rankranges(parser, values[1:], option_string=option_string)
        setattr(namespace, self.dest, [timerange_specs] + rankranges)
        timerange: UnixTimeRange = self._get_include_snapshot_times(timerange_specs)
        _add_time_and_rank_snapshot_filter(namespace, self.dest, timerange, rankranges)

    @staticmethod
    def _get_include_snapshot_times(times: list[timedelta | int | None]) -> UnixTimeRange:
        """Convert start and end times to ``UnixTimeRange`` for filtering.

        Returns None if both endpoints are None. Otherwise a missing start defaults to 0 and a missing end defaults to
        UNIX_TIME_INFINITY_SECS. If both endpoints are ints they are returned in ascending order; endpoints that are
        timedeltas are passed through unchanged and unordered.
        """

        def utc_unix_time_in_seconds(time_spec: timedelta | int | None, default: int) -> timedelta | int:
            if isinstance(time_spec, timedelta):
                return time_spec
            if isinstance(time_spec, int):
                return int(time_spec)
            return default

        lo, hi = times
        if lo is None and hi is None:
            return None
        lo = utc_unix_time_in_seconds(lo, default=0)
        hi = utc_unix_time_in_seconds(hi, default=UNIX_TIME_INFINITY_SECS)
        if isinstance(lo, int) and isinstance(hi, int):
            return (lo, hi) if lo <= hi else (hi, lo)
        return lo, hi

    @staticmethod
    def _parse_rankranges(parser: argparse.ArgumentParser, values: Any, option_string: str | None = None) -> list[RankRange]:
        """Parses rank range strings like 'latest 3..latest 5' into tuples.

        Each result item is a pair ``(lo, hi)`` in which both elements have the form ``(kind, num, is_percent)`` with
        kind being 'oldest' or 'latest'. A single spec 'KIND N' becomes the range KIND 0 .. KIND N. The form
        'all except KIND N' is rewritten into the complementary range: KIND N .. KIND 100% for absolute numbers, or
        OTHERKIND 0 .. OTHERKIND (100-N)% for percentages. Invalid specs are reported via ``parser.error``.
        """

        def parse_rank(spec: str) -> tuple[bool, str, int, bool]:
            """Parses one rank spec into (is_except, kind, num, is_percent)."""
            spec = spec.strip()
            if not (match := re.fullmatch(r"(all\s*except\s*)?(oldest|latest)\s*(\d+)%?", spec)):
                parser.error(f"{option_string}: Invalid rank format: {spec}")
            assert match
            is_except: bool = bool(match.group(1))
            kind: str = match.group(2)
            num: int = int(match.group(3))
            is_percent: bool = spec.endswith("%")
            if is_percent and num > 100:
                parser.error(f"{option_string}: Invalid rank: Percent must not be greater than 100: {spec}")
            return is_except, kind, num, is_percent

        rankranges: list[RankRange] = []
        for value in values:
            value = value.strip()
            if ".." in value:
                lo_split, hi_split = value.split("..", 1)
                lo = parse_rank(lo_split)
                hi = parse_rank(hi_split)
                if lo[0] or hi[0]:  # 'all except' is not supported inside an explicit lo..hi range
                    parser.error(f"{option_string}: Invalid rank range: {value}")
                if lo[1] != hi[1]:
                    parser.error(f"{option_string}: Ambiguous rank range: Must not compare oldest with latest: {value}")
            else:
                hi = parse_rank(value)
                is_except, kind, num, is_percent = hi
                if is_except:
                    if is_percent:
                        # 'all except oldest N%' is the same as 'latest (100-N)%', and vice versa
                        negated_kind: str = "oldest" if kind == "latest" else "latest"
                        lo = parse_rank(f"{negated_kind}0")
                        hi = parse_rank(f"{negated_kind}{100-num}%")
                    else:
                        lo = parse_rank(f"{kind}{num}")
                        hi = parse_rank(f"{kind}100%")
                else:
                    lo = parse_rank(f"{kind}0")
            rankranges.append((lo[1:], hi[1:]))  # drop the is_except flag
        return rankranges

516 

517 

518############################################################################# 

class CheckPercentRange(CheckRange):
    """Argparse action that accepts a number with an optional trailing '%' and range-checks it via CheckRange."""

    def __call__(
        self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, values: Any, option_string: str | None = None
    ) -> None:
        """Parses the value as float, delegates to the base class, then stores a ``(value, is_percent)`` tuple.

        ``is_percent`` tells whether the stripped input ended with '%'. Input that cannot be converted to float is
        reported via ``parser.error``. The first tuple element is whatever the base class stored under ``self.dest``.
        """
        assert isinstance(values, str)
        original = values
        text = values.strip()
        is_percent: bool = text.endswith("%")
        number: Any = text[:-1] if is_percent else text
        try:
            number = float(number)
        except ValueError:
            parser.error(f"{option_string}: Invalid percentage or number: {original}")
        super().__call__(parser, namespace, number, option_string=option_string)
        checked_value = getattr(namespace, self.dest)
        setattr(namespace, self.dest, (checked_value, is_percent))