1 """CSV to database or vice versa."""
import csv
import fnmatch
import os.path
import sys

from sqlalchemy.orm.attributes import instrumentation_registry
import sqlalchemy.sql.util
import sqlalchemy.types

from pokedex.db import metadata
import pokedex.db.tables as tables
from pokedex.defaults import get_default_csv_dir
def _get_table_names(metadata, patterns):
    """Return a list of table names from the given metadata.

    `metadata`
        SQLAlchemy MetaData whose tables are consulted.

    `patterns`
        Iterable of shell-style glob patterns.  If non-empty, only table
        names matching at least one pattern are returned; otherwise every
        table name is returned.  A pattern that looks like a filename
        (contains '.' or '/') is reduced to its bare table name first, so
        passing e.g. "csv/pokemon.csv" selects the "pokemon" table.
    """
    if not patterns:
        # No filtering requested; return everything.
        # list() so callers get a real list even on Python 3 dict views.
        return list(metadata.tables.keys())

    table_names = set()
    known_names = metadata.tables.keys()
    for pattern in patterns:
        if '.' in pattern or '/' in pattern:
            # If it looks like a filename, pull out just the table name
            _, filename = os.path.split(pattern)
            pattern, _ = os.path.splitext(filename)
        # Collect every known table matching this glob; a set deduplicates
        # overlapping patterns.
        table_names.update(fnmatch.filter(known_names, pattern))

    return list(table_names)
