3908d8d5ee6e5bbdf9b2e85ced8e4af870826cc5
[zzz-pokedex.git] / pokedex / __init__.py
1 # encoding: utf8
2 import sys
3
4 from sqlalchemy.exc import IntegrityError
5 import sqlalchemy.types
6
7 from .db import connect, metadata, tables as tables_module
8 from pokedex.lookup import lookup as pokedex_lookup
9
def main():
    """Run the ``pokedex`` command line: ``pokedex <command> [args...]``.

    ``sys.argv[1]`` names a command, which is dispatched to the function of
    the same name defined in this module; the remaining command-line
    arguments are passed to it positionally.  With no command, or a name
    that is not a command, ``help()`` is called instead.
    """
    import types

    if len(sys.argv) <= 1:
        help()
        return

    command = sys.argv[1]
    args = sys.argv[2:]

    # Find the command as a function defined in this file.  Checking the type
    # and __module__ keeps imported callables (connect, IntegrityError,
    # pokedex_lookup, ...) from being runnable as commands; main itself and
    # _private helpers are excluded as well.
    func = globals().get(command, None)
    if (isinstance(func, types.FunctionType)
            and func.__module__ == __name__
            and command != 'main'
            and not command.startswith('_')):
        func(*args)
    else:
        help()
23
24
def csvimport(engine_uri, directory='.'):
    """Load CSV data from *directory* into the database at *engine_uri*.

    Missing tables are created first with ``metadata.create_all()``.  Then,
    for each table in ``metadata.sorted_tables``, the file
    ``<directory>/<table name>.csv`` is read if it exists.  Its header row
    names the columns, and every following row becomes an instance of the
    table's declarative class from ``tables_module``; each instance is added
    and flushed on its own.  Rows whose flush raises ``IntegrityError`` are
    retried after the rest of the file, pass after pass, until a pass
    inserts nothing more.  A one-line status is printed for each table.

    Cell conversion: an empty string in a nullable column becomes ``None``;
    Boolean columns map ``'0'`` to ``False`` and anything else to ``True``;
    every other value is decoded from UTF-8.

    :param engine_uri: Database URI passed to ``connect``.
    :param directory: Directory holding the CSV files; defaults to the
        current directory.
    """
    import csv
    import os

    # Autocommit, so rows that fail (e.g. a foreign key pointing at a row
    # that is not loaded yet) can be retried individually later.
    session = connect(engine_uri, autocommit=True, autoflush=False)

    metadata.create_all()

    # Map each Table object to its declarative class.  There is no direct
    # list of the ORM classes available here, so scan tables_module for
    # subclasses of TableBase, skipping TableBase itself.
    table_base = tables_module.TableBase
    orm_classes = {}  # table object => table class
    for name in dir(tables_module):
        thingy = getattr(tables_module, name)
        if (isinstance(thingy, type)
                and issubclass(thingy, table_base)
                and thingy is not table_base):
            orm_classes[thingy.__table__] = thingy

    # Okay, run through the tables and actually load the data now
    for table_obj in metadata.sorted_tables:
        # A KeyError here means the table has no declarative class in
        # tables_module.
        table_class = orm_classes[table_obj]
        table_name = table_obj.name

        # Print the table name but leave the cursor in a fixed column
        sys.stdout.write('%s... %s ' % (table_name,
                                        ' ' * (40 - len(table_name))))
        sys.stdout.flush()

        try:
            csvfile = open(os.path.join(directory, '%s.csv' % table_name),
                           'rb')
        except IOError:
            # File doesn't exist; don't load anything!
            print('no data!')
            continue

        # Self-referential tables may contain rows with foreign keys to
        # other rows in the same table that do not exist yet.  Collect the
        # rows that fail and retry them once the rest are in.
        failed_rows = []

        with csvfile:
            reader = csv.reader(csvfile, lineterminator='\n')
            try:
                header = next(reader)
            except StopIteration:
                # Completely empty file -- not even a header row.
                print('no data!')
                continue
            column_names = [unicode(column) for column in header]

            for csvs in reader:
                row = table_class()

                for column_name, value in zip(column_names, csvs):
                    column = table_obj.c[column_name]
                    if column.nullable and value == '':
                        # Empty string in a nullable column really means NULL
                        value = None
                    elif isinstance(column.type, sqlalchemy.types.Boolean):
                        # Booleans are stored as the strings '0'/'1', and
                        # both are truthy; SQLA wants real True/False.
                        value = value != '0'
                    else:
                        # Otherwise, unflatten from bytes
                        value = value.decode('utf-8')

                    setattr(row, column_name, value)

                try:
                    session.add(row)
                    session.flush()
                except IntegrityError:
                    failed_rows.append(row)

        # Retry the failed rows in passes.  Build a fresh list each pass
        # rather than deleting from the list being iterated (which skips
        # elements), and stop as soon as a pass inserts nothing.
        while failed_rows:
            still_failing = []
            for row in failed_rows:
                try:
                    session.add(row)
                    session.flush()
                except IntegrityError:
                    still_failing.append(row)

            if len(still_failing) == len(failed_rows):
                # No progress this pass; give up on the rest.
                break
            failed_rows = still_failing

        if failed_rows:
            print('%d rows failed' % len(failed_rows))
        else:
            print('loaded')
