Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add start execution from triggerer support to dynamic task mapping #39912

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
1d8201e
feat(dagrun): add start_from_trigger support to mapped operator
Lee-W May 29, 2024
88a41b6
refactor(dagrun): deduplicate task scheduling if-else blocks
Lee-W May 30, 2024
b3e29ba
docs(deferring): add dynamic task mapping support to "Triggering Defe…
Lee-W May 31, 2024
7ada27d
test(dagrun): add test to start_trigger with mapped op
Lee-W May 31, 2024
5304bcf
refactor(mappedoperator): extract start trigger args expansion to map…
Lee-W Jun 21, 2024
61cfb89
feat(mapped_operator): add partial support to start_trigger_args
Lee-W Jun 22, 2024
7a94a9c
fix(mappedoperator): make mapped kwargs override partial kwargs when …
Lee-W Jun 24, 2024
f9c57d6
docs(deferring): add partial example to start_trigger
Lee-W Jun 24, 2024
1a22d0b
style(dagrun): change is True checking style
Lee-W Jun 24, 2024
2c3c64a
refactor(mappedoperator): split _expand_start_trigger into _expand_st…
Lee-W Jun 24, 2024
a94bfe2
docs(dagrun): improve the deferring section of the schedule_tis docst…
Lee-W Jun 24, 2024
340c6e3
feat(mappedoperator): do not include xcom when expanding start trigge…
Lee-W Jun 28, 2024
061b8e1
refactor: make include_xcom arg in _expand_mapped_kwargs a required a…
Lee-W Jul 12, 2024
62bef82
refactor(expandinput): make include_com in resolve a required argument
Lee-W Jul 12, 2024
b53b12b
docs(deferring): add more instruction for mapped op with start_from_t…
Lee-W Jul 12, 2024
0c29629
refactor(operator): make expand_start_from_trigger part of the abstra…
Lee-W Jul 19, 2024
3ee46d2
refactor(operator): make expand_start_trigger_args part of the abstra…
Lee-W Jul 19, 2024
636265e
refactor(taskinstance): change the exception type of "no exception an…
Lee-W Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,11 +550,13 @@ def __attrs_post_init__(self):
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
def _expand_mapped_kwargs(
self, context: Context, session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
# We only use op_kwargs_expand_input so this must always be empty.
if self.expand_input is not EXPAND_INPUT_EMPTY:
raise AssertionError(f"unexpected expand_input: {self.expand_input}")
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session)
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
return {"op_kwargs": op_kwargs}, resolved_oids

def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
Expand Down
22 changes: 22 additions & 0 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.task_group import TaskGroup

DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
Expand Down Expand Up @@ -427,6 +428,27 @@ def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> Bas
"""
raise NotImplementedError()

def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
    """
    Resolve whether this operator should begin execution from the triggerer.

    Concrete subclasses (notably ``MappedOperator``) override this to unmap
    ``start_from_trigger`` for the current task instance; the scheduler uses
    the result to decide whether to defer the task immediately instead of
    sending it to a worker.

    :meta private:
    """
    raise NotImplementedError()

def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
    """
    Resolve the ``start_trigger_args`` to use for the current task instance.

    Concrete subclasses (notably ``MappedOperator``) override this to unmap
    ``start_trigger_args`` so the scheduler knows how to start this task
    directly from the triggerer.

    :meta private:
    """
    raise NotImplementedError()

@property
def priority_weight_total(self) -> int:
"""
Expand Down
22 changes: 22 additions & 0 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,28 @@ def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> Bas
"""
return self

def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
    """
    Return ``start_from_trigger`` unchanged.

    A plain ``BaseOperator`` is never mapped, so there is nothing to expand;
    *context* and *session* are accepted only to satisfy the shared
    abstract-operator interface.

    :meta private:
    """
    return self.start_from_trigger

def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
    """
    Return ``start_trigger_args`` unchanged.

    A plain ``BaseOperator`` is never mapped, so there is nothing to expand;
    *context* and *session* are accepted only to satisfy the shared
    abstract-operator interface.

    :meta private:
    """
    return self.start_trigger_args


# TODO: Deprecate for Airflow 3.0
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
Expand Down
20 changes: 15 additions & 5 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,11 +1577,21 @@ def schedule_tis(
and not ti.task.outlets
):
dummy_ti_ids.append((ti.task_id, ti.map_index))
elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None:
ti.start_date = timezone.utcnow()
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
ti.try_number += 1
ti.defer_task(exception=None, session=session)
# check "start_trigger_args" to see whether the operator supports start execution from triggerer
# if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer
# this task.
# if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker
elif ti.task.start_trigger_args is not None:
context = ti.get_template_context()
start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session)

if start_from_trigger:
ti.start_date = timezone.utcnow()
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
ti.try_number += 1
ti.defer_task(exception=None, session=session)
else:
schedulable_ti_ids.append((ti.task_id, ti.map_index))
else:
schedulable_ti_ids.append((ti.task_id, ti.map_index))

Expand Down
31 changes: 20 additions & 11 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from self._input.iter_references()

@provide_session
def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
data, _ = self._input.resolve(context, session=session)
def resolve(self, context: Context, *, include_xcom: bool, session: Session = NEW_SESSION) -> Any:
data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
return data[self._key]


Expand Down Expand Up @@ -165,9 +165,11 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
lengths = self._get_map_lengths(run_id, session=session)
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)

def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
if _needs_run_time_resolution(value):
value = value.resolve(context, session=session)
def _expand_mapped_field(
self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
) -> Any:
if include_xcom and _needs_run_time_resolution(value):
value = value.resolve(context, session=session, include_xcom=include_xcom)
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
Expand Down Expand Up @@ -203,8 +205,13 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
if isinstance(x, XComArg):
yield from x.iter_references()

def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
def resolve(
    self, context: Context, session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
    """
    Expand every mapped field and report which results were resolved at run time.

    Returns the expanded kwargs mapping together with the ``id()`` set of the
    values that did NOT come from parse-time literals, so callers can tell
    runtime-resolved objects apart from static ones.
    """
    expanded: dict[str, Any] = {}
    for name, raw in self.value.items():
        expanded[name] = self._expand_mapped_field(
            name, raw, context, session=session, include_xcom=include_xcom
        )
    parse_time_keys = {name for name, _ in self._iter_parse_time_resolved_kwargs()}
    runtime_oids = {id(val) for name, val in expanded.items() if name not in parse_time_keys}
    return expanded, runtime_oids
Expand Down Expand Up @@ -248,7 +255,9 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
if isinstance(x, XComArg):
yield from x.iter_references()

def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
def resolve(
self, context: Context, session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
Expand All @@ -257,9 +266,9 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
if isinstance(self.value, collections.abc.Sized):
mapping = self.value[map_index]
if not isinstance(mapping, collections.abc.Mapping):
mapping = mapping.resolve(context, session)
else:
mappings = self.value.resolve(context, session)
mapping = mapping.resolve(context, session, include_xcom=include_xcom)
elif include_xcom:
mappings = self.value.resolve(context, session, include_xcom=include_xcom)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]
Expand Down
75 changes: 70 additions & 5 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from airflow.serialization.enums import DagAttributeTypes
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.triggers.base import StartTriggerArgs
from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
Expand Down Expand Up @@ -81,7 +82,6 @@
from airflow.models.param import ParamsDict
from airflow.models.xcom_arg import XComArg
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.context import Context
from airflow.utils.operator_resources import Resources
from airflow.utils.task_group import TaskGroup
Expand Down Expand Up @@ -688,14 +688,16 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Implement DAGNode."""
return DagAttributeTypes.OP, self.task_id

def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
def _expand_mapped_kwargs(
self, context: Context, session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
"""
Get the kwargs to create the unmapped operator.

This exists because taskflow operators expand against op_kwargs, not the
entire operator kwargs dict.
"""
return self._get_specified_expand_input().resolve(context, session)
return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom)

def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -729,6 +731,69 @@ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -
"params": params,
}

def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
    """
    Resolve whether this mapped task should start execution from the triggerer.

    Unmaps ``start_from_trigger`` for the current task instance so the
    scheduler can decide whether to defer the task immediately.

    :meta private:
    """
    # Without start_trigger_args the feature is unavailable regardless of flags.
    if not self.start_trigger_args:
        return False

    mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
    if self._disallow_kwargs_override:
        prevent_duplicates(
            self.partial_kwargs,
            mapped_kwargs,
            fail_reason="unmappable or already specified",
        )

    # Precedence: expanded (mapped) kwargs, then partial kwargs, then the
    # serialized default on the operator itself.
    if "start_from_trigger" in mapped_kwargs:
        return mapped_kwargs["start_from_trigger"]
    if "start_from_trigger" in self.partial_kwargs:
        return self.partial_kwargs["start_from_trigger"]
    return self.start_from_trigger

