Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add uncached downloading support #917

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/program/db/db_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,38 @@
from typing import TYPE_CHECKING

from loguru import logger
from sqlalchemy import delete, exists, insert, inspect, or_, select, text
from sqlalchemy import delete, insert, inspect, select, text
from sqlalchemy.orm import Session, joinedload, selectinload

import alembic
from program.media.stream import Stream, StreamBlacklistRelation, StreamRelation
from program.services.libraries.symlink import fix_broken_symlinks
from program.settings.manager import settings_manager
from program.utils import root_dir
from program.media.state import States

from .db import db

if TYPE_CHECKING:
from program.media.item import MediaItem

def get_items_by_state(state: States, session = None):
    """Return all MediaItems currently in `state`, detached from their session.

    Shows are loaded with their seasons and episodes eagerly
    (selectinload) so the detached objects remain fully traversable.

    Args:
        state: the States enum value to filter on.
        session: optional existing SQLAlchemy session to reuse. If omitted,
            a new session is created and closed before returning.

    Returns:
        list of detached MediaItem instances.
    """
    # Local import avoids a circular dependency between db_functions and media.item.
    from program.media.item import MediaItem, Season, Show

    # BUG FIX: the previous version did `with _session:` unconditionally,
    # which closed a session owned by the caller. Only close what we open.
    owns_session = session is None
    _session = session if session is not None else db.Session()
    try:
        query = (
            select(MediaItem)
            .where(MediaItem.state == state)
            .options(
                selectinload(Show.seasons)
                .selectinload(Season.episodes)
            )
        )
        items = _session.execute(query).scalars().all()
        # Detach each item so it stays usable after the session is closed.
        for item in items:
            _session.expunge(item)
        return items
    finally:
        if owns_session:
            _session.close()

def get_item_by_id(item_id: str, item_types = None, session = None):
if not item_id:
return None
Expand Down Expand Up @@ -185,10 +203,6 @@ def reset_streams(item: "MediaItem"):
)
session.commit()

def clear_streams(item: "MediaItem"):
"""Clear all streams for a media item."""
reset_streams(item)

def clear_streams_by_id(media_item_id: str):
"""Clear all streams for a media item by the MediaItem id."""
with db.Session() as session:
Expand Down
82 changes: 55 additions & 27 deletions src/program/media/item.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,49 @@
"""MediaItem class"""
from datetime import datetime
from enum import Enum
import json
from pathlib import Path
from typing import List, Optional, Self
from typing import List, Optional, Self, Type

from pydantic import BaseModel
import sqlalchemy
from loguru import logger
from RTN import parse
from sqlalchemy import Index
from sqlalchemy import JSON, Index, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, object_session, relationship

from program.db.db import db
from program.managers.sse_manager import sse_manager
from program.media.state import States
from program.media.subtitle import Subtitle
from program.services.downloaders.models import DownloadedTorrent

from ..db.db_functions import blacklist_stream, reset_streams
from ..db.db_functions import blacklist_stream, reset_streams as db_reset_streams
from .stream import Stream

class PydanticType(TypeDecorator):
    """SQLAlchemy column type that persists a Pydantic model as JSON.

    On write, the model is dumped to a JSON-compatible dict; on read, the
    stored JSON is validated back into the configured model class.
    """

    impl = JSON
    cache_ok = True

    def __init__(self, pydantic_model: Type[BaseModel], *args, **kwargs):
        """Remember which Pydantic model class this column round-trips."""
        super().__init__(*args, **kwargs)
        self.pydantic_model = pydantic_model

    def process_bind_param(self, value, dialect):
        """Serialize an outgoing value for storage.

        Accepts either an already-serialized dict (stored as-is) or a
        Pydantic model instance (dumped in JSON mode). None passes through.
        """
        if value is None:
            return None
        return value if isinstance(value, dict) else value.model_dump(mode="json")

    def process_result_value(self, value, dialect):
        """Rehydrate a stored JSON value into the model class, or None."""
        if value is None:
            return None
        return self.pydantic_model.model_validate(value)


class MediaItem(db.Model):
"""MediaItem class"""
Expand All @@ -34,7 +60,7 @@ class MediaItem(db.Model):
indexed_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
scraped_at: Mapped[Optional[datetime]] = mapped_column(sqlalchemy.DateTime, nullable=True)
scraped_times: Mapped[Optional[int]] = mapped_column(sqlalchemy.Integer, default=0)
active_stream: Mapped[Optional[dict]] = mapped_column(sqlalchemy.JSON, nullable=True)
active_stream: Mapped[DownloadedTorrent] = mapped_column(PydanticType(DownloadedTorrent))
streams: Mapped[list[Stream]] = relationship(secondary="StreamRelation", back_populates="parents", lazy="selectin", cascade="all")
blacklisted_streams: Mapped[list[Stream]] = relationship(secondary="StreamBlacklistRelation", back_populates="blacklisted_parents", lazy="selectin", cascade="all")
symlinked: Mapped[Optional[bool]] = mapped_column(sqlalchemy.Boolean, default=False)
Expand Down Expand Up @@ -95,7 +121,7 @@ def __init__(self, item: dict | None) -> None:

