Tiny fix for CLI help.
[zzz-pokedex.git] / pokedex / db / load.py
1 """CSV to database or vice versa."""
2 import csv
3 import os.path
4 import pkg_resources
5 import re
6 import sys
7
8 from sqlalchemy.orm.attributes import instrumentation_registry
9 import sqlalchemy.sql.util
10 import sqlalchemy.types
11
12 from pokedex.db import metadata
13 import pokedex.db.tables as tables
14
15
16 def _wildcard_char_to_regex(char):
17 """Converts a single wildcard character to the regex equivalent."""
18
19 if char == '?':
20 return '.?'
21 elif char == '*':
22 return '.*'
23 else:
24 return re.escape(char)
25
def _wildcard_glob_to_regex(glob):
    """Translate one wildcard glob into a regex STRING (not a compiled regex).

    A glob containing a `.` or a `/` is treated as a filename: its directory
    part and its extension are thrown away and only the remaining base name
    is translated.
    """

    # Something like "path/to/table.csv" should match the table "table"
    if '/' in glob or '.' in glob:
        basename = os.path.basename(glob)
        glob = os.path.splitext(basename)[0]

    return u''.join(_wildcard_char_to_regex(char) for char in glob)
36
def _wildcards_to_regex(strings):
    """Compile a list of wildcard globs into one regex object.

    Each glob is translated separately; the results are joined as
    alternatives and anchored at both ends, so the regex only matches a
    whole string that satisfies at least one glob.
    """

    alternatives = '|'.join(_wildcard_glob_to_regex(glob) for glob in strings)

    return re.compile('^(?:' + alternatives + ')$')
45
46
47 def _get_verbose_prints(verbose):
48 """If `verbose` is true, returns three functions: one for printing a
49 starting message, one for printing an interim status update, and one for
50 printing a success or failure message when finished.
51
52 If `verbose` is false, returns no-op functions.
53 """
54
55 if not verbose:
56 # Return dummies
57 def dummy(*args, **kwargs):
58 pass
59
60 return dummy, dummy, dummy
61
62 ### Okay, verbose == True; print stuff
63
64 def print_start(thing):
65 # Truncate to 66 characters, leaving 10 characters for a success
66 # or failure message
67 truncated_thing = thing[0:66]
68
69 # Also, space-pad to keep the cursor in a known column
70 num_spaces = 66 - len(truncated_thing)
71
72 print "%s...%s" % (truncated_thing, ' ' * num_spaces),
73 sys.stdout.flush()
74
75 if sys.stdout.isatty():
76 # stdout is a terminal; stupid backspace tricks are OK.
77 # Don't use print, because it always adds magical spaces, which
78 # makes backspace accounting harder
79
80 backspaces = [0]
81 def print_status(msg):
82 # Overwrite any status text with spaces before printing
83 sys.stdout.write('\b' * backspaces[0])
84 sys.stdout.write(' ' * backspaces[0])
85 sys.stdout.write('\b' * backspaces[0])
86 sys.stdout.write(msg)
87 sys.stdout.flush()
88 backspaces[0] = len(msg)
89
90 def print_done(msg='ok'):
91 # Overwrite any status text with spaces before printing
92 sys.stdout.write('\b' * backspaces[0])
93 sys.stdout.write(' ' * backspaces[0])
94 sys.stdout.write('\b' * backspaces[0])
95 sys.stdout.write(msg + "\n")
96 sys.stdout.flush()
97 backspaces[0] = 0
98
99 else:
100 # stdout is a file (or something); don't bother with status at all
101 def print_status(msg):
102 pass
103
104 def print_done(msg='ok'):
105 print msg
106
107 return print_start, print_status, print_done
108
109
def _load_csv_value(column, value):
    """Convert one raw CSV cell into the value to insert for `column`."""

    if column.nullable and value == '':
        # Empty string in a nullable column really means NULL
        return None

    if isinstance(column.type, sqlalchemy.types.Boolean):
        # Boolean values are stored as string values 0/1, but both of those
        # evaluate as true; SQLA wants True/False
        return value != '0'

    # Otherwise, unflatten from bytes
    return value.decode('utf-8')


