1 """CSV to database or vice versa."""
7 from sqlalchemy
.orm
.attributes
import instrumentation_registry
8 import sqlalchemy
.sql
.util
9 import sqlalchemy
.types
11 from pokedex
.db
import metadata
12 import pokedex
.db
.tables
as tables
13 from pokedex
.defaults
import get_default_csv_dir
def _get_table_names(metadata, patterns):
    """Returns a list of table names from the given metadata. If `patterns`
    exists, only tables matching one of the patterns will be returned.
    """
    if patterns:
        table_names = set()
        for pattern in patterns:
            if '.' in pattern or '/' in pattern:
                # If it looks like a filename, pull out just the table name
                _, filename = os.path.split(pattern)
                table_name, _ = os.path.splitext(filename)
                pattern = table_name

            table_names.update(fnmatch.filter(metadata.tables.keys(), pattern))
    else:
        table_names = metadata.tables.keys()

    return list(table_names)

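# Illustrative sketch (not part of the original module) of how the pattern
# matching above behaves; the table names used here are assumptions:
#
#     _get_table_names(metadata, ['pokemon*'])
#         -> every table whose name matches the glob, e.g. ['pokemon', ...]
#     _get_table_names(metadata, ['csv/items.csv'])
#         -> ['items']  (a filename-looking pattern is reduced to its stem)
#     _get_table_names(metadata, [])
#         -> every table name known to the metadata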
def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.
    """

    if not verbose:
        # Not verbose: return no-ops that ignore their arguments
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy
    ### Okay, verbose == True; print stuff
    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[0:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        print "%s...%s" % (truncated_thing, ' ' * num_spaces),
        sys.stdout.flush()
    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder
        backspaces = [0]  # length of the last status message we wrote

        def print_status(msg):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            print msg

    return print_start, print_status, print_done

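# Illustrative sketch of how the three printers are meant to be used together;
# the messages below are made up:
#
#     print_start, print_status, print_done = _get_verbose_prints(True)
#     print_start('some_table')   # "some_table...            " (no newline)
#     print_status('10%')         # overwrites the previous status in place
#     print_status('50%')
#     print_done('ok')            # erases the status, prints "ok" plus newline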
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
99 """Load data from CSV files into the given database session.
101 Tables are created automatically.
104 SQLAlchemy session to use.
107 List of tables to load. If omitted, all tables are loaded.
110 Directory the CSV files reside in. Defaults to the `pokedex` data
114 If set to True, existing `pokedex`-related tables will be dropped.
117 If set to True, status messages will be printed to stdout.
120 If set to False, load can be faster, but can corrupt the database if
121 it crashes or is interrupted.
    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)
    if directory is None:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")
    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        for n, table in enumerate(reversed(table_objs)):
            table.drop(checkfirst=True)
            print_status('%s/%s' % (n, len(table_objs)))
        print_done()

    print_start('Creating tables')
    for n, table in enumerate(table_objs):
        table.create()
        print_status('%s/%s' % (n, len(table_objs)))
    print_done()

    connection = session.connection()
    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)
        try:
            csvpath = "%s/%s.csv" % (directory, table_name)
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue
        csvsize = os.stat(csvpath).st_size

        reader = csv.reader(csvfile, lineterminator='\n')
        column_names = [unicode(column) for column in reader.next()]
        if not safe and session.connection().dialect.name == 'postgresql':
            # Postgres' CSV dialect works with our data, if we mark the
            # not-null columns with FORCE NOT NULL.
            # COPY is only allowed for DB superusers. If you're not one, use
            # safe loading (pokedex load -S).
            not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
            if not_null_cols:
                force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
            else:
                force_not_null = ''

            command = "COPY {table_name} ({columns}) FROM '{csvpath}' CSV HEADER {force_not_null}"
            session.connection().execute(
                command.format(
                    table_name=table_name,
                    columns=','.join('"%s"' % c for c in column_names),
                    csvpath=csvpath,
                    force_not_null=force_not_null,
                )
            )
            session.commit()
            print_done()

            # COPY loaded the whole file; skip the row-by-row loader below
            continue
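        # For reference, the statement generated above looks roughly like the
        # following; the table, columns, and path are illustrative only:
        #
        #     COPY items ("id","identifier","cost")
        #     FROM '/path/to/csv/items.csv'
        #     CSV HEADER FORCE NOT NULL "id","identifier"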
        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist. Pull these out and add
        # them to the session last
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # ( row referring to id, [foreign ids we need] )
        seen_ids = {}       # primary key we've seen => 1

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = []
        for column in table_obj.c:
            if any(_.references(table_obj) for _ in column.foreign_keys):
                self_ref_columns.append(column)
        new_rows = []
        def insert_and_commit():
            if not new_rows:
                # Nothing buffered; don't run the INSERT at all
                return
            session.connection().execute(insert_stmt, new_rows)
            session.commit()
            new_rows[:] = []

            progress = "%d%%" % (100 * csvfile.tell() // csvsize)
            print_status(progress)
        for csvs in reader:
            row_data = {}

            for column_name, value in zip(column_names, csvs):
                column = table_obj.c[column_name]
                if column.nullable and value == '':
                    # Empty string in a nullable column really means NULL
                    value = None
                elif isinstance(column.type, sqlalchemy.types.Boolean):
                    # Boolean values are stored as string values 0/1, but both
                    # of those evaluate as true; SQLA wants True/False
                    value = (value == '1')
                else:
                    # Otherwise, unflatten from bytes
                    value = value.decode('utf-8')

                # nb: Dictionaries flattened with ** have to have string keys
                row_data[ str(column_name) ] = value
            # May need to stash this row and add it later if it refers to a
            # later row in this table
            if self_ref_columns:
                foreign_ids = [row_data[_.name] for _ in self_ref_columns]
                foreign_ids = [_ for _ in foreign_ids if _]  # remove NULL ids

                if not foreign_ids:
                    # NULL key. Remember this row and add as usual.
                    seen_ids[row_data['id']] = 1

                elif all(_ in seen_ids for _ in foreign_ids):
                    # Non-NULL key we've already seen. Remember it and commit
                    # so we know the old row exists when we add the new one
                    insert_and_commit()
                    seen_ids[row_data['id']] = 1

                else:
                    # Non-NULL future id. Save this and insert it later!
                    deferred_rows.append((row_data, foreign_ids))
                    continue

            new_rows.append(row_data)
            # Remembering some zillion rows in the session consumes a lot of
            # RAM. Let's not do that. Commit every 1000 rows
            if len(new_rows) >= 1000:
                insert_and_commit()

        # Insert whatever is left over after the last full batch
        insert_and_commit()
        # Attempt to add any spare rows we've collected
        for row_data, foreign_ids in deferred_rows:
            if not all(_ in seen_ids for _ in foreign_ids):
                # Could happen if row A refers to B which refers to C.
                # This is ridiculous and doesn't happen in my data so far
                raise ValueError("Too many levels of self-reference! "
                                 "Row was: " + str(row_data))

            session.connection().execute(
                insert_stmt.values(**row_data)
            )
            seen_ids[row_data['id']] = 1

        session.commit()
        print_done()
    # Cheap sanity check: ask SQLite to verify the database isn't corrupt
    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")

def dump(session, tables=[], directory=None, verbose=False):
299 """Dumps the contents of a database to a set of CSV files. Probably not
300 useful to anyone besides a developer.
303 SQLAlchemy session to use.
306 List of tables to dump. If omitted, all tables are dumped.
309 Directory the CSV files should be put in. Defaults to the `pokedex`
313 If set to True, status messages will be printed to stdout.
    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)
    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
                            lineterminator='\n')
        columns = [col.name for col in table.columns]
        writer.writerow(columns)
        primary_key = table.primary_key
        for row in session.query(table).order_by(*primary_key).all():
            csvs = []
            for col in columns:
                # Convert Pythony values to something more universal
                val = getattr(row, col)
                if val is None:
                    # NULLs become empty cells, mirroring what load() expects
                    val = ''
                elif val is True:
                    val = '1'
                elif val is False:
                    val = '0'
                else:
                    val = unicode(val).encode('utf-8')
                csvs.append(val)

            writer.writerow(csvs)

        print_done()
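# Minimal usage sketch for dump(), mirroring the load() example above; the
# session setup and the output path are again assumptions:
#
#     from pokedex.db import connect
#     session = connect('sqlite:///pokedex.sqlite')
#     dump(session, directory='/tmp/pokedex-csv', verbose=True)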