Update test_schema to the new API. Add some missing column metadata.
[zzz-pokedex.git] / pokedex / db / multilang.py
index b031593..4adb68d 100644 (file)
@@ -3,6 +3,7 @@ from functools import partial
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import aliased, compile_mappers, mapper, relationship, synonym
 from sqlalchemy.orm.collections import attribute_mapped_collection
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import aliased, compile_mappers, mapper, relationship, synonym
 from sqlalchemy.orm.collections import attribute_mapped_collection
+from sqlalchemy.orm.scoping import ScopedSession
 from sqlalchemy.orm.session import Session, object_session
 from sqlalchemy.schema import Column, ForeignKey, Table
 from sqlalchemy.sql.expression import and_, bindparam, select
 from sqlalchemy.orm.session import Session, object_session
 from sqlalchemy.schema import Column, ForeignKey, Table
 from sqlalchemy.sql.expression import and_, bindparam, select
@@ -76,9 +77,11 @@ def create_translation_table(_table_name, foreign_class, relation_name,
     # Create the table object
     table = Table(_table_name, foreign_class.__table__.metadata,
         Column(foreign_key_name, Integer, ForeignKey(foreign_class.id),
     # Create the table object
     table = Table(_table_name, foreign_class.__table__.metadata,
         Column(foreign_key_name, Integer, ForeignKey(foreign_class.id),
-            primary_key=True, nullable=False),
+            primary_key=True, nullable=False,
+            info=dict(description="ID of the %s these texts relate to" % foreign_class.__singlename__)),
         Column('local_language_id', Integer, ForeignKey(language_class.id),
         Column('local_language_id', Integer, ForeignKey(language_class.id),
-            primary_key=True, nullable=False),
+            primary_key=True, nullable=False,
+            info=dict(description="Language these texts are in")),
     )
     Translations.__table__ = table
 
     )
     Translations.__table__ = table
 
@@ -96,7 +99,6 @@ def create_translation_table(_table_name, foreign_class, relation_name,
         'foreign_id': synonym(foreign_key_name),
         'local_language': relationship(language_class,
             primaryjoin=table.c.local_language_id == language_class.id,
         'foreign_id': synonym(foreign_key_name),
         'local_language': relationship(language_class,
             primaryjoin=table.c.local_language_id == language_class.id,
-            lazy='joined',
             innerjoin=True),
     })
 
             innerjoin=True),
     })
 
@@ -110,21 +112,19 @@ def create_translation_table(_table_name, foreign_class, relation_name,
     ))
     # Foo.bars_local
     # This is a bit clever; it uses bindparam() to make the join clause
     ))
     # Foo.bars_local
     # This is a bit clever; it uses bindparam() to make the join clause
-    # modifiable on the fly.  db sessions know the current language identifier
-    # populates the bindparam.  The manual alias and join are (a) to make the
-    # condition nice (sqla prefers an EXISTS) and to make the columns play nice
-    # when foreign_class == language_class.
+    # modifiable on the fly.  db sessions know the current language and
+    # populate the bindparam.
+    # The 'dummy' value is to trick SQLA; without it, SQLA thinks this
+    # bindparam is just its own auto-generated clause and the relationship
+    # does not behave correctly.
     local_relation_name = relation_name + '_local'
     local_relation_name = relation_name + '_local'
-    language_class_a = aliased(language_class)
     setattr(foreign_class, local_relation_name, relationship(Translations,
         primaryjoin=and_(
     setattr(foreign_class, local_relation_name, relationship(Translations,
         primaryjoin=and_(
-            foreign_class.id == Translations.foreign_id,
-            Translations.local_language_id == select(
-                [language_class_a.id],
-                language_class_a.identifier ==
-                    bindparam('_default_language', required=True),
-            ),
+            Translations.foreign_id == foreign_class.id,
+            Translations.local_language_id == bindparam('_default_language_id',
+                value='dummy', type_=Integer, required=True),
         ),
         ),
+        foreign_keys=[Translations.foreign_id, Translations.local_language_id],
         uselist=False,
         #innerjoin=True,
         lazy=relation_lazy,
         uselist=False,
         #innerjoin=True,
         lazy=relation_lazy,
@@ -147,16 +147,50 @@ def create_translation_table(_table_name, foreign_class, relation_name,
         setattr(foreign_class, name + '_map',
             association_proxy(relation_name, name, creator=creator))
 
         setattr(foreign_class, name + '_map',
             association_proxy(relation_name, name, creator=creator))
 
+    # Add to the list of translation classes
+    foreign_class.translation_classes.append(Translations)
+
     # Done
     return Translations
 
 class MultilangSession(Session):
     """A tiny Session subclass that adds support for a default language."""
     # Done
     return Translations
 
 class MultilangSession(Session):
     """A tiny Session subclass that adds support for a default language."""
-    default_language = 'en'
+    _default_language_id = 9  # English.  XXX magic constant
+
+    @property
+    def default_language(self):
+        # Reading the default language is not supported yet; only the raw id
+        # is kept on the session (see the setter below).
+        # XXX need to get the right mapped class for this to work
+        raise NotImplementedError
+
+    @default_language.setter
+    def default_language(self, new):
+        # `new` is stored as-is in _default_language_id.  The trailing
+        # `#.id` suggests a mapped language object may be accepted here
+        # later -- TODO confirm; for now callers presumably pass the id.
+        self._default_language_id = new#.id
+
+    @default_language.deleter
+    def default_language(self):
+        # Remove the per-instance value so lookups fall back to the
+        # class-level _default_language_id.  AttributeError means no
+        # per-instance value was set, which is deliberately ignored.
+        try:
+            del self._default_language_id
+        except AttributeError:
+            pass
 
     def execute(self, clause, params=None, *args, **kwargs):
         if not params:
             params = {}
 
     def execute(self, clause, params=None, *args, **kwargs):
         if not params:
             params = {}
-        params.setdefault('_default_language', self.default_language)
+        params.setdefault('_default_language_id', self._default_language_id)
         return super(MultilangSession, self).execute(
             clause, params, *args, **kwargs)
         return super(MultilangSession, self).execute(
             clause, params, *args, **kwargs)
+
+class MultilangScopedSession(ScopedSession):
+    """Dispatches language selection to the attached Session."""
+
+    # Both accessors forward to whatever object self.registry() returns
+    # (presumably the current MultilangSession -- confirm against how the
+    # scoped session is constructed).
+    @property
+    def default_language(self):
+        return self.registry().default_language
+
+    @default_language.setter
+    def default_language(self, new):
+        self.registry().default_language = new
+
+    def remove(self):
+        # Delete the language on the object returned by self.registry()
+        # first, then delegate to ScopedSession.remove().
+        del self.registry().default_language
+        super(MultilangScopedSession, self).remove()