Package install :: Package MoSTBioDat :: Package DataBase :: Package ForgetSQL2 :: Module testforgetsql2
[hide private]
[frames] | no frames]

Source Code for Module install.MoSTBioDat.DataBase.ForgetSQL2.testforgetsql2

  1  #!/usr/bin/env python 
  2  # *-* encoding: utf8 
  3  #  
  4  # Copyright (c) 2005-2006 Stian Soiland 
  5  #  
  6  # This library is free software; you can redistribute it and/or 
  7  # modify it under the terms of the GNU Lesser General Public 
  8  # License as published by the Free Software Foundation; either 
  9  # version 2.1 of the License, or (at your option) any later version. 
 10  #  
 11  # This library is distributed in the hope that it will be useful, 
 12  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 13  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU 
 14  # Lesser General Public License for more details. 
 15  #  
 16  # You should have received a copy of the GNU Lesser General Public 
 17  # License along with this library; if not, write to the Free Software 
 18  # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA 
 19  # 
 20  # Author: Stian Soiland <stian@soiland.no> 
 21  # URL: http://soiland.no/i/src/forgetsql2/ 
 22  # License: LGPL 
 23  # 
 24  """Tests for forgetSQL2. 
 25   
 26  By itself this module will test using temporary sqlite databases and the 
 27  accompanying test-data.sql. 
 28   
 29  Testing with MySQL host=localhost user=stain pw=(blank) db=test: 
 30      ./testforgetsql2.py --mysql localhost,stain,,test 
 31   
 32  Testing with PostgreSQL host=localhost: 
 33      ./testforgetsql2.py --postgresql 'host=localhost'  
 34  """ 
 35   
 36   
 37  import sys 
 38  import os 
 39  import codecs 
 40  import unittest 
 41  import StringIO 
 42  import logging 
 43  import re 
 44  import gc 
 45  from sets import Set 
 46  from doc_exception import ProgrammingError 
 47   
 48  from forgetsql2 import Database, TableBuilder, DBConnect 
 49  from forgetsql2 import NotFoundError, generate 
 50   
 51  gc.disable() 
 52               
