diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 0b0463713e29f..d0831adb6496f 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1696,10 +1696,18 @@ def _schedule_dag_run( and dag.dagrun_timeout and dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout ): - dag_run.state = State.FAILED - dag_run.end_date = timezone.utcnow() - self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id) + dag_run.set_state(State.FAILED) + unfinished_task_instances = ( + session.query(TI) + .filter(TI.dag_id == dag_run.dag_id) + .filter(TI.execution_date == dag_run.execution_date) + .filter(TI.state.in_(State.unfinished)) + ) + for task_instance in unfinished_task_instances: + task_instance.state = State.SKIPPED + session.merge(task_instance) session.flush() + self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id) # Work out if we should allow creating a new DagRun now? self._update_dag_next_dagruns([session.query(DagModel).get(dag_run.dag_id)], session) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index ac5a837d7c4d2..097317087e9f2 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -23,6 +23,7 @@ import unittest from datetime import timedelta from tempfile import NamedTemporaryFile, mkdtemp +from time import sleep from unittest import mock from unittest.mock import MagicMock, patch from zipfile import ZipFile @@ -3813,6 +3814,68 @@ def test_do_schedule_max_active_runs_upstream_failed(self): ti = run2.get_task_instance(task1.task_id, session) assert ti.state == State.QUEUED + def test_do_schedule_max_active_runs_dag_timed_out(self): + """Test that tasks are set to a finished state when their DAG times out""" + + dag = DAG( + dag_id='test_max_active_run_with_dag_timed_out', + start_date=DEFAULT_DATE, + schedule_interval='@once', + max_active_runs=1, + catchup=True, + ) + dag.dagrun_timeout = datetime.timedelta(seconds=1) + + with dag: + task1 = BashOperator( + task_id='task1', + bash_command=' for((i=1;i<=600;i+=1)); do sleep "$i"; done', + ) + + session = settings.Session() + dagbag = DagBag( + dag_folder=os.devnull, + include_examples=False, + read_dags_from_db=True, + ) + + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db(session=session) + + run1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + session=session, + ) + run1_ti = run1.get_task_instance(task1.task_id, session) + run1_ti.state = State.RUNNING + + sleep(1) + + run2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE + timedelta(seconds=10), + state=State.RUNNING, + session=session, + ) + + dag.sync_to_db(session=session) + + job = SchedulerJob(subdir=os.devnull) + job.executor = MockExecutor() + job.processor_agent = mock.MagicMock(spec=DagFileProcessorAgent) + + _ = job._do_scheduling(session) + + assert run1.state == State.FAILED + assert run1_ti.state == State.SKIPPED + assert run2.state == State.RUNNING + + _ = job._do_scheduling(session) + run2_ti = run2.get_task_instance(task1.task_id, session) + assert run2_ti.state == State.QUEUED + def test_do_schedule_max_active_runs_task_removed(self): """Test that tasks in removed state don't count as actively running."""