1 """CSV to database or vice versa."""
import csv
import fnmatch
import os
import sys

from sqlalchemy.orm.attributes import instrumentation_registry
import sqlalchemy.sql.util
import sqlalchemy.types

from pokedex.db import metadata
import pokedex.db.tables as tables
from pokedex.defaults import get_default_csv_dir
from pokedex.db.dependencies import find_dependent_tables


def _get_table_names(metadata, patterns):
    """Returns a list of table names from the given metadata.  If `patterns`
    exists, only tables matching one of the patterns will be returned.

    `metadata`
        Object whose `tables` mapping is keyed by table name.

    `patterns`
        Iterable of `fnmatch`-style patterns, or a false value (None, empty
        list) to select every table.  A pattern containing '.' or '/' is
        treated as a filename: its directory part and extension are stripped
        and the remaining stem is used as the pattern.

    Names are returned in the order `metadata.tables` yields them, each at
    most once, even when several patterns match the same table.
    """
    all_names = list(metadata.tables.keys())
    if not patterns:
        return all_names

    matched = set()
    for pattern in patterns:
        if '.' in pattern or '/' in pattern:
            # If it looks like a filename, pull out just the table name
            filename = os.path.basename(pattern)
            pattern, _ = os.path.splitext(filename)

        matched.update(fnmatch.filter(all_names, pattern))

    # Filter the full list rather than returning the set, so the result
    # order does not depend on set iteration order.
    return [name for name in all_names if name in matched]


def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.

    The result is always a 3-tuple `(print_start, print_status, print_done)`.
    `print_start(thing)` and `print_status(msg)` take one string;
    `print_done(msg='ok')` takes an optional one.  All output goes through
    `sys.stdout.write`, so the functions behave the same on Python 2 and 3.
    """
    if not verbose:
        # Return dummies
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff

    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[0:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        sys.stdout.write("%s...%s" % (truncated_thing, ' ' * num_spaces))
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder

        # Width of the status text currently on screen.  Kept in a
        # one-element list so the closures below can update it without
        # `nonlocal`, which Python 2 lacks.
        backspaces = [0]

        def erase_status():
            # Back up over the old status, blank it out, and back up again
            # so the cursor sits where the status text started.
            width = backspaces[0]
            sys.stdout.write('\b' * width + ' ' * width + '\b' * width)

        def print_status(msg):
            # Overwrite any status text with spaces before printing
            erase_status()
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            erase_status()
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            # The line is finished; nothing is left to erase for the next
            # print_start/print_status cycle.
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            # NOTE(review): the original body of this branch was not
            # visible; this mirrors the terminal variant above minus the
            # erase step.
            sys.stdout.write(msg + "\n")

    return print_start, print_status, print_done


def _csv_value_to_python(column, value):
    """Convert one raw CSV field into the value to insert for `column`."""
    if column.nullable and value == '':
        # Empty string in a nullable column really means NULL
        return None
    if isinstance(column.type, sqlalchemy.types.Boolean):
        # Boolean values are stored as string values 0/1, but both
        # of those evaluate as true; SQLA wants True/False
        return value != '0'
    # Otherwise, unflatten from bytes
    return value.decode('utf-8')


def _copy_csv_into_postgres(session, table_obj, column_names, csvpath):
    """Bulk-load one table's CSV file using PostgreSQL's COPY command.

    Postgres' CSV dialect works with our data, if we mark the not-null
    columns with FORCE NOT NULL.
    COPY is only allowed for DB superusers.  If you're not one, use safe
    loading (pokedex load -S).
    """
    # NOTE(review): the commits before and after COPY are reconstructed;
    # confirm they match the intended transaction boundaries.
    session.commit()

    not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
    if not_null_cols:
        force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
    else:
        # A bare "FORCE NOT NULL" with no column list is not valid SQL
        force_not_null = ''

    command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
    session.connection().execute(
        command % dict(
            table_name=table_obj.name,
            csvpath=csvpath,
            columns=','.join('"%s"' % c for c in column_names),
            force_not_null=force_not_null,
        )
    )
    session.commit()


def _insert_csv_rows(session, table_obj, reader, column_names, csvfile, csvsize, print_status):
    """Insert the remaining rows of `reader` into `table_obj` in batches.

    `reader` must already be positioned past the header row; `column_names`
    are the names read from that header.  `csvfile` and `csvsize` are only
    used to compute the percentage passed to `print_status`.

    Raises ValueError if a deferred self-referential row still points at a
    row that was never inserted.
    """
    insert_stmt = table_obj.insert()

    # Self-referential tables may contain rows with foreign keys of other
    # rows in the same table that do not yet exist.  Pull these out and add
    # them to the session last
    # ASSUMPTION: Self-referential tables have a single PK called "id"
    deferred_rows = []  # ( row referring to id, [foreign ids we need] )
    seen_ids = set()    # primary keys we've seen

    # Fetch foreign key columns that point at this table, if any
    self_ref_columns = [
        column for column in table_obj.c
        if any(fk.references(table_obj) for fk in column.foreign_keys)
    ]

    new_rows = []

    def insert_and_commit():
        if not new_rows:
            return
        session.connection().execute(insert_stmt, new_rows)
        session.commit()
        # Empty the batch in place so the closure keeps using the same list
        del new_rows[:]

        progress = "%d%%" % (100 * csvfile.tell() // csvsize)
        print_status(progress)

    for csvs in reader:
        row_data = {}
        for column_name, value in zip(column_names, csvs):
            column = table_obj.c[column_name]
            # nb: Dictionaries flattened with ** have to have string keys
            row_data[str(column_name)] = _csv_value_to_python(column, value)

        # May need to stash this row and add it later if it refers to a
        # later row in this table
        if self_ref_columns:
            foreign_ids = [row_data[col.name] for col in self_ref_columns]
            foreign_ids = [fid for fid in foreign_ids if fid]  # remove NULL ids

            if not foreign_ids:
                # NULL key.  Remember this row and add as usual.
                seen_ids.add(row_data['id'])

            elif all(fid in seen_ids for fid in foreign_ids):
                # Non-NULL key we've already seen.  Remember it and commit
                # so we know the old row exists when we add the new one
                insert_and_commit()
                seen_ids.add(row_data['id'])

            else:
                # Non-NULL future id.  Save this and insert it later!
                deferred_rows.append((row_data, foreign_ids))
                continue

        # Insert row!
        new_rows.append(row_data)

        # Remembering some zillion rows in the session consumes a lot of
        # RAM.  Let's not do that.  Commit every 1000 rows
        if len(new_rows) >= 1000:
            insert_and_commit()

    insert_and_commit()

    # Attempt to add any spare rows we've collected
    for row_data, foreign_ids in deferred_rows:
        if not all(fid in seen_ids for fid in foreign_ids):
            # Could happen if row A refers to B which refers to C.
            # This is ridiculous and doesn't happen in my data so far
            raise ValueError("Too many levels of self-reference! "
                             "Row was: " + str(row_data))

        session.connection().execute(
            insert_stmt.values(**row_data)
        )
        seen_ids.add(row_data['id'])

    session.commit()


def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True, recursive=False):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load.  If omitted, all tables are loaded.

    `directory`
        Directory the CSV files reside in.  Defaults to the directory
        returned by `get_default_csv_dir()`.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.

    `recursive`
        If set to True, load all dependent tables too.

    A table whose CSV file cannot be opened is reported as 'missing?' and
    skipped.  Raises ValueError if a self-referential table contains a row
    whose referenced row cannot be inserted before it.
    """
    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    # XXX why isn't this done in command_load
    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]

    if recursive:
        table_objs.extend(find_dependent_tables(table_objs))

    # Dependency order: referenced tables come before the tables using them
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)

    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        # Reverse dependency order, so dependents go before what they use
        for n, table in enumerate(reversed(table_objs)):
            table.drop(checkfirst=True)
            print_status('%s/%s' % (n, len(table_objs)))
        print_done()

    print_start('Creating tables')
    for n, table in enumerate(table_objs):
        # NOTE(review): the create call was not visible in the reviewed
        # text; confirm its arguments (e.g. bind / checkfirst).
        table.create()
        print_status('%s/%s' % (n, len(table_objs)))
    print_done()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        print_start(table_name)

        csvpath = "%s/%s.csv" % (directory, table_name)
        try:
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        # Close the file even if loading this table fails part-way
        with csvfile:
            csvsize = os.stat(csvpath).st_size

            reader = csv.reader(csvfile, lineterminator='\n')
            column_names = [unicode(column) for column in next(reader)]

            if not safe and session.connection().dialect.name == 'postgresql':
                _copy_csv_into_postgres(session, table_obj, column_names, csvpath)
            else:
                _insert_csv_rows(session, table_obj, reader, column_names,
                                 csvfile, csvsize, print_status)

        print_done()

    # SQLite check
    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")


