4 import sqlalchemy
.types
6 from .db
import connect
, metadata
, tables
as tables_module
15 # Find the command as a function in this file
16 func
= globals().get(command
, None)
17 if func
and callable(func
) and command
!= 'main':
# csvimport: load table data from CSV files into the database at engine_uri.
#
# Collects every declarative ORM class defined in the tables module (any
# class that subclasses tables_module.TableBase, excluding TableBase itself),
# keyed by its __table__.name. For each table, in sorted name order, it opens
# "<directory>/<table name>.csv" and reads the first CSV row as the column
# names. It then assigns each field to the matching attribute of `row`.
# Field handling:
#   - an empty string in a nullable column is treated as NULL;
#   - Boolean columns get special handling for their stored 0/1 strings;
#   - everything else is decoded from UTF-8.
# On MySQL, foreign-key checks are switched off before loading and switched
# back on at the end.
#
# Parameters:
#   engine_uri -- database URI, passed to connect() and also inspected for
#                 the substring 'mysql'.
#   directory  -- directory holding the <table name>.csv files (default '.').
#
# NOTE(review): this is Python 2 code (print statement, unicode(),
# reader.next()).
23 def csvimport(engine_uri
, directory
='.'):
# NOTE(review): instrumentation_registry is not referenced in any visible
# line of this function -- TODO confirm it is still needed.
26 from sqlalchemy
.orm
.attributes
import instrumentation_registry
28 session
= connect(engine_uri
)
33 # TODO try to insert data in preorder so we don't need this hack and won't
34 # break similarly on other engines
# NOTE(review): this is a plain substring test on the URI, so a URI that
# merely contains "mysql" elsewhere (e.g. in a database name) also matches.
35 if 'mysql' in engine_uri
:
36 session
.execute('SET FOREIGN_KEY_CHECKS = 0')
38 # SQLAlchemy offers no direct way for me to get a list of ORM
39 # classes other than to inspect the module they all happen to live in for
40 # things that look right.
41 table_base
= tables_module
.TableBase
44 for name
in dir(tables_module
):
45 # dir() returns attribute names as strings, so fetch each object by name.
46 thingy
= getattr(tables_module
, name
)
# Not a class at all.
48 if not isinstance(thingy
, type):
51 elif not issubclass(thingy
, table_base
):
52 # Not a declarative table; bail
54 elif thingy
== table_base
:
55 # Declarative table base, so not a real table; bail
58 # thingy is definitely a table class.
59 orm_classes
[thingy
.__table__
.name
] = thingy
61 # Okay, run through the tables and actually load the data now
62 for table_name
, table
in sorted(orm_classes
.items()):
63 # Print the table name but leave the cursor in a fixed column
# (the trailing comma suppresses the newline; names of 40+ characters get
# no padding).
64 print table_name
+ '...', ' ' * (40 - len(table_name
)),
# Opened in binary mode, as Python 2's csv module expects.
# NOTE(review): csvfile is not closed in any visible line -- confirm it is
# closed (or wrap it in a with block).
67 csvfile
= open("%s/%s.csv" %
(directory
, table_name
), 'rb')
69 # File doesn't exist; don't load anything!
73 reader
= csv
.reader(csvfile
, lineterminator
='\n')
# The first CSV row is the header: one column name per field.
74 column_names
= [unicode(column
) for column
in reader
.next()]
# NOTE(review): `csvs` and `row` are bound in lines not visible here;
# presumably one CSV record and one new `table` instance per record --
# confirm.
# zip() stops at the shorter sequence, so extra or missing fields in a
# record are silently ignored rather than reported.
79 for column_name
, value
in zip(column_names
, csvs
):
80 column
= table
.__table__
.c
[column_name
]
81 if column
.nullable
and value
== '':
82 # Empty string in a nullable column really means NULL
84 elif isinstance(column
.type, sqlalchemy
.types
.Boolean
):
85 # Boolean values are stored as string values 0/1, but both
86 # of those evaluate as true; SQLA wants True/False
92 # Otherwise, unflatten from bytes
93 value
= value
.decode('utf-8')
95 setattr(row
, column_name
, value
)
# NOTE(review): no try/finally is visible around the load, so an exception
# would presumably leave FOREIGN_KEY_CHECKS = 0 on this session -- confirm.
102 # Shouldn't matter since this is usually the end of the program and thus
103 # the connection too, but let's change this back just in case
104 if 'mysql' in engine_uri
:
105 session
.execute('SET FOREIGN_KEY_CHECKS = 1')
# csvexport: dump every table in `metadata` to "<directory>/<table name>.csv".
#
# Uses a session from connect(engine_uri) and visits tables in sorted name
# order. Each file starts with a header row of the table's column names,
# followed by one CSV row per result of session.query(table).all(). Each
# value is read with getattr(row, col) and converted to text before writing.
# The visible fallback conversion is unicode(val).encode('utf-8'); other
# conversions happen in lines not shown here.
#
# Parameters:
#   engine_uri -- database URI, passed to connect().
#   directory  -- output directory for the CSV files (default '.').
108 def csvexport(engine_uri
, directory
='.'):
110 session
= connect(engine_uri
)
112 for table_name
in sorted(metadata
.tables
.keys()):
114 table
= metadata
.tables
[table_name
]
# Binary mode, as Python 2's csv module expects.
# NOTE(review): the file object returned by open() is passed straight into
# csv.writer and never bound to a name, so it cannot be explicitly closed;
# flushing relies on garbage collection. Consider a with block.
116 writer
= csv
.writer(open("%s/%s.csv" %
(directory
, table_name
), 'wb'),
# Header row: the table's column names, in table column order.
118 columns
= [col
.name
for col
in table
.columns
]
119 writer
.writerow(columns
)
# NOTE(review): .all() builds the full result list before any rows are
# written; iterating the query directly would stream large tables.
121 for row
in session
.query(table
).all():
124 # Convert Pythony values to something more universal
125 val
= getattr(row
, col
)
# Fallback: stringify and encode as UTF-8 bytes for the Python 2 csv writer.
133 val
= unicode(val
).encode('utf-8')
137 writer
.writerow(csvs
)
141 print u
"""pokedex -- a command-line Pokédex interface
143 help Displays this message.
145 These commands are only useful for developers:
146 csvimport {uri} [dir] Import data from a set of CSVs to the database
148 csvexport {uri} [dir] Export data from the database given by the URI
150 Directory defaults to cwd.