Skip to content

Commit bd9f532

Browse files
Lee-W, romsharon98
authored and committed Jul 26, 2024
Add start execution from triggerer support to dynamic task mapping (apache#39912)
* feat(dagrun): add start_from_trigger support to mapped operator
* feat(mapped_operator): add partial support to start_trigger_args
* feat(mappedoperator): do not include xcom when expanding start trigger args and flag
1 parent 106babe commit bd9f532

File tree

13 files changed

+254
-43
lines changed

13 files changed

+254
-43
lines changed
 

Diff for: ‎airflow/decorators/base.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -550,11 +550,13 @@ def __attrs_post_init__(self):
550550
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
551551
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)
552552

553-
def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
553+
def _expand_mapped_kwargs(
554+
self, context: Context, session: Session, *, include_xcom: bool
555+
) -> tuple[Mapping[str, Any], set[int]]:
554556
# We only use op_kwargs_expand_input so this must always be empty.
555557
if self.expand_input is not EXPAND_INPUT_EMPTY:
556558
raise AssertionError(f"unexpected expand_input: {self.expand_input}")
557-
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session)
559+
op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session, include_xcom=include_xcom)
558560
return {"op_kwargs": op_kwargs}, resolved_oids
559561

560562
def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:

Diff for: ‎airflow/models/abstractoperator.py

+22
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@
5555
from airflow.models.operator import Operator
5656
from airflow.models.taskinstance import TaskInstance
5757
from airflow.task.priority_strategy import PriorityWeightStrategy
58+
from airflow.triggers.base import StartTriggerArgs
5859
from airflow.utils.task_group import TaskGroup
5960

6061
DEFAULT_OWNER: str = conf.get_mandatory_value("operators", "default_owner")
@@ -427,6 +428,27 @@ def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> Bas
427428
"""
428429
raise NotImplementedError()
429430

431+
def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
432+
"""
433+
Get the start_from_trigger value of the current abstract operator.
434+
435+
MappedOperator uses this to unmap start_from_trigger to decide whether to start the task
436+
execution directly from triggerer.
437+
438+
:meta private:
439+
"""
440+
raise NotImplementedError()
441+
442+
def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
443+
"""
444+
Get the start_trigger_args value of the current abstract operator.
445+
446+
MappedOperator uses this to unmap start_trigger_args to decide how to start a task from triggerer.
447+
448+
:meta private:
449+
"""
450+
raise NotImplementedError()
451+
430452
@property
431453
def priority_weight_total(self) -> int:
432454
"""

Diff for: ‎airflow/models/baseoperator.py

