projects
/
zzz-pokedex.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Make PokemonForm.pokemon an actual relation
[zzz-pokedex.git]
/
pokedex
/
db
/
load.py
diff --git
a/pokedex/db/load.py
b/pokedex/db/load.py
index
f0e4b6d
..
d307fd8
100644
(file)
--- a/
pokedex/db/load.py
+++ b/
pokedex/db/load.py
@@
-11,6
+11,7
@@
import sqlalchemy.types
from pokedex.db import metadata
import pokedex.db.tables as tables
from pokedex.defaults import get_default_csv_dir
from pokedex.db import metadata
import pokedex.db.tables as tables
from pokedex.defaults import get_default_csv_dir
+from pokedex.db.dependencies import find_dependent_tables
def _get_table_names(metadata, patterns):
def _get_table_names(metadata, patterns):
@@
-52,7
+53,7
@@
def _get_verbose_prints(verbose):
def print_start(thing):
# Truncate to 66 characters, leaving 10 characters for a success
# or failure message
def print_start(thing):
# Truncate to 66 characters, leaving 10 characters for a success
# or failure message
- truncated_thing = thing[
0
:66]
+ truncated_thing = thing[:66]
# Also, space-pad to keep the cursor in a known column
num_spaces = 66 - len(truncated_thing)
# Also, space-pad to keep the cursor in a known column
num_spaces = 66 - len(truncated_thing)
@@
-95,7
+96,7
@@
def _get_verbose_prints(verbose):
return print_start, print_status, print_done
return print_start, print_status, print_done
-def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
+def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True
, recursive=False
):
"""Load data from CSV files into the given database session.
Tables are created automatically.
"""Load data from CSV files into the given database session.
Tables are created automatically.
@@
-119,6
+120,9
@@
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
`safe`
If set to False, load can be faster, but can corrupt the database if
it crashes or is interrupted.
`safe`
If set to False, load can be faster, but can corrupt the database if
it crashes or is interrupted.
+
+ `recursive`
+ If set to True, load all dependent tables too.
"""
# First take care of verbosity
"""
# First take care of verbosity
@@
-128,8
+132,13
@@
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
if directory is None:
directory = get_default_csv_dir()
if directory is None:
directory = get_default_csv_dir()
+ # XXX why isn't this done in command_load
table_names = _get_table_names(metadata, tables)
table_objs = [metadata.tables[name] for name in table_names]
table_names = _get_table_names(metadata, tables)
table_objs = [metadata.tables[name] for name in table_names]
+
+ if recursive:
+ table_objs.extend(find_dependent_tables(table_objs))
+
table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
# SQLite speed tweaks
table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
# SQLite speed tweaks
@@
-203,12
+212,12
@@
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
# them to the session last
# ASSUMPTION: Self-referential tables have a single PK called "id"
deferred_rows = [] # ( row referring to id, [foreign ids we need] )
# them to the session last
# ASSUMPTION: Self-referential tables have a single PK called "id"
deferred_rows = [] # ( row referring to id, [foreign ids we need] )
- seen_ids =
{} # primary key we've seen => 1
+ seen_ids =
set() # primary keys we've seen
# Fetch foreign key columns that point at this table, if any
self_ref_columns = []
for column in table_obj.c:
# Fetch foreign key columns that point at this table, if any
self_ref_columns = []
for column in table_obj.c:
- if any(
_.references(table_obj) for _
in column.foreign_keys):
+ if any(
x.references(table_obj) for x
in column.foreign_keys):
self_ref_columns.append(column)
new_rows = []
self_ref_columns.append(column)
new_rows = []
@@
-247,18
+256,18
@@
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
# May need to stash this row and add it later if it refers to a
# later row in this table
if self_ref_columns:
# May need to stash this row and add it later if it refers to a
# later row in this table
if self_ref_columns:
- foreign_ids =
[row_data[_.name] for _ in self_ref_columns]
- foreign_ids
= [_ for _ in foreign_ids if _]
# remove NULL ids
+ foreign_ids =
set(row_data[x.name] for x in self_ref_columns)
+ foreign_ids
.discard(None)
# remove NULL ids
if not foreign_ids:
# NULL key. Remember this row and add as usual.
if not foreign_ids:
# NULL key. Remember this row and add as usual.
- seen_ids
[row_data['id']] = 1
+ seen_ids
.add(row_data['id'])
- elif
all(_ in seen_ids for _ in foreig
n_ids):
+ elif
foreign_ids.issubset(see
n_ids):
# Non-NULL key we've already seen. Remember it and commit
# so we know the old row exists when we add the new one
insert_and_commit()
# Non-NULL key we've already seen. Remember it and commit
# so we know the old row exists when we add the new one
insert_and_commit()
- seen_ids
[row_data['id']] = 1
+ seen_ids
.add(row_data['id'])
else:
# Non-NULL future id. Save this and insert it later!
else:
# Non-NULL future id. Save this and insert it later!
@@
-277,7
+286,7
@@
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
# Attempt to add any spare rows we've collected
for row_data, foreign_ids in deferred_rows:
# Attempt to add any spare rows we've collected
for row_data, foreign_ids in deferred_rows:
- if not
all(_ in seen_ids for _ in foreig
n_ids):
+ if not
foreign_ids.issubset(see
n_ids):
# Could happen if row A refers to B which refers to C.
# This is ridiculous and doesn't happen in my data so far
raise ValueError("Too many levels of self-reference! "
# Could happen if row A refers to B which refers to C.
# This is ridiculous and doesn't happen in my data so far
raise ValueError("Too many levels of self-reference! "
@@
-286,7
+295,7
@@
def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
session.connection().execute(
insert_stmt.values(**row_data)
)
session.connection().execute(
insert_stmt.values(**row_data)
)
- seen_ids
[row_data['id']] = 1
+ seen_ids
.add(row_data['id'])
session.commit()
print_done()
session.commit()
print_done()