diff --git a/pokedex/db/load.py b/pokedex/db/load.py
index 2b14883..0425dea 100644
--- a/pokedex/db/load.py
+++ b/pokedex/db/load.py
@@ -1,15 +1,49 @@
 """CSV to database or vice versa."""
 
 import csv
+import os.path
 import pkg_resources
+import re
 import sys
 
 from sqlalchemy.orm.attributes import instrumentation_registry
+import sqlalchemy.sql.util
 import sqlalchemy.types
 
 from pokedex.db import metadata
 import pokedex.db.tables as tables
 
+def _wildcard_char_to_regex(char):
+    """Converts a single wildcard character to the regex equivalent."""
+
+    if char == '?':
+        return '.?'
+    elif char == '*':
+        return '.*'
+    else:
+        return re.escape(char)
+
+def _wildcard_glob_to_regex(glob):
+    """Converts a single wildcard glob to a regex STRING."""
+
+    # If it looks like a filename, make it not one
+    if '.' in glob or '/' in glob:
+        _, filename = os.path.split(glob)
+        table_name, _ = os.path.splitext(filename)
+        glob = table_name
+
+    return u''.join(map(_wildcard_char_to_regex, glob))
+
+def _wildcards_to_regex(strings):
+    """Converts a list of wildcard globs to a single regex object."""
+
+    regex_parts = map(_wildcard_glob_to_regex, strings)
+
+    regex = '^(?:' + '|'.join(regex_parts) + ')$'
+
+    return re.compile(regex)
+
+
 def _get_verbose_prints(verbose):
     """If `verbose` is true, returns two functions: one for printing a
     starting message, and the other for printing a success or failure message when
@@ -44,7 +78,7 @@ def _get_verbose_prints(verbose):
     return dummy, dummy
 
 
-def load(session, directory=None, drop_tables=False, verbose=False):
+def load(session, tables=[], directory=None, drop_tables=False, verbose=False):
     """Load data from CSV files into the given database session.
 
     Tables are created automatically.
@@ -52,6 +86,9 @@ def load(session, directory=None, drop_tables=False, verbose=False):
     `session`
         SQLAlchemy session to use.
 
+    `tables`
+        List of tables to load.  If omitted, all tables are loaded.
+
     `directory`
         Directory the CSV files reside in.  Defaults to the `pokedex` data
         directory.
@@ -70,17 +107,29 @@ def load(session, directory=None, drop_tables=False, verbose=False):
     if not directory:
         directory = pkg_resources.resource_filename('pokedex', 'data/csv')
 
+    if tables:
+        regex = _wildcards_to_regex(tables)
+        table_names = filter(regex.match, metadata.tables.keys())
+    else:
+        table_names = metadata.tables.keys()
+
+    table_objs = [metadata.tables[name] for name in table_names]
+    table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
+
+
     # Drop all tables if requested
     if drop_tables:
         print_start('Dropping tables')
-        metadata.drop_all()
+        for table in reversed(table_objs):
+            table.drop(checkfirst=True)
         print_done()
 
-    metadata.create_all()
+    for table in table_objs:
+        table.create()
 
     connection = session.connection()
 
     # Okay, run through the tables and actually load the data now
-    for table_obj in metadata.sorted_tables:
+    for table_obj in table_objs:
         table_name = table_obj.name
         insert_stmt = table_obj.insert()
@@ -163,7 +212,7 @@ def load(session, directory=None, drop_tables=False, verbose=False):
 
             # Remembering some zillion rows in the session consumes a lot of
             # RAM.  Let's not do that.  Commit every 1000 rows
-            if len(new_rows) > 1000:
+            if len(new_rows) >= 1000:
                 insert_and_commit()
 
         insert_and_commit()
@@ -186,13 +235,16 @@
 
 
-def dump(session, directory=None, verbose=False):
+def dump(session, tables=[], directory=None, verbose=False):
     """Dumps the contents of a database to a set of CSV files.  Probably not
     useful to anyone besides a developer.
 
     `session`
         SQLAlchemy session to use.
 
+    `tables`
+        List of tables to dump.  If omitted, all tables are dumped.
+
     `directory`
         Directory the CSV files should be put in.  Defaults to the `pokedex`
         data directory.
 
@@ -208,7 +260,16 @@ def dump(session, directory=None, verbose=False):
     if not directory:
         directory = pkg_resources.resource_filename('pokedex', 'data/csv')
 
-    for table_name in sorted(metadata.tables.keys()):
+    if tables:
+        regex = _wildcards_to_regex(tables)
+        table_names = filter(regex.match, metadata.tables.keys())
+    else:
+        table_names = metadata.tables.keys()
+
+    table_names.sort()
+
+
+    for table_name in table_names:
        print_start(table_name)
 
        table = metadata.tables[table_name]
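
For illustration, here is a minimal, self-contained sketch (not part of the commit) of what the wildcard translation above does. It mirrors the patch's _wildcard_char_to_regex and _wildcards_to_regex in the repository's Python 2 idiom, where filter returns a list; the table names are invented for the example.

    import re

    def _wildcard_char_to_regex(char):
        # '?' becomes '.?' (zero or one character), '*' becomes '.*'
        # (any run of characters); everything else matches literally.
        if char == '?':
            return '.?'
        elif char == '*':
            return '.*'
        else:
            return re.escape(char)

    def _wildcards_to_regex(globs):
        # Anchor the alternation so each glob must match a whole table name.
        parts = [''.join(map(_wildcard_char_to_regex, glob)) for glob in globs]
        return re.compile('^(?:' + '|'.join(parts) + ')$')

    regex = _wildcards_to_regex(['pokemon*', 'moves'])
    names = ['pokemon', 'pokemon_abilities', 'moves', 'move_effects']
    print filter(regex.match, names)
    # ['pokemon', 'pokemon_abilities', 'moves']

Note that '?' maps to '.?' rather than '.', so unlike shell globbing it matches zero or one character instead of exactly one.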
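
The filename handling in _wildcard_glob_to_regex means a CSV path such as 'pokedex/data/csv/moves.csv' selects the same table as the bare name 'moves'. A usage sketch for the new `tables` argument, assuming a SQLAlchemy session bound to a scratch database (the SQLite URI and output directory are hypothetical, and the output directory is assumed to exist):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    from pokedex.db.load import dump, load

    engine = create_engine('sqlite:///pokedex.sqlite')
    session = sessionmaker(bind=engine)()

    # Load only the pokemon* tables plus moves, dropping and recreating
    # them first; other tables are left untouched.
    load(session, tables=['pokemon*', 'moves'], drop_tables=True, verbose=True)

    # Dump the same tables back out as CSV.
    dump(session, tables=['pokemon*', 'moves'], directory='/tmp/csv-out')

Because the selected Table objects are run through sqlalchemy.sql.util.sort_tables, creates happen in dependency order and drops in reverse, so a partial load of related tables still respects foreign keys.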