# pokedex/db/load.py
1 """CSV to database or vice versa."""
2 import csv
3 import fnmatch
4 import os.path
5 import sys
6
7 from sqlalchemy.orm.attributes import instrumentation_registry
8 import sqlalchemy.sql.util
9 import sqlalchemy.types
10
11 from pokedex.db import metadata
12 import pokedex.db.tables as tables
13 from pokedex.defaults import get_default_csv_dir
14
15
def _get_table_names(metadata, patterns):
    """Returns a list of table names from the given metadata.  If `patterns`
    exists, only tables matching one of the patterns will be returned.
    """
    if patterns:
        table_names = set()
        for pattern in patterns:
            if '.' in pattern or '/' in pattern:
                # If it looks like a filename, pull out just the table name
                _, filename = os.path.split(pattern)
                table_name, _ = os.path.splitext(filename)
                pattern = table_name

            table_names.update(fnmatch.filter(metadata.tables.keys(), pattern))
    else:
        table_names = metadata.tables.keys()

    return list(table_names)
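
# A minimal sketch of how the pattern matching behaves (table and file names
# here are hypothetical examples, not taken from the real metadata):
#
#     _get_table_names(metadata, ['pokemon*'])        # fnmatch-style globbing
#     _get_table_names(metadata, ['data/moves.csv'])  # filename -> table 'moves'
#     _get_table_names(metadata, [])                  # falsy -> every table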

def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.
    """

    if not verbose:
        # Return dummies
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff

    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[0:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        print "%s...%s" % (truncated_thing, ' ' * num_spaces),
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder

        backspaces = [0]
        def print_status(msg):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            print msg

    return print_start, print_status, print_done
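
# A minimal sketch of how the three printers work together (the output shown
# in comments is illustrative):
#
#     print_start, print_status, print_done = _get_verbose_prints(True)
#     print_start('pokemon')   # prints "pokemon..." padded to a fixed column
#     print_status('42%')      # overwrites the status in place on a terminal
#     print_done('ok')         # erases the status and prints "ok"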


def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load.  If omitted, all tables are loaded.

    `directory`
        Directory the CSV files reside in.  Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)


    if directory is None:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
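    # sort_tables orders the tables by foreign-key dependency, so that
    # referenced tables come before the tables that point at them; this
    # keeps creation and loading from tripping over missing FK targets.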

    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        for table in reversed(table_objs):
            table.drop(checkfirst=True)
        print_done()

    for table in table_objs:
        table.create()
    connection = session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)

        try:
            csvpath = "%s/%s.csv" % (directory, table_name)
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        csvsize = os.stat(csvpath).st_size

        reader = csv.reader(csvfile, lineterminator='\n')
        column_names = [unicode(column) for column in reader.next()]
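
        # Note: under Python 2 the csv module yields byte strings, so the
        # header row is converted to unicode here and row values are decoded
        # below.  For a hypothetical table the header might come back as
        # [u'id', u'identifier', u'cost'] (illustrative names only).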

        if not safe and session.connection().dialect.name == 'postgresql':
            # Postgres' CSV dialect is nearly the same as ours, except that
            # it treats completely empty values as NULL and empty quoted
            # strings ("") as empty strings.  The pokedex dump does not
            # quote empty strings, so both empty strings and NULLs are read
            # in as NULL.  For an empty string in a NOT NULL column, the
            # load will fail and fall back to the cross-backend row-by-row
            # loading below.  In nullable columns, we already load empty
            # strings as NULL.
            session.commit()
            not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
            if not_null_cols:
                force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
            else:
                force_not_null = ''
            command = "COPY {table_name} ({columns}) FROM '{csvpath}' CSV HEADER {force_not_null}"
            session.connection().execute(
                command.format(
                    table_name=table_name,
                    csvpath=csvpath,
                    columns=','.join('"%s"' % c for c in column_names),
                    force_not_null=force_not_null,
                )
            )
            session.commit()
            print_done()
            continue
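
        # For reference, the COPY statement built above expands to something
        # like the following (table, columns, and path are hypothetical):
        #
        #   COPY items ("id","identifier","cost") FROM '/path/to/items.csv'
        #       CSV HEADER FORCE NOT NULL "identifier"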

        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist.  Pull these out and
        # add them to the session last.
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # ( row referring to id, [foreign ids we need] )
        seen_ids = {}       # primary key we've seen => 1

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = []
        for column in table_obj.c:
            if any(_.references(table_obj) for _ in column.foreign_keys):
                self_ref_columns.append(column)

        new_rows = []
        def insert_and_commit():
            session.connection().execute(insert_stmt, new_rows)
            session.commit()
            new_rows[:] = []

            progress = "%d%%" % (100 * csvfile.tell() // csvsize)
            print_status(progress)

        for csvs in reader:
            row_data = {}

            for column_name, value in zip(column_names, csvs):
                column = table_obj.c[column_name]
                if column.nullable and value == '':
                    # Empty string in a nullable column really means NULL
                    value = None
                elif isinstance(column.type, sqlalchemy.types.Boolean):
                    # Boolean values are stored as string values 0/1, but
                    # both of those evaluate as true; SQLA wants True/False
                    if value == '0':
                        value = False
                    else:
                        value = True
                else:
                    # Otherwise, decode from UTF-8 bytes to unicode
                    value = value.decode('utf-8')

                # nb: Dictionaries flattened with ** have to have string keys
                row_data[str(column_name)] = value

            # May need to stash this row and add it later if it refers to a
            # later row in this table
            if self_ref_columns:
                foreign_ids = [row_data[_.name] for _ in self_ref_columns]
                foreign_ids = [_ for _ in foreign_ids if _]  # remove NULL ids

                if not foreign_ids:
                    # NULL key.  Remember this row and add as usual.
                    seen_ids[row_data['id']] = 1

                elif all(_ in seen_ids for _ in foreign_ids):
                    # Non-NULL key we've already seen.  Remember it and commit
                    # so we know the old row exists when we add the new one
                    insert_and_commit()
                    seen_ids[row_data['id']] = 1

                else:
                    # Non-NULL future id.  Save this and insert it later!
                    deferred_rows.append((row_data, foreign_ids))
                    continue

            # Insert row!
            new_rows.append(row_data)

            # Remembering some zillion rows in the session consumes a lot of
            # RAM.  Let's not do that.  Commit every 1000 rows
            if len(new_rows) >= 1000:
                insert_and_commit()

        insert_and_commit()

        # Attempt to add any spare rows we've collected
        for row_data, foreign_ids in deferred_rows:
            if not all(_ in seen_ids for _ in foreign_ids):
                # Could happen if row A refers to B which refers to C.
                # This is ridiculous and doesn't happen in my data so far
                raise ValueError("Too many levels of self-reference!  "
                                 "Row was: " + str(row_data))

            session.connection().execute(
                insert_stmt.values(**row_data)
            )
            seen_ids[row_data['id']] = 1
        session.commit()

        print_done()

    # SQLite check
    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")


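# A minimal usage sketch for load(), assuming a SQLAlchemy session bound to
# the target database (the engine URI is an example, not a default of this
# module):
#
#     from sqlalchemy import create_engine
#     from sqlalchemy.orm import sessionmaker
#
#     engine = create_engine('sqlite:///pokedex.sqlite')
#     session = sessionmaker(bind=engine)()
#     load(session, tables=['pokemon*'], drop_tables=True, verbose=True)

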
def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump.  If omitted, all tables are dumped.

    `directory`
        Directory the CSV files should be put in.  Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)


    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_names.sort()


    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
                            lineterminator='\n')
        columns = [col.name for col in table.columns]
        writer.writerow(columns)

        primary_key = table.primary_key
        for row in session.query(table).order_by(*primary_key).all():
            csvs = []
            for col in columns:
                # Convert Pythony values to something more universal
                val = getattr(row, col)
                if val is None:
                    val = ''
                elif val is True:
                    val = '1'
                elif val is False:
                    val = '0'
                else:
                    val = unicode(val).encode('utf-8')

                csvs.append(val)

            writer.writerow(csvs)

        print_done()
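
# A minimal usage sketch for dump(), mirroring the load() example above
# (the engine URI and output directory are examples only):
#
#     from sqlalchemy import create_engine
#     from sqlalchemy.orm import sessionmaker
#
#     engine = create_engine('sqlite:///pokedex.sqlite')
#     session = sessionmaker(bind=engine)()
#     dump(session, directory='/tmp/pokedex-csv', verbose=True)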