Support filtering by strings (Pokemon.name, Pokemon.names['fr'], etc.)
authorPetr Viktorin <encukou@gmail.com>
Sat, 12 Mar 2011 12:36:08 +0000 (14:36 +0200)
committerEevee <git@veekun.com>
Sun, 13 Mar 2011 22:43:01 +0000 (15:43 -0700)
pokedex/db/tables.py
pokedex/tests/test_strings.py [new file with mode: 0644]

index 89a0efa..f847aab 100644 (file)
@@ -15,6 +15,14 @@ Columns have a info dictionary with these keys:
   - identifier: A fan-made identifier in the [-_a-z0-9]* format. Not intended
     for translation.
   - latex: A formula in LaTeX syntax.
   - identifier: A fan-made identifier in the [-_a-z0-9]* format. Not intended
     for translation.
   - latex: A formula in LaTeX syntax.
+
+A localizable text column is exposed as two properties:
+The plural-name property (e.g. Pokemon.names) is a language-to-text dictionary:
+  bulbasaur.names['en'] == "Bulbasaur" and bulbasaur.names['de'] == "Bisasam".
+  On the class, Pokemon.names['en'] can be used to filter a query.
+The singular-name property returns the text in the default language (English):
+  for example, bulbasaur.name == "Bulbasaur".
+  Setting pokedex.db.tables.default_lang changes the default language.
 """
 # XXX: Check if "gametext" is set correctly everywhere
 
 """
 # XXX: Check if "gametext" is set correctly everywhere
 
@@ -30,7 +38,9 @@ from sqlalchemy.orm import (
     )
 from sqlalchemy.orm.session import Session
 from sqlalchemy.orm.collections import attribute_mapped_collection
     )
 from sqlalchemy.orm.session import Session
 from sqlalchemy.orm.collections import attribute_mapped_collection
+from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.sql import and_
 from sqlalchemy.sql import and_
+from sqlalchemy.sql.expression import ColumnOperators
 from sqlalchemy.types import *
 from inspect import isclass
 
 from sqlalchemy.types import *
 from inspect import isclass
 
@@ -1789,6 +1799,8 @@ for table in list(table_classes):
 
 ### Add text/prose tables
 
 
 ### Add text/prose tables
 
+default_lang = u'en'
+
 def makeTextTable(object_table, name_plural, name_singular, columns, lazy):
     # With "Language", we'd have two language_id. So, rename one to 'lang'
     safe_name = object_table.__singlename__
 def makeTextTable(object_table, name_plural, name_singular, columns, lazy):
     # With "Language", we'd have two language_id. So, rename one to 'lang'
     safe_name = object_table.__singlename__
@@ -1801,6 +1813,8 @@ def makeTextTable(object_table, name_plural, name_singular, columns, lazy):
     class Strings(object):
         __tablename__ = tablename
         __singlename__ = singlename
     class Strings(object):
         __tablename__ = tablename
         __singlename__ = singlename
+        _attrname = name_plural
+        _language_identifier = association_proxy('language', 'identifier')
 
     for name, plural, column in columns:
         column.name = name
 
     for name, plural, column in columns:
         column.name = name
@@ -1833,35 +1847,71 @@ def makeTextTable(object_table, name_plural, name_singular, columns, lazy):
         ))
     Strings.object = getattr(Strings, safe_name)
 
         ))
     Strings.object = getattr(Strings, safe_name)
 
-    # Link the tables themselves, so we can get to them
+    # Link the tables themselves, so we can get them if needed
     Strings.object_table = object_table
     setattr(object_table, name_singular + '_table', Strings)
 
     for colname, pluralname, column in columns:
     Strings.object_table = object_table
     setattr(object_table, name_singular + '_table', Strings)
 
     for colname, pluralname, column in columns:
-        # Provide a relation with all the names, and an English accessor
+        # Provide a property with all the names, and an English accessor
         # for backwards compatibility
         # for backwards compatibility
-        def scope(colname, pluralname, column):
-            def get_strings(self):
-                return dict(
-                        (l, getattr(t, colname))
-                        for l, t in getattr(self, name_plural).items()
-                    )
-
-            def get_english_string(self):
-                try:
-                    return get_strings(self)['en']
-                except KeyError:
-                    raise AttributeError(colname)
-
-            setattr(object_table, pluralname, property(get_strings))
-            setattr(object_table, colname, property(get_english_string))
-        scope(colname, pluralname, column)
+        setattr(object_table, pluralname, StringProperty(
+                object_table, Strings, colname,
+            ))
+        setattr(object_table, colname, DefaultLangProperty(pluralname))
 
         if colname == 'name':
             object_table.name_table = Strings
 
     return Strings
 
 
         if colname == 'name':
             object_table.name_table = Strings
 
     return Strings
 