+22
Original file line number | Diff line number | Diff line change
@@ -1795,6 +1795,28 @@ def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> Bas
17951795
"""
17961796
return self
17971797

1798+
def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
1799+
"""
1800+
Get the start_from_trigger value of the current abstract operator.
1801+
1802+
Since a BaseOperator is not mapped to begin with, this simply returns
1803+
the original value of start_from_trigger.
1804+
1805+
:meta private:
1806+
"""
1807+
return self.start_from_trigger
1808+
1809+
def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
1810+
"""
1811+
Get the start_trigger_args value of the current abstract operator.
1812+
1813+
Since a BaseOperator is not mapped to begin with, this simply returns
1814+
the original value of start_trigger_args.
1815+
1816+
:meta private:
1817+
"""
1818+
return self.start_trigger_args
1819+
17981820

17991821
# TODO: Deprecate for Airflow 3.0
18001822
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]

Diff for: ‎airflow/models/dagrun.py

+15 -5
Original file line number | Diff line number | Diff line change
@@ -1577,11 +1577,21 @@ def schedule_tis(
15771577
and not ti.task.outlets
15781578
):
15791579
dummy_ti_ids.append((ti.task_id, ti.map_index))
1580-
elif ti.task.start_from_trigger is True and ti.task.start_trigger_args is not None:
1581-
ti.start_date = timezone.utcnow()
1582-
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
1583-
ti.try_number += 1
1584-
ti.defer_task(exception=None, session=session)
1580+
# check "start_trigger_args" to see whether the operator supports start execution from triggerer
1581+
# if so, we'll then check "start_from_trigger" to see whether this feature is turned on and defer
1582+
# this task.
1583+
# if not, we'll add this "ti" into "schedulable_ti_ids" and later execute it to run in the worker
1584+
elif ti.task.start_trigger_args is not None:
1585+
context = ti.get_template_context()
1586+
start_from_trigger = ti.task.expand_start_from_trigger(context=context, session=session)
1587+
1588+
if start_from_trigger:
1589+
ti.start_date = timezone.utcnow()
1590+
if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE:
1591+
ti.try_number += 1
1592+
ti.defer_task(exception=None, session=session)
1593+
else:
1594+
schedulable_ti_ids.append((ti.task_id, ti.map_index))
15851595
else:
15861596
schedulable_ti_ids.append((ti.task_id, ti.map_index))
15871597

Diff for: ‎airflow/models/expandinput.py

+20 -11
Original file line number | Diff line number | Diff line change
@@ -69,8 +69,8 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
6969
yield from self._input.iter_references()
7070

7171
@provide_session
72-
def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
73-
data, _ = self._input.resolve(context, session=session)
72+
def resolve(self, context: Context, *, include_xcom: bool, session: Session = NEW_SESSION) -> Any:
73+
data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
7474
return data[self._key]
7575

7676

@@ -165,9 +165,11 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
165165
lengths = self._get_map_lengths(run_id, session=session)
166166
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
167167

168-
def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
169-
if _needs_run_time_resolution(value):
170-
value = value.resolve(context, session=session)
168+
def _expand_mapped_field(
169+
self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
170+
) -> Any:
171+
if include_xcom and _needs_run_time_resolution(value):
172+
value = value.resolve(context, session=session, include_xcom=include_xcom)
171173
map_index = context["ti"].map_index
172174
if map_index < 0:
173175
raise RuntimeError("can't resolve task-mapping argument without expanding")
@@ -203,8 +205,13 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
203205
if isinstance(x, XComArg):
204206
yield from x.iter_references()
205207

206-
def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
207-
data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
208+
def resolve(
209+
self, context: Context, session: Session, *, include_xcom: bool
210+
) -> tuple[Mapping[str, Any], set[int]]:
211+
data = {
212+
k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom)
213+
for k, v in self.value.items()
214+
}
208215
literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
209216
resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
210217
return data, resolved_oids
@@ -248,7 +255,9 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
248255
if isinstance(x, XComArg):
249256
yield from x.iter_references()
250257

251-
def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
258+
def resolve(
259+
self, context: Context, session: Session, *, include_xcom: bool
260+
) -> tuple[Mapping[str, Any], set[int]]:
252261
map_index = context["ti"].map_index
253262
if map_index < 0:
254263
raise RuntimeError("can't resolve task-mapping argument without expanding")
@@ -257,9 +266,9 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
257266
if isinstance(self.value, collections.abc.Sized):
258267
mapping = self.value[map_index]
259268
if not isinstance(mapping, collections.abc.Mapping):
260-
mapping = mapping.resolve(context, session)
261-
else:
262-
mappings = self.value.resolve(context, session)
269+
mapping = mapping.resolve(context, session, include_xcom=include_xcom)
270+
elif include_xcom:
271+
mappings = self.value.resolve(context, session, include_xcom=include_xcom)
263272
if not isinstance(mappings, collections.abc.Sequence):
264273
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
265274
mapping = mappings[map_index]

Diff for: ‎airflow/models/mappedoperator.py

+70 -5
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,7 @@
5151
from airflow.serialization.enums import DagAttributeTypes
5252
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
5353
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
54+
from airflow.triggers.base import StartTriggerArgs
5455
from airflow.typing_compat import Literal
5556
from airflow.utils.context import context_update_for_unmapped
5657
from airflow.utils.helpers import is_container, prevent_duplicates
@@ -81,7 +82,6 @@
8182
from airflow.models.param import ParamsDict
8283
from airflow.models.xcom_arg import XComArg
8384
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
84-
from airflow.triggers.base import StartTriggerArgs
8585
from airflow.utils.context import Context
8686
from airflow.utils.operator_resources import Resources
8787
from airflow.utils.task_group import TaskGroup
@@ -688,14 +688,16 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
688688
"""Implement DAGNode."""
689689
return DagAttributeTypes.OP, self.task_id
690690

691-
def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
691+
def _expand_mapped_kwargs(
692+
self, context: Context, session: Session, *, include_xcom: bool
693+
) -> tuple[Mapping[str, Any], set[int]]:
692694
"""
693695
Get the kwargs to create the unmapped operator.
694696
695697
This exists because taskflow operators expand against op_kwargs, not the
696698
entire operator kwargs dict.
697699
"""
698-
return self._get_specified_expand_input().resolve(context, session)
700+
return self._get_specified_expand_input().resolve(context, session, include_xcom=include_xcom)
699701

700702
def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
701703
"""
@@ -729,6 +731,69 @@ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -
729731
"params": params,
730732
}
731733

