Make the string properties mutable
author Petr Viktorin <encukou@gmail.com>
Sun, 13 Mar 2011 15:47:31 +0000 (17:47 +0200)
committer Eevee <git@veekun.com>
Sun, 13 Mar 2011 22:43:42 +0000 (15:43 -0700)
pokedex/db/tables.py
pokedex/tests/test_strings.py

index 0b2c80f..76bd753 100644 (file)
@@ -26,7 +26,7 @@ The singular-name property returns the name in the default language, English.
 """
 # XXX: Check if "gametext" is set correctly everywhere
 
-import operator
+import collections
 
 from sqlalchemy import Column, ForeignKey, MetaData, PrimaryKeyConstraint, Table
 from sqlalchemy.ext.declarative import (
@@ -36,11 +36,12 @@ from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import (
         backref, eagerload_all, relation, class_mapper, synonym, mapper,
     )
-from sqlalchemy.orm.session import Session
+from sqlalchemy.orm.session import Session, object_session
 from sqlalchemy.orm.collections import attribute_mapped_collection
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.sql import and_
 from sqlalchemy.sql.expression import ColumnOperators
+from sqlalchemy.schema import ColumnDefault
 from sqlalchemy.types import *
 from inspect import isclass
 
@@ -1808,6 +1809,10 @@ def makeTextTable(object_table, name_plural, name_singular, columns, lazy):
 
     for name, plural, column in columns:
         column.name = name
+        if not column.nullable:
+            # A Python side default value, so that the strings can be set
+            # one by one without the DB complaining about missing values
+            column.default = ColumnDefault(u'')
 
     table = Table(tablename, metadata,
             Column(safe_name + '_id', Integer, ForeignKey(object_table.id),
@@ -1862,11 +1867,9 @@ class StringProperty(object):
 
     def __get__(self, instance, cls):
-        if instance:
-            return dict(
-                    (l, getattr(t, self.colname))
-                    for l, t
-                    in getattr(instance, self.stringclass._attrname).items()
-                )
+        # Instance access returns a live, writable mapping view of the
+        # strings; only class access (instance is None) gets the descriptor.
+        if instance is not None:
+            return StringMapping(instance, self)
         else:
             return self
 
@@ -1876,6 +1877,65 @@ class StringProperty(object):
     def __str__(self):
         return '<StringDict %s.%s>' % (self.cls, self.colname)
 
+class StringMapping(collections.MutableMapping):
+    """Writable mapping view of one translatable column of an object.
+
+    Lookups, membership tests, iteration and len() are delegated to the
+    object's per-language row collection
+    (``getattr(instance, stringclass._attrname)``); each value is the
+    ``colname`` attribute of the corresponding row.
+    """
+    def __init__(self, instance, prop):
+        self.stringclass = prop.stringclass
+        self.instance = instance
+        self.strings = getattr(instance, prop.stringclass._attrname)
+        self.colname = prop.colname
+
+    def __len__(self):
+        return len(self.strings)
+
+    def __iter__(self):
+        return iter(self.strings)
+
+    def __contains__(self, lang):
+        return lang in self.strings
+
+    def __getitem__(self, lang):
+        return getattr(self.strings[lang], self.colname)
+
+    def __setitem__(self, lang, value):
+        """Set the string for `lang`, creating the language's row if needed.
+
+        A new row is attached to the collection and added to the instance's
+        session; a string `lang` is first resolved to a Language by its
+        identifier.
+        """
+        try:
+            # Modifying an existing row
+            row = self.strings[lang]
+        except KeyError:
+            # We need to add a whole row for the language
+            row = self.stringclass()
+            row.object_id = self.instance.id
+            # NOTE(review): object_session() returns None for a detached
+            # instance, so the lookup/add below would raise AttributeError.
+            session = object_session(self.instance)
+            if isinstance(lang, basestring):
+                lang = session.query(Language).filter_by(
+                        identifier=lang).one()
+            row.language = lang
+            self.strings[lang] = row
+            session.add(row)
+        setattr(row, self.colname, value)
+
+    def __delitem__(self, lang):
+        # Single strings cannot be removed; the error suggests deleting
+        # the whole per-language collection instead.
+        raise NotImplementedError('Cannot delete a single string. '
+                'Perhaps you want to delete all of %s.%s?' %
+                (self.instance, self.stringclass._attrname)
+            )
+
 class StringExpression(ColumnOperators):
     def __init__(self, prop, lang):
         self.prop = prop
@@ -1902,6 +1945,16 @@ class DefaultLangProperty(object):
         else:
             return getattr(cls, self.colname)[default_lang]
 
+    def __set__(self, instance, value):
+        """Store `value` as the instance's `default_lang` string."""
+        strings = getattr(instance, self.colname)
+        strings[default_lang] = value
+
+    def __delete__(self, instance):
+        """Forward deletion of the `default_lang` entry to the mapping."""
+        strings = getattr(instance, self.colname)
+        del strings[default_lang]
+
 for table in list(table_classes):
     # Find all the language-specific columns, keeping them in the order they
     # were defined