class StringProperty(object):
    """Descriptor exposing every translation of one localized column.

    Read on an instance, it returns a dict mapping each key of the
    instance's text-row collection (the attribute named by
    ``stringclass._attrname``) to that row's ``colname`` value.
    Read on the class, it returns the descriptor itself, so that
    ``Pokemon.names['fr']`` builds a filterable expression.

    cls: the class that owns this property.
    stringclass: the generated text-table class; must define ``_attrname``.
    colname: the attribute on each text row that holds the string.
    """

    def __init__(self, cls, stringclass, colname):
        self.cls = cls
        self.colname = colname
        self.stringclass = stringclass

    def __get__(self, instance, cls):
        # Compare against None explicitly: a falsy-but-real instance must
        # still get its translations, not the descriptor object.
        if instance is None:
            return self
        rows = getattr(instance, self.stringclass._attrname)
        return dict(
                (lang, getattr(row, self.colname))
                for lang, row in rows.items()
            )

    def __getitem__(self, lang):
        """Return a query expression for this column in language *lang*.

        *lang* may be a language identifier string or an object with an
        ``identifier`` attribute.
        """
        return StringExpression(self, lang)

    def __repr__(self):
        return '<StringDict %s.%s>' % (
                getattr(self.cls, '__name__', self.cls), self.colname)

    # Keep str() output identical to repr().
    __str__ = __repr__
+
class StringExpression(ColumnOperators):
    """Query-side proxy for a localized string in one specific language.

    Instances come from ``StringProperty.__getitem__`` (for example
    ``Pokemon.names['fr']``). Any column operator applied to the proxy
    (``==``, ``>``, ...) is rewritten into an ``any()`` criterion on the
    owning class's text-row attribute. That criterion matches rows whose
    language identifier equals ``self.lang`` and whose text column satisfies
    the operator.
    """

    def __init__(self, prop, lang):
        self.prop = prop
        strings = prop.stringclass
        self.column = getattr(strings, prop.colname)
        self.lang_column = strings._language_identifier
        # Accept either a bare identifier string or a language object that
        # carries an ``identifier`` attribute.
        self.lang = lang if isinstance(lang, basestring) else lang.identifier

    def operate(self, op, *values, **kwargs):
        prop = self.prop
        text_rows = getattr(prop.cls, prop.stringclass._attrname)
        criterion = and_(
                self.lang_column == self.lang,
                op(self.column, *values, **kwargs),
            )
        return text_rows.any(criterion)
+
class DefaultLangProperty(object):
    """Descriptor exposing a localized column in the default language.

    ``colname`` names the plural property on the same class (for example
    ``'names'``). Read on an instance, it returns that mapping's entry for
    the module-level ``default_lang``. Read on the class, it indexes the
    class-level plural property with ``default_lang``, which yields a
    filterable expression when that property is a StringProperty.
    ``default_lang`` is looked up on every access, so reassigning it takes
    effect immediately.
    """

    def __init__(self, colname):
        # Despite the parameter name, this is the *plural* attribute name.
        self.colname = colname

    def __get__(self, instance, cls):
        if instance is None:
            return getattr(cls, self.colname)[default_lang]
        strings = getattr(instance, self.colname)
        try:
            return strings[default_lang]
        except KeyError:
            # Surface a missing translation as AttributeError, as the
            # previous property-based accessor did, so hasattr() and
            # getattr(obj, name, default) keep working.
            raise AttributeError('%s[%r]' % (self.colname, default_lang))
+
 for table in list(table_classes):
     # Find all the language-specific columns, keeping them in the order they
     # were defined
 for table in list(table_classes):
     # Find all the language-specific columns, keeping them in the order they
     # were defined
diff --git a/pokedex/tests/test_strings.py b/pokedex/tests/test_strings.py
new file mode 100644 (file)
index 0000000..065b0ed
--- /dev/null
@@ -0,0 +1,43 @@
+# Encoding: UTF-8
+
+from nose.tools import *
+
+from pokedex.db import tables, connect
+
class TestStrings(object):
    """Integration tests for the localized string properties on Pokemon.

    Each test queries the database returned by ``connect()``, so a
    populated pokedex database must be available.
    """
    def setup(self):
        self.connection = connect()

    def test_filter(self):
        q = self.connection.query(tables.Pokemon).filter(
                tables.Pokemon.name == u"Marowak")
        assert q.one().identifier == 'marowak'

    def test_filter_other_language(self):
        # Filtering through the plural property must work for languages
        # other than the default one, too.
        q = self.connection.query(tables.Pokemon).filter(
                tables.Pokemon.names['fr'] == u"Grahyèna")
        assert q.one().identifier == 'mightyena'

    def test_gt(self):
        # Assuming that the identifiers are just lowercase names
        q1 = self.connection.query(tables.Pokemon).filter(
                tables.Pokemon.name > u"Xatu").order_by(
                tables.Pokemon.id)
        q2 = self.connection.query(tables.Pokemon).filter(
                tables.Pokemon.identifier > u"xatu").order_by(
                tables.Pokemon.id)
        assert q1.all() == q2.all()

    def test_languages(self):
        q = self.connection.query(tables.Pokemon).filter(
                tables.Pokemon.name == u"Mightyena")
        pkmn = q.one()
        for lang, name in (
                ('en', u'Mightyena'),
                ('ja', u'グラエナ'),
                ('roomaji', u'Guraena'),
                ('fr', u'Grahyèna'),
            ):
            assert pkmn.names[lang] == name

    @raises(KeyError)
    def test_bad_lang(self):
        q = self.connection.query(tables.Pokemon).filter(
                tables.Pokemon.name == u"Mightyena")
        pkmn = q.one()
        pkmn.names["identifier of a language that doesn't exist"]