diff --git a/airflow/models/pool.py b/airflow/models/pool.py index 3819d6d135fda..f351fd593eacf 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -16,11 +16,12 @@ # specific language governing permissions and limitations # under the License. -from typing import Dict, Iterable, Tuple +from typing import Dict, Iterable, Optional, Tuple from sqlalchemy import Column, Integer, String, Text, func from sqlalchemy.orm.session import Session +from airflow.exceptions import AirflowException from airflow.models.base import Base from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.typing_compat import TypedDict @@ -33,6 +34,7 @@ class PoolStats(TypedDict): total: int running: int queued: int + open: int class Pool(Base): @@ -79,7 +81,7 @@ def get_default_pool(session: Session = None): @provide_session def slots_stats(session: Session = None) -> Dict[str, PoolStats]: """ - Get Pool stats (Number of Running, Queued & Total tasks) + Get Pool stats (Number of Running, Queued, Open & Total tasks) :param session: SQLAlchemy ORM Session """ @@ -89,7 +91,7 @@ def slots_stats(session: Session = None) -> Dict[str, PoolStats]: pool_rows: Iterable[Tuple[str, int]] = session.query(Pool.pool, Pool.slots).all() for (pool_name, total_slots) in pool_rows: - pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0) + pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0) state_count_by_pool = ( session.query(TaskInstance.pool, TaskInstance.state, func.count()) @@ -98,19 +100,28 @@ def slots_stats(session: Session = None) -> Dict[str, PoolStats]: ).all() # calculate queued and running metrics + count: int for (pool_name, state, count) in state_count_by_pool: - pool = pools.get(pool_name) - if not pool: + stats_dict: Optional[PoolStats] = pools.get(pool_name) + if not stats_dict: continue - pool[state] = count + # TypedDict key must be a string literal, so we use if-statements to set value + if state == "running": + stats_dict["running"] = count + elif state == "queued": + stats_dict["queued"] = count + else: + raise AirflowException( + f"Unexpected state. Expected values: {EXECUTION_STATES}." + ) # calculate open metric - for pool_name, stats in pools.items(): - if stats["total"] == -1: + for pool_name, stats_dict in pools.items(): + if stats_dict["total"] == -1: # -1 means infinite - stats["open"] = -1 + stats_dict["open"] = -1 else: - stats["open"] = stats["total"] - stats[State.RUNNING] - stats[State.QUEUED] + stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"] return pools