Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement context accessor for DatasetEvent extra #38481

Merged
merged 5 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,11 @@ repos:
entry: ./scripts/ci/pre_commit/pre_commit_sync_init_decorator.py
pass_filenames: false
files: ^airflow/models/dag\.py$|^airflow/(?:decorators|utils)/task_group\.py$
- id: check-template-context-variable-in-sync
name: Check all template context variable references are in sync
language: python
entry: ./scripts/ci/pre_commit/pre_commit_template_context_key_sync.py
files: ^airflow/models/taskinstance\.py$|^airflow/utils/context\.pyi?$|^docs/apache-airflow/templates-ref\.rst$
- id: check-base-operator-usage
language: pygrep
name: Check BaseOperator core imports
Expand Down
11 changes: 9 additions & 2 deletions airflow/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@ def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | N
return ProvidersManager().dataset_uri_handlers.get(scheme)


def _sanitize_uri(uri: str) -> str:
def sanitize_uri(uri: str) -> str:
"""Sanitize a dataset URI.

This checks for URI validity, and normalizes the URI if needed. A fully
normalized URI is returned.

:meta private:
"""
if not uri:
raise ValueError("Dataset URI cannot be empty")
if uri.isspace():
Expand Down Expand Up @@ -110,7 +117,7 @@ class Dataset(os.PathLike, BaseDatasetEventInput):
"""A representation of data dependencies between workflows."""

uri: str = attr.field(
converter=_sanitize_uri,
converter=sanitize_uri,
validator=[attr.validators.min_len(1), attr.validators.max_len(3000)],
)
extra: dict[str, Any] | None = None
Expand Down
14 changes: 11 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,13 @@
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.utils import timezone
from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge
from airflow.utils.context import (
ConnectionAccessor,
Context,
DatasetEventAccessors,
VariableAccessor,
context_merge,
)
from airflow.utils.email import send_email
from airflow.utils.helpers import prune_dict, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -766,6 +772,7 @@ def get_triggering_events() -> dict[str, list[DatasetEvent | DatasetEventPydanti
"dag_run": dag_run,
"data_interval_end": timezone.coerce_datetime(data_interval.end),
"data_interval_start": timezone.coerce_datetime(data_interval.start),
"dataset_events": DatasetEventAccessors(),
"ds": ds,
"ds_nodash": ds_nodash,
"execution_date": logical_date,
Expand Down Expand Up @@ -2569,7 +2576,7 @@ def _run_raw_task(
session.add(Log(self.state, self))
session.merge(self).task = self.task
if self.state == TaskInstanceState.SUCCESS:
self._register_dataset_changes(session=session)
self._register_dataset_changes(events=context["dataset_events"], session=session)

session.commit()
if self.state == TaskInstanceState.SUCCESS:
Expand All @@ -2579,7 +2586,7 @@ def _run_raw_task(

return None

def _register_dataset_changes(self, *, session: Session) -> None:
def _register_dataset_changes(self, *, events: DatasetEventAccessors, session: Session) -> None:
if TYPE_CHECKING:
assert self.task

Expand All @@ -2590,6 +2597,7 @@ def _register_dataset_changes(self, *, session: Session) -> None:
dataset_manager.register_dataset_change(
task_instance=self,
dataset=obj,
extra=events[obj].extra,
session=session,
)

Expand Down
34 changes: 34 additions & 0 deletions airflow/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@
ValuesView,
)

import attrs
import lazy_object_proxy

from airflow.datasets import Dataset, sanitize_uri
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.utils.types import NOTSET

Expand All @@ -54,6 +56,7 @@
"dag_run",
"data_interval_end",
"data_interval_start",
"dataset_events",
"ds",
"ds_nodash",
"execution_date",
Expand Down Expand Up @@ -146,6 +149,37 @@ def get(self, key: str, default_conn: Any = None) -> Any:
return default_conn


@attrs.define()
class DatasetEventAccessor:
"""Wrapper to access a DatasetEvent instance in template."""

extra: dict[str, Any]


class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
"""Lazy mapping of dataset event accessors."""

def __init__(self) -> None:
self._dict: dict[str, DatasetEventAccessor] = {}

def __iter__(self) -> Iterator[str]:
return iter(self._dict)

def __len__(self) -> int:
return len(self._dict)

def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor:
if isinstance(key, str):
uri = sanitize_uri(key)
elif isinstance(key, Dataset):
uri = key.uri
else:
return NotImplemented
if uri not in self._dict:
self._dict[uri] = DatasetEventAccessor({})
return self._dict[uri]


class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
"""Warn for usage of deprecated context variables in a task."""

Expand Down
12 changes: 11 additions & 1 deletion airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@
# declare "these are defined, but don't error if others are accessed" someday.
from __future__ import annotations

from typing import Any, Collection, Container, Iterable, Mapping, overload
from typing import Any, Collection, Container, Iterable, Iterator, Mapping, overload

from pendulum import DateTime

from airflow.configuration import AirflowConfigParser
from airflow.datasets import Dataset
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
Expand All @@ -55,6 +56,14 @@ class VariableAccessor:
class ConnectionAccessor:
def get(self, key: str, default_conn: Any = None) -> Any: ...

class DatasetEventAccessor:
extra: dict[str, Any]

class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
def __iter__(self) -> Iterator[str]: ...
def __len__(self) -> int: ...
def __getitem__(self, key: str | Dataset) -> DatasetEventAccessor: ...

# NOTE: Please keep this in sync with the following:
# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
# * Table in docs/apache-airflow/templates-ref.rst
Expand All @@ -65,6 +74,7 @@ class Context(TypedDict, total=False):
dag_run: DagRun | DagRunPydantic
data_interval_end: DateTime
data_interval_start: DateTime
dataset_events: DatasetEventAccessors
ds: str
ds_nodash: str
exception: BaseException | str | None
Expand Down
2 changes: 2 additions & 0 deletions contributing-docs/08_static_code_checks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ require Breeze Docker image to be built locally.
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-system-tests-tocs | Check that system tests is properly added | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-template-context-variable-in-sync | Check all template context variable references are in sync | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-tests-in-the-right-folders | Check if tests are in the right folders | |
+-----------------------------------------------------------+--------------------------------------------------------------+---------+
| check-tests-unittest-testcase | Check that unit tests do not inherit from unittest.TestCase | |
Expand Down
Loading