Skip to content
Snippets Groups Projects
Commit 85d63a1c authored by Anton Sarukhanov's avatar Anton Sarukhanov
Browse files

Make get_or_create less braindead. Instead of looking for a 100% identical...

Make get_or_create less braindead. Instead of looking for a 100% identical match, it looks for one with matching keys/indices. If those match, the old record is returned. Caveat: Returned record not guaranteed to 100% match the entered parameters - only the keys. Could be bad if not accounted for
parent 9bcf5f80
No related branches found
No related tags found
No related merge requests found
......@@ -68,15 +68,3 @@ def ajax():
if __name__ == '__main__':
# Run Flask
app.run(host='0.0.0.0')
# Run Celery
from celery import current_app
from celery.bin import worker
application = current_app.get_current_object()
worker = worker(app=application)
options = {
'broker': app.config['CELERY_BROKER_URL'],
'loglevel': 'INFO',
'traceback': True,
}
worker.run(**options)
from datetime import datetime
from itertools import chain
from flask import current_app
from flask.ext.sqlalchemy import SQLAlchemy
from sqlalchemy.dialects import postgresql
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.engine import reflection
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import backref, column_property
from sqlalchemy.orm.collections import attribute_mapped_collection
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
from sqlalchemy.schema import Table
from sqlalchemy.sql.expression import ClauseElement
......@@ -25,9 +28,21 @@ class BMModel():
@classmethod
def get_or_create(self, session, create_method='', create_method_kwargs=None, **kwargs):
""" Imitate Django's get_or_create() """
""" Try to find an existing object filtering by kwargs. If not found, create. """
keys = [k.name for k in inspect(self).primary_key]
inspector = reflection.Inspector.from_engine(db.engine)
keys = list(chain.from_iterable([i['column_names'] for i in
inspector.get_indexes(inspect(self).mapped_table)]))
keys += [k.name for k in inspect(self).primary_key]
filter_args = {arg: kwargs[arg] for arg in kwargs if arg in keys}
try:
return session.query(self).filter_by(**kwargs).one()
return session.query(self).filter_by(**filter_args).one()
except MultipleResultsFound:
raise Exception("{0} matches in get_or_create. This should never happen."
" More primary keys are probably needed for some models."
"\nkwargs = {1}\nkeys = {2}\nfilter_args = {3}"
.format(session.query(self).filter_by(**filter_args).count(),
kwargs, keys, filter_args))
except NoResultFound:
kwargs.update(create_method_kwargs or {})
new = getattr(self, create_method, self)(**kwargs)
......@@ -180,7 +195,7 @@ class Region(db.Model, BMModel):
id = db.Column(db.Integer, primary_key=True)
# title - Region title
title = db.Column(db.String, unique=True)
title = db.Column(db.String, unique=True, index=True)
# API Request which was used to retrieve this data
api_call_id = db.Column(db.Integer, db.ForeignKey('api_call.id', ondelete="set null"))
......@@ -280,7 +295,7 @@ class RouteStop(db.Model, BMModel):
collection_class=attribute_mapped_collection("route_id"),
cascade="all, delete-orphan"))
stop_tag = db.Column(db.String)
stop_tag = db.Column(db.String, nullable=False)
......@@ -295,9 +310,8 @@ class Stop(db.Model, BMModel):
id = db.Column(db.Integer, primary_key=True)
# routes - Routes which serve this stop.
#routes = db.relationship("Route", secondary=route_stop, back_populates="stops", collection_class=attribute_mapped_collection("tag"), cascade="all", passive_deletes=True)
routes = association_proxy("route_stop", "route",
creator = lambda k,v: RouteStop(stop_tag=k, route=v))
creator = lambda k,v: RouteStop(stop_id=self.id, stop_tag=k, route_id=v.id))
# stop_id - Numeric ID
# Not all routes/stops have this! Cannot be used as an index/lookup.
......
......@@ -199,7 +199,7 @@ class Nextbus():
tag = direction.get('tag'),
title = direction.get('title'),
name = direction.get('name'),
api_call = api_call)
api_call_id = api_call.id)
route_obj.directions.append(d)
def save_stops(route_xml, route_obj, api_call):
stops = route_xml.findall('stop')
......@@ -209,10 +209,11 @@ class Nextbus():
lat = float(stop.get('lat')),
lon = float(stop.get('lon')),
stop_id = stop.get('stopId'),
api_call = api_call)
api_call_id = api_call.id)
db.session.flush()
rs = RouteStop.get_or_create(db.session,
route = route_obj,
stop = s,
route_id = route_obj.id,
stop_id = s.id,
stop_tag = stop.get('tag'))
r = Route.get_or_create(db.session,
tag = route_xml.get('tag'),
......@@ -227,6 +228,7 @@ class Nextbus():
api_call = api_call)
save_directions(route_xml, r, api_call)
save_stops(route_xml, r, api_call)
db.session.flush()
return r
# Get list of routes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment