From 42ec8abc587dd95f06521df28aee6063cb80f009 Mon Sep 17 00:00:00 2001
From: a_magical_me <andrew@turnipmints.mooo.com>
Date: Sun, 3 Apr 2011 02:10:33 -0700
Subject: [PATCH] load: Add --recursive option.

Helps somewhat with #526 (`pokedex load` is slow) by making it easier to
load only the tables you're interested in.
---
 pokedex/db/dependencies.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++
 pokedex/db/load.py         | 11 +++++++++-
 pokedex/main.py            |  9 ++++++--
 3 files changed, 71 insertions(+), 3 deletions(-)
 create mode 100644 pokedex/db/dependencies.py

diff --git a/pokedex/db/dependencies.py b/pokedex/db/dependencies.py
new file mode 100644
index 0000000..1f1b118
--- /dev/null
+++ b/pokedex/db/dependencies.py
@@ -0,0 +1,54 @@
+import sqlalchemy.sql.visitors as visitors
+
+from pokedex.db.tables import metadata
+
+# stolen from sqlalchemy.sql.util.sort_tables
+def compute_dependencies(tables):
+    """Construct a reverse dependency graph for the given tables.
+
+    Returns a dict which maps a table to the list of tables which depend on it.
+    """
+    tables = list(tables)
+    graph = {}
+    def visit_foreign_key(fkey):
+        if fkey.use_alter:
+            return
+        parent_table = fkey.column.table
+        if parent_table in tables:
+            child_table = fkey.parent.table
+            if parent_table is not child_table:
+                graph.setdefault(parent_table, []).append(child_table)
+
+    for table in tables:
+        visitors.traverse(table,
+                          {'schema_visitor': True},
+                          {'foreign_key': visit_foreign_key})
+
+        graph.setdefault(table, []).extend(table._extra_dependencies)
+
+    return graph
+
+#: The dependency graph for pokedex.db.tables
+_pokedex_graph = compute_dependencies(metadata.tables.values())
+
+def find_dependent_tables(tables, graph=None):
+    """Recursively find all tables which depend on the given tables.
+
+    The returned set does not include the original tables.
+    """
+    if graph is None:
+        graph = _pokedex_graph
+    tables = list(tables)
+    dependents = set()
+    def add_dependents_of(table):
+        for dependent_table in graph.get(table, []):
+            if dependent_table not in dependents:
+                dependents.add(dependent_table)
+                add_dependents_of(dependent_table)
+
+    for table in tables:
+        add_dependents_of(table)
+
+    dependents -= set(tables)
+
+    return dependents
diff --git a/pokedex/db/load.py b/pokedex/db/load.py
index f0e4b6d..88f5332 100644
--- a/pokedex/db/load.py
+++ b/pokedex/db/load.py
@@ -11,6 +11,7 @@ import sqlalchemy.types
 from pokedex.db import metadata
 import pokedex.db.tables as tables
 from pokedex.defaults import get_default_csv_dir
+from pokedex.db.dependencies import find_dependent_tables
 
 
 def _get_table_names(metadata, patterns):
@@ -95,7 +96,7 @@ def _get_verbose_prints(verbose):
     return print_start, print_status, print_done
 
 
-def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True):
+def load(session, tables=[], directory=None, drop_tables=False, verbose=False, safe=True, recursive=False):
     """Load data from CSV files into the given database session.
 
     Tables are created automatically.
@@ -119,6 +120,9 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
     `safe`
         If set to False, load can be faster, but can corrupt the database if
         it crashes or is interrupted.
+
+    `recursive`
+        If set to True, load all dependent tables too.
     """
 
     # First take care of verbosity
@@ -128,8 +132,13 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
     if directory is None:
         directory = get_default_csv_dir()
 
+    # XXX why isn't this done in command_load
     table_names = _get_table_names(metadata, tables)
     table_objs = [metadata.tables[name] for name in table_names]
+
+    if recursive:
+        table_objs.extend(find_dependent_tables(table_objs))
+
     table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
 
     # SQLite speed tweaks
diff --git a/pokedex/main.py b/pokedex/main.py
index e9810e5..68dc097 100644
--- a/pokedex/main.py
+++ b/pokedex/main.py
@@ -121,6 +121,7 @@ def command_load(*args):
     parser = get_parser(verbose=True)
     parser.add_option('-d', '--directory', dest='directory', default=None)
     parser.add_option('-D', '--drop-tables', dest='drop_tables', default=False, action='store_true')
+    parser.add_option('-r', '--recursive', dest='recursive', default=False, action='store_true')
     parser.add_option('-S', '--safe', dest='safe', default=False, action='store_true',
         help="Do not use backend-specific optimalizations.")
     options, tables = parser.parse_args(list(args))
@@ -139,7 +140,8 @@ def command_load(*args):
                                   drop_tables=options.drop_tables,
                                   tables=tables,
                                   verbose=options.verbose,
-                                  safe=options.safe)
+                                  safe=options.safe,
+                                  recursive=options.recursive)
 
 def command_reindex(*args):
     parser = get_parser(verbose=True)
@@ -277,7 +279,10 @@ System options:
     -d|--directory=DIR  By default, load and dump will use the CSV files in the
                         pokedex install directory.  Use this option to specify
                         a different directory.
-    -D|--drop-tables    With load, drop all tables before loading data.
+
+Load options:
+    -D|--drop-tables    Drop all tables before loading data.
+    -r|--recursive      Load (and drop) all dependent tables.
 
     Additionally, load and dump accept a list of table names (possibly with
     wildcards) and/or csv fileames as an argument list.
-- 
2.7.4