Replace all_tables by table_classes; get rid of globals()
authorPetr Viktorin <encukou@gmail.com>
Tue, 8 Mar 2011 20:55:47 +0000 (22:55 +0200)
committerEevee <git@veekun.com>
Sun, 13 Mar 2011 22:43:00 +0000 (15:43 -0700)
pokedex/db/tables.py
pokedex/tests/test_schema.py

index 1e3c3a7..88b206a 100644 (file)
@@ -21,7 +21,9 @@ Columns have a info dictionary with these keys:
 import operator
 
 from sqlalchemy import Column, ForeignKey, MetaData, PrimaryKeyConstraint, Table
-from sqlalchemy.ext.declarative import declarative_base, declared_attr
+from sqlalchemy.ext.declarative import (
+        declarative_base, declared_attr, DeclarativeMeta,
+    )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.orm import backref, eagerload_all, relation, class_mapper
 from sqlalchemy.orm.session import Session
@@ -32,8 +34,17 @@ from inspect import isclass
 
 from pokedex.db import markdown
 
+# A list of all table classes will live in table_classes
+table_classes = []
+
+class TableMetaclass(DeclarativeMeta):
+    def __init__(cls, name, bases, attrs):
+        super(TableMetaclass, cls).__init__(name, bases, attrs)
+        if hasattr(cls, '__tablename__'):
+            table_classes.append(cls)
+
 metadata = MetaData()
-TableBase = declarative_base(metadata=metadata)
+TableBase = declarative_base(metadata=metadata, metaclass=TableMetaclass)
 
 ### Helper classes
 class Named(object):
@@ -1737,16 +1748,9 @@ VersionGroup.version_group_regions = relation(VersionGroupRegion, backref='versi
 VersionGroup.regions = association_proxy('version_group_regions', 'region')
 VersionGroup.pokedex = relation(Pokedex, back_populates='version_groups')
 
-### Convenience function
-def all_tables():
-    u"""Yields all tables in the pokédex"""
-    for table in set(t for t in globals().values() if isclass(t)):
-        if issubclass(table, TableBase) and table is not TableBase:
-            yield table
-
 
 ### Add name tables
-for table in all_tables():
+for table in list(table_classes):
     if issubclass(table, OfficiallyNamed):
         cls = TextColumn
         info=dict(description="The name", format='plaintext', official=True)
@@ -1832,7 +1836,7 @@ def makeTextTable(object_table, name_plural, name_singular, columns, lazy):
 
     return Strings
 
-for table in all_tables():
+for table in list(table_classes):
     # Find all the language-specific columns, keeping them in the order they
     # were defined
     all_columns = []
@@ -1857,12 +1861,10 @@ for table in all_tables():
 
     if text_columns:
         string_table = makeTextTable(table, 'texts', 'text', text_columns, lazy=False)
-        globals()[string_table.__name__] = string_table
     if prose_columns:
         string_table = makeTextTable(table, 'prose', 'prose', prose_columns, lazy=True)
-        globals()[string_table.__name__] = string_table
 
 ### Add language relations
-for table in all_tables():
+for table in list(table_classes):
     if issubclass(table, LanguageSpecific):
         table.language = relation(Language, primaryjoin=table.language_id == Language.id)
index 5a03777..f1583a7 100644 (file)
@@ -19,7 +19,7 @@ def test_variable_names():
         classname = table.__name__
         if classname and varname[0].isupper():
             assert varname == classname, '%s refers to %s' % (varname, classname)
-    for table in tables.all_tables():
+    for table in tables.table_classes:
         assert getattr(tables, table.__name__) is table
 
 def test_texts():
@@ -28,7 +28,7 @@ def test_texts():
     Mostly protects against copy/paste oversights and rebase hiccups.
     If there's a reason to relax the tests, do it
     """
-    for table in sorted(tables.all_tables(), key=lambda t: t.__name__):
+    for table in sorted(tables.table_classes, key=lambda t: t.__name__):
         if issubclass(table, tables.LanguageSpecific):
             good_formats = 'markdown plaintext gametext'.split()
             assert_text = '%s is language-specific'
@@ -61,7 +61,7 @@ def test_identifiers_with_names():
 
     ...have either names or identifiers.
     """
-    for table in sorted(tables.all_tables(), key=lambda t: t.__name__):
+    for table in sorted(tables.table_classes, key=lambda t: t.__name__):
         if issubclass(table, tables.Named):
             assert issubclass(table, tables.OfficiallyNamed) or issubclass(table, tables.UnofficiallyNamed), table
             assert hasattr(table, 'identifier'), table