index 065b0ed..eba0819 100644 (file)
@@ -8,6 +8,12 @@ class TestStrings(object):
     def setup(self):
         self.connection = connect()
 
+    def teardown(self):
+        # Undo whatever the test changed (several tests write strings),
+        # then close it so each setup()'s connect() doesn't stay open.
+        self.connection.rollback()
+        self.connection.close()
+
     def test_filter(self):
         q = self.connection.query(tables.Pokemon).filter(
                 tables.Pokemon.name == u"Marowak")
@@ -41,3 +44,71 @@ class TestStrings(object):
                 tables.Pokemon.name == u"Mightyena")
         pkmn = q.one()
         pkmn.names["identifier of a language that doesn't exist"]
+
+    def _jade_orb(self):
+        # Most of the tests below work on this one item
+        return self.connection.query(tables.Item).filter_by(
+                identifier=u"jade-orb").one()
+
+    def test_mutating(self):
+        # A string set through one key type is visible through the other
+        orb = self._jade_orb()
+        german = self.connection.query(tables.Language).filter_by(
+                identifier=u"de").one()
+        orb.names['de'] = u"foo"
+        for key in ('de', german):
+            assert orb.names[key] == "foo"
+        orb.names[german] = u"xyzzy"
+        for key in ('de', german):
+            assert orb.names[key] == "xyzzy"
+
+    def test_mutating_default(self):
+        # The singular property can be assigned and read back
+        orb = self._jade_orb()
+        orb.name = u"foo"
+        assert orb.name == "foo"
+
+    def test_string_mapping(self):
+        # names mirrors texts and answers to Language objects and
+        # identifiers alike
+        orb = self._jade_orb()
+        assert len(orb.names) == len(orb.texts)
+        for lang in orb.texts:
+            assert orb.names[lang] == orb.texts[lang].name
+            assert orb.names[lang] == orb.names[lang.identifier]
+            assert lang in orb.names
+            assert lang.identifier in orb.names
+        assert "language that doesn't exist" not in orb.names
+        assert tables.Language() not in orb.names
+
+    def test_new_language(self):
+        # Setting a string for a language without a row creates the row
+        orb = self._jade_orb()
+        fake_lang = tables.Language()
+        fake_lang.id = -1
+        fake_lang.identifier = u'test'
+        fake_lang.iso639 = fake_lang.iso3166 = u'--'
+        fake_lang.official = False
+        self.connection.add(fake_lang)
+        orb.names[u'test'] = u"foo"
+        assert orb.names[fake_lang] == "foo"
+        assert orb.names['test'] == "foo"
+        assert 'de' in orb.names
+        assert fake_lang in orb.names
+        orb.names[fake_lang] = u"xyzzy"
+        assert orb.names[fake_lang] == "xyzzy"
+        assert orb.names['test'] == "xyzzy"
+
+    @raises(NotImplementedError)
+    def test_delstring(self):
+        # Single strings cannot be deleted
+        orb = self._jade_orb()
+        del orb.names['en']
+
+    def test_markdown(self):
+        # .as_text works through both the default-language property and
+        # the per-language mapping
+        bolt = self.connection.query(tables.Move).filter_by(
+                identifier=u"thunderbolt").one()
+        assert '10%' in bolt.effect.as_text
+        assert '10%' in bolt.effects['en'].as_text