Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow partial reads and align with GTFS specifications for stop_lat, stop_long, and transfer_type #68

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
12 changes: 7 additions & 5 deletions pygtfs/gtfs_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def in_range(self, key, value):
def _validate_float_range(float_min, float_max, *field_names):
@validates(*field_names)
def in_range(self, key, value):
if value is None:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this make all of the in_range validated fields nullable? What am I missing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since some floats are nullable, this rule was previously too strict, but you're right- this solution may be overly broad.

I can look into adding a nullable argument so this check won't fail only for nullable columns.

Copy link
Collaborator Author

@InterferencePattern InterferencePattern Jan 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check this, I think I've made the necessary change to at least allow a check on nullable columns. I still need to make it conditional on the value of another column (but I'm new to sqlalchemy)

return None
float_value = float(value)
if not (float_min <= float_value <= float_max):
raise PygtfsValidationError(
Expand Down Expand Up @@ -145,8 +147,8 @@ class Stop(Base):
stop_code = Column(Unicode, nullable=True, index=True)
stop_name = Column(Unicode)
stop_desc = Column(Unicode, nullable=True)
stop_lat = Column(Float)
stop_lon = Column(Float)
stop_lat = Column(Float, nullable=True)
stop_lon = Column(Float, nullable=True)
zone_id = Column(Unicode, nullable=True)
stop_url = Column(Unicode, nullable=True)
location_type = Column(Integer, nullable=True)
Expand Down Expand Up @@ -328,7 +330,6 @@ class Trip(Base):
primaryjoin=and_(foreign(service_id) == Service.service_id,
feed_id == Service.feed_id))


_validate_direction_id = _validate_int_choice([None, 0, 1], 'direction_id')
_validate_wheelchair = _validate_int_choice([None, 0, 1, 2],
'wheelchair_accessible')
Expand Down Expand Up @@ -507,7 +508,7 @@ class Transfer(Base):
primaryjoin=and_(Trip.trip_id == foreign(to_trip_id),
Trip.feed_id == feed_id))

_validate_transfer_type = _validate_int_choice([None, 0, 1, 2, 3],
_validate_transfer_type = _validate_int_choice([None, 0, 1, 2, 3, 4, 5],
'transfer_type')

def __repr__(self):
Expand Down Expand Up @@ -543,7 +544,8 @@ def __repr__(self):
Column('trans_id', Unicode),
Column('lang', Unicode),
ForeignKeyConstraint(['stop_feed_id', 'stop_id'], [Stop.feed_id, Stop.stop_id]),
ForeignKeyConstraint(['translation_feed_id', 'trans_id', 'lang'], [Translation.feed_id, Translation.trans_id, Translation.lang]),
ForeignKeyConstraint(['translation_feed_id', 'trans_id', 'lang'],
[Translation.feed_id, Translation.trans_id, Translation.lang]),
)


Expand Down
29 changes: 15 additions & 14 deletions pygtfs/loader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from __future__ import (division, absolute_import, print_function,
unicode_literals)

from datetime import date
import sys
from datetime import date

import six
from sqlalchemy import and_
from sqlalchemy.sql.expression import select, join

from .gtfs_entities import (Feed, Service, ServiceException, gtfs_required,
from . import feed
from .exceptions import PygtfsException
from .gtfs_entities import (Feed, gtfs_required,
Translation, Stop, Trip, ShapePoint, _stop_translations,
_trip_shapes, gtfs_calendar, gtfs_all)
from . import feed


def list_feeds(schedule):
Expand All @@ -21,7 +20,6 @@ def list_feeds(schedule):


def delete_feed(schedule, feed_filename, interactive=False):

feed_name = feed.derive_feed_name(feed_filename)
feeds_with_name = schedule.session.query(Feed).filter(Feed.feed_name == feed_name).all()
delete_all = not interactive
Expand All @@ -46,8 +44,7 @@ def overwrite_feed(schedule, feed_filename, *args, **kwargs):


def append_feed(schedule, feed_filename, strip_fields=True,
chunk_size=5000, agency_id_override=None):

chunk_size=5000, agency_id_override=None, ignore_failures=True):
InterferencePattern marked this conversation as resolved.
Show resolved Hide resolved
fd = feed.Feed(feed_filename, strip_fields)

gtfs_tables = {}
Expand Down Expand Up @@ -77,24 +74,28 @@ def append_feed(schedule, feed_filename, strip_fields=True,
continue
gtfs_table = gtfs_tables[gtfs_class]


skipped_records = 0
read_records = 0
for i, record in enumerate(gtfs_table):
if not record:
# Empty row.
continue

try:
instance = gtfs_class(feed_id=feed_id, **record._asdict())
schedule.session.add(instance)
read_records += 1
except:
print("Failure while writing {0}".format(record))
raise
schedule.session.add(instance)
skipped_records += 1
print("Failure while writing {}".format(record))
if not ignore_failures:
raise
if i % chunk_size == 0 and i > 0:
schedule.session.flush()
sys.stdout.write('.')
sys.stdout.flush()
print('%d record%s read for %s.' % ((i+1), '' if i == 0 else 's',
gtfs_class))
print('{0} records read for {1}'.format(read_records, gtfs_class))
InterferencePattern marked this conversation as resolved.
Show resolved Hide resolved
print('{0} records skipped for {1}'.format(skipped_records, gtfs_class))
schedule.session.flush()
schedule.session.commit()
# load many to many relationships
Expand Down