734+
def expand_start_from_trigger(self, *, context: Context, session: Session) -> bool:
735+
"""
736+
Get the start_from_trigger value of the current abstract operator.
737+
738+
MappedOperator uses this to unmap start_from_trigger to decide whether to start the task
739+
execution directly from triggerer.
740+
741+
:meta private:
742+
"""
743+
# start_from_trigger only makes sense when start_trigger_args exists.
744+
if not self.start_trigger_args:
745+
return False
746+
747+
mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
748+
if self._disallow_kwargs_override:
749+
prevent_duplicates(
750+
self.partial_kwargs,
751+
mapped_kwargs,
752+
fail_reason="unmappable or already specified",
753+
)
754+
755+
# Ordering is significant; mapped kwargs should override partial ones.
756+
return mapped_kwargs.get(
757+
"start_from_trigger", self.partial_kwargs.get("start_from_trigger", self.start_from_trigger)
758+
)
759+
760+
def expand_start_trigger_args(self, *, context: Context, session: Session) -> StartTriggerArgs | None:
761+
"""
762+
Get the kwargs to create the unmapped start_trigger_args.
763+
764+
This method is for allowing mapped operator to start execution from triggerer.
765+
"""
766+
if not self.start_trigger_args:
767+
return None
768+
769+
mapped_kwargs, _ = self._expand_mapped_kwargs(context, session, include_xcom=False)
770+
if self._disallow_kwargs_override:
771+
prevent_duplicates(
772+
self.partial_kwargs,
773+
mapped_kwargs,
774+
fail_reason="unmappable or already specified",
775+
)
776+
777+
# Ordering is significant; mapped kwargs should override partial ones.
778+
trigger_kwargs = mapped_kwargs.get(
779+
"trigger_kwargs",
780+
self.partial_kwargs.get("trigger_kwargs", self.start_trigger_args.trigger_kwargs),
781+
)
782+
next_kwargs = mapped_kwargs.get(
783+
"next_kwargs",
784+
self.partial_kwargs.get("next_kwargs", self.start_trigger_args.next_kwargs),
785+
)
786+
timeout = mapped_kwargs.get(
787+
"trigger_timeout", self.partial_kwargs.get("trigger_timeout", self.start_trigger_args.timeout)
788+
)
789+
return StartTriggerArgs(
790+
trigger_cls=self.start_trigger_args.trigger_cls,
791+
trigger_kwargs=trigger_kwargs,
792+
next_method=self.start_trigger_args.next_method,
793+
next_kwargs=next_kwargs,
794+
timeout=timeout,
795+
)
796+
732797
def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
733798
"""
734799
Get the "normal" Operator after applying the current mapping.
@@ -749,7 +814,7 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->
749814
if isinstance(resolve, collections.abc.Mapping):
750815
kwargs = resolve
751816
elif resolve is not None:
752-
kwargs, _ = self._expand_mapped_kwargs(*resolve)
817+
kwargs, _ = self._expand_mapped_kwargs(*resolve, include_xcom=True)
753818
else:
754819
raise RuntimeError("cannot unmap a non-serialized operator without context")
755820
kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
@@ -844,7 +909,7 @@ def render_template_fields(
844909
# set_current_task_session context manager to store the session in the current task.
845910
session = get_current_task_instance_session()
846911

847-
mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
912+
mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session, include_xcom=True)
848913
unmapped_task = self.unmap(mapped_kwargs)
849914
context_update_for_unmapped(context, unmapped_task)
850915

Diff for: ‎airflow/models/param.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -329,7 +329,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
329329
def iter_references(self) -> Iterable[tuple[Operator, str]]:
330330
return ()
331331

332-
def resolve(self, context: Context) -> Any:
332+
def resolve(self, context: Context, *, include_xcom: bool) -> Any:
333333
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
334334
with contextlib.suppress(KeyError):
335335
return context["dag_run"].conf[self._name]

Diff for: ‎airflow/models/taskinstance.py

+14 -5
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,7 @@
8383
AirflowTaskTimeout,
8484
DagRunNotFound,
8585
RemovedInAirflow3Warning,
86+
TaskDeferralError,
8687
TaskDeferred,
8788
UnmappableXComLengthPushed,
8889
UnmappableXComTypePushed,
@@ -1617,15 +1618,23 @@ def _defer_task(
16171618
next_kwargs = exception.kwargs
16181619
timeout = exception.timeout
16191620
elif ti.task is not None and ti.task.start_trigger_args is not None:
1621+
context = ti.get_template_context()
1622+
start_trigger_args = ti.task.expand_start_trigger_args(context=context, session=session)
1623+
if start_trigger_args is None:
1624+
raise TaskDeferralError(
1625+
"A none 'None' start_trigger_args has been change to 'None' during expandion"
1626+
)
1627+
1628+
trigger_kwargs = start_trigger_args.trigger_kwargs or {}
1629+
next_kwargs = start_trigger_args.next_kwargs
1630+
next_method = start_trigger_args.next_method
1631+
timeout = start_trigger_args.timeout
16201632
trigger_row = Trigger(
16211633
classpath=ti.task.start_trigger_args.trigger_cls,
1622-
kwargs=ti.task.start_trigger_args.trigger_kwargs or {},
1634+
kwargs=trigger_kwargs,
16231635
)
1624-
next_kwargs = ti.task.start_trigger_args.next_kwargs
1625-
next_method = ti.task.start_trigger_args.next_method
1626-
timeout = ti.task.start_trigger_args.timeout
16271636
else:
1628-
raise AirflowException("exception and ti.task.start_trigger_args cannot both be None")
1637+
raise TaskDeferralError("exception and ti.task.start_trigger_args cannot both be None")
16291638

16301639
# First, make the trigger entry
16311640
session.add(trigger_row)

Diff for: ‎airflow/models/xcom_arg.py

+8 -8
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
208208
raise NotImplementedError()
209209

210210
@provide_session
211-
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
211+
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
212212
"""
213213
Pull XCom value.
214214
@@ -437,7 +437,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
437437
)
438438

