Source code for hynet.data.connection

"""
Manage *hynet*'s database connections.
"""

import logging
import textwrap

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import NoResultFound

from hynet.data.structure import (SCHEMA_VERSION,
                                  Base,
                                  DBInfo,
                                  DBBus,
                                  DBBranch,
                                  DBConverter,
                                  DBInjector,
                                  DBShunt,
                                  DBScenario)
from hynet.types_ import DBInfoKey

# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# Marker that ``connect`` looks for to decide whether a URI already carries a
# dialect prefix; it is also used there to build the default ``sqlite`` URI.
DIALECT_PREFIX_SEPARATOR = ':///'
# Column width to which ``DBConnection.description`` wraps the stored text.
DESCRIPTION_LINE_LENGTH = 75


def connect(database_uri):
    """
    Open a connection to a *hynet* grid database.

    If the argument does not contain the dialect prefix separator
    (``DIALECT_PREFIX_SEPARATOR``), it is treated as a file name and turned
    into an ``sqlite`` URI before the connection is created.

    Parameters
    ----------
    database_uri : str
        URI or file name of the *hynet* grid database.

    Returns
    -------
    database : DBConnection
        Connection to the *hynet* grid database.
    """
    has_dialect_prefix = DIALECT_PREFIX_SEPARATOR in database_uri
    if has_dialect_prefix:
        uri = database_uri
    else:
        # Plain file name: address it via the SQLite dialect.
        uri = 'sqlite' + DIALECT_PREFIX_SEPARATOR + database_uri
    return DBConnection(uri)
class DBConnection:
    """
    Manager for a *hynet* grid database connection.

    See Also
    --------
    hynet.data.connection.DBTransaction
    """

    def __init__(self, database_uri):
        """
        Establish a connection to the specified *hynet* grid database.

        Missing tables are created. If the database holds no version entry,
        it is initialized with the current ``SCHEMA_VERSION``.

        Parameters
        ----------
        database_uri: str
            URI of the *hynet* grid database.

        Raises
        ------
        IOError
            If the version stored in the database is lower than
            ``SCHEMA_VERSION``.
        """
        # Create an engine and bind it to the metadata and session maker
        self.database_uri = database_uri
        self.engine = create_engine(database_uri)
        Base.metadata.bind = self.engine
        Base.metadata.create_all(self.engine)  # Create tables if they're absent
        self._session_maker = sessionmaker(bind=self.engine)

        # Check database version or, if empty, initialize it
        with DBTransaction(self) as transaction:
            try:
                version = transaction.query(DBInfo).filter(
                    DBInfo.key == DBInfoKey.VERSION).one().value
            except NoResultFound:
                transaction.add(DBInfo(key=DBInfoKey.VERSION,
                                       value=SCHEMA_VERSION))
            else:
                # NOTE(review): versions are compared as floats, so e.g.
                # '1.10' would rank below '1.9' -- confirm this matches the
                # versioning scheme.
                if float(version) < float(SCHEMA_VERSION):
                    # ``version`` already is the stored value (a string);
                    # it must not be dereferenced with ``.value`` again.
                    raise IOError(
                        "Deprecated database version '{}' (expected '{}' "
                        "or higher)".format(version, SCHEMA_VERSION))

        _log.debug("Connected to database '{:s}'".format(self.database_uri))

    @property
    def empty(self):
        """Return True if the database does not contain grid information."""
        tables = (DBBus, DBBranch, DBConverter, DBInjector, DBShunt,
                  DBScenario)
        with DBTransaction(self) as transaction:
            # Stops at the first table that contains at least one row.
            has_data = any(transaction.query(table).first() is not None
                           for table in tables)
        return not has_data

    @property
    def version(self):
        """Return the *hynet* grid database format version of this database."""
        return self.get_setting(DBInfoKey.VERSION)

    @property
    def grid_name(self):
        """Return the name of the grid in this database ('' if not set)."""
        try:
            return self.get_setting(DBInfoKey.GRID_NAME)
        except ValueError:
            return ''

    @grid_name.setter
    def grid_name(self, value):
        """Set the name of the grid in this database."""
        self.set_setting(DBInfoKey.GRID_NAME, value)

    @property
    def description(self):
        """
        Return the description of this database ('' if not set).

        Before the text is returned, the description retrieved from the
        database is wrapped to an appropriate column width to improve
        readability.
        """
        try:
            description = self.get_setting(DBInfoKey.DESCRIPTION)
            return textwrap.fill(description, width=DESCRIPTION_LINE_LENGTH)
        except ValueError:
            return ''

    @description.setter
    def description(self, value):
        """Set the description of this database."""
        self.set_setting(DBInfoKey.DESCRIPTION, value)

    def get_setting(self, key):
        """
        Return the database setting for the specified key.

        Parameters
        ----------
        key : DBInfoKey
            The key for which the value shall be retrieved.

        Returns
        -------
        value : str
            The value associated with the provided key.

        Raises
        ------
        ValueError
            If the setting was not found.
        """
        with DBTransaction(self) as transaction:
            try:
                setting = transaction.query(DBInfo)\
                    .filter(DBInfo.key == key).one().value
            except NoResultFound:
                raise ValueError("Setting '{:s}' was not found.".format(key))
        return setting

    def set_setting(self, key, value):
        """
        Set the database setting for the specified key.

        An existing entry for the key is overwritten; otherwise a new entry
        is created.

        Parameters
        ----------
        key : DBInfoKey
            The key for which the value shall be set.
        value : str
            Value to be set.

        Raises
        ------
        ValueError
            If the value is not a string.
        """
        if not isinstance(value, str):
            raise ValueError("The setting value must be a string.")
        with DBTransaction(self) as transaction:
            try:
                setting = transaction.query(DBInfo)\
                    .filter(DBInfo.key == key).one()
            except NoResultFound:
                transaction.add(DBInfo(key=key, value=value))
            else:
                setting.value = value
                transaction.update(setting)

    def start_session(self):
        """
        Return a new database session as an SQLAlchemy Session object.

        **Remark:** This function is for internal use.
        """
        return self._session_maker()
