1 """CSV to database or vice versa."""
7 from sqlalchemy
.orm
.attributes
import instrumentation_registry
8 import sqlalchemy
.sql
.util
9 import sqlalchemy
.types
11 from pokedex
.db
import metadata
12 import pokedex
.db
.tables
as tables
13 from pokedex
.defaults
import get_default_csv_dir

def _get_table_names(metadata, patterns):
    """Returns a list of table names from the given metadata. If `patterns`
    exists, only tables matching one of the patterns will be returned.
    """
    if patterns:
        table_names = set()
        for pattern in patterns:
            if '.' in pattern or '/' in pattern:
                # If it looks like a filename, pull out just the table name
                _, filename = os.path.split(pattern)
                table_name, _ = os.path.splitext(filename)
                pattern = table_name

            table_names.update(fnmatch.filter(metadata.tables.keys(), pattern))
    else:
        table_names = metadata.tables.keys()

    return list(table_names)
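
# A rough illustration of the pattern handling above (the table names are
# hypothetical examples, not a guaranteed part of the schema):
#
#     _get_table_names(metadata, ['pokemon*'])
#     # -> e.g. ['pokemon', 'pokemon_abilities', 'pokemon_types']
#
#     _get_table_names(metadata, ['data/csv/pokemon.csv'])
#     # -> ['pokemon']  (a path is reduced to its bare table name first)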

def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.
    """

    if not verbose:
        # Not verbose; return dummies that ignore their arguments
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff

    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[0:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        print "%s...%s" % (truncated_thing, ' ' * num_spaces),
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder
        backspaces = [0]

        def print_status(msg):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            print msg

    return print_start, print_status, print_done
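
# Sketch of how the three printers cooperate on a single status line
# (assuming verbose output on a terminal):
#
#     print_start, print_status, print_done = _get_verbose_prints(True)
#     print_start('pokemon')   # prints "pokemon...", padded, no newline
#     print_status('40%')      # overwrites the previous status in place
#     print_status('80%')
#     print_done('ok')         # prints "ok" and finally ends the line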

def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load. If omitted, all tables are loaded.

    `directory`
        Directory the CSV files reside in. Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.
    """
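    # Hypothetical invocation, assuming some helper (here called `connect`)
    # that builds an SQLAlchemy session; the URL and pattern are examples only:
    #
    #     from pokedex.db import connect
    #     session = connect('sqlite:///pokedex.sqlite')
    #     load(session, tables=['pokemon*'], drop_tables=True, verbose=True)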

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)

    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        for n, table in enumerate(reversed(table_objs)):
            table.drop(checkfirst=True)
            print_status('%s/%s' % (n, len(table_objs)))
        print_done()

    print_start('Creating tables')
    for n, table in enumerate(table_objs):
        table.create()
        print_status('%s/%s' % (n, len(table_objs)))
    print_done()

    connection = session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)

        try:
            csvpath = "%s/%s.csv" % (directory, table_name)
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        csvsize = os.stat(csvpath).st_size

        reader = csv.reader(csvfile, lineterminator='\n')
        column_names = [unicode(column) for column in reader.next()]

        if not safe and session.connection().dialect.name == 'postgresql':
            # Postgres' CSV dialect works with our data, if we mark the
            # not-null columns with FORCE NOT NULL.
            # COPY is only allowed for DB superusers. If you're not one, use
            # safe loading (pokedex load -S).
            not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
            if not_null_cols:
                force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
            else:
                force_not_null = ''

            command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
            session.connection().execute(
                command % dict(
                    table_name=table_name,
                    csvpath=csvpath,
                    columns=','.join('"%s"' % c for c in column_names),
                    force_not_null=force_not_null,
                )
            )
            session.commit()
            print_done()
            continue

        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist. Pull these out and add
        # them to the session last
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # ( row referring to id, [foreign ids we need] )
        seen_ids = {}       # primary key we've seen => 1

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = []
        for column in table_obj.c:
            if any(_.references(table_obj) for _ in column.foreign_keys):
                self_ref_columns.append(column)

        new_rows = []
        def insert_and_commit():
            # Flush the pending batch of rows, then report progress based on
            # how far into the CSV file we've read
            if not new_rows:
                return
            session.connection().execute(insert_stmt, new_rows)
            session.commit()
            new_rows[:] = []

            progress = "%d%%" % (100 * csvfile.tell() // csvsize)
            print_status(progress)

        for csvs in reader:
            row_data = {}

            for column_name, value in zip(column_names, csvs):
                column = table_obj.c[column_name]
                if column.nullable and value == '':
                    # Empty string in a nullable column really means NULL
                    value = None
                elif isinstance(column.type, sqlalchemy.types.Boolean):
                    # Boolean values are stored as string values 0/1, but both
                    # of those evaluate as true; SQLA wants True/False
                    if value == '0':
                        value = False
                    else:
                        value = True
                else:
                    # Otherwise, unflatten from bytes
                    value = value.decode('utf-8')

                # nb: Dictionaries flattened with ** have to have string keys
                row_data[ str(column_name) ] = value

            # May need to stash this row and add it later if it refers to a
            # later row in this table
            if self_ref_columns:
                foreign_ids = [row_data[_.name] for _ in self_ref_columns]
                foreign_ids = [_ for _ in foreign_ids if _]  # remove NULL ids

                if not foreign_ids:
                    # NULL key. Remember this row and add as usual.
                    seen_ids[row_data['id']] = 1

                elif all(_ in seen_ids for _ in foreign_ids):
                    # Non-NULL key we've already seen. Remember it and commit
                    # so we know the old row exists when we add the new one
                    insert_and_commit()
                    seen_ids[row_data['id']] = 1

                else:
                    # Non-NULL future id. Save this and insert it later!
                    deferred_rows.append((row_data, foreign_ids))
                    continue

            # Normal case: queue the row for the next batch insert
            new_rows.append(row_data)

            # Remembering some zillion rows in the session consumes a lot of
            # RAM. Let's not do that. Commit every 1000 rows
            if len(new_rows) >= 1000:
                insert_and_commit()

        # Flush whatever is left over from the last partial batch
        insert_and_commit()

        # Attempt to add any spare rows we've collected
        for row_data, foreign_ids in deferred_rows:
            if not all(_ in seen_ids for _ in foreign_ids):
                # Could happen if row A refers to B which refers to C.
                # This is ridiculous and doesn't happen in my data so far
                raise ValueError("Too many levels of self-reference! "
                                 "Row was: " + str(row_data))

            session.connection().execute(
                insert_stmt.values(**row_data)
            )
            seen_ids[row_data['id']] = 1
        session.commit()

        print_done()

    # Run a final consistency check on SQLite once everything is loaded
    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")

def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files. Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump. If omitted, all tables are dumped.

    `directory`
        Directory the CSV files should be put in. Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.
    """
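    # Hypothetical invocation (the session setup mirrors the load() example
    # above; the table names and directory are examples only):
    #
    #     dump(session, tables=['items', 'moves'], directory='/tmp/csv',
    #          verbose=True)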

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)

    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
                            lineterminator='\n')
        columns = [col.name for col in table.columns]
        writer.writerow(columns)

        primary_key = table.primary_key
        for row in session.query(table).order_by(*primary_key).all():
            csvs = []
            for col in columns:
                # Convert Pythony values to something more universal
                val = getattr(row, col)
                if val is None:
                    # NULL is written as an empty string (see load() above)
                    val = ''
                elif val is True:
                    # Booleans are stored as 0/1 in the CSVs
                    val = '1'
                elif val is False:
                    val = '0'
                else:
                    val = unicode(val).encode('utf-8')

                csvs.append(val)

            writer.writerow(csvs)

        print_done()