class TestFramework(unittest.TestCase):
    """Common fixture shared by all forgetSQL2 test cases.

    For every test the fixture redirects logging to an in-memory buffer,
    selects the DB-API module named by ``db_mod``, recreates the test
    tables from ``test-data.sql`` and prepares subclasses of ``Database``
    and ``TableBuilder`` bound to a fresh ``DBConnect``.
    """

    # Backend to test against: "sqlite", "mysql" or "postgresql"
    db_mod = "sqlite"
    # Parameters for the DB-API connect() call.  A dict is passed as
    # keyword arguments, any other sequence as positional arguments.
    # Must be set for "mysql"/"postgresql"; generated for "sqlite".
    db_connect = None
    prepared = False

    # Tables removed before the test data is loaded
    _tables = ("county", "municipal", "postal", "insertion",
               "shop", "changed")
    # Definition of an auto-incrementing primary key, per backend
    _insertion_id_types = {
        "sqlite": "INTEGER PRIMARY KEY AUTOINCREMENT",
        "mysql": "INTEGER PRIMARY KEY AUTO_INCREMENT",
        "postgresql": "SERIAL PRIMARY KEY",
    }

    def setUp(self):
        """Prepare logger, database module, tables and helper classes."""
        self.prepareLogger()
        self.prepareDatabase()
        self.prepareTables()
        self.prepareClasses()

    def tearDown(self):
        """Verify that nothing unexpected was logged, then clean up.

        The cleanup lives in a ``finally`` clause so that a failing log
        assertion does not leave the temporary sqlite file or the
        connection behind.
        """
        try:
            self.assertEqual(self.lastLog(), "")
        finally:
            if self.db_mod == "sqlite":
                try:
                    os.unlink(self.db_connect["database"])
                except OSError:
                    pass
            self.db_c.close()

    def prepareClasses(self):
        """Prepare subclasses of Database and TableBuilder.

        The reason for this is to introduce the _db class attribute
        without modifying the actual Database class.

        The new subclasses are accessible as self.Database and
        self.TableBuilder.
        """
        connection = DBConnect(self.db, self.db_connect)
        self.db_c = connection

        class DB(Database):
            _db = connection

        class TB(TableBuilder):
            _db = connection

        self.Database = DB
        self.TableBuilder = TB

    # Copied from the Cerebrum project Ceresync
    # (It should be OK since we make this module GPL)
    def prepareLogger(self):
        """Make logger use a buffer instead of stderr.

        The lines logged (format: "WARNING: asdlksldk") can be
        fetched with the method lastLog.
        """
        self.logbuf = StringIO.StringIO()
        handler = logging.StreamHandler(self.logbuf)
        handler.setLevel(logging.INFO)
        handler.setFormatter(
            logging.Formatter("%(levelname)s: %(message)s"))
        root = logging.getLogger()
        del root.handlers[:]  # any old ones goodbye
        root.addHandler(handler)
        root.setLevel(logging.INFO)

    def lastLog(self):
        """Return what has been logged since the last call."""
        logged = self.logbuf.getvalue()
        self.logbuf.seek(0)
        self.logbuf.truncate()
        return logged

    def prepareDatabase(self):
        """Import the DB-API module selected by ``db_mod`` as ``self.db``.

        Raises:
            ValueError: if ``db_mod`` names an unsupported backend.
        """
        if self.db_mod == "sqlite":
            self.prepareDatabaseSqlite()
        elif self.db_mod == "mysql":
            self.prepareDatabaseMysql()
        elif self.db_mod == "postgresql":
            self.prepareDatabasePostgresql()
        else:
            # String exceptions are not supported by current interpreters
            raise ValueError("Unknown db_mod: %r" % (self.db_mod,))

    def prepareDatabaseMysql(self):
        """Use MySQLdb; ``db_connect`` must already be set."""
        import MySQLdb as db
        self.db = db
        assert self.db_connect

    def prepareDatabasePostgresql(self):
        """Use psycopg; ``db_connect`` must already be set."""
        import psycopg as db
        self.db = db
        assert self.db_connect

    def prepareDatabaseSqlite(self):
        """Use pysqlite2 with a fresh temporary database file.

        Exits the process with status 2 if pysqlite2 is not installed.
        """
        # FIXME: Should also do all the tests with mysql
        import tempfile
        try:
            from pysqlite2 import dbapi2 as db
        except ImportError:
            sys.stderr.write("Need pysqlite2 for db testing\n")
            sys.exit(2)
        self.db = db
        # mkstemp() creates the file itself, avoiding the race of
        # mktemp(); sqlite treats the empty file as an empty database.
        handle, path = tempfile.mkstemp()
        os.close(handle)
        # Include db.PARSE_DECLTYPES to support boolean conversion
        self.db_connect = dict(database=path,
                               detect_types=db.PARSE_DECLTYPES)

    def prepareTables(self):
        """Recreate the test tables and load test-data.sql.

        Works on a raw connection of its own which is always rolled
        back and closed afterwards.

        Raises:
            ValueError: if ``db_mod`` names an unsupported backend.
        """
        # First connect raw
        db = self._connectRaw()
        try:
            if self.db_mod == "postgresql":
                db.autocommit(1)
            c = db.cursor()
            if "sqlite" in str(db).lower():
                # Ignore sync-checks for faster import
                c.execute("pragma synchronous = off")
            sql = self._readTestData()
            self._dropTables(c)
            self._loadTestData(c, sql)
            self._createInsertionTable(c)
            db.commit()
        finally:
            try:
                db.rollback()
            finally:
                db.close()

    def _connectRaw(self):
        """Open a new DB-API connection from ``db_connect``."""
        if isinstance(self.db_connect, dict):
            return self.db.connect(**self.db_connect)
        return self.db.connect(*self.db_connect)

    def _readTestData(self):
        """Return the contents of test-data.sql next to this module."""
        path = os.path.join(os.path.dirname(__file__), "test-data.sql")
        sql_file = codecs.open(path, encoding="utf8")
        try:
            return sql_file.read()
        finally:
            sql_file.close()

    def _dropTables(self, c):
        """Drop left-over test tables using cursor ``c``."""
        if self.db_mod == "mysql":
            for table in self._tables:
                c.execute("DROP TABLE IF EXISTS %s" % table)
        elif self.db_mod == "postgresql":
            c.execute("""SELECT tablename FROM pg_catalog.pg_tables
                         WHERE schemaname=pg_catalog.current_schema()""")
            existing = c.fetchall()
            for table in self._tables:
                if (table,) in existing:
                    c.execute("DROP TABLE %s" % table)
        elif self.db_mod == "sqlite":
            # No need to drop tables in sqlite, we start from a blank
            # database each time
            pass
        else:
            raise ValueError("Unknown db: %r" % (self.db_mod,))

    def _loadTestData(self, c, sql):
        """Execute the statements of ``sql`` using cursor ``c``."""
        if self.db_mod == "sqlite":
            # We have to fake since sqlite does not support the
            # fancy "bool" type.
            sql = sql.replace("FALSE", "0")
            sql = sql.replace("TRUE", "1")
            c.executescript(sql)
        elif self.db_mod in ("mysql", "postgresql"):
            for statement in sql.split(";"):
                if not statement.strip():
                    continue  # Skip empty lines
                c.execute(statement.encode("utf8"))

    def _createInsertionTable(self, c):
        """Create the table "insertion" with a backend specific key."""
        # Separate from test-data.sql because the auto-increment syntax
        # ("AUTOINCREMENT" vs "AUTO_INCREMENT" vs SERIAL) differs
        if self.db_mod not in self._insertion_id_types:
            raise ValueError("Unknown db: %r" % (self.db_mod,))
        c.execute("""
            CREATE TABLE insertion (
                insertion_id %s,
                value VARCHAR(15)
            )""" % self._insertion_id_types[self.db_mod])
