Package install :: Package MoSTBioDat :: Package DataBase :: Package ForgetSQL2 :: Module forgetsql2
[hide private]
[frames] | no frames]

Source Code for Module install.MoSTBioDat.DataBase.ForgetSQL2.forgetsql2

  1  #!/usr/bin/env python 
  2  # *-* encoding: utf8 
  3  #  
  4  # Copyright (c) 2005-2006 Stian Soiland 
  5  #  
  6  # This library is free software; you can redistribute it and/or 
  7  # modify it under the terms of the GNU Lesser General Public 
  8  # License as published by the Free Software Foundation; either 
  9  # version 2.1 of the License, or (at your option) any later version. 
 10  #  
 11  # This library is distributed in the hope that it will be useful, 
 12  # but WITHOUT ANY WARRANTY; without even the implied warranty of 
 13  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU 
 14  # Lesser General Public License for more details. 
 15  #  
 16  # You should have received a copy of the GNU Lesser General Public 
 17  # License along with this library; if not, write to the Free Software 
 18  # Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307  USA 
 19  # 
 20  # Author: Stian Soiland <stian@soiland.no> 
 21  # URL: http://soiland.no/i/src/forgetsql2/ 
 22  # License: LGPL 
 23  # 
 24   
 25  """ForgetSQL - SQL to object database wrapper. 
 26   
 27  NOTE:  
 28  This version 2 of forgetSQL is NOT backwards compatible with 
 29  version 0.5.1. This is the reason it is called forgetsql2.py - so that 
 30  it can co-exist with the old version. 
 31   
 32  You should only need to use the generate() function to generate the 
 33  classes mapping to your tables, which will be subclasses of the  
 34  Table class. 
 35   
 36   
 37  Example usage:: 
 38   
 39      import MySQLdb, forgetsql2 
 40      # Connect to MySQLdb using keyword parameters 
  41      db = forgetsql2.generate(MySQLdb, {'db': 'fish'}) 
 42   
 43      # Iterate through generated class from the table "postal" 
 44      for postal in db.Postal: 
 45          # Print normal fields 
 46          print postal.postal_no, postal.postal_name, postal.municipal_id 
 47          # Follow the foreign key municipal_id to retrieve the entry 
 48          # from the Municipal class 
 49          municipal = postal.get_municipal()     
 50          print municipal.municipal_name 
 51           
 52      # Retrieve by primary key 
 53      rogaland = db.County(county_id=11)     
 54      # Iterate over municipals that have foreign keys to rogaland 
 55      for municipal in rogaland.get_municipals(): 
 56          print municipal.municipal_name 
 57      
 58  """ 
 59   
 60   
 61  import sys 
 62  import os 
 63  import types 
 64  import StringIO 
 65  import logging 
 66  from sets import Set 
 67  from itertools import izip, count 
 68  from MoSTBioDat.Log.MoSTBioDatLog import MoSTBioDatLog 
 69  import re  
 70  try: 
 71      import threading 
 72  except ImportError: 
 73      threading = None     
 74   
 75  from doc_exception import DocstringException, ProgrammingError 
 76   
# Pattern for "$name" placeholders in WHERE clauses; group 1 is the name.
# A name starts with an ASCII letter, followed by letters, digits, "_"
# or "-". Database._execute rewrites matches to "%(name)s" for MySQL and
# PostgreSQL.
# NOTE(review): because "-" is accepted, "$a-1" captures "a-1" rather
# than "a" - confirm hyphenated placeholder names are really intended.
_col_pat = re.compile(r"\$([a-zA-Z][a-zA-Z0-9_-]*)")
 79   
80 -def _sql_bool(value):
81 """Convert SQL bool value to Python True/False""" 82 if isinstance(value, bool): 83 return value 84 value = str(value) 85 value = value.lower() 86 if value in ("t", "true", "1"): 87 return True 88 elif value in ("f", "false", "0"): 89 return False 90 else: 91 raise TypeError, value
92
class error(DocstringException):
    # Root of this module's table exceptions. NOTE(review): the
    # docstring below presumably doubles as the default message via
    # DocstringException, so it is kept verbatim.
    """Generic table error"""
95
class NotFoundError(error):
    # Raised by Table._load() when no row matches the requested
    # primary key. NOTE(review): the docstring may serve as the default
    # message via DocstringException - keep it verbatim.
    """Could not find"""
98
class PrimaryKeyError(error, ProgrammingError):
    # Raised by Table.__init__() when the keyword arguments do not match
    # the table's primary key columns. Also a ProgrammingError, since it
    # signals a caller mistake. NOTE(review): docstring may be the
    # default message via DocstringException - keep it verbatim.
    """Missing/wrong primary keys"""
