e6b328780d964ec24ca31fc043a970c68b7b2f0b
1 """CSV to database or vice versa."""
import csv
import fnmatch
import os.path
import sys

from sqlalchemy.orm.attributes import instrumentation_registry
import sqlalchemy.sql.util
import sqlalchemy.types

from pokedex.db import metadata
import pokedex.db.tables as tables
from pokedex.defaults import get_default_csv_dir
def _get_table_names(metadata, patterns):
    """Return a list of table names from the given metadata.

    `metadata`
        Object whose `tables` mapping is keyed by table name.

    `patterns`
        Iterable of `fnmatch`-style patterns.  A pattern containing '.' or
        '/' is treated as a filename and reduced to its base name without
        extension before matching.  If empty or None, every table name is
        returned.

    Returns a new list.  With patterns, each matching name appears once and
    the list is sorted; without patterns, the names come back in the order
    the `tables` mapping yields them.
    """
    all_names = list(metadata.tables.keys())

    if not patterns:
        return all_names

    table_names = set()
    for pattern in patterns:
        if '.' in pattern or '/' in pattern:
            # If it looks like a filename, pull out just the table name
            _, filename = os.path.split(pattern)
            pattern, _ = os.path.splitext(filename)

        table_names.update(fnmatch.filter(all_names, pattern))

    # A set has no defined order; sort so callers see a stable result.
    return sorted(table_names)