225 226
class TestTestFramework(TestFramework):
    """Sanity checks of the fixture provided by TestFramework."""

    def testLog(self):
        """Logged lines are returned once by lastLog()."""
        logging.error("Hello")
        self.assertEqual(self.lastLog(), "ERROR: Hello\n")
        self.assertEqual(self.lastLog(), "")

    def testDatabase(self):
        """The test data has been imported into the database."""
        if self.db_mod == "sqlite":
            # Only sqlite keeps the database in a local file.  The
            # path must be passed to exists(); asserting on the
            # function object itself would always succeed.
            self.assertTrue(os.path.exists(self.db_connect["database"]))
        if isinstance(self.db_connect, dict):
            db = self.db.connect(**self.db_connect)
        else:
            db = self.db.connect(*self.db_connect)
        try:
            c = db.cursor()
            # Check that data got imported
            c.execute("SELECT count(*) FROM county")
            self.assertEqual(c.fetchone(), (22,))
            c.execute("SELECT * FROM municipal WHERE municipal_id=1103")
            self.assertEqual(list(c.fetchall()),
                             [(1103, "Stavanger", 11)])
            c.execute("SELECT * FROM postal WHERE postal_no=4042")
            # FIXME: BOOL with mysql might not equal False
            self.assertEqual(c.fetchone(),
                             (4042, "HAFRSFJORD", 1103, False))
            # Check that utf-8 worked fine.  Note that a source
            # encoding declaration must be present at the top of
            # this file for the u"Østfold" literal to work
            c.execute("SELECT county_name FROM county WHERE county_id=1")
            a = c.fetchone()[0]
            if not isinstance(a, unicode):
                a = a.decode("utf8")
            self.assertEqual(a, u"Østfold")
        finally:
            # Release the extra connection even if an assertion failed
            try:
                db.rollback()
            finally:
                db.close()