131
def csvexport(engine_uri, directory='.'):
    """Dump every table in ``metadata`` to ``<directory>/<table name>.csv``.

    Each file starts with a header row of column names.  Values are
    flattened into the form ``csvimport`` reads back: ``None`` becomes an
    empty string, booleans become ``'1'``/``'0'``, and everything else is
    converted to unicode and encoded as UTF-8.  Rows are written in
    primary-key order when the table has a primary key.

    :param engine_uri: Database URI passed to ``connect``.
    :param directory: Directory to write the CSV files into; defaults to the
        current directory.
    """
    import csv
    import os

    session = connect(engine_uri)

    for table_name in sorted(metadata.tables.keys()):
        print(table_name)
        table = metadata.tables[table_name]
        columns = [col.name for col in table.columns]

        # Order by primary key so repeated exports produce the same row
        # order (and therefore small diffs); without ORDER BY the order is
        # whatever the database returns.
        query = session.query(table)
        primary_key = list(table.primary_key.columns)
        if primary_key:
            query = query.order_by(*primary_key)

        csv_path = os.path.join(directory, '%s.csv' % table_name)
        # The csv module wants binary-mode files on Python 2; ``with`` makes
        # sure each file is flushed and closed.
        with open(csv_path, 'wb') as csvfile:
            writer = csv.writer(csvfile, lineterminator='\n')
            writer.writerow(columns)

            for row in query:
                csvs = []
                for col in columns:
                    # Convert Pythony values to something more universal
                    val = getattr(row, col)
                    if val is None:
                        val = ''
                    elif isinstance(val, bool):
                        # Test the type rather than ``val == True``: 1.0 ==
                        # True and 0.0 == False, so plain numbers would be
                        # collapsed to '1'/'0'.
                        val = '1' if val else '0'
                    else:
                        val = unicode(val).encode('utf-8')

                    csvs.append(val)

                writer.writerow(csvs)
162
def lookup(engine_uri, name):
    """Look up *name* in the database at *engine_uri* and print the matches.

    Prints ``Matched:`` followed by one line per result from
    ``pokedex.lookup.lookup``, formatted as
    ``<table name> <object name> (<matchiness>)``.

    :param engine_uri: Database URI passed to ``connect``.
    :param name: The name to search for.
    """
    # XXX: find a way to not require the URI here.
    session = connect(engine_uri)

    results = pokedex_lookup(session, name)
    print("Matched:")
    # ``obj`` rather than ``object`` so the builtin is not shadowed.
    for obj, matchiness in results:
        print('%s %s (%.03f)' % (obj.__tablename__, obj.name, matchiness))
171
172
def help():
    """Print usage information for the ``pokedex`` command, then exit.

    Always ends the process with ``sys.exit(0)``; control never returns to
    the caller.
    """
    text = u"""pokedex -- a command-line Pokédex interface

help                    Displays this message.
lookup {uri} {name}     Look up something in the Pokédex.

These commands are only useful for developers:
csvimport {uri} [dir]   Import data from a set of CSVs to the database
                        given by the URI.
csvexport {uri} [dir]   Export data from the database given by the URI
                        to a set of CSVs.
                        Directory defaults to cwd.
"""

    # Encode explicitly: on Python 2, printing unicode when stdout is a pipe
    # or file (sys.stdout.encoding is None) falls back to ASCII and raises
    # UnicodeEncodeError on the accented characters above.
    encoding = getattr(sys.stdout, 'encoding', None) or 'utf-8'
    sys.stdout.write(text.encode(encoding, 'replace'))

    sys.exit(0)