From cca6737b7e79efb4d0fd28798b883e63f793ff03 Mon Sep 17 00:00:00 2001 From: Eevee Date: Thu, 28 May 2009 21:16:18 -0700 Subject: [PATCH] Fixed csvimport to load in table dependency order. --- pokedex/__init__.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/pokedex/__init__.py b/pokedex/__init__.py index c24a5db..6ac7344 100644 --- a/pokedex/__init__.py +++ b/pokedex/__init__.py @@ -29,17 +29,11 @@ def csvimport(engine_uri, directory='.'): metadata.create_all() - # Oh, mysql-chan. - # TODO try to insert data in preorder so we don't need this hack and won't - # break similarly on other engines - if 'mysql' in engine_uri: - session.execute('SET FOREIGN_KEY_CHECKS = 0') - # SQLAlchemy is retarded and there is no way for me to get a list of ORM # classes besides to inspect the module they all happen to live in for # things that look right. table_base = tables_module.TableBase - orm_classes = {} + orm_classes = {} # table object => table class for name in dir(tables_module): # dir() returns strings! How /convenient/. @@ -56,10 +50,13 @@ def csvimport(engine_uri, directory='.'): continue # thingy is definitely a table class! Hallelujah. - orm_classes[thingy.__table__.name] = thingy + orm_classes[thingy.__table__] = thingy # Okay, run through the tables and actually load the data now - for table_name, table in sorted(orm_classes.items()): + for table_obj in metadata.sorted_tables: + table_class = orm_classes[table_obj] + table_name = table_obj.name + # Print the table name but leave the cursor in a fixed column print table_name + '...', ' ' * (40 - len(table_name)), @@ -74,10 +71,10 @@ def csvimport(engine_uri, directory='.'): column_names = [unicode(column) for column in reader.next()] for csvs in reader: - row = table() + row = table_class() for column_name, value in zip(column_names, csvs): - column = table.__table__.c[column_name] + column = table_obj.c[column_name] if column.nullable and value == '': # Empty string in a nullable column really means NULL value = None @@ -99,11 +96,6 @@ def csvimport(engine_uri, directory='.'): session.commit() print 'loaded' - # Shouldn't matter since this is usually the end of the program and thus - # the connection too, but let's change this back just in case - if 'mysql' in engine_uri: - session.execute('SET FOREIGN_KEY_CHECKS = 1') - def csvexport(engine_uri, directory='.'): import csv -- 2.7.4