X-Git-Url: http://git.veekun.com/zzz-pokedex.git/blobdiff_plain/00e0bf75c998b9d9c743d383f1596c91f9a03766..95cf16a6a0521a5bd78222f847953e6c7dd6fbe6:/pokedex/db/multilang.py diff --git a/pokedex/db/multilang.py b/pokedex/db/multilang.py index 3274f61..d027d13 100644 --- a/pokedex/db/multilang.py +++ b/pokedex/db/multilang.py @@ -5,8 +5,9 @@ from sqlalchemy.orm import aliased, compile_mappers, mapper, relationship, synon from sqlalchemy.orm.collections import attribute_mapped_collection from sqlalchemy.orm.scoping import ScopedSession from sqlalchemy.orm.session import Session, object_session +from sqlalchemy.engine.base import Connection from sqlalchemy.schema import Column, ForeignKey, Table -from sqlalchemy.sql.expression import and_, bindparam, select +from sqlalchemy.sql.expression import and_, bindparam, select, Select from sqlalchemy.types import Integer def create_translation_table(_table_name, foreign_class, relation_name, @@ -157,50 +158,40 @@ def create_translation_table(_table_name, foreign_class, relation_name, class MultilangSession(Session): """A tiny Session subclass that adds support for a default language. - Caller will need to assign something to `default_language` before this will - actually work. + Needs to be used with `MultilangScopedSession`, below. """ - _default_language_id = 0 # Better fill this in, caller + default_language_id = None def __init__(self, *args, **kwargs): - self.language_class = kwargs.pop('language_class') + if 'default_language_id' in kwargs: + self.default_language_id = kwargs.pop('default_language_id') + super(MultilangSession, self).__init__(*args, **kwargs) - @property - def default_language(self): - return self.query(self.language_class) \ - .filter_by(id=self._default_language_id) \ - .one() - - @default_language.setter - def default_language(self, new): - self._default_language_id = new.id - - @default_language.deleter - def default_language(self): - try: - del self._default_language_id - except AttributeError: - pass - - def execute(self, clause, params=None, *args, **kwargs): - if not params: - params = {} - params.setdefault('_default_language_id', self._default_language_id) - return super(MultilangSession, self).execute( - clause, params, *args, **kwargs) + def connection(self, *args, **kwargs): + """Monkeypatch the connection. Not pretty at all. + """ + conn = super(MultilangSession, self).connection(*args, **kwargs) + original_execute = conn.execute + if original_execute.__name__ != 'monkeypatched_execute': + def monkeypatched_execute(statement, *multiparams, **params): + if isinstance(statement, Select): + boundparams = dict(multiparams[0]) + boundparams.setdefault('_default_language_id', self.default_language_id) + multiparams = [boundparams] + list(multiparams[1:]) + return original_execute(statement, *multiparams, **params) + conn.execute = monkeypatched_execute + return conn class MultilangScopedSession(ScopedSession): """Dispatches language selection to the attached Session.""" @property - def default_language(self): - return self.registry().default_language - - @default_language.setter - def default_language(self, new): - self.registry().default_language = new - - def remove(self): - del self.registry().default_language - super(MultilangScopedSession, self).remove() + def default_language_id(self): + """Passes the new default language id through to the current session. + """ + return self.registry().default_language_id + + @default_language_id.setter + def default_language_id(self, new): + self.registry().default_language_id = new