def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
    """
    Build the unmapped ``StartTriggerArgs`` for the current task instance.

    This allows a mapped operator to start execution from the triggerer:
    the per-map-index trigger/next kwargs and timeout are resolved from the
    expanded kwargs, falling back to partial kwargs and then the serialized
    defaults.
    """
    if not self.start_trigger_args:
        return None

    mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
    if self._disallow_kwargs_override:
        prevent_duplicates(
            self.partial_kwargs,
            mapped_kwargs,
            fail_reason="unmappable or already specified",
        )

    def _pick(key: str, fallback: Any) -> Any:
        # Precedence: expanded (mapped) kwargs over partial kwargs over the
        # value serialized in start_trigger_args.
        return mapped_kwargs.get(key, self.partial_kwargs.get(key, fallback))

    return StartTriggerArgs(
        trigger_cls=self.start_trigger_args.trigger_cls,
        trigger_kwargs=_pick("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
        next_method=self.start_trigger_args.next_method,
        next_kwargs=_pick("next_kwargs", self.start_trigger_args.next_kwargs),
        timeout=_pick("trigger_timeout", self.start_trigger_args.timeout),
    )

def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
"""
Get the "normal" Operator after applying the current mapping.
Expand All @@ -749,7 +814,7 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->
if isinstance(resolve, collections.abc.Mapping):
kwargs = resolve
elif resolve is not None:
kwargs, _ = self._expand_mapped_kwargs(*resolve)
kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True)
else:
raise RuntimeError("cannot unmap a non-serialized operator without context")
kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
Expand Down Expand Up @@ -844,7 +909,7 @@ def render_template_fields(
# set_current_task_session context manager to store the session in the current task.
session = get_current_task_instance_session()

mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True)
unmapped_task = self.unmap(mapped_kwargs)
context_update_for_unmapped(context, unmapped_task)

Expand Down
2 changes: 1 addition & 1 deletion airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()

def resolve(self, context: Context) -> Any:
def resolve(self, context: Context, *, include_xcom: bool) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
with contextlib.suppress(KeyError):
return context["dag_run"].conf[self._name]
Expand Down
19 changes: 14 additions & 5 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
AirflowTaskTimeout,
DagRunNotFound,
RemovedInAirflow3Warning,
TaskDeferralError,
TaskDeferred,
UnmappableXComLengthPushed,
UnmappableXComTypePushed,
Expand Down Expand Up @@ -1617,15 +1618,23 @@ def _defer_task(
next_kwargs = exception.kwargs
timeout = exception.timeout
elif ti.task is not None and ti.task.start_trigger_args is not None:
context = ti.get_template_context()
start_trigger_args = ti.task.expand_start_trigger_args(context=context, session=session)
if start_trigger_args is None:
raise TaskDeferralError(
"A none 'None' start_trigger_args has been change to 'None' during expandion"
)

trigger_kwargs = start_trigger_args.trigger_kwargs or {}
next_kwargs = start_trigger_args.next_kwargs
next_method = start_trigger_args.next_method
timeout = start_trigger_args.timeout
trigger_row = Trigger(
classpath=ti.task.start_trigger_args.trigger_cls,
kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
kwargs=trigger_kwargs,
)
next_kwargs = ti.task.start_trigger_args.next_kwargs
next_method = ti.task.start_trigger_args.next_method
timeout = ti.task.start_trigger_args.timeout
else:
raise AirflowException("exception and ti.task.start_trigger_args cannot both be None")
raise TaskDeferralError("exception and ti.task.start_trigger_args cannot both be None")

# First, make the trigger entry
session.add(trigger_row)
Expand Down
16 changes: 8 additions & 8 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
raise NotImplementedError()

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
"""
Pull XCom value.

Expand Down Expand Up @@ -437,7 +437,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
ti = context["ti"]
if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
Expand Down Expand Up @@ -551,8 +551,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return self.arg.get_task_map_length(run_id, session=session)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
value = self.arg.resolve(context, session=session)
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
    """Resolve the upstream XCom and wrap it in a lazy map result."""
    upstream = self.arg.resolve(context, session=session, include_xcom=include_xcom)
    if isinstance(upstream, (Sequence, dict)):
        return _MapResult(upstream, self.callables)
    raise ValueError(f"XCom map expects sequence or dict, not {type(upstream).__name__}")
Expand Down Expand Up @@ -632,8 +632,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return max(ready_lengths)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
values = [arg.resolve(context, session=session) for arg in self.args]
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
for value in values:
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
Expand Down Expand Up @@ -707,8 +707,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return sum(ready_lengths)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
values = [arg.resolve(context, session=session) for arg in self.args]
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
for value in values:
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}")
Expand Down
Loading