def _get_verbose_prints(verbose):
    """If `verbose` is true, returns three functions: one for printing a
    starting message, one for printing an interim status update, and one for
    printing a success or failure message when finished.

    If `verbose` is false, returns no-op functions.
    """

    if not verbose:
        # Not verbose: return dummies that swallow everything
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff

    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[0:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        # end='' replaces the Python 2 trailing-comma print; no newline so the
        # status/done message lands on the same line
        print("%s...%s" % (truncated_thing, ' ' * num_spaces), end='')
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder

        # Length of the last status message, so it can be erased; a one-item
        # list so the nested functions can rebind it (py2-style nonlocal)
        backspaces = [0]

        def print_status(msg):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(' ' * backspaces[0])
            sys.stdout.write('\b' * backspaces[0])
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0
    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            print(msg)

    return print_start, print_status, print_done
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load.  If omitted, all tables are loaded.
        (Read-only here, so the mutable default is harmless.)

    `directory`
        Directory the CSV files reside in.  Defaults to the `pokedex` data
        directory.

    `drop_tables`
        If set to True, existing `pokedex`-related tables will be dropped.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    # Resolve requested names to table objects, sorted so a table comes
    # after any table it has a foreign key to
    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)

    # SQLite speed tweaks: trades crash-safety for speed, hence `safe`
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        # Reverse dependency order so FK constraints don't complain
        for n, table in enumerate(reversed(table_objs)):
            table.drop(checkfirst=True)
            print_status('%s/%s' % (n, len(table_objs)))
        print_done()

    print_start('Creating tables')
    for n, table in enumerate(table_objs):
        table.create()
        print_status('%s/%s' % (n, len(table_objs)))
    print_done()

    connection = session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)

        try:
            csvpath = "%s/%s.csv" % (directory, table_name)
            # Text mode + newline='' is the Python 3 way to feed csv.reader;
            # values come back as str, so no manual decoding below
            csvfile = open(csvpath, 'r', encoding='utf-8', newline='')
        except OSError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        # File size lets the status callback report percentage progress
        csvsize = os.stat(csvpath).st_size

        reader = csv.reader(csvfile, lineterminator='\n')
        # First row is the header: the column names for this table
        column_names = [str(column) for column in next(reader)]

        if not safe and session.connection().dialect.name == 'postgresql':
            # Postgres' CSV dialect is nearly the same as ours, except that it
            # treats completely empty values as NULL, and empty quoted
            # strings ("") as empty strings.
            # The pokedex dump does not quote empty strings, so both empty
            # strings and NULLs would be read in as NULL.  FORCE NOT NULL on
            # the NOT NULL columns makes empty values load as empty strings
            # there; in nullable columns we already load empty strings as NULL.
            not_null_cols = [c for c in column_names
                             if not table_obj.c[c].nullable]
            if not_null_cols:
                force_not_null = 'FORCE NOT NULL ' + \
                    ','.join('"%s"' % c for c in not_null_cols)
            else:
                force_not_null = ''

            # NOTE(review): identifiers/paths are interpolated directly into
            # the SQL; they come from our own metadata and a local path, not
            # untrusted input.
            command = "COPY {table_name} ({columns}) FROM '{csvpath}' CSV HEADER {force_not_null}"
            session.connection().execute(
                command.format(
                    table_name=table_name,
                    csvpath=csvpath,
                    columns=','.join('"%s"' % c for c in column_names),
                    force_not_null=force_not_null,
                )
            )
            session.commit()
            csvfile.close()
            print_done()
            continue

        # Self-referential tables may contain rows with foreign keys of other
        # rows in the same table that do not yet exist.  Pull these out and add
        # them to the session last
        # ASSUMPTION: Self-referential tables have a single PK called "id"
        deferred_rows = []  # ( row referring to id, [foreign ids we need] )
        seen_ids = {}       # primary key we've seen => 1

        # Fetch foreign key columns that point at this table, if any
        self_ref_columns = []
        for column in table_obj.c:
            if any(_.references(table_obj) for _ in column.foreign_keys):
                self_ref_columns.append(column)

        new_rows = []

        def insert_and_commit():
            # Flush the accumulated rows in a single executemany, then commit
            # so a crash can't leave a giant open transaction behind
            if not new_rows:
                # Guard: executing an insert with an empty parameter list
                # would insert a single all-defaults row
                return
            session.connection().execute(insert_stmt, new_rows)
            session.commit()
            new_rows[:] = []

            if csvsize:
                progress = "%d%%" % (100 * csvfile.tell() // csvsize)
                print_status(progress)

        for csvs in reader:
            row_data = {}

            for column_name, value in zip(column_names, csvs):
                column = table_obj.c[column_name]
                if column.nullable and value == '':
                    # Empty string in a nullable column really means NULL
                    value = None
                elif isinstance(column.type, sqlalchemy.types.Boolean):
                    # Boolean values are stored as string values 0/1, but both
                    # of those evaluate as true; SQLA wants True/False
                    if value == '0':
                        value = False
                    else:
                        value = True
                # (csv was opened in text mode, so other values are already
                # decoded str; the old utf-8 bytes-decoding step is gone)

                # nb: Dictionaries flattened with ** have to have string keys
                row_data[str(column_name)] = value

            # May need to stash this row and add it later if it refers to a
            # later row in this table
            if self_ref_columns:
                foreign_ids = [row_data[_.name] for _ in self_ref_columns]
                foreign_ids = [_ for _ in foreign_ids if _]  # remove NULL ids

                if not foreign_ids:
                    # NULL key.  Remember this row and add as usual.
                    seen_ids[row_data['id']] = 1
                elif all(_ in seen_ids for _ in foreign_ids):
                    # Non-NULL key we've already seen.  Remember it and commit
                    # so we know the old row exists when we add the new one
                    insert_and_commit()
                    seen_ids[row_data['id']] = 1
                else:
                    # Non-NULL future id.  Save this and insert it later!
                    deferred_rows.append((row_data, foreign_ids))
                    continue

            new_rows.append(row_data)

            # Remembering some zillion rows in the session consumes a lot of
            # RAM.  Let's not do that.  Commit every 1000 rows
            if len(new_rows) >= 1000:
                insert_and_commit()

        insert_and_commit()
        csvfile.close()

        # Attempt to add any spare rows we've collected
        for row_data, foreign_ids in deferred_rows:
            if not all(_ in seen_ids for _ in foreign_ids):
                # Could happen if row A refers to B which refers to C.
                # This is ridiculous and doesn't happen in my data so far.
                # (Fixed: this used to reference an undefined name `row`.)
                raise ValueError("Too many levels of self-reference!  "
                                 "Row was: " + str(row_data))

            session.connection().execute(
                insert_stmt.values(**row_data)
            )
            seen_ids[row_data['id']] = 1
        session.commit()

        print_done()

    # Sanity-check a SQLite database, since the unsafe pragmas above could
    # have let corruption slip in
    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")
def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump.  If omitted, all tables are dumped.
        (Read-only here, so the mutable default is harmless.)

    `directory`
        Directory the CSV files should be put in.  Defaults to the `pokedex`
        data directory.

    `verbose`
        If set to True, status messages will be printed to stdout.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    # Deterministic file order makes diffs between dumps meaningful
    table_names.sort()

    for table_name in table_names:
        print_start(table_name)
        table = metadata.tables[table_name]

        # `with` closes (and flushes) the file even if a row blows up
        # mid-dump; the old code leaked the handle.  Text mode + newline=''
        # is the Python 3 way to feed csv.writer.
        with open("%s/%s.csv" % (directory, table_name), 'w',
                  encoding='utf-8', newline='') as outfile:
            writer = csv.writer(outfile, lineterminator='\n')
            columns = [col.name for col in table.columns]
            writer.writerow(columns)

            # Dump in primary-key order so output is stable run to run
            primary_key = table.primary_key
            for row in session.query(table).order_by(*primary_key).all():
                csvs = []
                for col in columns:
                    # Convert Pythony values to something more universal
                    val = getattr(row, col)
                    if val is None:
                        val = ''
                    elif val is True:
                        val = '1'
                    elif val is False:
                        val = '0'
                    else:
                        # Text-mode file handles encoding; no manual utf-8
                        val = str(val)
                    csvs.append(val)

                writer.writerow(csvs)

        print_done()