101
class UnsupportedDBError(error, ProgrammingError):
    # Raised when the database module, its paramstyle or the detected
    # database type is not one this module knows how to handle.
    # NOTE(review): docstring may be the default message via
    # DocstringException - keep it verbatim.
    """Unsupported db module"""
class DBConnect(object):
    """Database connection.

    In addition to connecting to the database and providing cursors,
    this class guesses information such as the database type (mysql,
    postgresql, sqlite) and the query parameter style.
    """

    def __init__(self, module, connect_info, log=None):
        """Construct and initialize the database connection.

        Parameters:
          module
            a DB API 2.0 compliant database module object
          connect_info
            connect info for module.connect()
            * If the parameter is a string it will be passed directly
              to module.connect(connect_info).
            * If a tuple or a list, it will be passed as positional
              arguments as module.connect(*connect_info)
            * If a dictionary, it will be passed as keyword arguments as
              module.connect(**connect_info).
          log
            optional logger object; if given, its info() method is
            called after every successful connect, including the
            reconnects made by cursor().

        The connection is opened immediately by calling connect().
        """
        # The actual DB module
        self.module = module
        # parameters for connecting
        self.connect_info = connect_info
        # optional logger, remembered so reconnects can use it as well
        self.log = log
        # connection object - made by connect()
        self.connection = None
        # db type, "mysql", "postgresql" or "sqlite" - set by guess_db_info()
        self.type = None
        # parameter style, "?" or "%s" - set by guess_db_info()
        self.param = None
        # Connect to the database
        self.connect(log)

    def _set_connection(self, connection):
        """Set the database connection.

        The connection is as returned by self.module.connect().

        If the database module is not threadsafe on connection level
        (module.threadsafety < 2) and the threading module is
        available, the connection is stored on the current thread
        object, so each thread gets its own connection. Otherwise it is
        stored on self.
        """
        if self.module.threadsafety < 2 and threading:
            # Not thread safe on connection level, store the connection
            # in the current thread
            name = "forgetsql_conn_%s" % id(self)
            setattr(threading.currentThread(), name, connection)
        else:
            # store normally
            self._connection = connection

    def _get_connection(self):
        """Return the connection stored by _set_connection().

        In the per-thread case, None is returned for a thread that has
        not connected yet.
        """
        if self.module.threadsafety < 2 and threading:
            name = "forgetsql_conn_%s" % id(self)
            return getattr(threading.currentThread(), name, None)
        return self._connection

    connection = property(_get_connection, _set_connection)

    def connect(self, log=None):
        """Connect to the database.

        The connection object returned by self.module.connect() is
        stored in self.connection. Afterwards guess_db_info() and
        prepare_db_types() are called.

        If log is not given, the logger passed to the constructor (if
        any) is used to report the successful connection.
        """
        if log is None:
            log = getattr(self, "log", None)
        info = self.connect_info
        if isinstance(info, dict):
            # dict etc, kwargs style, connect(a=x1, b=x2)
            connection = self.module.connect(**info)
        elif isinstance(info, (tuple, list)):
            # tuple etc, args style connect(a, b)
            connection = self.module.connect(*info)
        else:
            # probably strings (URIs etc.) connect(a)
            connection = self.module.connect(info)
        if log:
            log.info('Connection successful')
        self.connection = connection
        # DB info should not change between connects, but you never know
        self.guess_db_info()
        # FIXME: AVOID AUTOCOMMIT!! (deliberately not enabled for
        # postgresql, even though connection.autocommit(1) was considered)
        self.prepare_db_types()

    def prepare_db_types(self):
        """Prepare type casting in database results.

        Currently this makes sure that booleans are returned as real
        True/False objects:
        - postgresql (psycopg only): a BOOLEAN type caster using
          _sql_bool is registered, unless "SELECT TRUE" already returns
          a bool.
        - sqlite: a "bool" converter is installed, but only when the
          connect info mentions detect_types.
        """
        if self.type == "postgresql":
            if "psycopg" not in str(self.module):
                return
            c = self.cursor()
            c.execute("SELECT TRUE")
            if isinstance(c.fetchone()[0], bool):
                # Already registered
                return
            bool_type_oid = c.description[0][1]
            bool_type = self.module.new_type((bool_type_oid,),
                                             "BOOLEAN",
                                             _sql_bool)
            self.module.register_type(bool_type)
        if self.type == "sqlite" and "detect_types" in str(self.connect_info):
            # Add detect_types=sqlite.PARSE_DECLTYPES to connect info
            # for type conversion
            self.module.converters["bool"] = _sql_bool

    def guess_db_info(self):
        """Guess database type and parameter style.

        self.type is set to "mysql", "postgresql" or "sqlite", based on
        the module's __name__ (or str(module) if it has none).

        self.param is set to "?" or "%s" and denotes how query
        parameters expanded by cursor.execute(sql, params) are to be
        written in the sql.

        Raises UnsupportedDBError for an unrecognised module or
        paramstyle.
        """
        try:
            db_name = self.module.__name__.lower()
        except AttributeError:
            db_name = str(self.module)
        if "mysql" in db_name:
            self.type = "mysql"
        elif "sqlite" in db_name:
            self.type = "sqlite"
        elif "psycopg" in db_name:
            self.type = "postgresql"
        else:
            raise UnsupportedDBError(self.module)
        paramstyle = self.module.paramstyle
        if paramstyle == "qmark":
            # FIXME: Should support named parameters $fish
            self.param = "?"
        elif paramstyle in ("format", "pyformat"):
            # FIXME: Should support named parameters %(fish)s for pyformat
            self.param = "%s"
        else:
            raise UnsupportedDBError("paramstyle=%s" % paramstyle)

    def _checked_cursor(self):
        """Open a cursor on the current connection and verify it works.

        Connection errors won't show until we query something, so a
        trivial query is run first. An unexpected answer raises
        AssertionError; this check is explicit rather than an assert
        statement so that it survives python -O.
        """
        c = self.connection.cursor()
        c.execute("SELECT 1+1")
        row = c.fetchone()
        # tuple() so drivers returning rows as lists also pass
        if row is None or tuple(row) != (2,):
            raise AssertionError("Unexpected result %r for SELECT 1+1"
                                 % (row,))
        return c

    def cursor(self):
        """Fetch a cursor. Reconnect if needed.

        If opening or checking the cursor fails with self.module.Error
        (usually a timed-out connection), the database is reconnected
        once and the check is repeated; a second failure propagates.
        """
        if not self.connection:
            self.connect()
        try:
            return self._checked_cursor()
        except self.module.Error:
            exc = sys.exc_info()[1]
            # Usually because of timeouts
            logging.warning("Reconnecting database due to %s",
                            exc.__class__)
            # Reconnect and try again. The old connection is only
            # replaced, not closed: with threadsafety >= 2 other threads
            # may still hold cursors on it.
            self.connect()
            return self._checked_cursor()

    def close(self):
        """Close the connection.

        Any pending transactions will be rolled back. Closing an already
        closed DBConnect is a no-op.
        """
        connection = self.connection
        if connection is not None:
            connection.close()
            self.connection = None
