df0d0f171ba322a8d1d9a5f0a919970bab1b100c
1 """CSV to database or vice versa."""
7 from sqlalchemy
.orm
.attributes
import instrumentation_registry
8 import sqlalchemy
.sql
.util
9 import sqlalchemy
.types
11 from pokedex
.db
import metadata
12 import pokedex
.db
.tables
as tables
13 from pokedex
.defaults
import get_default_csv_dir
16 def _get_table_names(metadata
, patterns
):
17 """Returns a list of table names from the given metadata. If `patterns`
18 exists, only tables matching one of the patterns will be returned.
22 for pattern
in patterns
:
23 if '.' in pattern
or '/' in pattern
:
24 # If it looks like a filename, pull out just the table name
25 _
, filename
= os
.path
.split(pattern
)
26 table_name
, _
= os
.path
.splitext(filename
)
29 table_names
.update(fnmatch
.filter(metadata
.tables
.keys(), pattern
))
31 table_names
= metadata
.tables
.keys()
33 return list(table_names
)
35 def _get_verbose_prints(verbose
):
36 """If `verbose` is true, returns three functions: one for printing a
37 starting message, one for printing an interim status update, and one for
38 printing a success or failure message when finished.
40 If `verbose` is false, returns no-op functions.
45 def dummy(*args
, **kwargs
):
48 return dummy
, dummy
, dummy
50 ### Okay, verbose == True; print stuff
52 def print_start(thing
):
53 # Truncate to 66 characters, leaving 10 characters for a success
55 truncated_thing
= thing
[0:66]
57 # Also, space-pad to keep the cursor in a known column
58 num_spaces
= 66 - len(truncated_thing
)
60 print "%s...%s" %
(truncated_thing
, ' ' * num_spaces
),
63 if sys
.stdout
.isatty():
64 # stdout is a terminal; stupid backspace tricks are OK.
65 # Don't use print, because it always adds magical spaces, which
66 # makes backspace accounting harder
69 def print_status(msg
):
70 # Overwrite any status text with spaces before printing
71 sys
.stdout
.write('\b' * backspaces
[0])
72 sys
.stdout
.write(' ' * backspaces
[0])
73 sys
.stdout
.write('\b' * backspaces
[0])
76 backspaces
[0] = len(msg
)
78 def print_done(msg
='ok'):
79 # Overwrite any status text with spaces before printing
80 sys
.stdout
.write('\b' * backspaces
[0])
81 sys
.stdout
.write(' ' * backspaces
[0])
82 sys
.stdout
.write('\b' * backspaces
[0])
83 sys
.stdout
.write(msg
+ "\n")
88 # stdout is a file (or something); don't bother with status at all
89 def print_status(msg
):
92 def print_done(msg
='ok'):
95 return print_start
, print_status
, print_done
98 def load(session
, tables
=[], directory
=None, drop_tables
=False, verbose
=False):
99 """Load data from CSV files into the given database session.
101 Tables are created automatically.
104 SQLAlchemy session to use.
107 List of tables to load. If omitted, all tables are loaded.
110 Directory the CSV files reside in. Defaults to the `pokedex` data
114 If set to True, existing `pokedex`-related tables will be dropped.
117 If set to True, status messages will be printed to stdout.
120 # First take care of verbosity
121 print_start
, print_status
, print_done
= _get_verbose_prints(verbose
)
124 if directory
is None:
125 directory
= get_default_csv_dir()
127 table_names
= _get_table_names(metadata
, tables
)
128 table_objs
= [metadata
.tables
[name
] for name
in table_names
]
129 table_objs
= sqlalchemy
.sql
.util
.sort_tables(table_objs
)
132 # Drop all tables if requested
134 print_start('Dropping tables')
135 for table
in reversed(table_objs
):
136 table
.drop(checkfirst
=True)
139 for table
in table_objs
:
141 connection
= session
.connection()
143 # Okay, run through the tables and actually load the data now
144 for table_obj
in table_objs
:
145 table_name
= table_obj
.name
146 insert_stmt
= table_obj
.insert()
148 print_start(table_name
)
151 csvpath
= "%s/%s.csv" %
(directory
, table_name
)
152 csvfile
= open(csvpath
, 'rb')
154 # File doesn't exist; don't load anything!
155 print_done('missing?')
158 csvsize
= os
.stat(csvpath
).st_size
160 reader
= csv
.reader(csvfile
, lineterminator
='\n')
161 column_names
= [unicode(column
) for column
in reader
.next()]
163 # Self-referential tables may contain rows with foreign keys of other
164 # rows in the same table that do not yet exist. Pull these out and add
165 # them to the session last
166 # ASSUMPTION: Self-referential tables have a single PK called "id"
167 deferred_rows
= [] # ( row referring to id, [foreign ids we need] )
168 seen_ids
= {} # primary key we've seen => 1
170 # Fetch foreign key columns that point at this table, if any
171 self_ref_columns
= []
172 for column
in table_obj
.c
:
173 if any(_
.references(table_obj
) for _
in column
.foreign_keys
):
174 self_ref_columns
.append(column
)
177 def insert_and_commit():
178 session
.connection().execute(insert_stmt
, new_rows
)
182 progress
= "%d%%" %
(100 * csvfile
.tell() // csvsize
)
183 print_status(progress
)
188 for column_name
, value
in zip(column_names
, csvs
):
189 column
= table_obj
.c
[column_name
]
190 if column
.nullable
and value
== '':
191 # Empty string in a nullable column really means NULL
193 elif isinstance(column
.type, sqlalchemy
.types
.Boolean
):
194 # Boolean values are stored as string values 0/1, but both
195 # of those evaluate as true; SQLA wants True/False
201 # Otherwise, unflatten from bytes
202 value
= value
.decode('utf-8')
204 # nb: Dictionaries flattened with ** have to have string keys
205 row_data
[ str(column_name
) ] = value
207 # May need to stash this row and add it later if it refers to a
208 # later row in this table
210 foreign_ids
= [row_data
[_
.name
] for _
in self_ref_columns
]
211 foreign_ids
= [_
for _
in foreign_ids
if _
] # remove NULL ids
214 # NULL key. Remember this row and add as usual.
215 seen_ids
[row_data
['id']] = 1
217 elif all(_
in seen_ids
for _
in foreign_ids
):
218 # Non-NULL key we've already seen. Remember it and commit
219 # so we know the old row exists when we add the new one
221 seen_ids
[row_data
['id']] = 1
224 # Non-NULL future id. Save this and insert it later!
225 deferred_rows
.append((row_data
, foreign_ids
))
229 new_rows
.append(row_data
)
231 # Remembering some zillion rows in the session consumes a lot of
232 # RAM. Let's not do that. Commit every 1000 rows
233 if len(new_rows
) >= 1000:
238 # Attempt to add any spare rows we've collected
239 for row_data
, foreign_ids
in deferred_rows
:
240 if not all(_
in seen_ids
for _
in foreign_ids
):
241 # Could happen if row A refers to B which refers to C.
242 # This is ridiculous and doesn't happen in my data so far
243 raise ValueError("Too many levels of self-reference! "
244 "Row was: " + str(row
))
246 session
.connection().execute(
247 insert_stmt
.values(**row_data
)
249 seen_ids
[row_data
['id']] = 1
256 def dump(session
, tables
=[], directory
=None, verbose
=False):
257 """Dumps the contents of a database to a set of CSV files. Probably not
258 useful to anyone besides a developer.
261 SQLAlchemy session to use.
264 List of tables to dump. If omitted, all tables are dumped.
267 Directory the CSV files should be put in. Defaults to the `pokedex`
271 If set to True, status messages will be printed to stdout.
274 # First take care of verbosity
275 print_start
, print_status
, print_done
= _get_verbose_prints(verbose
)
279 directory
= get_default_csv_dir()
281 table_names
= _get_table_names(metadata
, tables
)
285 for table_name
in table_names
:
286 print_start(table_name
)
287 table
= metadata
.tables
[table_name
]
289 writer
= csv
.writer(open("%s/%s.csv" %
(directory
, table_name
), 'wb'),
291 columns
= [col
.name
for col
in table
.columns
]
292 writer
.writerow(columns
)
294 primary_key
= table
.primary_key
295 for row
in session
.query(table
).order_by(*primary_key
).all():
298 # Convert Pythony values to something more universal
299 val
= getattr(row
, col
)
307 val
= unicode(val
).encode('utf-8')
311 writer
.writerow(csvs
)