261
class TestDBConnect(TestFramework):
    """Tests of the DBConnect connection wrapper."""

    def testConstructor(self):
        """A new DBConnect exposes module, connect info and connection."""
        d = DBConnect(self.db, self.db_connect)
        try:
            self.assertEqual(d.module, self.db)
            self.assertEqual(d.connect_info, self.db_connect)
            self.assertTrue(d.connection)
            self.assertTrue(d.connection.cursor())
        finally:
            # This connection is separate from the fixture's one, so it
            # has to be released here
            d.close()

    def testCursor(self):
        """A cursor from the shared connection executes statements."""
        cursor = self.Database._db.cursor()
        cursor.execute("SELECT 1+1")
        self.assertEqual(cursor.fetchone(), (2,))
        self.assertEqual(self.lastLog(), "")

    def testCursorReconnect(self):
        """Asking for a cursor on a closed connection reconnects."""
        d = self.Database()
        # Nasty!
        d._db.connection.close()
        # should reconnect
        cursor = d._db.cursor()
        cursor.execute("SELECT 1+1")
        self.assertEqual(cursor.fetchone(), (2,))
        self.assertTrue(self.lastLog().startswith(
            "WARNING: Reconnecting database due to "))

    def testSameConnection(self):
        """Instances and subclasses share one connection object."""
        d1 = self.Database()
        d2 = self.Database()

        class Subclass(self.Database):
            pass

        d3 = Subclass()
        self.assertEqual(d1._db, self.Database._db)
        self.assertEqual(d2._db, self.Database._db)
        self.assertEqual(d3._db, self.Database._db)
        d1._db.cursor()
        self.assertEqual(d1._db, self.Database._db)
        self.assertEqual(d2._db, self.Database._db)
        # OK.. and if we close/lose the connection, will we
        # also replace it all over the place?
        old_conn = id(d1._db.connection)
        self.assertEqual(old_conn, id(d2._db.connection))
        d3._db.connection.close()
        d1._db.cursor()
        self.assertTrue(self.lastLog().startswith(
            "WARNING: Reconnecting database due to "))
        self.assertNotEqual(old_conn, id(d2._db.connection))
        self.assertEqual(d1._db.connection,
                         self.Database._db.connection)
        self.assertEqual(d2._db.connection,
                         self.Database._db.connection)
        self.assertEqual(d3._db.connection,
                         self.Database._db.connection)
315 316
class TestDatabase(TestFramework):
    """Tests of the query helpers on the Database class."""

    def testQuery(self):
        """_query() yields one dictionary per row."""
        query = self.Database._query
        rows = query("SELECT 1+1 AS mysum")
        self.assertEqual(list(rows), [{'mysum': 2}])
        rows = query("SELECT 1+1 AS mysum, 5+3 AS other")
        self.assertEqual(list(rows), [{'mysum': 2, "other": 8}])

        rows = query("SELECT county_id FROM county "
                     "ORDER BY county_id")
        # The expected ids are 1 up to and including 23, with 13 left out
        expected = []
        for county_id in range(1, 24):
            if county_id != 13:
                expected.append({"county_id": county_id})
        self.assertEqual(list(rows), expected)

    def testQueryOne(self):
        """_query_one() returns one row and warns about further hits."""
        query_one = self.Database._query_one
        param = self.Database._db.param

        row = query_one("SELECT 1+1 AS mysum")
        self.assertEqual(row, {'mysum': 2})
        row = query_one("SELECT 1+1 AS mysum, 5+3 AS other")
        self.assertEqual(row, {'mysum': 2, "other": 8})

        lookup = "SELECT * FROM postal WHERE postal_no=%s" % param
        row = query_one(lookup, (4001,))
        stavanger = {
            "postal_no": 4001,
            "postal_name": "STAVANGER",
            "municipal_id": 1103,
            "is_pobox": 1,
        }
        self.assertEqual(row, stavanger)

        sql = "SELECT * FROM postal WHERE postal_no > %s" % param
        query_one(sql, (4001,))
        # FIXME: Not %s in error
        warning = ('WARNING: More than one hit '
                   'returned for "%s" %% (4001,)\n' % sql)
        self.assertEqual(self.lastLog(), warning)
352
class TestTableBuilder(TestFramework):
    """Tests of the table discovery done by TableBuilder."""

    def setUp(self):
        """Add a builder instance on top of the common fixture."""
        super(TestTableBuilder, self).setUp()
        self.builder = self.TableBuilder()

    def testAllTables(self):
        """all_tables() finds the tables of the test data."""
        self.builder.all_tables()
        known = self.builder.table_names
        for table in ("postal", "municipal", "county"):
            self.assertTrue(table in known)

    def testBuildTable(self):
        """build_table() creates a class carrying fields and primary key."""
        postal = self.builder.build_table("postal")
        self.assertEqual(postal.__name__, "Postal")
        expected_fields = Set(
            ('is_pobox', 'municipal_id', 'postal_name', 'postal_no'))
        self.assertEqual(Set(postal._fields), expected_fields)
        self.assertEqual(postal._primary, ["postal_no"])

    def testBuildTables(self):
        """build_tables() builds a class for every table name."""
        self.builder.build_tables()
        built = Set(self.builder.tables)
        # Make sure all tables have been built
        self.assertEqual(built, Set(self.builder.table_names))

    def testFindForeign(self):
        """Foreign keys are registered on the built classes."""
        self.builder.build_tables()
        tables = self.builder.tables
        Postal = tables["postal"]
        Municipal = tables["municipal"]
        County = tables["county"]
        # build_tables() has already called find_foreign() for us
        self.assertEqual(Postal._foreigns,
                         {"municipal_id": Municipal})
        # FIXME: Should test a table with several foreigns
        self.assertEqual(Set(Municipal._foreigns),
                         Set(("county_id",)))