281
class Database(object):
    """Base class for objects that use the database.

    Provides class methods for executing and querying SQL against the
    database.

    Note that the _db *class attribute* must have been set to a
    DBConnect instance before any of these methods can be used.
    Normally this is done by subclassing::

        class MyDatabase(Database):
            _db = DBConnect(db_module, connect_info)
    """
    # a DBConnect instance, which holds the actual connection, and more
    # important to us, the cursor() method.
    _db = None

    def _execute(cls, sql, parameters={}):
        """Execute SQL and return the cursor.

        For MySQL and PostgreSQL, "$name" placeholders in sql are
        rewritten to "%(name)s" before executing. For SQLite the sql is
        passed through unchanged (SQLite accepts "$name" parameters
        itself). Any other database type raises UnsupportedDBError.

        The default {} for parameters is only read, never mutated.
        """
        cursor = cls._db.cursor()
        db_type = cls._db.type
        if db_type == "mysql" or db_type == "postgresql":
            # FIXME: Avoid regex-hack-fixing this param crap
            sql = _col_pat.sub(r"%(\1)s", sql)
        elif db_type != "sqlite":
            raise UnsupportedDBError(db_type)
        logging.debug("%s %r", sql, parameters)
        cursor.execute(sql, parameters)
        return cursor
    _execute = classmethod(_execute)

    def _iter_cursor(cls, cursor):
        """Provide an iterator over the cursor's result rows.

        If the cursor itself is iterable, its iterator is returned.
        Otherwise a generator is returned that yields rows from
        repeated calls to fetchmany() until it returns no rows.
        """
        try:
            return iter(cursor)
        except TypeError:
            def iterator():
                rows = cursor.fetchmany()
                while rows:
                    for row in rows:
                        yield row
                    rows = cursor.fetchmany()
            return iterator()
    _iter_cursor = classmethod(_iter_cursor)

    def _query(cls, sql, parameters={}):
        """Execute SQL and yield one dictionary per row.

        The dictionaries map column names (from cursor.description) to
        values. The optional parameters argument is used for variable
        expansion as explained in PEP 249 .execute().

        Raises ProgrammingError if the cursor has no description but
        rows are returned anyway.
        """
        cursor = cls._execute(sql, parameters)
        if not cursor.description:
            # Should only happen when there is no data to yield, so if
            # there *is* something anyway, raise an exception.
            # (Previously a 3-argument raise passed sql as the
            # traceback, which itself failed with TypeError.)
            for _row in cls._iter_cursor(cursor):
                raise ProgrammingError(
                    "Could not find description for sql: %s" % sql)
            return
        fields = [d[0] for d in cursor.description]
        for row in cls._iter_cursor(cursor):
            yield dict(izip(fields, row))
    _query = classmethod(_query)

    def _query_one(cls, sql, parameters={}):
        """Execute SQL as with _query(), but return only the first row.

        Return None if no rows were returned. If more than one row is
        returned, a warning is logged, and only the first row is
        returned.
        """
        res = cls._query(sql, parameters)
        try:
            result = res.next()
        except StopIteration:
            return None
        # Log if there is more than one
        try:
            res.next()
        except StopIteration:
            pass
        else:
            logging.warning('More than one hit returned for "%s" %% %s',
                            sql, parameters)
        return result
    _query_one = classmethod(_query_one)
