Oops: the last commit contained only data. Here is the move-flag code.
[zzz-pokedex.git] / pokedex / db / load.py
1 """CSV to database or vice versa."""
2 import csv
3 import pkg_resources
4 import sys
5
6 from sqlalchemy.orm.attributes import instrumentation_registry
7 import sqlalchemy.types
8
9 from pokedex.db import metadata
10 import pokedex.db.tables as tables
11
12
13 def _get_verbose_prints(verbose):
14 """If `verbose` is true, returns two functions: one for printing a starting
15 message, and the other for printing a success or failure message when
16 finished.
17
18 If `verbose` is false, returns two no-op functions.
19 """
20
21 if verbose:
22 import sys
23 def print_start(thing):
24 # Truncate to 66 characters, leaving 10 characters for a success
25 # or failure message
26 truncated_thing = thing[0:66]
27
28 # Also, space-pad to keep the cursor in a known column
29 num_spaces = 66 - len(truncated_thing)
30
31 print "%s...%s" % (truncated_thing, ' ' * num_spaces),
32 sys.stdout.flush()
33
34 def print_done(msg='ok'):
35 print msg
36 sys.stdout.flush()
37
38 return print_start, print_done
39
40 # Not verbose; return dummies
41 def dummy(*args, **kwargs):
42 pass
43
44 return dummy, dummy
45
46
def load(session, directory=None, drop_tables=False, verbose=False):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `directory`
        Directory the CSV files reside in.  Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    Each table in `metadata.sorted_tables` is read from
    `<directory>/<table name>.csv`, whose first row names the columns.  A
    table whose file cannot be opened is reported as "missing?" and skipped.

    Raises ValueError if a self-referential table contains rows whose
    references can never be satisfied (they point at ids that never appear
    in the file, or at each other in a cycle).
    """

    # First take care of verbosity
    print_start, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = pkg_resources.resource_filename('pokedex', 'data/csv')

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        metadata.drop_all()
        print_done()

    metadata.create_all()

    # SQLAlchemy offers no direct list of the declarative ORM classes, so
    # inspect the module they all happen to live in for things that look
    # right.
    table_base = tables.TableBase
    orm_classes = {}  # table object => table class

    for name in dir(tables):
        # dir() returns names; fetch the actual objects
        thingy = getattr(tables, name)

        if not isinstance(thingy, type):
            # Not a class; bail
            continue
        elif not issubclass(thingy, table_base):
            # Not a declarative table; bail
            continue
        elif thingy is table_base:
            # Declarative table base, so not a real table; bail
            continue

        # thingy is definitely a table class
        orm_classes[thingy.__table__] = thingy

    # Okay, run through the tables and actually load the data now
    for table_obj in metadata.sorted_tables:
        table_class = orm_classes[table_obj]
        table_name = table_obj.name

        print_start(table_name)

        try:
            csvfile = open("%s/%s.csv" % (directory, table_name), 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist.  Pull these out and
        # add them to the session last.
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # (row referring to id, [foreign ids we need])
        seen_ids = set()    # primary keys of rows already added

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = [
            col for col in table_obj.c
            if any(fk.references(table_obj) for fk in col.foreign_keys)
        ]

        # Make sure the CSV file is closed even if loading a row fails
        try:
            reader = csv.reader(csvfile, lineterminator='\n')
            column_names = [unicode(column) for column in reader.next()]

            for csvs in reader:
                row = table_class()

                for column_name, value in zip(column_names, csvs):
                    column = table_obj.c[column_name]
                    if column.nullable and value == '':
                        # Empty string in a nullable column really means NULL
                        value = None
                    elif isinstance(column.type, sqlalchemy.types.Boolean):
                        # Boolean values are stored as the strings '0'/'1',
                        # but both are truthy; SQLA wants True/False
                        value = (value != '0')
                    else:
                        # Otherwise, unflatten from bytes
                        value = value.decode('utf-8')

                    setattr(row, column_name, value)

                # May need to stash this row and add it later if it refers to
                # a later row in this table
                if self_ref_columns:
                    foreign_ids = [getattr(row, col.name)
                                   for col in self_ref_columns]
                    # Remove NULL ids
                    foreign_ids = [fid for fid in foreign_ids if fid]

                    if not foreign_ids:
                        # NULL key.  Remember this row and add as usual.
                        seen_ids.add(row.id)
                    elif all(fid in seen_ids for fid in foreign_ids):
                        # Non-NULL key we've already seen.  Commit so we know
                        # the old row exists when we add the new one.
                        session.commit()
                        seen_ids.add(row.id)
                    else:
                        # Non-NULL future id.  Save this and insert it later!
                        deferred_rows.append((row, foreign_ids))
                        continue

                session.add(row)

                # Remembering some zillion rows in the session consumes a lot
                # of RAM.  Let's not do that.  Commit every 1000 rows.
                if len(session.new) > 1000:
                    session.commit()
        finally:
            csvfile.close()

        session.commit()

        # Add the deferred rows once everything they refer to has been added.
        # A deferred row may refer to another deferred row (A -> B -> C), so
        # keep sweeping until a pass makes no progress.
        while deferred_rows:
            still_deferred = []
            for row, foreign_ids in deferred_rows:
                if all(fid in seen_ids for fid in foreign_ids):
                    seen_ids.add(row.id)
                    session.add(row)
                else:
                    still_deferred.append((row, foreign_ids))

            if len(still_deferred) == len(deferred_rows):
                # Nothing could be added this pass: the remaining rows refer
                # to ids that never appear in this table, or form a cycle
                raise ValueError("Unresolvable self-reference! "
                                 "Row was: " + str(still_deferred[0][0].__dict__))

            # Send this pass's rows to the database before the rows that
            # refer to them are added on the next pass
            session.flush()
            deferred_rows = still_deferred

        session.commit()

        print_done()
198
199
200
def dump(session, directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `directory`
        Directory the CSV files should be put in.  Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.

    One file, `<directory>/<table name>.csv`, is written per table in
    `metadata`, tables in name order.  The first row holds the column names,
    and data rows are ordered by primary key.  NULL is written as an empty
    string, booleans as '1'/'0', and everything else as UTF-8 text.
    """

    # First take care of verbosity
    print_start, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = pkg_resources.resource_filename('pokedex', 'data/csv')

    for table_name in sorted(metadata.tables.keys()):
        print_start(table_name)
        table = metadata.tables[table_name]

        csvfile = open("%s/%s.csv" % (directory, table_name), 'wb')
        # Close (and thus flush) the file deterministically instead of
        # leaving it to garbage collection
        try:
            writer = csv.writer(csvfile, lineterminator='\n')
            columns = [col.name for col in table.columns]
            writer.writerow(columns)

            primary_key = table.primary_key
            for row in session.query(table).order_by(*primary_key).all():
                csvs = []
                for col in columns:
                    # Convert Pythony values to something more universal.
                    # Use identity/type checks: with `==`, numeric values
                    # such as 1.0 or Decimal('0.0') compare equal to
                    # True/False and would be misformatted as '1'/'0'.
                    val = getattr(row, col)
                    if val is None:
                        val = ''
                    elif isinstance(val, bool):
                        val = '1' if val else '0'
                    else:
                        val = unicode(val).encode('utf-8')

                    csvs.append(val)

                writer.writerow(csvs)
        finally:
            csvfile.close()

        print_done()