Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typing in airflow/models/pool.py #9835

Merged
merged 1 commit into from
Jul 15, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions airflow/models/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
# specific language governing permissions and limitations
# under the License.

from typing import Dict, Iterable, Tuple
from typing import Dict, Iterable, Optional, Tuple

from sqlalchemy import Column, Integer, String, Text, func
from sqlalchemy.orm.session import Session

from airflow.exceptions import AirflowException
from airflow.models.base import Base
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.typing_compat import TypedDict
Expand All @@ -33,6 +34,7 @@ class PoolStats(TypedDict):
total: int
running: int
queued: int
open: int


class Pool(Base):
Expand Down Expand Up @@ -79,7 +81,7 @@ def get_default_pool(session: Session = None):
@provide_session
def slots_stats(session: Session = None) -> Dict[str, PoolStats]:
"""
Get Pool stats (Number of Running, Queued & Total tasks)
Get Pool stats (number of running & queued tasks, and open & total slots)

:param session: SQLAlchemy ORM Session
"""
Expand All @@ -89,7 +91,7 @@ def slots_stats(session: Session = None) -> Dict[str, PoolStats]:

pool_rows: Iterable[Tuple[str, int]] = session.query(Pool.pool, Pool.slots).all()
for (pool_name, total_slots) in pool_rows:
pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0)
pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)

state_count_by_pool = (
session.query(TaskInstance.pool, TaskInstance.state, func.count())
Expand All @@ -98,19 +100,28 @@ def slots_stats(session: Session = None) -> Dict[str, PoolStats]:
).all()

# calculate queued and running metrics
count: int
for (pool_name, state, count) in state_count_by_pool:
pool = pools.get(pool_name)
if not pool:
stats_dict: Optional[PoolStats] = pools.get(pool_name)
if not stats_dict:
continue
pool[state] = count
# TypedDict keys must be string literals, so we use if-statements to set the value
if state == "running":
stats_dict["running"] = count
elif state == "queued":
stats_dict["queued"] = count
else:
raise AirflowException(
f"Unexpected state. Expected values: {EXECUTION_STATES}."
)

# calculate open metric
for pool_name, stats in pools.items():
if stats["total"] == -1:
for pool_name, stats_dict in pools.items():
if stats_dict["total"] == -1:
# -1 means infinite
stats["open"] = -1
stats_dict["open"] = -1
else:
stats["open"] = stats["total"] - stats[State.RUNNING] - stats[State.QUEUED]
stats_dict["open"] = stats_dict["total"] - stats_dict["running"] - stats_dict["queued"]

return pools

Expand Down