Stubbing out a lookup function. #15
[zzz-pokedex.git] / pokedex / __init__.py
1 # encoding: utf8
2 import sys
3
4 from sqlalchemy.exc import IntegrityError
5 import sqlalchemy.types
6
7 from .db import connect, metadata, tables as tables_module
8
def main():
    """Entry point: run the command named by ``sys.argv[1]``.

    The command name is looked up among this module's globals, and the
    remaining command-line arguments are passed to it positionally.  Only
    public functions defined in this very module count as commands: imported
    callables (such as ``connect`` or ``IntegrityError``), names starting
    with an underscore, and ``main`` itself are rejected.  An unknown command,
    or no command at all, shows ``help()``.
    """
    if len(sys.argv) <= 1:
        help()  # help() calls sys.exit(); the return is purely defensive
        return

    command = sys.argv[1]
    args = sys.argv[2:]

    # Find the command as a function defined in this file.  Comparing
    # __module__ keeps names merely imported into this namespace from being
    # reachable from the command line.
    func = globals().get(command, None)
    is_command = (
        func is not None
        and callable(func)
        and getattr(func, '__module__', None) == __name__
        and not command.startswith('_')
        and command != 'main'
    )

    if is_command:
        func(*args)
    else:
        help()
22
23
24 def csvimport(engine_uri, directory='.'):
25 import csv
26
27 from sqlalchemy.orm.attributes import instrumentation_registry
28
29 # Use autocommit in case rows fail due to foreign key incest
30 session = connect(engine_uri, autocommit=True, autoflush=False)
31
32 metadata.create_all()
33
34 # SQLAlchemy is retarded and there is no way for me to get a list of ORM
35 # classes besides to inspect the module they all happen to live in for
36 # things that look right.
37 table_base = tables_module.TableBase
38 orm_classes = {} # table object => table class
39
40 for name in dir(tables_module):
41 # dir() returns strings! How /convenient/.
42 thingy = getattr(tables_module, name)
43
44 if not isinstance(thingy, type):
45 # Not a class; bail
46 continue
47 elif not issubclass(thingy, table_base):
48 # Not a declarative table; bail
49 continue
50 elif thingy == table_base:
51 # Declarative table base, so not a real table; bail
52 continue
53
54 # thingy is definitely a table class! Hallelujah.
55 orm_classes[thingy.__table__] = thingy
56
57 # Okay, run through the tables and actually load the data now
58 for table_obj in metadata.sorted_tables:
59 table_class = orm_classes[table_obj]
60 table_name = table_obj.name
61
62 # Print the table name but leave the cursor in a fixed column
63 print table_name + '...', ' ' * (40 - len(table_name)),
64 sys.stdout.flush()
65
66 try:
67 csvfile = open("%s/%s.csv" % (directory, table_name), 'rb')
68 except IOError:
69 # File doesn't exist; don't load anything!
70 print 'no data!'
71 continue
72
73 reader = csv.reader(csvfile, lineterminator='\n')
74 column_names = [unicode(column) for column in reader.next()]
75
76 # Self-referential tables may contain rows with foreign keys of
77 # other rows in the same table that do not yet exist. We'll keep
78 # a running list of these and try inserting them again after the
79 # rest are done
80 failed_rows = []
81
82 for csvs in reader:
83 row = table_class()
84
85 for column_name, value in zip(column_names, csvs):
86 column = table_obj.c[column_name]
87 if column.nullable and value == '':
88 # Empty string in a nullable column really means NULL
89 value = None
90 elif isinstance(column.type, sqlalchemy.types.Boolean):
91 # Boolean values are stored as string values 0/1, but both
92 # of those evaluate as true; SQLA wants True/False
93 if value == '0':
94 value = False
95 else:
96 value = True
97 else:
98 # Otherwise, unflatten from bytes
99 value = value.decode('utf-8')
100
101 setattr(row, column_name, value)
102
103 try:
104 session.add(row)
105 session.flush()
106 except IntegrityError as e:
107 failed_rows.append(row)
108
109 # Loop over the failed rows and keep trying to insert them. If a loop
110 # doesn't manage to insert any rows, bail.
111 do_another_loop = True
112 while failed_rows and do_another_loop:
113 do_another_loop = False
114
115 for i, row in enumerate(failed_rows):
116 try:
117 session.add(row)
118 session.flush()
119
120 # Success!
121 del failed_rows[i]
122 do_another_loop = True
123 except IntegrityError as e:
124 pass
125
126 if failed_rows:
127 print len(failed_rows), "rows failed"
128 else:
129 print 'loaded'
130
131 def csvexport(engine_uri, directory='.'):
132 import csv
133 session = connect(engine_uri)
134
135 for table_name in sorted(metadata.tables.keys()):
136 print table_name
137 table = metadata.tables[table_name]
138
139 writer = csv.writer(open("%s/%s.csv" % (directory, table_name), 'wb'),
140 lineterminator='\n')
141 columns = [col.name for col in table.columns]
142 writer.writerow(columns)
143
144 for row in session.query(table).all():
145 csvs = []
146 for col in columns:
147 # Convert Pythony values to something more universal
148 val = getattr(row, col)
149 if val == None:
150 val = ''
151 elif val == True:
152 val = '1'
153 elif val == False:
154 val = '0'
155 else:
156 val = unicode(val).encode('utf-8')
157
158 csvs.append(val)
159
160 writer.writerow(csvs)
161
162
def help():
    """Print command-line usage information and exit with status 0.

    The usage text contains a non-ASCII character ("é").  It is therefore
    encoded explicitly, using stdout's reported encoding when there is one
    and UTF-8 otherwise.  Unencodable characters are replaced.  This avoids
    the UnicodeEncodeError that Python 2's ``print`` of a unicode string
    raises when stdout is piped or uses an ASCII locale.
    """
    text = u"""pokedex -- a command-line Pokédex interface

help Displays this message.

These commands are only useful for developers:
csvimport {uri} [dir] Import data from a set of CSVs to the database
given by the URI.
csvexport {uri} [dir] Export data from the database given by the URI
to a set of CSVs.
Directory defaults to cwd.
"""

    encoding = getattr(sys.stdout, 'encoding', None) or 'utf-8'
    # `print` used to append a newline after the text; keep doing so.
    sys.stdout.write(text.encode(encoding, 'replace') + '\n')

    sys.exit(0)