|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +"""Abstract base class for DAG importers.""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +import logging |
| 22 | +import os |
| 23 | +import threading |
| 24 | +from abc import ABC, abstractmethod |
| 25 | +from collections.abc import Iterator |
| 26 | +from dataclasses import dataclass, field |
| 27 | +from pathlib import Path |
| 28 | +from typing import TYPE_CHECKING |
| 29 | + |
| 30 | +from airflow._shared.module_loading.file_discovery import find_path_from_directory |
| 31 | +from airflow.configuration import conf |
| 32 | +from airflow.utils.file import might_contain_dag |
| 33 | + |
| 34 | +if TYPE_CHECKING: |
| 35 | + from airflow.sdk import DAG |
| 36 | + |
| 37 | +log = logging.getLogger(__name__) |
| 38 | + |
| 39 | + |
| 40 | +@dataclass |
| 41 | +class DagImportError: |
| 42 | + """Structured error information for DAG import failures.""" |
| 43 | + |
| 44 | + file_path: str |
| 45 | + message: str |
| 46 | + error_type: str = "import" |
| 47 | + line_number: int | None = None |
| 48 | + column_number: int | None = None |
| 49 | + context: str | None = None |
| 50 | + suggestion: str | None = None |
| 51 | + stacktrace: str | None = None |
| 52 | + |
| 53 | + def format_message(self) -> str: |
| 54 | + """Format the error as a human-readable string.""" |
| 55 | + parts = [f"Error in {self.file_path}"] |
| 56 | + if self.line_number is not None: |
| 57 | + loc = f"line {self.line_number}" |
| 58 | + if self.column_number is not None: |
| 59 | + loc += f", column {self.column_number}" |
| 60 | + parts.append(f"Location: {loc}") |
| 61 | + parts.append(f"Error ({self.error_type}): {self.message}") |
| 62 | + if self.context: |
| 63 | + parts.append(f"Context:\n{self.context}") |
| 64 | + if self.suggestion: |
| 65 | + parts.append(f"Suggestion: {self.suggestion}") |
| 66 | + return "\n".join(parts) |
| 67 | + |
| 68 | + |
| 69 | +@dataclass |
| 70 | +class DagImportWarning: |
| 71 | + """Warning information for non-fatal issues during DAG import.""" |
| 72 | + |
| 73 | + file_path: str |
| 74 | + message: str |
| 75 | + warning_type: str = "general" |
| 76 | + line_number: int | None = None |
| 77 | + |
| 78 | + |
| 79 | +@dataclass |
| 80 | +class DagImportResult: |
| 81 | + """Result of importing DAGs from a file.""" |
| 82 | + |
| 83 | + file_path: str |
| 84 | + dags: list[DAG] = field(default_factory=list) |
| 85 | + errors: list[DagImportError] = field(default_factory=list) |
| 86 | + skipped_files: list[str] = field(default_factory=list) |
| 87 | + warnings: list[DagImportWarning] = field(default_factory=list) |
| 88 | + |
| 89 | + @property |
| 90 | + def success(self) -> bool: |
| 91 | + """Return True if no fatal errors occurred.""" |
| 92 | + return len(self.errors) == 0 |
| 93 | + |
| 94 | + |
| 95 | +class AbstractDagImporter(ABC): |
| 96 | + """Abstract base class for DAG importers.""" |
| 97 | + |
| 98 | + @classmethod |
| 99 | + @abstractmethod |
| 100 | + def supported_extensions(cls) -> list[str]: |
| 101 | + """Return file extensions this importer handles (e.g., ['.py', '.zip']).""" |
| 102 | + |
| 103 | + @abstractmethod |
| 104 | + def import_file( |
| 105 | + self, |
| 106 | + file_path: str | Path, |
| 107 | + *, |
| 108 | + bundle_path: Path | None = None, |
| 109 | + bundle_name: str | None = None, |
| 110 | + safe_mode: bool = True, |
| 111 | + ) -> DagImportResult: |
| 112 | + """Import DAGs from a file.""" |
| 113 | + |
| 114 | + def can_handle(self, file_path: str | Path) -> bool: |
| 115 | + """Check if this importer can handle the given file.""" |
| 116 | + path = Path(file_path) if isinstance(file_path, str) else file_path |
| 117 | + return path.suffix.lower() in self.supported_extensions() |
| 118 | + |
| 119 | + def get_relative_path(self, file_path: str | Path, bundle_path: Path | None) -> str: |
| 120 | + """Get the relative file path from the bundle root.""" |
| 121 | + if bundle_path is None: |
| 122 | + return str(file_path) |
| 123 | + try: |
| 124 | + return str(Path(file_path).relative_to(bundle_path)) |
| 125 | + except ValueError: |
| 126 | + return str(file_path) |
| 127 | + |
| 128 | + def list_dag_files( |
| 129 | + self, |
| 130 | + directory: str | os.PathLike[str], |
| 131 | + safe_mode: bool = True, |
| 132 | + ) -> Iterator[str]: |
| 133 | + """ |
| 134 | + List DAG files in a directory that this importer can handle. |
| 135 | +
|
| 136 | + Override this method to customize file discovery for your importer. |
| 137 | + The default implementation finds files matching supported_extensions() |
| 138 | + and respects .airflowignore files. |
| 139 | +
|
| 140 | + :param directory: Directory to search for DAG files |
| 141 | + :param safe_mode: Whether to use heuristics to filter non-DAG files |
| 142 | + :return: Iterator of file paths |
| 143 | + """ |
| 144 | + ignore_file_syntax = conf.get_mandatory_value("core", "DAG_IGNORE_FILE_SYNTAX", fallback="glob") |
| 145 | + supported_exts = [ext.lower() for ext in self.supported_extensions()] |
| 146 | + |
| 147 | + for file_path in find_path_from_directory(directory, ".airflowignore", ignore_file_syntax): |
| 148 | + path = Path(file_path) |
| 149 | + |
| 150 | + if not path.is_file(): |
| 151 | + continue |
| 152 | + |
| 153 | + # Check if this importer handles this file extension |
| 154 | + if path.suffix.lower() not in supported_exts: |
| 155 | + continue |
| 156 | + |
| 157 | + # Apply safe_mode heuristic if enabled |
| 158 | + if safe_mode and not might_contain_dag(file_path, safe_mode): |
| 159 | + continue |
| 160 | + |
| 161 | + yield file_path |
| 162 | + |
| 163 | + |
| 164 | +class DagImporterRegistry: |
| 165 | + """ |
| 166 | + Registry for DAG importers. Singleton that manages importers by file extension. |
| 167 | +
|
| 168 | + Each file extension can only be handled by one importer at a time. If multiple |
| 169 | + importers claim the same extension, the last registered one wins and a warning |
| 170 | + is logged. The built-in PythonDagImporter handles .py and .zip extensions. |
| 171 | + """ |
| 172 | + |
| 173 | + _instance: DagImporterRegistry | None = None |
| 174 | + _importers: dict[str, AbstractDagImporter] |
| 175 | + _lock = threading.Lock() |
| 176 | + |
| 177 | + def __new__(cls) -> DagImporterRegistry: |
| 178 | + with cls._lock: |
| 179 | + if cls._instance is None: |
| 180 | + cls._instance = super().__new__(cls) |
| 181 | + cls._instance._importers = {} |
| 182 | + cls._instance._register_default_importers() |
| 183 | + return cls._instance |
| 184 | + |
| 185 | + def _register_default_importers(self) -> None: |
| 186 | + from airflow.dag_processing.importers.python_importer import PythonDagImporter |
| 187 | + |
| 188 | + self.register(PythonDagImporter()) |
| 189 | + |
| 190 | + def register(self, importer: AbstractDagImporter) -> None: |
| 191 | + """ |
| 192 | + Register an importer for its supported extensions. |
| 193 | +
|
| 194 | + Each extension can only have one importer. If an extension is already registered, |
| 195 | + the new importer will override it and a warning will be logged. |
| 196 | + """ |
| 197 | + for ext in importer.supported_extensions(): |
| 198 | + ext_lower = ext.lower() |
| 199 | + if ext_lower in self._importers: |
| 200 | + existing = self._importers[ext_lower] |
| 201 | + log.warning( |
| 202 | + "Extension '%s' already registered by %s, overriding with %s", |
| 203 | + ext, |
| 204 | + type(existing).__name__, |
| 205 | + type(importer).__name__, |
| 206 | + ) |
| 207 | + self._importers[ext_lower] = importer |
| 208 | + |
| 209 | + def get_importer(self, file_path: str | Path) -> AbstractDagImporter | None: |
| 210 | + """Get the appropriate importer for a file, or None if unsupported.""" |
| 211 | + path = Path(file_path) if isinstance(file_path, str) else file_path |
| 212 | + return self._importers.get(path.suffix.lower()) |
| 213 | + |
| 214 | + def can_handle(self, file_path: str | Path) -> bool: |
| 215 | + """Check if any registered importer can handle this file.""" |
| 216 | + return self.get_importer(file_path) is not None |
| 217 | + |
| 218 | + def supported_extensions(self) -> list[str]: |
| 219 | + """Return all registered file extensions.""" |
| 220 | + return list(self._importers.keys()) |
| 221 | + |
| 222 | + def list_dag_files( |
| 223 | + self, |
| 224 | + directory: str | os.PathLike[str], |
| 225 | + safe_mode: bool = True, |
| 226 | + ) -> list[str]: |
| 227 | + """ |
| 228 | + List all DAG files in a directory using all registered importers. |
| 229 | +
|
| 230 | + If directory is actually a file, returns that file if any importer can handle it. |
| 231 | +
|
| 232 | + :param directory: Directory (or file) to search for DAG files |
| 233 | + :param safe_mode: Whether to use heuristics to filter non-DAG files |
| 234 | + :return: List of file paths (deduplicated) |
| 235 | + """ |
| 236 | + path = Path(directory) |
| 237 | + |
| 238 | + # If it's a file, just return it if we can handle it |
| 239 | + if path.is_file(): |
| 240 | + if self.can_handle(path): |
| 241 | + return [str(path)] |
| 242 | + return [] |
| 243 | + |
| 244 | + if not path.is_dir(): |
| 245 | + return [] |
| 246 | + |
| 247 | + seen_files: set[str] = set() |
| 248 | + file_paths: list[str] = [] |
| 249 | + |
| 250 | + for importer in set(self._importers.values()): |
| 251 | + for file_path in importer.list_dag_files(directory, safe_mode): |
| 252 | + if file_path not in seen_files: |
| 253 | + seen_files.add(file_path) |
| 254 | + file_paths.append(file_path) |
| 255 | + |
| 256 | + return file_paths |
| 257 | + |
| 258 | + @classmethod |
| 259 | + def reset(cls) -> None: |
| 260 | + """Reset the singleton (for testing).""" |
| 261 | + cls._instance = None |
| 262 | + |
| 263 | + |
| 264 | +def get_importer_registry() -> DagImporterRegistry: |
| 265 | + """Get the global importer registry instance.""" |
| 266 | + return DagImporterRegistry() |
0 commit comments