1cff39d5a72ffa2934b16b9d2ce4986a0b3bd325
[zzz-pokedex.git] / pokedex / db / load.py
1 """CSV to database or vice versa."""
2 import csv
3 import pkg_resources
4 import sys
5
6 from sqlalchemy.orm.attributes import instrumentation_registry
7 import sqlalchemy.types
8
9 from pokedex.db import metadata
10 import pokedex.db.tables as tables
11
12
13 def load(session, directory=None, drop_tables=False):
14 """Load data from CSV files into the given database session.
15
16 Tables are created automatically.
17
18 `session`
19 SQLAlchemy session to use.
20
21 `directory`
22 Directory the CSV files reside in. Defaults to the `pokedex` data
23 directory.
24
25 `drop_tables`
26 If set to True, existing `pokedex`-related tables will be dropped.
27 """
28
29 if not directory:
30 directory = pkg_resources.resource_filename('pokedex', 'data/csv')
31
32 # Drop all tables if requested
33 if options.drop_tables:
34 print 'Dropping tables...'
35 metadata.drop_all()
36
37 metadata.create_all()
38
39 # SQLAlchemy is retarded and there is no way for me to get a list of ORM
40 # classes besides to inspect the module they all happen to live in for
41 # things that look right.
42 table_base = tables.TableBase
43 orm_classes = {} # table object => table class
44
45 for name in dir(tables):
46 # dir() returns strings! How /convenient/.
47 thingy = getattr(tables, name)
48
49 if not isinstance(thingy, type):
50 # Not a class; bail
51 continue
52 elif not issubclass(thingy, table_base):
53 # Not a declarative table; bail
54 continue
55 elif thingy == table_base:
56 # Declarative table base, so not a real table; bail
57 continue
58
59 # thingy is definitely a table class! Hallelujah.
60 orm_classes[thingy.__table__] = thingy
61
62 # Okay, run through the tables and actually load the data now
63 for table_obj in metadata.sorted_tables:
64 table_class = orm_classes[table_obj]
65 table_name = table_obj.name
66
67 # Print the table name but leave the cursor in a fixed column
68 print table_name + '...', ' ' * (40 - len(table_name)),
69 sys.stdout.flush()
70
71 try:
72 csvfile = open("%s/%s.csv" % (directory, table_name), 'rb')
73 except IOError:
74 # File doesn't exist; don't load anything!
75 print 'no data!'
76 continue
77
78 reader = csv.reader(csvfile, lineterminator='\n')
79 column_names = [unicode(column) for column in reader.next()]
80
81 # Self-referential tables may contain rows with foreign keys of other
82 # rows in the same table that do not yet exist. Pull these out and add
83 # them to the session last
84 # ASSUMPTION: Self-referential tables have a single PK called "id"
85 deferred_rows = [] # ( row referring to id, [foreign ids we need] )
86 seen_ids = {} # primary key we've seen => 1
87
88 # Fetch foreign key columns that point at this table, if any
89 self_ref_columns = []
90 for column in table_obj.c:
91 if any(_.references(table_obj) for _ in column.foreign_keys):
92 self_ref_columns.append(column)
93
94 for csvs in reader:
95 row = table_class()
96
97 for column_name, value in zip(column_names, csvs):
98 column = table_obj.c[column_name]
99 if column.nullable and value == '':
100 # Empty string in a nullable column really means NULL
101 value = None
102 elif isinstance(column.type, sqlalchemy.types.Boolean):
103 # Boolean values are stored as string values 0/1, but both
104 # of those evaluate as true; SQLA wants True/False
105 if value == '0':
106 value = False
107 else:
108 value = True
109 else:
110 # Otherwise, unflatten from bytes
111 value = value.decode('utf-8')
112
113 setattr(row, column_name, value)
114
115 # May need to stash this row and add it later if it refers to a
116 # later row in this table
117 if self_ref_columns:
118 foreign_ids = [getattr(row, _.name) for _ in self_ref_columns]
119 foreign_ids = [_ for _ in foreign_ids if _] # remove NULL ids
120
121 if not foreign_ids:
122 # NULL key. Remember this row and add as usual.
123 seen_ids[row.id] = 1
124
125 elif all(_ in seen_ids for _ in foreign_ids):
126 # Non-NULL key we've already seen. Remember it and commit
127 # so we know the old row exists when we add the new one
128 session.commit()
129 seen_ids[row.id] = 1
130
131 else:
132 # Non-NULL future id. Save this and insert it later!
133 deferred_rows.append((row, foreign_ids))
134 continue
135
136 session.add(row)
137
138 session.commit()
139
140 # Attempt to add any spare rows we've collected
141 for row, foreign_ids in deferred_rows:
142 if not all(_ in seen_ids for _ in foreign_ids):
143 # Could happen if row A refers to B which refers to C.
144 # This is ridiculous and doesn't happen in my data so far
145 raise ValueError("Too many levels of self-reference! "
146 "Row was: " + str(row.__dict__))
147
148 session.add(row)
149 seen_ids[row.id] = 1
150 session.commit()
151
152 print 'loaded'
153
154
155
156 def dump(session, directory=None):
157 """Dumps the contents of a database to a set of CSV files. Probably not
158 useful to anyone besides a developer.
159
160 `session`
161 SQLAlchemy session to use.
162
163 `directory`
164 Directory the CSV files should be put in. Defaults to the `pokedex`
165 data directory.
166 """
167
168 if not directory:
169 directory = pkg_resources.resource_filename('pokedex', 'data/csv')
170
171 for table_name in sorted(metadata.tables.keys()):
172 print table_name
173 table = metadata.tables[table_name]
174
175 writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
176 lineterminator='\n')
177 columns = [col.name for col in table.columns]
178 writer.writerow(columns)
179
180 primary_key = table.primary_key
181 for row in session.query(table).order_by(*primary_key).all():
182 csvs = []
183 for col in columns:
184 # Convert Pythony values to something more universal
185 val = getattr(row, col)
186 if val == None:
187 val = ''
188 elif val == True:
189 val = '1'
190 elif val == False:
191 val = '0'
192 else:
193 val = unicode(val).encode('utf-8')
194
195 csvs.append(val)
196
197 writer.writerow(csvs)