1 """CSV to database or vice versa."""
import csv
import fnmatch
import os.path
import sys

from sqlalchemy.orm.attributes import instrumentation_registry
import sqlalchemy.sql.util
import sqlalchemy.types

import pokedex
from pokedex.db import metadata, tables, translations
from pokedex.defaults import get_default_csv_dir
from pokedex.db.dependencies import find_dependent_tables

def _get_table_names(metadata, patterns):
    """Returns a list of table names from the given metadata. If `patterns`
    exists, only tables matching one of the patterns will be returned.
    """
    if patterns:
        table_names = set()
        for pattern in patterns:
            if '.' in pattern or '/' in pattern:
                # If it looks like a filename, pull out just the table name
                _, filename = os.path.split(pattern)
                table_name, _ = os.path.splitext(filename)
                pattern = table_name

            table_names.update(fnmatch.filter(metadata.tables.keys(), pattern))
    else:
        table_names = metadata.tables.keys()

    return list(table_names)
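
# Illustrative only: with the default metadata, a call like
#     _get_table_names(metadata, ['pokemon*', 'data/csv/types.csv'])
# matches every table whose name starts with "pokemon", plus "types";
# a filename-looking pattern is first reduced to its bare table name.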

def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.
    """

    if not verbose:
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff
    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        print "%s...%s" % (truncated_thing, ' ' * num_spaces),
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder
        backspaces = [0]

        def print_status(msg):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            print msg

    return print_start, print_status, print_done
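
# Illustrative only: the three functions returned for verbose=True are meant
# to be used together, along these lines:
#     print_start('pokemon')   # prints "pokemon...          ", no newline yet
#     print_status('42%')      # overwrites the previous status in place
#     print_done('ok')         # replaces the status and ends the line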

def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True, recursive=True, langs=None):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load. If omitted, all tables are loaded.

    `directory`
        Directory the CSV files reside in. Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.

    `recursive`
        If set to True, load all dependent tables too.

    `langs`
        List of identifiers of extra languages to load, or None to load them
        all.
    """
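
    # Illustrative example; assumes a session obtained elsewhere (e.g. via
    # pokedex.db.connect, as the pokedex CLI does):
    #     load(session, tables=['pokemon'], recursive=True, verbose=True)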

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    # XXX why isn't this done in command_load
    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]

    if recursive:
        table_objs.extend(find_dependent_tables(table_objs))

    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)

    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        for n, table in enumerate(reversed(table_objs)):
            table.drop(checkfirst=True)
            print_status('%s/%s' % (n, len(table_objs)))
        print_done()

    print_start('Creating tables')
    for n, table in enumerate(table_objs):
        table.create()
        print_status('%s/%s' % (n, len(table_objs)))
    print_done()

    connection = session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)

        try:
            csvpath = "%s/%s.csv" % (directory, table_name)
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        csvsize = os.stat(csvpath).st_size

        reader = csv.reader(csvfile, lineterminator='\n')
        column_names = [unicode(column) for column in reader.next()]

        if not safe and session.connection().dialect.name == 'postgresql':
            # Postgres' CSV dialect works with our data, if we mark the
            # not-null columns with FORCE NOT NULL.
            # COPY is only allowed for DB superusers. If you're not one, use
            # safe loading (pokedex load -S).
            not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
            if not_null_cols:
                force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
            else:
                force_not_null = ''

            command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
            session.connection().execute(
                command % dict(
                    table_name=table_name,
                    csvpath=csvpath,
                    columns=','.join('"%s"' % c for c in column_names),
                    force_not_null=force_not_null,
                )
            )
            session.commit()
            print_done()
            continue
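
        # Illustrative only: for a hypothetical two-column `berries` table
        # this would issue something like
        #     COPY berries ("id","name") FROM '/tmp/berries.csv'
        #         CSV HEADER FORCE NOT NULL "id"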

        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist. Pull these out and add
        # them to the session last
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # ( row referring to id, [foreign ids we need] )
        seen_ids = set()    # primary keys we've seen

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = []
        for column in table_obj.c:
            if any(x.references(table_obj) for x in column.foreign_keys):
                self_ref_columns.append(column)
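
        # (For example, a species table whose evolves_from_species_id column
        # points back at its own id would be caught here -- hypothetical
        # column names, but that is the shape of data this handles.)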

        new_rows = []
        def insert_and_commit():
            if not new_rows:
                return
            session.connection().execute(insert_stmt, new_rows)
            session.commit()
            new_rows[:] = []

            progress = "%d%%" % (100 * csvfile.tell() // csvsize)
            print_status(progress)
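
        # nb: new_rows is cleared in place (new_rows[:] = []) above, so this
        # closure and the row loop below keep sharing one list object.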

        for csvs in reader:
            row_data = {}

            for column_name, value in zip(column_names, csvs):
                column = table_obj.c[column_name]
                if column.nullable and value == '':
                    # Empty string in a nullable column really means NULL
                    value = None
                elif isinstance(column.type, sqlalchemy.types.Boolean):
                    # Boolean values are stored as string values 0/1, but both
                    # of those evaluate as true; SQLA wants True/False
                    if value == '0':
                        value = False
                    else:
                        value = True
                else:
                    # Otherwise, unflatten from bytes
                    value = value.decode('utf-8')

                # nb: Dictionaries flattened with ** have to have string keys
                row_data[ str(column_name) ] = value

            # May need to stash this row and add it later if it refers to a
            # later row in this table
            if self_ref_columns:
                foreign_ids = set(row_data[x.name] for x in self_ref_columns)
                foreign_ids.discard(None)  # remove NULL ids

                if not foreign_ids:
                    # NULL key. Remember this row and add as usual.
                    seen_ids.add(row_data['id'])

                elif foreign_ids.issubset(seen_ids):
                    # Non-NULL key we've already seen. Remember it and commit
                    # so we know the old row exists when we add the new one
                    insert_and_commit()
                    seen_ids.add(row_data['id'])

                else:
                    # Non-NULL future id. Save this and insert it later!
                    deferred_rows.append((row_data, foreign_ids))
                    continue

            new_rows.append(row_data)

            # Remembering some zillion rows in the session consumes a lot of
            # RAM. Let's not do that. Commit every 1000 rows
            if len(new_rows) >= 1000:
                insert_and_commit()

        insert_and_commit()

        # Attempt to add any spare rows we've collected
        for row_data, foreign_ids in deferred_rows:
            if not foreign_ids.issubset(seen_ids):
                # Could happen if row A refers to B which refers to C.
                # This is ridiculous and doesn't happen in my data so far
                raise ValueError("Too many levels of self-reference! "
                                 "Row was: " + str(row_data))

            session.connection().execute(
                insert_stmt.values(**row_data)
            )
            seen_ids.add(row_data['id'])

        session.commit()
        print_done()

    print_start('Translations')
    transl = translations.Translations(csv_directory=directory)

    new_row_count = 0
    for translation_class, rows in transl.get_load_data(langs):
        table_obj = translation_class.__table__
        if table_obj in table_objs:
            insert_stmt = table_obj.insert()
            session.connection().execute(insert_stmt, rows)
            session.commit()
            # We don't have a total, but at least show some increasing number
            new_row_count += len(rows)
            print_status(str(new_row_count))
    print_done()

    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")

def dump(session, tables=[], directory=None, verbose=False, langs=['en']):
    """Dumps the contents of a database to a set of CSV files. Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump. If omitted, all tables are dumped.

    `directory`
        Directory the CSV files should be put in. Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `langs`
        List of identifiers of languages to dump unofficial texts for.
    """
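
    # Illustrative example; assumes a session as for load():
    #     dump(session, tables=['pokemon'], directory='/tmp/csv',
    #          langs=['en', 'de'])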

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    languages = dict((l.id, l) for l in session.query(pokedex.db.tables.Language))

    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_names.sort()

    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
                            lineterminator='\n')
        columns = [col.name for col in table.columns]

        # For name tables, dump rows for official languages, as well as
        # for those in `langs`.
        # For other translation tables, only dump rows for languages in `langs`
        # For non-translation tables, dump all rows.
        if 'local_language_id' in columns:
            if any(col.info.get('official') for col in table.columns):
                def include_row(row):
                    return (languages[row.local_language_id].official or
                            languages[row.local_language_id].identifier in langs)
            else:
                def include_row(row):
                    return languages[row.local_language_id].identifier in langs
        else:
            def include_row(row):
                return True

        writer.writerow(columns)

        primary_key = table.primary_key
        for row in session.query(table).order_by(*primary_key).all():
            if not include_row(row):
                continue

            csvs = []
            for col in columns:
                # Convert Pythony values to something more universal
                # (the inverse of the parsing in load: NULL becomes '',
                # booleans become 0/1)
                val = getattr(row, col)
                if val is None:
                    val = ''
                elif val is True:
                    val = '1'
                elif val is False:
                    val = '0'
                else:
                    val = unicode(val).encode('utf-8')

                csvs.append(val)

            writer.writerow(csvs)

        print_done()