Match default language by id, not identifier.
author Eevee <git@veekun.com>
Wed, 30 Mar 2011 03:15:41 +0000 (20:15 -0700)
committer Eevee <git@veekun.com>
Wed, 30 Mar 2011 03:15:41 +0000 (20:15 -0700)
pokedex/db/__init__.py
pokedex/db/multilang.py

index a6c8f6e..e2790da 100644 (file)
@@ -2,7 +2,7 @@ from sqlalchemy import MetaData, Table, engine_from_config, orm
 
 from ..defaults import get_default_db_uri
 from .tables import metadata
 
 from ..defaults import get_default_db_uri
 from .tables import metadata
-from .multilang import MultilangSession
+from .multilang import MultilangSession, MultilangScopedSession
 
 
 def connect(uri=None, session_args={}, engine_args={}, engine_prefix=''):
 
 
 def connect(uri=None, session_args={}, engine_args={}, engine_prefix=''):
@@ -42,6 +42,6 @@ def connect(uri=None, session_args={}, engine_args={}, engine_prefix=''):
     all_session_args = dict(autoflush=True, autocommit=False, bind=engine)
     all_session_args.update(session_args)
     sm = orm.sessionmaker(class_=MultilangSession, **all_session_args)
     all_session_args = dict(autoflush=True, autocommit=False, bind=engine)
     all_session_args.update(session_args)
     sm = orm.sessionmaker(class_=MultilangSession, **all_session_args)
-    session = orm.scoped_session(sm)
+    session = MultilangScopedSession(sm)
 
     return session
 
     return session
index b031593..f3438dd 100644 (file)
@@ -3,6 +3,7 @@ from functools import partial
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import aliased, compile_mappers, mapper, relationship, synonym
 from sqlalchemy.orm.collections import attribute_mapped_collection
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import aliased, compile_mappers, mapper, relationship, synonym
 from sqlalchemy.orm.collections import attribute_mapped_collection
+from sqlalchemy.orm.scoping import ScopedSession
 from sqlalchemy.orm.session import Session, object_session
 from sqlalchemy.schema import Column, ForeignKey, Table
 from sqlalchemy.sql.expression import and_, bindparam, select
 from sqlalchemy.orm.session import Session, object_session
 from sqlalchemy.schema import Column, ForeignKey, Table
 from sqlalchemy.sql.expression import and_, bindparam, select
@@ -96,7 +97,6 @@ def create_translation_table(_table_name, foreign_class, relation_name,
         'foreign_id': synonym(foreign_key_name),
         'local_language': relationship(language_class,
             primaryjoin=table.c.local_language_id == language_class.id,
         'foreign_id': synonym(foreign_key_name),
         'local_language': relationship(language_class,
             primaryjoin=table.c.local_language_id == language_class.id,
-            lazy='joined',
             innerjoin=True),
     })
 
             innerjoin=True),
     })
 
@@ -110,21 +110,19 @@ def create_translation_table(_table_name, foreign_class, relation_name,
     ))
     # Foo.bars_local
     # This is a bit clever; it uses bindparam() to make the join clause
     ))
     # Foo.bars_local
     # This is a bit clever; it uses bindparam() to make the join clause
-    # modifiable on the fly.  db sessions know the current language identifier
-    # populates the bindparam.  The manual alias and join are (a) to make the
-    # condition nice (sqla prefers an EXISTS) and to make the columns play nice
-    # when foreign_class == language_class.
+    # modifiable on the fly.  db sessions know the current language and
+    # populate the bindparam.
+    # The 'dummy' value is to trick SQLA; without it, SQLA thinks this
+    # bindparam is just its own auto-generated clause and everything gets
+    # fucked up.
     local_relation_name = relation_name + '_local'
     local_relation_name = relation_name + '_local'
-    language_class_a = aliased(language_class)
     setattr(foreign_class, local_relation_name, relationship(Translations,
         primaryjoin=and_(
     setattr(foreign_class, local_relation_name, relationship(Translations,
         primaryjoin=and_(
-            foreign_class.id == Translations.foreign_id,
-            Translations.local_language_id == select(
-                [language_class_a.id],
-                language_class_a.identifier ==
-                    bindparam('_default_language', required=True),
-            ),
+            Translations.foreign_id == foreign_class.id,
+            Translations.local_language_id == bindparam('_default_language_id',
+                value='dummy', type_=Integer, required=True),
         ),
         ),
+        foreign_keys=[Translations.foreign_id, Translations.local_language_id],
         uselist=False,
         #innerjoin=True,
         lazy=relation_lazy,
         uselist=False,
         #innerjoin=True,
         lazy=relation_lazy,
@@ -152,11 +150,42 @@ def create_translation_table(_table_name, foreign_class, relation_name,
 
 class MultilangSession(Session):
     """A tiny Session subclass that adds support for a default language."""
 
 class MultilangSession(Session):
     """A tiny Session subclass that adds support for a default language."""
-    default_language = 'en'
+    _default_language_id = 9  # English.  XXX magic constant
+
+    @property
+    def default_language(self):
+        # XXX need to get the right mapped class for this to work
+        raise NotImplementedError
+
+    @default_language.setter
+    def default_language(self, new):
+        self._default_language_id = new#.id
+
+    @default_language.deleter
+    def default_language(self):
+        try:
+            del self._default_language_id
+        except AttributeError:
+            pass
 
     def execute(self, clause, params=None, *args, **kwargs):
         if not params:
             params = {}
 
     def execute(self, clause, params=None, *args, **kwargs):
         if not params:
             params = {}
-        params.setdefault('_default_language', self.default_language)
+        params.setdefault('_default_language_id', self._default_language_id)
         return super(MultilangSession, self).execute(
             clause, params, *args, **kwargs)
         return super(MultilangSession, self).execute(
             clause, params, *args, **kwargs)
+
+class MultilangScopedSession(ScopedSession):
+    """Dispatches language selection to the attached Session."""
+
+    @property
+    def default_language(self):
+        return self.registry().default_language
+
+    @default_language.setter
+    def default_language(self, new):
+        self.registry().default_language = new
+
+    def remove(self):
+        del self.registry().default_language
+        super(MultilangScopedSession, self).remove()