
Commit

#112 - converted most session.query.get calls to session.get (step 2 & 3 partly)
Philipp Kraft committed Dec 20, 2023
1 parent ffa70ca commit 503e9c5
Showing 22 changed files with 108 additions and 109 deletions.
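
Every hunk below applies the same SQLAlchemy migration: Query.get() is a legacy API as of SQLAlchemy 1.4, while its replacement Session.get() takes the mapped class and the primary key directly, checks the session's identity map before emitting SQL, and returns None when no row matches. A minimal, self-contained sketch of the before/after (the Site model here is an illustrative stand-in, not the project's):

    import sqlalchemy as sa
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Site(Base):
        __tablename__ = 'site'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String)

    engine = sa.create_engine('sqlite://')
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Site(id=1, name='example'))

        # legacy 1.x call, flagged as a legacy API under SQLAlchemy 2.0
        site = session.query(Site).get(1)

        # 2.0-style call this commit switches to: identity map first,
        # None when the primary key does not exist
        site = session.get(Site, 1)
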
2 changes: 1 addition & 1 deletion odmf/dataimport/base.py
@@ -391,7 +391,7 @@ def to_config(self) -> RawConfigParser:
         from .. import db
         config = RawConfigParser(allow_no_value=True)
         with db.session_scope() as session:
-            inst = session.query(db.Datasource).get(self.instrument)
+            inst = session.get(db.Datasource, self.instrument)
             if not inst:
                 raise ValueError(
                     'Error in import description: %s is not a valid instrument id')
8 changes: 4 additions & 4 deletions odmf/dataimport/importlog.py
@@ -87,7 +87,7 @@ def __init__(self, filename, user, sheetname=0):
                 df[c] = df[c].astype('Int64')

         with db.session_scope() as session:
-            _user: db.Person = session.query(db.Person).get(user)
+            _user: db.Person = session.get(db.Person, user)
             if not _user:
                 raise RuntimeError('%s is not a valid user' % user)
             else:
@@ -150,7 +150,7 @@ def logexists(self, session, site, time, timetolerance=30):

     def get_dataset(self, session, row, data) -> db.Timeseries:
         """Loads the dataset from a row and checks if it is manually measured and at the correct site"""
-        ds = session.query(db.Dataset).get(data.dataset)
+        ds = session.get(db.Dataset, data.dataset)
         if not ds:
             raise LogImportRowError(row, f'Dataset {data.dataset} does not exist')
         # check dataset is manual measurement
@@ -193,8 +193,8 @@ def row_to_log(self, session, row, data):
         Creates a new db.Log object from a row without dataset
         """
         time = data.time.to_pydatetime()
-        site = session.query(db.Site).get(data.site)
-        user = session.query(db.Person).get(self.user)
+        site = session.get(db.Site, data.site)
+        user = session.get(db.Person, self.user)

         if not site:
             raise LogImportRowError(row, f'Log: Site #{data.site} not found')
12 changes: 6 additions & 6 deletions odmf/dataimport/pandas_import.py
@@ -41,7 +41,7 @@ def __init__(self, session, idescr: ImportDescription, col: ImportColumn,

         if col.append:
             try:
-                self.dataset: db.Timeseries = session.query(db.Dataset).get(int(col.append))
+                self.dataset: db.Timeseries = session.get(db.Dataset, int(col.append))
                 assert self.dataset.type == 'timeseries'
                 self.dataset.start = min(start, self.dataset.start)
                 self.dataset.end = max(end, self.dataset.end)
@@ -101,11 +101,11 @@ def columndatasets_from_description(
     """
     with session.no_autoflush:
         # Get instrument, user and site object from db
-        inst = session.query(db.Datasource).get(idescr.instrument)
-        user = session.query(db.Person).get(user)
-        site = session.query(db.Site).get(siteid)
+        inst = session.get(db.Datasource, idescr.instrument)
+        user = session.get(db.Person, user)
+        site = session.get(db.Site, siteid)
         # Get "raw" as data quality, to use as a default value
-        raw = session.query(db.Quality).get(0)
+        raw = session.get(db.Quality, 0)
         # Get all the relevant valuetypes (vt) from db as a dict for fast look up
         valuetypes = {
             vt.id: vt for vt in
@@ -332,7 +332,7 @@ def get_newid_range(ds: db.Timeseries):
         dsid = int(dsid)
         # int conversion is necessary to prevent
         # (psycopg2.ProgrammingError) can't adapt type 'numpy.int64'
-        ds = session.query(db.Dataset).get(dsid)
+        ds = session.get(db.Dataset, dsid)
         if ds:
             # Filter data for the current ds
             ds_data = data[ds_ids == dsid]
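
The int(dsid) cast kept in the last hunk matters independently of the API change: ids taken from a pandas index or array are numpy.int64, which psycopg2 cannot adapt as a bind parameter (the error quoted in the comment above). A short sketch of the idea, assuming an open session and the project's db module; the column name is made up here:

    import pandas as pd

    data = pd.DataFrame({'dataset': [12, 12, 15]})
    for dsid in data['dataset'].unique():
        # dsid is numpy.int64 at this point; the plain int() cast avoids
        # psycopg2's "can't adapt type 'numpy.int64'" error on lookup
        ds = session.get(db.Dataset, int(dsid))
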
7 changes: 3 additions & 4 deletions odmf/db/base.py
@@ -8,7 +8,6 @@

 import sqlalchemy as sql
 import sqlalchemy.orm as orm
-from sqlalchemy.ext.declarative import declarative_base
 from contextlib import contextmanager
 from functools import total_ordering

@@ -112,10 +111,10 @@ def query(cls, session):

     @classmethod
     def get(cls, session, id):
-        return session.query(cls).get(id)
+        return session.get(cls, id)


-Base = declarative_base(cls=Base)
+Base = orm.declarative_base(cls=Base)
 metadata = Base.metadata


@@ -150,7 +149,7 @@ def q(self) -> orm.Query:
         return self.session.query(self.cls).filter_by(**self.filter)

     def __getitem__(self, item):
-        if (res:=self.q.get(item)) is not None:
+        if (res:=self.session.get(self.cls, item)) is not None:
             return res
         else:
             raise KeyError(f'{item} not found in {self.cls}')
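
Two separate modernisations land in this file: the deprecated sqlalchemy.ext.declarative import is dropped in favour of orm.declarative_base (available there since SQLAlchemy 1.4), and the dict-style getter asks the session directly instead of building a Query first. A condensed sketch of that getter pattern under the new API (class and attribute names are illustrative, not necessarily the project's):

    import sqlalchemy.orm as orm

    Base = orm.declarative_base()

    class Getter:
        """Mapping-style access to one mapped class by primary key."""

        def __init__(self, session, cls):
            self.session = session
            self.cls = cls

        def __getitem__(self, item):
            # Session.get returns None for an unknown key; a mapping
            # interface is expected to raise KeyError instead
            if (res := self.session.get(self.cls, item)) is not None:
                return res
            raise KeyError(f'{item} not found in {self.cls}')
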
2 changes: 1 addition & 1 deletion odmf/db/dataset.py
@@ -307,7 +307,7 @@ def removedataset(*args):
     """Removes a dataset and its records entirely from the database
     !!Handle with care, there will be no more checking!!"""
     with session_scope() as session:
-        datasets = [session.query(Dataset).get(int(a)) for a in args]
+        datasets = [session.get(Dataset, int(a)) for a in args]
         for ds in datasets:
             dsid = ds.id
             if ds.is_timeseries():
8 changes: 4 additions & 4 deletions odmf/plot/plot.py
@@ -115,14 +115,14 @@ def load(self, start=None, end=None):
         return series

     def valuetype(self, session):
-        return session.query(db.ValueType).get(
+        return session.get(db.ValueType,
             int(self.valuetypeid)) if self.valuetypeid else None

     def site(self, session):
-        return session.query(db.Site).get(int(self.siteid)) if self.siteid else None
+        return session.get(db.Site, int(self.siteid)) if self.siteid else None

     def instrument(self, session):
-        return session.query(db.Datasource).get(
+        return session.get(db.Datasource,
             int(self.instrumentid)) if self.instrumentid else None

     def export_csv(self, stream, start=None, end=None):
@@ -194,7 +194,7 @@ def get_ylabel(self):
         elif self.lines:
             with db.session_scope() as session:
                 l = self.lines[0]
-                valuetype = session.query(db.ValueType).get(l.valuetypeid)
+                valuetype = session.get(db.ValueType, l.valuetypeid)
                 return f'{valuetype.name} [{valuetype.unit}]'
         else:
             return 'unknown'
4 changes: 2 additions & 2 deletions odmf/tools/create_db.py
@@ -30,7 +30,7 @@ def add_admin(password=None):
     from odmf import db
     from odmf.tools import hashpw
     with db.session_scope() as session:
-        if session.query(db.Person).get('odmf.admin'):
+        if session.get(db.Person, 'odmf.admin'):
             logger.info('odmf.admin exists already')
         else:
             user = db.Person(username='odmf.admin', firstname='odmf', surname='admin', access_level=4)
@@ -50,7 +50,7 @@ def add_quality_data(data):
     with db.session_scope() as session:

         for q in data:
-            if not session.query(db.Quality).get(q['id']):
+            if not session.get(db.Quality, q['id']):
                 session.add(db.Quality(**q))
                 logger.debug(f'Added quality level {q["id"]}')
     active=True
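
In the setup code, Session.get doubles as an existence check so that seeding stays idempotent: the admin account and each quality level are only inserted when their primary key is not present yet. A compact sketch of that seed-if-missing idiom, assuming the project's db module and session_scope; the sample records are invented:

    quality_levels = [{'id': 0, 'name': 'raw'}, {'id': 1, 'name': 'checked'}]

    with db.session_scope() as session:
        for q in quality_levels:
            # get() returns None for an unknown id, so existing rows stay untouched
            if not session.get(db.Quality, q['id']):
                session.add(db.Quality(**q))
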
22 changes: 11 additions & 11 deletions odmf/webpage/api/dataset_api.py
@@ -54,7 +54,7 @@ def parse_id(dsid: str) -> int:
     def get_dataset(dsid: str, check_access=True) -> db.Dataset:
         dsid = DatasetAPI.parse_id(dsid)
         with db.session_scope() as session:
-            ds = session.query(db.Dataset).get(dsid)
+            ds = session.get(db.Dataset, dsid)
             if not ds:
                 raise web.APIError(404, f'ds{dsid} does not exist')
             elif check_access and ds.access > ds.get_access_level(users.current):
@@ -158,11 +158,11 @@ def new(self, **kwargs):
         """
         with db.session_scope() as session:
             try:
-                pers = session.query(db.Person).get(kwargs.get('measured_by'))
-                vt = session.query(db.ValueType).get(kwargs.get('valuetype'))
-                q = session.query(db.Quality).get(kwargs.get('quality'))
-                s = session.query(db.Site).get(kwargs.get('site'))
-                src = session.query(db.Datasource).get(kwargs.get('source'))
+                pers = session.get(db.Person, kwargs.get('measured_by'))
+                vt = session.get(db.ValueType, kwargs.get('valuetype'))
+                q = session.get(db.Quality, kwargs.get('quality'))
+                s = session.get(db.Site, kwargs.get('site'))
+                src = session.get(db.Datasource, kwargs.get('source'))

                 ds = db.Timeseries()
                 # Get properties from the keyword arguments kwargs
@@ -214,7 +214,7 @@ def new(self, **kwargs):
     @web.method.post_or_delete
     def delete(self, dsid: int):
         with db.session_scope() as session:
-            if not (ds := session.query(db.Dataset).get(dsid)):
+            if not (ds := session.get(db.Dataset, dsid)):
                 raise web.APIError(404, f'Dataset {dsid} not found')
             if isinstance(ds, db.Timeseries) and ds.size():
                 raise web.APIError(500, f'Dataset ds{dsid} has {ds.size()} records. Call api.dataset.delete_records({dsid}) first, to delete all records')
@@ -226,7 +226,7 @@ def delete(self, dsid: int):
     def count_records(self, dsid: int):
         web.mime.plain.set()
         with db.session_scope() as session:
-            if not (ds := session.query(db.Dataset).get(dsid)):
+            if not (ds := session.get(db.Dataset, dsid)):
                 raise web.APIError(404, f'Dataset {dsid} not found')
             return f'{ds.size()}'.encode('utf-8')

@@ -239,7 +239,7 @@ def delete_records(self, dsid: int, start=None, end=None):
         """
         web.mime.plain.set()
         with db.session_scope() as session:
-            if not (ds := session.query(db.Timeseries).get(dsid)):
+            if not (ds := session.get(db.Timeseries, dsid)):
                 raise web.APIError(404, f'Dataset {dsid} not found')
             if ds.access > ds.get_access_level(users.current):
                 raise web.APIError(403, 'Not enough privileges')
@@ -276,7 +276,7 @@ def addrecord(self, dsid: int, value: float, time: str,
         with db.session_scope() as session:
             try:
                 dsid = self.parse_id(dsid)
-                ds = session.query(db.Timeseries).get(dsid)
+                ds = session.get(db.Timeseries, dsid)
                 if ds.access > ds.get_access_level(users.current):
                     raise web.APIError(403, 'Not enough privileges')

@@ -332,7 +332,7 @@ def addrecords_json(self):
                         f'(allowed keywords are dsid, dataset and dataset_id)')
                 if not dataset or dataset.id != dsid:
                     # load dataset from db
-                    dataset = session.query(db.Dataset).get(dsid)
+                    dataset = session.get(db.Dataset, dsid)
                 else:
                     ... # reuse last dataset
                 if not dataset:
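
Most handlers in this module now share the same guard: fetch by primary key with the walrus operator and answer 404 when Session.get comes back empty. A sketch of that recurring shape with a made-up helper name (web.APIError is the project's exception class, assumed here):

    def fetch_or_404(session, model, pk):
        # session.get returns None when the primary key is unknown
        if not (obj := session.get(model, pk)):
            raise web.APIError(404, f'{model.__name__} {pk} not found')
        return obj

    # usage inside a handler: ds = fetch_or_404(session, db.Timeseries, dsid)
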
2 changes: 1 addition & 1 deletion odmf/webpage/api/site_api.py
@@ -31,7 +31,7 @@ def index(self, siteid: int=None):
             url, res = get_help(self, self.url)
             return web.json_out(res)
         with db.session_scope() as session:
-            if not (site := session.query(db.Site).get(siteid)):
+            if not (site := session.get(db.Site, siteid)):
                 raise web.APIError(404, f'#{site} does not exist')
             else:
                 return web.json_out(site)