X-Git-Url: http://git.veekun.com/zzz-pokedex.git/blobdiff_plain/9f083a8f296186583d592ceb7263b24f0c282d18..refs/heads/encukou-sqla-0.7:/pokedex/db/multilang.py diff --git a/pokedex/db/multilang.py b/pokedex/db/multilang.py index b031593..d027d13 100644 --- a/pokedex/db/multilang.py +++ b/pokedex/db/multilang.py @@ -3,9 +3,11 @@ from functools import partial from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.orm import aliased, compile_mappers, mapper, relationship, synonym from sqlalchemy.orm.collections import attribute_mapped_collection +from sqlalchemy.orm.scoping import ScopedSession from sqlalchemy.orm.session import Session, object_session +from sqlalchemy.engine.base import Connection from sqlalchemy.schema import Column, ForeignKey, Table -from sqlalchemy.sql.expression import and_, bindparam, select +from sqlalchemy.sql.expression import and_, bindparam, select, Select from sqlalchemy.types import Integer def create_translation_table(_table_name, foreign_class, relation_name, @@ -76,9 +78,11 @@ def create_translation_table(_table_name, foreign_class, relation_name, # Create the table object table = Table(_table_name, foreign_class.__table__.metadata, Column(foreign_key_name, Integer, ForeignKey(foreign_class.id), - primary_key=True, nullable=False), + primary_key=True, nullable=False, + info=dict(description="ID of the %s these texts relate to" % foreign_class.__singlename__)), Column('local_language_id', Integer, ForeignKey(language_class.id), - primary_key=True, nullable=False), + primary_key=True, nullable=False, + info=dict(description="Language these texts are in")), ) Translations.__table__ = table @@ -96,8 +100,8 @@ def create_translation_table(_table_name, foreign_class, relation_name, 'foreign_id': synonym(foreign_key_name), 'local_language': relationship(language_class, primaryjoin=table.c.local_language_id == language_class.id, - lazy='joined', - innerjoin=True), + innerjoin=True, + lazy='joined'), }) # Add full-table relations to the original class @@ -110,21 +114,19 @@ def create_translation_table(_table_name, foreign_class, relation_name, )) # Foo.bars_local # This is a bit clever; it uses bindparam() to make the join clause - # modifiable on the fly. db sessions know the current language identifier - # populates the bindparam. The manual alias and join are (a) to make the - # condition nice (sqla prefers an EXISTS) and to make the columns play nice - # when foreign_class == language_class. + # modifiable on the fly. db sessions know the current language and + # populate the bindparam. + # The 'dummy' value is to trick SQLA; without it, SQLA thinks this + # bindparam is just its own auto-generated clause and everything gets + # fucked up. local_relation_name = relation_name + '_local' - language_class_a = aliased(language_class) setattr(foreign_class, local_relation_name, relationship(Translations, primaryjoin=and_( - foreign_class.id == Translations.foreign_id, - Translations.local_language_id == select( - [language_class_a.id], - language_class_a.identifier == - bindparam('_default_language', required=True), - ), + Translations.foreign_id == foreign_class.id, + Translations.local_language_id == bindparam('_default_language_id', + value='dummy', type_=Integer, required=True), ), + foreign_keys=[Translations.foreign_id, Translations.local_language_id], uselist=False, #innerjoin=True, lazy=relation_lazy, @@ -147,16 +149,49 @@ def create_translation_table(_table_name, foreign_class, relation_name, setattr(foreign_class, name + '_map', association_proxy(relation_name, name, creator=creator)) + # Add to the list of translation classes + foreign_class.translation_classes.append(Translations) + # Done return Translations class MultilangSession(Session): - """A tiny Session subclass that adds support for a default language.""" - default_language = 'en' - - def execute(self, clause, params=None, *args, **kwargs): - if not params: - params = {} - params.setdefault('_default_language', self.default_language) - return super(MultilangSession, self).execute( - clause, params, *args, **kwargs) + """A tiny Session subclass that adds support for a default language. + + Needs to be used with `MultilangScopedSession`, below. + """ + default_language_id = None + + def __init__(self, *args, **kwargs): + if 'default_language_id' in kwargs: + self.default_language_id = kwargs.pop('default_language_id') + + super(MultilangSession, self).__init__(*args, **kwargs) + + def connection(self, *args, **kwargs): + """Monkeypatch the connection. Not pretty at all. + """ + conn = super(MultilangSession, self).connection(*args, **kwargs) + original_execute = conn.execute + if original_execute.__name__ != 'monkeypatched_execute': + def monkeypatched_execute(statement, *multiparams, **params): + if isinstance(statement, Select): + boundparams = dict(multiparams[0]) + boundparams.setdefault('_default_language_id', self.default_language_id) + multiparams = [boundparams] + list(multiparams[1:]) + return original_execute(statement, *multiparams, **params) + conn.execute = monkeypatched_execute + return conn + +class MultilangScopedSession(ScopedSession): + """Dispatches language selection to the attached Session.""" + + @property + def default_language_id(self): + """Passes the new default language id through to the current session. + """ + return self.registry().default_language_id + + @default_language_id.setter + def default_language_id(self, new): + self.registry().default_language_id = new