self.scraped_at = None
self.scraped_times = 0
self.active_stream = item.get("active_stream", {})
self.active_stream = DownloadedTorrent()
self.streams: List[Stream] = []
self.blacklisted_streams: List[Stream] = []

Expand Down Expand Up @@ -159,9 +185,10 @@ def is_stream_blacklisted(self, stream: Stream):
return stream in self.blacklisted_streams

def blacklist_active_stream(self):
    """Blacklist the stream matching the active torrent's infohash, if any."""
    infohash = self.active_stream.infohash
    if not infohash:
        # Nothing was ever downloaded for this item; nothing to blacklist.
        logger.debug(f"No active stream for {self.log_string}, will not blacklist")
        return
    matching = next((s for s in self.streams if s.infohash == infohash), None)
    if matching:
        self.blacklist_stream(matching)

Expand Down Expand Up @@ -189,6 +216,8 @@ def _determine_state(self):
return States.Symlinked
elif self.file and self.folder:
return States.Downloaded
elif self.active_stream.infohash:
return States.Downloading
elif self.is_scraped():
return States.Scraped
elif self.title and self.is_released:
Expand Down Expand Up @@ -263,11 +292,9 @@ def to_extended_dict(self, abbreviated_children=False, with_streams=True):
dict["country"] = self.country if hasattr(self, "country") else None
dict["network"] = self.network if hasattr(self, "network") else None
if with_streams:
dict["streams"] = getattr(self, "streams", [])
dict["blacklisted_streams"] = getattr(self, "blacklisted_streams", [])
dict["active_stream"] = (
self.active_stream if hasattr(self, "active_stream") else None
)
dict["streams"] = [stream.infohash for stream in self.streams]
dict["blacklisted_streams"] = [stream.infohash for stream in self.blacklisted_streams]
dict["active_stream"] = self.active_stream.model_dump_json()
dict["number"] = self.number if hasattr(self, "number") else None
dict["symlinked"] = self.symlinked if hasattr(self, "symlinked") else None
dict["symlinked_at"] = (
Expand Down Expand Up @@ -340,23 +367,20 @@ def get_aliases(self) -> dict:
def __hash__(self):
    """Hash by database id so items can live in sets and dict keys."""
    item_id = self.id
    return hash(item_id)

def reset(self):
def reset(self, reset_streams=True):
    """Reset item attributes, recursing into children for shows/seasons.

    Args:
        reset_streams: when True (default), also clear the item's scraped
            streams from the database; when False, keep them so the item
            can be retried without rescraping.

    Side effects: calls `_reset` on this item (and all child seasons /
    episodes), then recomputes and stores the item's state.
    """
    # NOTE: this reconstructs the post-diff version of the method; the
    # scraped diff interleaved old and new lines here.
    if self.type == "show":
        for season in self.seasons:
            for episode in season.episodes:
                episode._reset(reset_streams)
            season._reset(reset_streams)
    elif self.type == "season":
        for episode in self.episodes:
            episode._reset(reset_streams)
    self._reset(reset_streams)
    # store_state() now derives the state itself (Requested/Indexed/...)
    # instead of the caller choosing based on `self.title`.
    self.store_state()

def _reset(self):
def _reset(self, reset_streams=True):
"""Reset item attributes for rescraping."""
if self.symlink_path:
if Path(self.symlink_path).exists():
Expand All @@ -373,10 +397,10 @@ def _reset(self):
self.set("folder", None)
self.set("alternative_folder", None)

reset_streams(self)
self.active_stream = {}
if reset_streams:
db_reset_streams(self)

self.set("active_stream", {})
self.active_stream = DownloadedTorrent()
self.set("symlinked", False)
self.set("symlinked_at", None)
self.set("update_folder", None)
Expand Down Expand Up @@ -459,6 +483,8 @@ def _determine_state(self):
return States.Symlinked
if any(season.state == States.Downloaded for season in self.seasons):
return States.Downloaded
elif self.active_stream.infohash:
return States.Downloading
if self.is_scraped():
return States.Scraped
if any(season.state == States.Indexed for season in self.seasons):
Expand Down Expand Up @@ -567,6 +593,8 @@ def _determine_state(self):
return States.Symlinked
if any(episode.file and episode.folder for episode in self.episodes):
return States.Downloaded
if self.active_stream.infohash:
return States.Downloading
if self.is_scraped():
return States.Scraped
if any(episode.state == States.Indexed for episode in self.episodes):
Expand Down
1 change: 1 addition & 0 deletions src/program/media/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class States(Enum):
Requested = "Requested"
Indexed = "Indexed"
Scraped = "Scraped"
Downloading = "Downloading"
Downloaded = "Downloaded"
Symlinked = "Symlinked"
Completed = "Completed"
Expand Down
2 changes: 1 addition & 1 deletion src/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
PlexWatchlist,
TraktContent,
)
from program.services.downloaders import Downloader
from program.services.downloaders.downloader import Downloader
from program.services.indexers.trakt import TraktIndexer
from program.services.libraries import SymlinkLibrary
from program.services.libraries.symlink import fix_broken_symlinks
Expand Down
Loading