class metaclass_table(type):
    """Metaclass that makes Table classes themselves iterable.

    Iterating over a class yields every instance produced by calling
    cls.where() without arguments, i.e. every row of the table.

    Example::

        # Assume generated class County
        for county in County:
            print county.county_name
    """
    def __iter__(cls):
        # Generator, so cls.where() is only called on the first next()
        rows = cls.where()
        for row in rows:
            yield row
389
class Table(Database):
    """Representation of a table.

    Instances of this class represent a row in the table, and can be
    retrieved in several ways. In the examples below, the subclass
    Thing is assumed.

    - By calling the constructor with a primary key::

        # thing with primary key thing_id=1447
        thing = Thing(thing_id=1447)

    - By iterating over the class::

        # all things
        for thing in Thing:
            pass

    - By iterating over a filtered subset using SQL where::

        # All things with thing.value > 13
        for thing in Thing.where("value > $value", value=13):
            pass

    - By following a foreign key from another table::

        # the thing with thing_id = person.thing_id
        thing = person.get_thing()

    - By iterating over rows that have a foreign key to another table::

        # all things which have thing.box_id = box.box_id
        for thing in box.get_things():
            pass

    Subclasses must provide the class attributes _table_name, _fields,
    _primary and _foreigns, plus a _db connection; TableBuilder sets
    these on the classes it generates.
    """
    # To get __iter__ behaviour
    __metaclass__ = metaclass_table

    def __init__(self, _db_row=None, **primary):
        """Instantiate a new or existing database row.

        If no parameters are given, a new, blank instance is created,
        which will be INSERT-ed when save() is called.

        Else, if keyword arguments are supplied, they must match the
        primary keys in _primary exactly (otherwise PrimaryKeyError is
        raised), and will be used for loading an existing row. In such
        a case, the instance can be UPDATE-d by calling save() - or
        restored to the database values by calling undo().

        Mostly for internal usage, if _db_row is provided, it is
        assumed to be a dictionary of row values as returned from
        cls._query(), from which the object values will be loaded
        instead.
        """
        if primary and Set(primary) != Set(self._primary):
            raise PrimaryKeyError(primary)
        if not (primary or _db_row):
            # Don't fetch anything, we're new and blank
            return
        self._load(_db_row=_db_row, **primary)

    def where(cls, where=None, **parameters):
        """Yield all instances limited by the ``where`` clause.

        Use $field in the where clause and supply values as keyword
        arguments. Parameters will be properly escaped by the database
        module.

        If you supply keyword parameters, but no (or an empty)
        where-clause, a where clause "WHERE x=$x AND y=$y" is generated
        from the keywords. Note that the keyword names are inserted
        into the SQL as column names, so they must not come from
        untrusted input.

        Example::

            # Assume generated class County
            for county in County.where("county_name=$name",
                                       name="Oslo"):
                print county

            for county in County.where(county_name="Oslo"):
                print county
        """
        sql = "SELECT * FROM %s" % cls._table_name
        if where:
            sql += " WHERE %s" % where
        elif parameters:
            # An empty where string used to drop the keyword filter and
            # return every row; treat it like None instead.
            conditions = ["%s=$%s" % (key, key) for key in parameters]
            sql += " WHERE " + " AND ".join(conditions)
        for row in cls._query(sql, parameters):
            yield cls(_db_row=row)
    where = classmethod(where)

    def get(cls, where=None, **parameters):
        """Like where(), but returns the first instance or None."""
        for elem in cls.where(where, **parameters):
            return elem
        return None
    get = classmethod(get)

    def _where_primary(cls):
        """Get the WHERE part matching the primary key.

        Note that the placeholders are prefixed with p__ ($p__field) to
        avoid mixups with the field placeholders used by UPDATE.
        """
        where = ["%s=$p__%s" % (key, key) for key in cls._primary]
        return " AND ".join(where)
    _where_primary = classmethod(_where_primary)

    def _load(self, _db_row=None, reload=False, **primary):
        """Load from database.

        If _db_row is supplied, it is assumed to be a dict as returned
        by _query, and the instance will be loaded without any new
        database calls.

        If reload is True, the saved primary keys in the attribute
        _primary_values will be used to reload all attributes.

        Else, if keyword arguments are supplied, they must match the
        primary keys in _primary, and will be used for loading a unique
        row.

        Raises ProgrammingError if none of these is given, and
        NotFoundError if no matching row exists.
        """
        if reload:
            params = self._primary_values
        elif _db_row:
            params = None
        elif primary:
            # Convert to p__ style keynames as required by
            # _where_primary
            params = dict([("p__" + field, value)
                           for field, value in primary.items()])
        else:
            raise ProgrammingError("Missing parameter for _load()")
        if not _db_row:
            # Fetch from database
            sql = "SELECT * FROM %s WHERE %s" % (
                self._table_name, self._where_primary())
            _db_row = self._query_one(sql, params)
            if not _db_row:
                # On reload, primary is empty; report the keys used
                raise NotFoundError(primary or params)

        for field in self._fields:
            setattr(self, field, _db_row[field])
        self._save_primary()

    def __repr__(self):
        primaries = ["%s=%s" % (key, getattr(self, key, "?"))
                     for key in self._primary]
        return "<%s %s>" % (self.__class__.__name__, " ".join(primaries))

    def undo(self):
        """Undo changed attributes.

        If this instance represents an existing row, the instance will
        be reloaded to the *current* database values.

        If this is a new instance not yet inserted into the database,
        all database related attributes are removed.
        """
        if hasattr(self, "_primary_values"):
            self._load(reload=True)
        else:
            for field in self._fields:
                try:
                    delattr(self, field)
                except AttributeError:
                    continue

    def _last_insert_id(self):
        """Return the id the database generated for the last INSERT.

        Raises UnsupportedDBError for database types where the id
        cannot be fetched after inserting.
        """
        db_type = self._db.type
        if db_type == "mysql":
            sql = "SELECT LAST_INSERT_ID() AS id"
        elif db_type == "sqlite":
            sql = "SELECT last_insert_rowid() AS id"
        elif db_type == "postgresql":
            # Assume SERIAL and auto generated sequence name
            # table_field_seq
            seq_name = "%s_%s_seq" % (self._table_name, self._primary[0])
            sql = "SELECT currval('%s') AS id" % seq_name
        else:
            # Other databases would probably use sequences BEFORE
            # inserting.
            raise UnsupportedDBError(db_type)
        return self._query_one(sql)["id"]

    def save(self):
        """Save changes to the database.

        Existing rows (those with _primary_values) are UPDATE-d, new
        ones INSERT-ed. Attributes that are not set are left out, so the
        database can supply defaults. Afterwards the instance is
        reloaded from the database.

        Return number of rows updated/inserted, normally 1.
        (This is database dependant, sqlite will often return 0)
        """
        params = {}
        if hasattr(self, "_primary_values"):
            # it's an UPDATE.
            assignments = []
            for field in self._fields:
                try:
                    params[field] = getattr(self, field)
                except AttributeError:
                    # Blank values we assume will get default values
                    # from the database.. for instance "current date"
                    # etc.
                    continue
                assignments.append("%s=$%s" % (field, field))
            params.update(self._primary_values)
            sql = "UPDATE %s SET %s WHERE %s" % (
                self._table_name, ",".join(assignments),
                self._where_primary())
        else:
            # it's an INSERT
            fields = []
            values = []
            for field in self._fields:
                try:
                    params[field] = getattr(self, field)
                except AttributeError:
                    continue
                fields.append(field)
                values.append("$%s" % field)
            sql = "INSERT INTO %s(%s) VALUES (%s)" % (
                self._table_name, ",".join(fields), ",".join(values))
        curs = self._execute(sql, params)

        if len(self._primary) == 1 and \
                getattr(self, self._primary[0], None) is None:
            # It's one of those fetch-id-after-inserting-databases
            # NOTE: We cannot assume this for multi-valued primary
            # keys, as it is often legal to have a primary key with one
            # of the values NULL.
            setattr(self, self._primary[0], self._last_insert_id())

        # Set/Update _primary_values so we can do a reload
        self._save_primary()
        self._load(reload=True)
        return curs.rowcount

    def _save_primary(self):
        """Store a copy of the current primary keys.

        The attribute _primary_values will be updated with the current
        primary key values. These will be used for reload-ing and
        updates in case the normal attributes are changed.

        The values are saved as a dictionary with p__fieldname as keys,
        to match the WHERE clause generated by _where_primary().
        """
        self._primary_values = dict([("p__" + field, getattr(self, field))
                                     for field in self._primary])

    # Will be used by all generated _get_something() methods
    def _get_foreign(self, _foreign):
        """Fetch the row referenced by foreign key field _foreign.

        Returns None if the field is None, otherwise an instance of the
        foreign table class loaded by that primary key.
        """
        table = self._foreigns[_foreign]
        if getattr(self, _foreign) is None:
            return None
        if isinstance(_foreign, unicode):
            # We can't (shouldn't) have unicode kw args!
            _foreign = _foreign.encode("ascii", "ignore")
        primary = {_foreign: getattr(self, _foreign)}
        return table(**primary)

    # Will be used by all generated _set_something() methods
    def _set_foreign(self, value, _foreign):
        """Set foreign key field _foreign from an instance (or None).

        Raises ProgrammingError if value is neither None nor an
        instance of the foreign table class.
        """
        table = self._foreigns[_foreign]
        if value is None:
            setattr(self, _foreign, None)
            return
        if not isinstance(value, table):
            # (value,) so a tuple value can't break the formatting
            raise ProgrammingError("Unsupported foreign type %s" % (value,))
        # Fetch the foreign primary key
        setattr(self, _foreign, getattr(value, _foreign))

    # Will be used by all generated _get_somethings() methods
    def _get_children(self, _Child, _child_field, _my_field):
        """Yield all children as instances.

        A child is a row of _Child whose _child_field equals our
        _my_field, i.e. whose foreign key points to us.
        """
        sql = "SELECT * FROM %s WHERE %s=%s" % (
            _Child._table_name, _child_field, self._db.param)
        params = (getattr(self, _my_field),)
        for row in self._query(sql, params):
            yield _Child(_db_row=row)