390
class TestBuiltClass(TestFramework):
    """Tests of the classes generated by TableBuilder.build_tables()."""

    def setUp(self):
        """Build all table classes on top of the common fixture."""
        super(TestBuiltClass, self).setUp()
        self.builder = self.TableBuilder()
        self.builder.build_tables()

    def testLoad(self):
        """A row is loaded by its primary key."""
        Postal = self.builder.tables["postal"]
        # Fetch Stavanger
        stavanger = Postal(postal_no=4001)
        self.assertEqual(stavanger.postal_no, 4001)
        self.assertEqual(stavanger.postal_name, "STAVANGER")
        self.assertEqual(stavanger.municipal_id, 1103)
        self.assertEqual(stavanger.is_pobox, 1)

    def testNotFound(self):
        """Loading an unknown primary key raises NotFoundError."""
        Postal = self.builder.tables["postal"]
        self.assertRaises(NotFoundError, Postal, postal_no=9999)

    def testPrimaryFails(self):
        """Loading by a field other than the primary key is an error."""
        Postal = self.builder.tables["postal"]
        self.assertRaises(ProgrammingError, Postal,
                          postal_name="STAVANGER")

    def testUndo(self):
        """undo() discards attribute changes that were not saved."""
        Postal = self.builder.tables["postal"]
        # Fetch Stavanger
        stavanger = Postal(postal_no=4001)
        stavanger.postal_name = "Fisk"
        self.assertEqual(stavanger.postal_name, "Fisk")
        stavanger.undo()
        self.assertEqual(stavanger.postal_name, "STAVANGER")

        fresh = Postal()
        fresh.postal_name = "Fjosk"
        self.assertTrue(hasattr(fresh, "postal_name"))
        fresh.undo()
        self.assertFalse(hasattr(fresh, "postal_name"))

    def testForeigns(self):
        """Generated get_/set_ methods follow foreign keys."""
        Postal = self.builder.tables["postal"]
        # Fetch Stavanger
        stavanger = Postal(postal_no=4001)
        municipal = stavanger.get_municipal()
        self.assertEqual(municipal.municipal_id, 1103)
        self.assertEqual(municipal.municipal_name, "Stavanger")
        self.assertEqual(municipal.county_id, 11)
        county = municipal.get_county()
        self.assertEqual(county.county_name, "Rogaland")

        # Point at some other county for a moment so that the
        # generated set_ method has something to change
        municipal.county_id = 12
        self.assertEqual(municipal.county_id, 12)
        municipal.set_county(county)
        self.assertEqual(municipal.county_id, 11)

    def testForeignsNone(self):
        """A foreign reference can be set to None."""
        Postal = self.builder.tables["postal"]
        # Fetch Stavanger
        stavanger = Postal(postal_no=4001)
        stavanger.set_municipal(None)
        self.assertEqual(stavanger.municipal_id, None)
        self.assertEqual(stavanger.get_municipal(), None)

    def testRepr(self):
        """repr() shows the class name and the primary key."""
        Postal = self.builder.tables["postal"]
        stavanger = Postal(postal_no=4001)
        self.assertEqual(repr(stavanger), "<Postal postal_no=4001>")

    def testChildren(self):
        """Rows referring to a municipal are found via get_postals()."""
        Municipal = self.builder.tables["municipal"]
        stavanger = Municipal(municipal_id=1103)  # Stavanger
        count = len(list(stavanger.get_postals()))
        # Should be something inbetween
        self.assertTrue(40 < count < 100)

    def testIterate(self):
        """Iterating over a table class yields all of its rows."""
        County = self.builder.tables["county"]
        # every id from 1 up to and including 23 is used, except 13
        remaining = list(range(1, 24))
        for county in County:
            remaining.remove(county.county_id)
        self.assertEqual(remaining, [13])

        self.assertEqual(len(list(County)), 22)

    def testWhere(self):
        """where() takes an expression with named $parameters."""
        County = self.builder.tables["county"]
        matches = County.where("county_name=$name", name="Oslo")
        found = [county.county_id for county in matches]
        # Only one row, and it should be county_id=3
        self.assertEqual(found, [3])

    def testWhereKeywords(self):
        """where() takes field=value keyword arguments."""
        County = self.builder.tables["county"]
        found = [county.county_id
                 for county in County.where(county_name="Oslo")]
        # Only one row, and it should be county_id=3
        self.assertEqual(found, [3])

    def testGet(self):
        """get() returns a single row, or None when nothing matches."""
        County = self.builder.tables["county"]
        oslo = County.get("county_name=$name", name="Oslo")
        self.assertEqual(oslo.county_id, 3)
        oslo = County.get(county_name="Oslo")
        self.assertEqual(oslo.county_id, 3)

        missing = County.get("county_name=$name", name="Not Found")
        self.assertEqual(missing, None)
        missing = County.get(county_name="Not Found")
        self.assertEqual(missing, None)