class DBTransaction:
    """
    Database transaction that is automatically committed at the exit block.

    Intended for use as a context manager: the session is committed when the
    ``with`` block exits normally and rolled back if it exits with an
    exception. The session is closed in either case.
    """

    def __init__(self, database):
        """
        Create a database transaction.

        Parameters
        ----------
        database: DBConnection
            Connection to the *hynet* grid database.
        """
        self._session = database.start_session()

    def __enter__(self):
        """
        Initiate the transaction.

        Returns
        -------
        DBTransaction
            ``self``.
        """
        return self

    def __exit__(self, type_, value, traceback):
        """
        Commit the transaction.

        Attempts to commit the current transaction. In case that the commit
        fails, a roll back of the transaction is attempted and the commit
        error is re-raised. If the ``with`` block itself raised, the
        transaction is rolled back instead of committed. The session is
        always closed afterwards.

        Parameters
        ----------
        type_ : Exception class or None
            In case of an exception, class of the exception. None otherwise.
        value : Exception instance or None
            In case of an exception, the exception object. None otherwise.
        traceback : object or None
            In case of an exception, object with traceback info. None
            otherwise.

        Returns
        -------
        bool
            Always ``False``, so an exception from the ``with`` block
            propagates to the caller.
        """
        try:
            if type_ is None:
                try:
                    self._session.commit()
                except Exception:
                    # Leave the database unchanged if the commit failed,
                    # then let the caller see the original commit error.
                    self._session.rollback()
                    raise
            else:
                self._session.rollback()
        finally:
            # Release the session even if commit or rollback raised.
            self._session.close()
            self._session = None
        return False  # In case of an exception, reraise it

    def add(self, object_):
        """Add an object to the session."""
        self._session.add(object_)

    def add_all(self, collection):
        """Add a collection of objects to the session."""
        self._session.add_all(collection)

    def update(self, object_):
        """Update the state of the object in the session (via ``merge``)."""
        self._session.merge(object_)

    def delete(self, object_):
        """Mark the object as deleted in the session."""
        self._session.delete(object_)

    def delete_all(self, collection):
        """Mark a collection of objects as deleted."""
        for object_ in collection:
            self.delete(object_)

    def query(self, object_type):
        """Return a new SQLAlchemy Query object for this session."""
        return self._session.query(object_type)

    def execute(self, query):
        """Execute the SQL expression construct or string statement."""
        return self._session.execute(query)