684 685
class TableBuilder(Database):
    """Build Table subclasses by investigating the database.

    As with the Database class, remember to subclass in the DBConnect
    instance in the _db class attribute.

    The table builder will list all tables in the connected database
    and generate one Table subclass for each table.

    In addition to figuring out column names and primary keys, the
    table builder will also guess foreign keys and add methods like
    get_something(), set_something() and get_somethings().
    """

    def __init__(self, TableBase=Table):
        """Build tables using the provided TableBase as a base class.

        If the parameter TableBase is not provided, Table will be used
        as a superclass.

        If the base table does not have a valid _db attribute (ie. a
        DBConnect connection), it will be inherited from the
        TableBuilder class via an intermediate subclass.
        """
        # Base of built classes
        if not TableBase._db:
            class TableBase(TableBase):
                _db = self._db
        self.TableBase = TableBase

    def all_tables(self):
        """Find all table names in the active database.

        Depending on the database type, this function will find the
        list of all table names and store them in self.table_names.

        Table names ending in _seq are not included, but are placed in
        self.sequences instead.
        """
        c = self._db.cursor()
        if self._db.type == "mysql":
            c.execute("SHOW TABLES")
        elif self._db.type == "sqlite":
            c.execute("SELECT name FROM sqlite_master WHERE type='table'")
        elif self._db.type == "postgresql":
            # FIXME: Support other schemas and tablespaces, views, etc.
            c.execute("""SELECT tablename FROM pg_catalog.pg_tables
                WHERE schemaname=pg_catalog.current_schema()""")
        else:
            raise UnsupportedDBError(self._db.type)
        tables = [t for (t,) in c.fetchall()]
        self.table_names = [t for t in tables if not t.endswith("_seq")]
        self.sequences = [t for t in tables if t.endswith("_seq")]

    def build_class(self, table_name):
        """Build the (empty) subclass for table_name.

        The generated class is named table_name.capitalize(): the first
        character upper-cased, the rest lower-cased. So the class for
        "postal" is Postal, and the class for "my_table" is My_table
        (not MyTable). A unicode table name is reduced to ASCII for the
        class name, while _table_name keeps the original value.
        """
        class table(self.TableBase):
            _table_name = table_name
            _children = []
        if isinstance(table_name, unicode):
            table_name = table_name.encode("ascii", "ignore")
        table.__name__ = table_name.capitalize()
        return table

    def add_fields(self, table):
        """Find the fields and primary keys of a table class.

        Sets table._fields (a dict of column name -> database type
        information) and table._primary (a list of primary key column
        names; all columns if the table has no primary key).
        """
        table_name = table._table_name
        fields = {}
        primary = []
        c = self._db.cursor()
        if self._db.type == "mysql":
            c.execute("DESCRIBE %s" % table_name)
            for field, col_type, is_null, key, default, extra in c.fetchall():
                fields[field] = col_type
                if key == "PRI":
                    primary.append(field)
        elif self._db.type == "sqlite":
            c.execute("PRAGMA table_info(%s)" % table_name)
            for cid, field, col_type, is_null, default, key in c.fetchall():
                fields[field] = col_type
                if key:
                    primary.append(field)
        elif self._db.type == "postgresql":
            # We could do it through pg_.* but it is a bit complicated
            # compared to this approach
            c.execute("SELECT * FROM %s LIMIT 0" % table_name)
            for column in c.description:
                # column[1] is the DB-API type_code (type-oid ?)
                fields[column[0]] = column[1]
            # find primary keys will have to do it the pg_ way :/
            c.execute("""SELECT index.indexrelid
                FROM pg_catalog.pg_index index
                JOIN pg_catalog.pg_class i ON (index.indexrelid = i.oid)
                JOIN pg_catalog.pg_class u ON (index.indrelid = u.oid)
                WHERE index.indisprimary AND u.relname = %s
                """, (table_name,))
            index_oid = c.fetchone()
            if index_oid:
                # Ask for index column 1, 2, ... until none is returned
                for n in count(1):
                    c.execute("SELECT pg_get_indexdef(%s, %s, FALSE)",
                              (index_oid[0], n))
                    field = c.fetchone()[0]
                    if not field:
                        break
                    primary.append(field)
        else:
            raise UnsupportedDBError(self._db.type)

        # If we don't have a primary, we will have to
        # match on ALL fields for UPDATE/DELETE.
        if not primary:
            primary = list(fields)
        table._fields = fields
        table._primary = primary

    def find_foreign(self, table):
        """Guess foreign keys.

        Basically a field fish_id is assumed a foreign key for the
        table fish - if it exists, is not the table itself, and has
        fish_id among its primary keys. A reverse mapping is appended to
        the foreign table's _children.

        Note that self.tables must already hold all table classes with
        their fields (build_tables() ensures this).
        """
        table._foreigns = {}
        for field in table._fields.keys():
            if not field.endswith("_id"):
                continue
            # Chop of _id
            table_name = field[:-3]
            if table_name == table._table_name:
                continue
            foreign = self.tables.get(table_name)
            if foreign is None:
                continue
            if field not in foreign._primary:
                continue
            table._foreigns[field] = foreign
            # And add a reverse mapping
            # his_table, his_field, my_field
            foreign._children.append((table, field, field))

    def build_table(self, table_name):
        """Build a table class and find all its fields."""
        table = self.build_class(table_name)
        self.add_fields(table)
        return table

    def generate_foreign_methods(self, table):
        """Generate get/set-methods for foreign keys.

        The method names will be named like get_other() / set_other()
        for the table class Other.
        """
        # Function __name__ is only writable from Python 2.4 on
        can_rename = sys.version_info[:2] >= (2, 4)
        for foreign, Foreign in table._foreigns.items():
            get_name = "get_" + Foreign.__name__.lower()
            set_name = "set_" + Foreign.__name__.lower()

            # foreign is bound as a default to avoid late binding
            def _get_foreign(self, _foreign=foreign):
                return super(table, self)._get_foreign(_foreign)

            def _set_foreign(self, value, _foreign=foreign):
                super(table, self)._set_foreign(value, _foreign)

            if can_rename:
                _get_foreign.__name__ = get_name
                _set_foreign.__name__ = set_name
            setattr(table, get_name, _get_foreign)
            setattr(table, set_name, _set_foreign)

    def generate_children_methods(self, table):
        """Generate get-methods for retrieving foreign key children.

        The method names will be named like get_others() for the table
        class Other.
        """
        # Function __name__ is only writable from Python 2.4 on
        can_rename = sys.version_info[:2] >= (2, 4)
        for (Child, child_field, my_field) in table._children:
            # transform name, ie. "car" -> "get_cars"
            child_name = "get_" + Child.__name__.lower() + "s"

            # loop values are bound as defaults to avoid late binding
            def _get_children(self, _Child=Child,
                              _child_field=child_field,
                              _my_field=my_field):
                return super(table, self)._get_children(
                    _Child, _child_field, _my_field)

            if can_rename:
                _get_children.__name__ = child_name
            setattr(table, child_name, _get_children)

    def build_tables(self):
        """Fully generate the list of table classes.

        This is the main method which will retrieve all tables,
        generate the Table subclasses and finally generate foreign key
        methods.

        The generated table classes will be available in self.tables
        using the SQL table name as a key.
        """
        self.all_tables()
        self.tables = {}
        for table_name in self.table_names:
            self.tables[table_name] = self.build_table(table_name)
        for table in self.tables.values():
            self.find_foreign(table)
        for table in self.tables.values():
            self.generate_foreign_methods(table)
            self.generate_children_methods(table)