439439
@provide_session
440-
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
440+
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
441441
ti = context["ti"]
442442
if TYPE_CHECKING:
443443
assert isinstance(ti, TaskInstance)
@@ -551,8 +551,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
551551
return self.arg.get_task_map_length(run_id, session=session)
552552

553553
@provide_session
554-
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
555-
value = self.arg.resolve(context, session=session)
554+
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
555+
value = self.arg.resolve(context, session=session, include_xcom=include_xcom)
556556
if not isinstance(value, (Sequence, dict)):
557557
raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
558558
return _MapResult(value, self.callables)
@@ -632,8 +632,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
632632
return max(ready_lengths)
633633

634634
@provide_session
635-
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
636-
values = [arg.resolve(context, session=session) for arg in self.args]
635+
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
636+
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
637637
for value in values:
638638
if not isinstance(value, (Sequence, dict)):
639639
raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
@@ -707,8 +707,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
707707
return sum(ready_lengths)
708708

709709
@provide_session
710-
def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
711-
values = [arg.resolve(context, session=session) for arg in self.args]
710+
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool) -> Any:
711+
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
712712
for value in values:
713713
if not isinstance(value, (Sequence, dict)):
714714
raise ValueError(f"XCom concat expects sequence or dict, not {type(value).__name__}")

0 commit comments

Comments
 (0)