510 511 512
class TestSave(TestFramework):
    """Tests of inserting and updating rows through save()."""

    def setUp(self):
        """Build the table classes and keep the two used by the tests."""
        super(TestSave, self).setUp()
        self.builder = self.TableBuilder()
        self.builder.build_tables()
        self.Postal = self.builder.tables["postal"]
        self.Insertion = self.builder.tables["insertion"]

    def tearDown(self):
        """Undo the changes made by the tests, then run the base cleanup.

        The base class tearDown() runs in a ``finally`` clause so that
        the connection is released even if one of the statements fails.
        """
        try:
            # Roll back those stupid changes we might make
            self.builder._execute(
                "UPDATE postal SET postal_name='STAVANGER'"
                " WHERE postal_no=4001")
            self.builder._execute(
                "DELETE FROM postal WHERE postal_no=9998")
            self.builder._execute(
                "DELETE FROM postal WHERE postal_no=9999")
        finally:
            super(TestSave, self).tearDown()

    def testUpdate(self):
        """Saving a changed row stores the new value."""
        # Fetch Stavanger
        svg = self.Postal(postal_no=4001)
        self.assertEqual(svg.postal_name, "STAVANGER")
        svg.postal_name = "Nesten Stavanger"
        self.assertEqual(svg.save(), 1)
        # Reload as a new object and check that it was stored
        svg = self.Postal(postal_no=4001)
        self.assertEqual(svg.postal_name, "Nesten Stavanger")

    def testSave(self):
        """Saving a new object inserts a row."""
        p = self.Postal()
        p.postal_no = 9999
        p.postal_name = "Ingenmannsland"
        p.municipal_id = 1103
        self.assertEqual(p.save(), 1)
        p1 = self.Postal(postal_no=9999)
        self.assertEqual(p1.postal_name, "Ingenmannsland")
        self.assertEqual(p1.municipal_id, 1103)
        # Should get default values
        self.assertEqual(p1.is_pobox, 0)

    def testSaveExists(self):
        """Inserting an already used primary key raises a database error."""
        p = self.Postal()
        p.postal_no = 4001
        p.postal_name = "Ingenmannsland"
        # Will collide with 4001 STAVANGER
        self.assertRaises(p._db.module.Error, p.save)

    def testSaveAutoincrement(self):
        """The generated primary key is set on the object by save()."""
        ins = self.Insertion()
        ins.value = "fish"
        self.assertTrue(not hasattr(ins, "insertion_id"))
        ins.save()
        self.assertTrue(ins.insertion_id > 0)

    def testUpdatePrimary(self):
        """The primary key of a stored row can be changed."""
        p = self.Postal()
        p.postal_no = 9999
        p.postal_name = "Ingenmannsland"
        p.municipal_id = 1103
        self.assertEqual(p.save(), 1)
        p1 = self.Postal(postal_no=9999)
        p1.postal_no = 9998
        self.assertEqual(p1.save(), 1)
        # The row must now be loadable by its new key, unchanged otherwise
        p2 = self.Postal(postal_no=9998)
        self.assertEqual(p2.postal_name, "Ingenmannsland")

    def testUpdateWasDeleted(self):
        """Saving an object whose row has disappeared raises NotFoundError."""
        p = self.Postal()
        p.postal_no = 9999
        p.postal_name = "Ingenmannsland"
        p.municipal_id = 1103
        p.save()
        self.builder._execute("DELETE FROM postal WHERE postal_no=9999")
        # Should fail
        self.assertRaises(NotFoundError, p.save)
586
class TestGenerate(TestFramework):
    """Tests of the generate() shortcut."""

    def testGenerate(self):
        """generate() returns an object exposing classes and helpers."""
        generated = generate(self.db, self.db_connect)
        for class_name in ("Postal", "Municipal", "County"):
            self.assertTrue(hasattr(generated, class_name))

        stavanger = generated.Postal(postal_no=4001)
        self.assertEqual(stavanger.postal_name, "STAVANGER")

        # These helpers only need to be callable without raising
        generated.query_one("SELECT 1+1")
        generated.query("SELECT * FROM postal")
        generated.execute("DELETE FROM postal WHERE postal_no=9999")
        generated.db.cursor()
601 602
603 -def main():
604 if "--mysql" in sys.argv: 605 TestFramework.db_mod = "mysql" 606 pos = sys.argv.index("--mysql") 607 db_args = sys.argv[pos+1] 608 db_args = db_args.split(",") 609 TestFramework.db_connect = db_args 610 del sys.argv[pos+1] 611 del sys.argv[pos] 612 elif "--postgresql" in sys.argv: 613 TestFramework.db_mod = "postgresql" 614 pos = sys.argv.index("--postgresql") 615 db_args = sys.argv[pos+1] 616 TestFramework.db_connect = [db_args] 617 del sys.argv[pos+1] 618 del sys.argv[pos] 619 unittest.main() 620 621 if __name__ == "__main__": 622 main() 623