1 """CSV to database or vice versa."""
2 import csv
3 import fnmatch
4 import os.path
5 import sys
6
7 from sqlalchemy.orm.attributes import instrumentation_registry
8 import sqlalchemy.sql.util
9 import sqlalchemy.types
10
11 from pokedex.db import metadata
12 import pokedex.db.tables as tables
13 from pokedex.defaults import get_default_csv_dir
14
15
16 def _get_table_names(metadata, patterns):
17 """Returns a list of table names from the given metadata. If `patterns`
18 exists, only tables matching one of the patterns will be returned.
19 """
20 if patterns:
21 table_names = set()
22 for pattern in patterns:
23 if '.' in pattern or '/' in pattern:
24 # If it looks like a filename, pull out just the table name
25 _, filename = os.path.split(pattern)
26 table_name, _ = os.path.splitext(filename)
27 pattern = table_name
28
29 table_names.update(fnmatch.filter(metadata.tables.keys(), pattern))
30 else:
31 table_names = metadata.tables.keys()
32
33 return list(table_names)
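
# A quick illustration of the pattern handling above (hypothetical calls;
# patterns are shell-style globs, courtesy of fnmatch):
#
#     _get_table_names(metadata, ['pokemon'])        # just the pokemon table
#     _get_table_names(metadata, ['pokemon*'])       # tables starting with "pokemon"
#     _get_table_names(metadata, ['csv/moves.csv'])  # filename; same as ['moves']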


def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.
    """

    if not verbose:
        # Return dummies
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff

    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[0:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        print "%s...%s" % (truncated_thing, ' ' * num_spaces),
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder

        backspaces = [0]
        def print_status(msg):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            print msg

    return print_start, print_status, print_done
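
# Usage sketch (hypothetical): on a terminal, the three returned functions
# cooperate to draw a single self-updating status line:
#
#     start, status, done = _get_verbose_prints(True)
#     start('pokemon')    # prints "pokemon..." padded to a fixed column
#     status('42%')       # backspaces over the previous status, prints anew
#     done('ok')          # erases the status, prints "ok", ends the line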


def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load.  If omitted, all tables are loaded.

    `directory`
        Directory the CSV files reside in.  Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)

    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        for n, table in enumerate(reversed(table_objs)):
            table.drop(checkfirst=True)
            print_status('%s/%s' % (n, len(table_objs)))
        print_done()

    print_start('Creating tables')
    for n, table in enumerate(table_objs):
        table.create()
        print_status('%s/%s' % (n, len(table_objs)))
    print_done()
    connection = session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)

        try:
            csvpath = "%s/%s.csv" % (directory, table_name)
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        csvsize = os.stat(csvpath).st_size

        reader = csv.reader(csvfile, lineterminator='\n')
        column_names = [unicode(column) for column in reader.next()]

        if not safe and session.connection().dialect.name == 'postgresql':
            # Postgres' CSV dialect is nearly the same as ours, except that
            # it treats completely empty values as NULL and empty quoted
            # strings ("") as empty strings.  The pokedex dump does not quote
            # empty strings, so both would be read in as NULL; an empty
            # string in a NOT NULL column would then fail to load.  FORCE NOT
            # NULL (built below) makes those columns read empty values as
            # empty strings instead.  Nullable columns already load empty
            # strings as NULL, matching the row-by-row loader below.
            session.commit()
            not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
            if not_null_cols:
                force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
            else:
                force_not_null = ''
            command = "COPY {table_name} ({columns}) FROM '{csvpath}' CSV HEADER {force_not_null}"
            session.connection().execute(
                command.format(
                    table_name=table_name,
                    csvpath=csvpath,
                    columns=','.join('"%s"' % c for c in column_names),
                    force_not_null=force_not_null,
                )
            )
            session.commit()
            print_done()
            continue
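
            # For illustration, the command built above renders roughly as
            # (table and column names are examples only):
            #
            #     COPY items ("id","identifier","cost") FROM '/tmp/items.csv'
            #         CSV HEADER FORCE NOT NULL "id","identifier"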

        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist.  Pull these out and
        # add them to the session last
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # ( row referring to id, [foreign ids we need] )
        seen_ids = {}       # primary key we've seen => 1
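
        # Sketch of the deferral logic with a hypothetical self-referencing
        # column, evolves_from_id: given CSV rows "2,1" and "3,4" in that
        # order, row 2 goes in as usual (row 1 is already seen), while row 3
        # is deferred until row 4 has been inserted.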

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = []
        for column in table_obj.c:
            if any(_.references(table_obj) for _ in column.foreign_keys):
                self_ref_columns.append(column)

        new_rows = []
        def insert_and_commit():
            # Write out the batch of pending rows, then update the progress
            # meter based on how far through the file we've read
            session.connection().execute(insert_stmt, new_rows)
            session.commit()
            new_rows[:] = []

            progress = "%d%%" % (100 * csvfile.tell() // csvsize)
            print_status(progress)

        for csvs in reader:
            row_data = {}

            for column_name, value in zip(column_names, csvs):
                column = table_obj.c[column_name]
                if column.nullable and value == '':
                    # Empty string in a nullable column really means NULL
                    value = None
                elif isinstance(column.type, sqlalchemy.types.Boolean):
                    # Boolean values are stored as string values 0/1, but
                    # both of those evaluate as true; SQLA wants True/False
                    if value == '0':
                        value = False
                    else:
                        value = True
                else:
                    # Otherwise, unflatten from bytes
                    value = value.decode('utf-8')

                # nb: Dictionaries flattened with ** have to have string keys
                row_data[str(column_name)] = value
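
            # Illustration (hypothetical row): with columns id (Integer, NOT
            # NULL), identifier (Unicode), and is_default (Boolean), the CSV
            # values ['1', 'bulbasaur', '1'] become
            # {'id': u'1', 'identifier': u'bulbasaur', 'is_default': True}.
            # Numeric strings stay strings; SQLA coerces them on INSERT.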

            # May need to stash this row and add it later if it refers to a
            # later row in this table
            if self_ref_columns:
                foreign_ids = [row_data[_.name] for _ in self_ref_columns]
                foreign_ids = [_ for _ in foreign_ids if _]  # remove NULL ids

                if not foreign_ids:
                    # NULL key.  Remember this row and add as usual.
                    seen_ids[row_data['id']] = 1

                elif all(_ in seen_ids for _ in foreign_ids):
                    # Non-NULL key we've already seen.  Remember it and commit
                    # so we know the old row exists when we add the new one
                    insert_and_commit()
                    seen_ids[row_data['id']] = 1

                else:
                    # Non-NULL future id.  Save this and insert it later!
                    deferred_rows.append((row_data, foreign_ids))
                    continue

            # Insert row!
            new_rows.append(row_data)

            # Remembering some zillion rows in the session consumes a lot of
            # RAM.  Let's not do that.  Commit every 1000 rows
            if len(new_rows) >= 1000:
                insert_and_commit()

        insert_and_commit()

        # Attempt to add any spare rows we've collected
        for row_data, foreign_ids in deferred_rows:
            if not all(_ in seen_ids for _ in foreign_ids):
                # Could happen if row A refers to B which refers to C.
                # This is ridiculous and doesn't happen in my data so far
                raise ValueError("Too many levels of self-reference!  "
                                 "Row was: " + str(row_data))

            session.connection().execute(
                insert_stmt.values(**row_data)
            )
            seen_ids[row_data['id']] = 1
            session.commit()

        print_done()

    # SQLite check
    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")


def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump.  If omitted, all tables are dumped.

    `directory`
        Directory the CSV files should be put in.  Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_names.sort()

    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
                            lineterminator='\n')
        columns = [col.name for col in table.columns]
        writer.writerow(columns)

        primary_key = table.primary_key
        for row in session.query(table).order_by(*primary_key).all():
            csvs = []
            for col in columns:
                # Convert Pythony values to something more universal
                val = getattr(row, col)
                if val == None:
                    val = ''
                elif val == True:
                    val = '1'
                elif val == False:
                    val = '0'
                else:
                    val = unicode(val).encode('utf-8')

                csvs.append(val)

            writer.writerow(csvs)

        print_done()