def _get_verbose_prints(verbose):
    """Return a `(print_start, print_status, print_done)` triple.

    If `verbose` is true, the three functions print a starting message, an
    interim status update, and a success or failure message when finished.

    If `verbose` is false, all three are the same no-op function.

    Whether stdout is a terminal is decided once, when this function is
    called; the returned functions look up `sys.stdout` each time they write.
    """

    if not verbose:
        # Return dummies
        def dummy(*args, **kwargs):
            pass

        return dummy, dummy, dummy

    ### Okay, verbose == True; print stuff

    def print_start(thing):
        # Truncate to 66 characters, leaving 10 characters for a success
        # or failure message
        truncated_thing = thing[0:66]

        # Also, space-pad to keep the cursor in a known column
        num_spaces = 66 - len(truncated_thing)

        # Written directly (no trailing newline) so the status/done text
        # lands on the same line.
        sys.stdout.write("%s...%s" % (truncated_thing, ' ' * num_spaces))
        sys.stdout.flush()

    if sys.stdout.isatty():
        # stdout is a terminal; stupid backspace tricks are OK.
        # Don't use print, because it always adds magical spaces, which
        # makes backspace accounting harder

        # Length of the status text currently on screen.  A one-element list
        # so the closures below can update it in place.
        backspaces = [0]

        def erase_status():
            # Back up over the old status, blank it out, and back up again
            # so the cursor sits where the status began.
            width = backspaces[0]
            sys.stdout.write('\b' * width + ' ' * width + '\b' * width)

        def print_status(msg):
            # Overwrite any status text with spaces before printing
            erase_status()
            sys.stdout.write(msg)
            sys.stdout.flush()
            backspaces[0] = len(msg)

        def print_done(msg='ok'):
            # Overwrite any status text with spaces before printing
            erase_status()
            sys.stdout.write(msg + "\n")
            sys.stdout.flush()
            backspaces[0] = 0

    else:
        # stdout is a file (or something); don't bother with status at all
        def print_status(msg):
            pass

        def print_done(msg='ok'):
            # One space separates the padded label from the final message.
            sys.stdout.write(" %s\n" % (msg,))

    return print_start, print_status, print_done
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
    """Load data from CSV files into the given database session.

    Tables are created automatically.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to load.  If omitted, all tables are loaded.  (The
        default list is never mutated here.)

    `directory`
        Directory the CSV files reside in.  Defaults to whatever
        `get_default_csv_dir()` returns.

    `drop_tables`
        If set to True, the selected tables are dropped (when they exist)
        before being created again.

    `verbose`
        If set to True, status messages will be printed to stdout.

    `safe`
        If set to False, load can be faster, but can corrupt the database if
        it crashes or is interrupted.  Only has an effect on SQLite.

    A table whose CSV file cannot be opened is reported as 'missing?' and
    skipped.  Raises ValueError if a self-referential row still refers to an
    id that has not been loaded once every other row of its table is in.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if directory is None:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)
    table_objs = [metadata.tables[name] for name in table_names]
    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)

    # SQLite speed tweaks
    if not safe and session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA synchronous=OFF")
        session.connection().execute("PRAGMA journal_mode=OFF")

    # Drop all tables if requested
    if drop_tables:
        print_start('Dropping tables')
        # Reverse of the load order
        for table in reversed(table_objs):
            table.drop(checkfirst=True)
        print_done()

    # NOTE(review): create() is called without checkfirst, so this presumably
    # fails for tables that already exist unless drop_tables was set --
    # confirm that is intended.
    for table in table_objs:
        table.create()
    # Kept from the original; the value is not used further down.
    connection = session.connection()

    # Okay, run through the tables and actually load the data now
    for table_obj in table_objs:
        table_name = table_obj.name
        insert_stmt = table_obj.insert()

        print_start(table_name)

        csvpath = "%s/%s.csv" % (directory, table_name)
        try:
            csvfile = open(csvpath, 'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print_done('missing?')
            continue

        # From here on the file is open; make sure it gets closed again even
        # if a row blows up halfway through.
        try:
            csvsize = os.stat(csvpath).st_size

            reader = csv.reader(csvfile, lineterminator='\n')
            column_names = [unicode(column) for column in reader.next()]

            # Self-referential tables may contain rows with foreign keys of
            # other rows in the same table that do not yet exist.  Pull these
            # out and add them to the session last
            # ASSUMPTION: Self-referential tables have a single PK called "id"
            deferred_rows = []  # ( row referring to id, [foreign ids we need] )
            seen_ids = {}       # primary key we've seen => 1

            # Fetch foreign key columns that point at this table, if any
            self_ref_columns = [
                column for column in table_obj.c
                if any(_.references(table_obj) for _ in column.foreign_keys)
            ]

            # Rows waiting to be inserted in one batch
            new_rows = []

            def insert_and_commit():
                # Nothing pending means nothing to send; still commit and
                # report progress so callers see the same status updates.
                if new_rows:
                    session.connection().execute(insert_stmt, new_rows)
                session.commit()
                # Emptied in place: the loop below appends to this same list
                del new_rows[:]

                progress = "%d%%" % (100 * csvfile.tell() // csvsize)
                print_status(progress)

            for csvs in reader:
                row_data = {}

                for column_name, value in zip(column_names, csvs):
                    column = table_obj.c[column_name]
                    if column.nullable and value == '':
                        # Empty string in a nullable column really means NULL
                        value = None
                    elif isinstance(column.type, sqlalchemy.types.Boolean):
                        # Boolean values are stored as string values 0/1, but
                        # both of those evaluate as true; SQLA wants
                        # True/False
                        value = (value != '0')
                    else:
                        # Otherwise, unflatten from bytes
                        value = value.decode('utf-8')

                    # nb: Dictionaries flattened with ** have to have string
                    # keys
                    row_data[str(column_name)] = value

                # May need to stash this row and add it later if it refers to
                # a later row in this table
                if self_ref_columns:
                    foreign_ids = [row_data[_.name] for _ in self_ref_columns]
                    foreign_ids = [_ for _ in foreign_ids if _]  # remove NULL ids

                    if not foreign_ids:
                        # NULL key.  Remember this row and add as usual.
                        seen_ids[row_data['id']] = 1

                    elif all(_ in seen_ids for _ in foreign_ids):
                        # Non-NULL key we've already seen.  Remember it and
                        # commit so we know the old row exists when we add
                        # the new one
                        insert_and_commit()
                        seen_ids[row_data['id']] = 1

                    else:
                        # Non-NULL future id.  Save this and insert it later!
                        deferred_rows.append((row_data, foreign_ids))
                        continue

                # Insert row!
                new_rows.append(row_data)

                # Remembering some zillion rows in the session consumes a lot
                # of RAM.  Let's not do that.  Commit every 1000 rows
                if len(new_rows) >= 1000:
                    insert_and_commit()

            # Flush whatever is left over from the last partial batch
            insert_and_commit()

            # Attempt to add any spare rows we've collected
            for row_data, foreign_ids in deferred_rows:
                if not all(_ in seen_ids for _ in foreign_ids):
                    # Could happen if row A refers to B which refers to C.
                    # This is ridiculous and doesn't happen in my data so far
                    raise ValueError("Too many levels of self-reference! "
                                     "Row was: " + str(row_data))

                session.connection().execute(
                    insert_stmt.values(**row_data)
                )
                seen_ids[row_data['id']] = 1

            session.commit()
        finally:
            csvfile.close()

        print_done()

    # SQLite check
    if session.connection().dialect.name == 'sqlite':
        session.connection().execute("PRAGMA integrity_check")
def dump(session, tables=[], directory=None, verbose=False):
    """Dumps the contents of a database to a set of CSV files.  Probably not
    useful to anyone besides a developer.

    `session`
        SQLAlchemy session to use.

    `tables`
        List of tables to dump.  If omitted, all tables are dumped.  (The
        default list is never mutated here.)

    `directory`
        Directory the CSV files should be put in.  Defaults to whatever
        `get_default_csv_dir()` returns.

    `verbose`
        If set to True, status messages will be printed to stdout.

    One `<table name>.csv` file is written per table: a header row of column
    names, then one row per database row, ordered by primary key.
    """

    # First take care of verbosity
    print_start, print_status, print_done = _get_verbose_prints(verbose)

    if not directory:
        directory = get_default_csv_dir()

    table_names = _get_table_names(metadata, tables)

    # Sorted so the tables are processed (and reported) in a stable order
    for table_name in sorted(table_names):
        print_start(table_name)
        table = metadata.tables[table_name]

        csvfile = open("%s/%s.csv" % (directory, table_name), 'wb')
        # Close the file explicitly so everything written is flushed to disk
        # before moving on, even if a row fails to convert.
        try:
            writer = csv.writer(csvfile, lineterminator='\n')
            columns = [col.name for col in table.columns]
            writer.writerow(columns)

            primary_key = table.primary_key
            for row in session.query(table).order_by(*primary_key).all():
                csvs = []
                for col in columns:
                    # Convert Pythony values to something more universal.
                    # NOTE(review): assumed to mirror what load() reads back
                    # (empty string for NULL, 0/1 for booleans) -- confirm.
                    val = getattr(row, col)
                    if val is None:
                        val = ''
                    elif val is True:
                        val = '1'
                    elif val is False:
                        val = '0'
                    else:
                        val = unicode(val).encode('utf-8')

                    csvs.append(val)

                writer.writerow(csvs)
        finally:
            csvfile.close()

        print_done()