def _load_table_rows(session, table_obj, csvfile, csvsize, print_status):
    """Insert every row of the open CSV file `csvfile` into `table_obj`.

    The first CSV line is taken as the list of column names.  Rows are
    inserted and committed in batches; `print_status` is called with a
    percentage (file position relative to `csvsize`) after each batch.

    Raises ValueError if a self-referential row still points at an id that
    has not been loaded once all other rows are in.
    """

    insert_stmt = table_obj.insert()

    reader = csv.reader(csvfile, lineterminator='\n')
    column_names = [unicode(column) for column in reader.next()]

    # Self-referential tables may contain rows with foreign keys of other
    # rows in the same table that do not yet exist.  Pull these out and add
    # them last
    # ASSUMPTION: Self-referential tables have a single PK called "id"
    deferred_rows = []  # ( row referring to id, [foreign ids we need] )
    seen_ids = set()    # primary keys we've seen

    # Fetch foreign key columns that point at this table, if any
    self_ref_columns = [
        column for column in table_obj.c
        if any(fk.references(table_obj) for fk in column.foreign_keys)
    ]

    new_rows = []
    def insert_and_commit():
        if not new_rows:
            # Nothing is buffered (this happens on the self-reference path
            # and on the final flush), so there is no CSV row to insert
            return

        session.connection().execute(insert_stmt, new_rows)
        session.commit()
        new_rows[:] = []

        progress = "{0}%".format(100 * csvfile.tell() // csvsize)
        print_status(progress)

    for csvs in reader:
        row_data = {}

        for column_name, value in zip(column_names, csvs):
            column = table_obj.c[column_name]

            # nb: Dictionaries flattened with ** have to have string keys
            row_data[str(column_name)] = _load_csv_value(column, value)

        # May need to stash this row and add it later if it refers to a
        # later row in this table
        if self_ref_columns:
            foreign_ids = [row_data[column.name]
                           for column in self_ref_columns]
            foreign_ids = [fid for fid in foreign_ids if fid]  # drop NULLs

            if not foreign_ids:
                # NULL key.  Remember this row and add as usual.
                seen_ids.add(row_data['id'])

            elif all(fid in seen_ids for fid in foreign_ids):
                # Non-NULL key we've already seen.  Remember it and commit
                # so we know the old row exists when we add the new one
                insert_and_commit()
                seen_ids.add(row_data['id'])

            else:
                # Non-NULL future id.  Save this and insert it later!
                deferred_rows.append((row_data, foreign_ids))
                continue

        # Insert row!
        new_rows.append(row_data)

        # Remembering some zillion rows consumes a lot of RAM.  Let's not
        # do that.  Commit every 1000 rows
        if len(new_rows) >= 1000:
            insert_and_commit()

    insert_and_commit()

    # Attempt to add any spare rows we've collected
    for row_data, foreign_ids in deferred_rows:
        if not all(fid in seen_ids for fid in foreign_ids):
            # Could happen if row A refers to B which refers to C.
            # This is ridiculous and doesn't happen in my data so far
            raise ValueError("Too many levels of self-reference! "
                             "Row was: " + str(row_data))

        session.connection().execute(
            insert_stmt.values(**row_data)
        )
        seen_ids.add(row_data['id'])
        session.commit()


def load(session, tables=[], directory=None, drop_tables=False, verbose=False):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load.  If omitted, all tables are loaded.  Entries
        may be wildcard globs or CSV filenames.  The default list is never
        modified.

    `directory`
        Directory the CSV files reside in.  Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    A table whose CSV file cannot be opened is reported as 'missing?' and
    skipped.  Raises ValueError if a self-referential table needs more than
    one level of deferred inserts.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = pkg_resources.resource_filename('pokedex', 'data/csv')

    if tables:
        regex = _wildcards_to_regex(tables)
        table_names = filter(regex.match, metadata.tables.keys())
    else:
        table_names = metadata.tables.keys()

    table_objs = [metadata.tables[name] for name in table_names]
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)

    # Drop all tables if requested, in the reverse of the sorted order
    if drop_tables:
        print_start('Dropping tables')
        for table in reversed(table_objs):
            table.drop(checkfirst=True)
        print_done()

    for table in table_objs:
        table.create()

    # NOTE(review): the returned connection is not used below; the call is
    # kept in case obtaining it up front matters to the session -- confirm
    session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name

        print_start(table_name)

        csvpath = "%s/%s.csv" % (directory, table_name)
        try:
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        # Make sure the file is closed again, even if loading blows up
        with csvfile:
            csvsize = os.stat(csvpath).st_size
            _load_table_rows(session, table_obj, csvfile, csvsize,
                             print_status)

        print_done()
270
271
272
def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump.  If omitted, all tables are dumped.  Entries
        may be wildcard globs or CSV filenames.  The default list is never
        modified.

    `directory`
        Directory the CSV files should be put in.  Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.

    One `<table name>.csv` file is written per table: a header line with the
    column names, then one line per row, ordered by primary key.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = pkg_resources.resource_filename('pokedex', 'data/csv')

    if tables:
        regex = _wildcards_to_regex(tables)
        table_names = filter(regex.match, metadata.tables.keys())
    else:
        table_names = metadata.tables.keys()

    table_names.sort()

    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        # Use a with block so the file is flushed and closed before moving
        # on to the next table, even if the query fails
        with open("%s/%s.csv" % (directory, table_name), 'wb') as csvfile:
            writer = csv.writer(csvfile, lineterminator='\n')
            columns = [col.name for col in table.columns]
            writer.writerow(columns)

            primary_key = table.primary_key
            for row in session.query(table).order_by(*primary_key).all():
                csvs = []
                for col in columns:
                    # Convert Pythony values to something more universal.
                    # Identity checks are deliberate: with == a value such
                    # as 1.0 would compare equal to True and be written
                    # as '1'
                    val = getattr(row, col)
                    if val is None:
                        val = ''
                    elif val is True:
                        val = '1'
                    elif val is False:
                        val = '0'
                    else:
                        val = unicode(val).encode('utf-8')

                    csvs.append(val)

                writer.writerow(csvs)

        print_done()