1 """CSV to database or vice versa."""
7 import sqlalchemy
.sql
.util
8 import sqlalchemy
.types
10 from pokedex
.db
import metadata
11 import pokedex
.db
.tables
as tables
12 from pokedex
.defaults
import get_default_csv_dir
13 from pokedex
.db
.dependencies
import find_dependent_tables


def _get_table_names(metadata, patterns):
    """Returns a list of table names from the given metadata.  If `patterns`
    exists, only tables matching one of the patterns will be returned.
    """
    if patterns:
        table_names = set()
        for pattern in patterns:
            if '.' in pattern or '/' in pattern:
                # If it looks like a filename, pull out just the table name
                _, filename = os.path.split(pattern)
                table_name, _ = os.path.splitext(filename)
                pattern = table_name

            table_names.update(fnmatch.filter(metadata.tables.keys(), pattern))
    else:
        table_names = metadata.tables.keys()

    return list(table_names)
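
# Illustrative example (added for clarity, not in the original module): given
# tables "pokemon" and "pokemon_forms", the pattern "pokemon*" selects both,
# while a filename-like pattern such as "data/csv/pokemon.csv" is first
# reduced to the bare table name "pokemon" before matching.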


def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.
    """

    if not verbose:
        # Return no-ops that accept any arguments
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff

    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        print "%s...%s" % (truncated_thing, ' ' * num_spaces),
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder
        backspaces = [0]

        def print_status(msg):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            print msg

    return print_start, print_status, print_done
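
# Illustrative usage of the helpers above (added for clarity; not part of the
# original module):
#
#     print_start, print_status, print_done = _get_verbose_prints(True)
#     print_start('pokemon')   # prints "pokemon..." padded to a fixed column
#     print_status('42%')      # overwrites the previous status in place
#     print_done('ok')         # final message plus a newline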


def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True, recursive=False):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load.  If omitted, all tables are loaded.

    `directory`
        Directory the CSV files reside in.  Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.

    `recursive`
        If set to True, load all dependent tables too.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    # XXX why isn't this done in command_load
    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]

    if recursive:
        table_objs.extend(find_dependent_tables(table_objs))

    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
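    # Note (added): sort_tables() orders the tables by foreign-key dependency,
    # so referenced tables are created and loaded before the tables that refer
    # to them, and dropped in the reverse order below.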

    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        for n, table in enumerate(reversed(table_objs)):
            table.drop(checkfirst=True)
            print_status('%s/%s' % (n, len(table_objs)))
        print_done()

    print_start('Creating tables')
    for n, table in enumerate(table_objs):
        table.create()
        print_status('%s/%s' % (n, len(table_objs)))
    print_done()

    connection = session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)

        try:
            csvpath = "%s/%s.csv" % (directory, table_name)
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        csvsize = os.stat(csvpath).st_size

        reader = csv.reader(csvfile, lineterminator='\n')
        column_names = [unicode(column) for column in reader.next()]

        if not safe and session.connection().dialect.name == 'postgresql':
            # Postgres' CSV dialect works with our data, if we mark the
            # not-null columns with FORCE NOT NULL.
            # COPY is only allowed for DB superusers.  If you're not one, use
            # safe loading (pokedex load -S).
            not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
            if not_null_cols:
                force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
            else:
                force_not_null = ''
            command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
            session.connection().execute(
                command % dict(
                    table_name=table_name,
                    csvpath=csvpath,
                    columns=','.join('"%s"' % c for c in column_names),
                    force_not_null=force_not_null,
                )
            )
            session.commit()
            print_done()

            # COPY loaded the whole file; skip the row-by-row path below
            continue

        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist.  Pull these out and
        # add them to the session last
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # ( row referring to id, [foreign ids we need] )
        seen_ids = set()    # primary keys we've seen

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = []
        for column in table_obj.c:
            if any(x.references(table_obj) for x in column.foreign_keys):
                self_ref_columns.append(column)
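
        # Worked example (added for clarity): if a row arrives with id=5 and a
        # self-referential column pointing at id=7, and id=7 has not been seen
        # yet, the row is stashed in deferred_rows and only inserted after the
        # main pass, once the row with id=7 exists.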

        # Rows are inserted in chunks to keep memory usage down
        new_rows = []
        def insert_and_commit():
            if not new_rows:
                return
            session.connection().execute(insert_stmt, new_rows)
            session.commit()
            new_rows[:] = []

            progress = "%d%%" % (100 * csvfile.tell() // csvsize)
            print_status(progress)

        for csvs in reader:
            row_data = {}

            for column_name, value in zip(column_names, csvs):
                column = table_obj.c[column_name]
                if column.nullable and value == '':
                    # Empty string in a nullable column really means NULL
                    value = None
                elif isinstance(column.type, sqlalchemy.types.Boolean):
                    # Boolean values are stored as string values 0/1, but both
                    # of those evaluate as true; SQLA wants True/False
                    if value == '0':
                        value = False
                    else:
                        value = True
                else:
                    # Otherwise, unflatten from bytes
                    value = value.decode('utf-8')

                # nb: Dictionaries flattened with ** have to have string keys
                row_data[ str(column_name) ] = value

            # May need to stash this row and add it later if it refers to a
            # later row in this table
            if self_ref_columns:
                foreign_ids = set(row_data[x.name] for x in self_ref_columns)
                foreign_ids.discard(None)  # remove NULL ids

                if not foreign_ids:
                    # NULL key.  Remember this row and add as usual.
                    seen_ids.add(row_data['id'])

                elif foreign_ids.issubset(seen_ids):
                    # Non-NULL key we've already seen.  Remember it and commit
                    # so we know the old row exists when we add the new one
                    insert_and_commit()
                    seen_ids.add(row_data['id'])

                else:
                    # Non-NULL future id.  Save this and insert it later!
                    deferred_rows.append((row_data, foreign_ids))
                    continue

            new_rows.append(row_data)

            # Remembering some zillion rows in the session consumes a lot of
            # RAM.  Let's not do that.  Commit every 1000 rows
            if len(new_rows) >= 1000:
                insert_and_commit()

        # Insert any leftover rows from the last partial chunk
        insert_and_commit()

        # Attempt to add any spare rows we've collected
        for row_data, foreign_ids in deferred_rows:
            if not foreign_ids.issubset(seen_ids):
                # Could happen if row A refers to B which refers to C.
                # This is ridiculous and doesn't happen in my data so far
                raise ValueError("Too many levels of self-reference!  "
                                 "Row was: " + str(row_data))

            session.connection().execute(
                insert_stmt.values(**row_data)
            )
            seen_ids.add(row_data['id'])
            session.commit()

        print_done()

    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")


def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump.  If omitted, all tables are dumped.

    `directory`
        Directory the CSV files should be put in.  Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_names.sort()

    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
                            lineterminator='\n')
        columns = [col.name for col in table.columns]
        writer.writerow(columns)

        primary_key = table.primary_key
        for row in session.query(table).order_by(*primary_key).all():
            csvs = []
            for col in columns:
                # Convert Pythony values to something more universal
                val = getattr(row, col)
                if val is None:
                    val = ''
                elif val is True:
                    val = '1'
                elif val is False:
                    val = '0'
                else:
                    val = unicode(val).encode('utf-8')

                csvs.append(val)

            writer.writerow(csvs)

        print_done()
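

# Illustrative usage (added; not part of the original module), mirroring the
# load() example above -- the session factory is again an assumption:
#
#     from pokedex.db import connect
#     session = connect('sqlite:///pokedex.sqlite')
#     dump(session, tables=['pokemon'], directory='/tmp/pokedex-csv',
#          verbose=True)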