896 897
def _describe_connect_info(connect_info):
    """Return a short, password-free description of a connect target.

    Only dictionaries are inspected: the value of the first of the keys
    "host", "database", "db" or "dbname" that is present is used.
    Anything else (tuples, lists, DSN strings, dicts without those keys)
    may carry credentials and is described simply as "database".
    """
    if isinstance(connect_info, dict):
        for key in ("host", "database", "db", "dbname"):
            if key in connect_info:
                return str(connect_info[key])
    return "database"


def generate(db_module, connect_info, globals=None, log=None):
    """Generate forgetSQL classes and return them as a module object.

    The db_module can be MySQLdb, a psycopg module or an sqlite module.
    This parameter must be the actual module object, imported by the
    caller.

    The connect_info is provided to db_module.connect() and can be
      - a dictionary (sent as keyword arguments)
      - a tuple/list (sent as positional arguments)
      - a string (sent as 1st argument)

    If the optional parameter ``globals`` is provided, instead of
    generating a new module, generated classes will be inserted into
    the namespace dictionary, usually as provided by globals(), and
    None is returned.

    If the optional ``log`` logger is given, progress is logged through
    it. If connecting fails, the error is printed (and logged) and the
    process exits with sys.exit(1).

    Besides the table classes, the names execute, query, query_one and
    db are exported.

    Example::

        # simple usage
        import MySQLdb, forgetsql2
        db = forgetsql2.generate(MySQLdb, {'db': 'fish'})
        for postal in db.Postal:
            print postal.postal_no, postal.postal_name

        # export symbols to this module
        import MySQLdb, forgetsql2
        forgetsql2.generate(MySQLdb, {'db': 'fish'}, globals())
        for postal in Postal:
            print postal.postal_no, postal.postal_name

    The second variant can be used in a separate module for bigger
    projects, for instance database.py. Other modules can then do::

        import database
        for postal in database.Postal:
            print postal.postal_no
    """
    # connect_info['host'] used to be read unconditionally, which broke
    # the documented tuple/list/string forms and dicts without 'host'.
    target = _describe_connect_info(connect_info)
    if log:
        log.info("Trying to connect to %s", target)
    try:
        connection = DBConnect(db_module, connect_info, log)
    except Exception:
        exc = sys.exc_info()[1]
        print('Error: %s' % exc)
        if log:
            log.exception('Error: %s', exc)
        sys.exit(1)
    print('Connection to %s succeeded.' % target)

    # subclass in the _db connection
    class TB(TableBuilder):
        _db = connection

    builder = TB()
    if log:
        log.info('Building tables')
    builder.build_tables()
    module = None
    if globals is None:
        # Generate a module object. Note that it is not adviced to add
        # this module object to sys.modules. Even though that would make
        # "import mymodule" work in other modules, it would only work
        # after generate() has been called. In a complicated module
        # hierarchy such an assumption is not always easy to ensure.
        module = types.ModuleType("forgetsql2.generated")
        globals = module.__dict__
    for table in builder.tables.values():
        globals[table.__name__] = table
    # And export general purpose database methods
    globals['execute'] = TB._execute
    globals['query'] = TB._query
    globals['query_one'] = TB._query_one
    globals['db'] = TB._db
    return module