def _python_value_to_csv(val):
    """Flatten one database value into the byte string written to CSV.

    Inverse of the conventions `load` reads back: NULL is written as an
    empty string and booleans as '1'/'0'.
    """
    # NOTE(review): the original conversion branches were not visible in the
    # reviewed text; confirm against the data the loader expects.
    if val is None:
        return ''
    if val is True:
        return '1'
    if val is False:
        return '0'
    return unicode(val).encode('utf-8')


def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump.  If omitted, all tables are dumped.

    `directory`
        Directory the CSV files should be put in.  Defaults to the directory
        returned by `get_default_csv_dir()`.

    `verbose`
        If set to True, status messages will be printed to stdout.

    Each table is written to `<directory>/<table name>.csv` with a header
    row of column names followed by the rows ordered by primary key.
    """
    # First take care of verbosity
    print_start, _, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)

    for table_name in sorted(table_names):
        print_start(table_name)
        table = metadata.tables[table_name]
        columns = [col.name for col in table.columns]

        # The with-block guarantees the file is flushed and closed before
        # the table is reported as done.
        with open("%s/%s.csv" % (directory, table_name), 'wb') as csvfile:
            writer = csv.writer(csvfile, lineterminator='\n')
            writer.writerow(columns)

            primary_key = table.primary_key
            for row in session.query(table).order_by(*primary_key).all():
                # Convert Pythony values to something more universal
                csvs = [_python_value_to_csv(getattr(row, col)) for col in columns]
                writer.writerow(csvs)

        print_done()