# oracle/zxjdbc.py
# Copyright (C) 2005-2013 the SQLAlchemy authors and contributors
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""
.. dialect:: oracle+zxjdbc
    :name: zxJDBC for Jython
    :dbapi: zxjdbc
    :connectstring: oracle+zxjdbc://user:pass@host/dbname
    :driverurl: http://www.oracle.com/technology/software/tech/java/sqlj_jdbc/index.html.

"""
import collections
import decimal
import re

from sqlalchemy import sql, types as sqltypes, util
from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector
from sqlalchemy.dialects.oracle.base import OracleCompiler, OracleDialect, OracleExecutionContext
from sqlalchemy.engine import result as _result
from sqlalchemy.sql import expression

# Placeholders; rebound to the real Java classes by
# OracleDialect_zxjdbc.__init__ once the Java imports succeed.
SQLException = zxJDBC = None

# First dotted run of digits inside a version string, e.g. "10.2.0.4.0".
_VERSION_RE = re.compile(r'\d+(?:\.\d+)+')


def _version_tuple(version):
    """Return the first dotted number found in *version* as a tuple of ints.

    An empty tuple is returned when *version* is empty/None or contains no
    dotted number, so that it compares lower than any real version tuple.
    """
    if not version:
        return ()
    match = _VERSION_RE.search(version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split('.'))


class _ZxJDBCDate(sqltypes.Date):
    """Date type whose result processor calls ``.date()`` on each value."""

    def result_processor(self, dialect, coltype):
        """Return a processor mapping non-None values to ``value.date()``."""
        def process(value):
            if value is None:
                return None
            return value.date()
        return process


class _ZxJDBCNumeric(sqltypes.Numeric):
    """Numeric type that coerces results to Decimal or float."""

    def result_processor(self, dialect, coltype):
        """Return a processor honoring ``self.asdecimal``.

        With ``asdecimal`` set, non-Decimal values are converted through
        their string form; otherwise Decimal values are converted to float.
        ``None`` is passed through untouched in both modes.
        """
        # XXX: does the dialect return Decimal or not???
        # if it does (in all cases), we could use a None processor as well as
        # the to_float generic processor
        if self.asdecimal:
            def process(value):
                # None must be passed through: Decimal(str(None)) would
                # raise decimal.InvalidOperation.
                if value is None or isinstance(value, decimal.Decimal):
                    return value
                return decimal.Decimal(str(value))
        else:
            def process(value):
                if isinstance(value, decimal.Decimal):
                    return float(value)
                return value
        return process


class OracleCompiler_zxjdbc(OracleCompiler):
    """Compiler that renders RETURNING ... INTO with ReturningParam binds."""

    def returning_clause(self, stmt, returning_cols):
        """Render the ``RETURNING <cols> INTO <binds>`` clause.

        Side effects: stores the expanded columns on ``self.returning_cols``,
        appends ``(1-based index, dbapi type)`` pairs to
        ``self.returning_parameters`` and registers one ``ret_<i>`` bind
        parameter per column in ``self.binds``.
        """
        self.returning_cols = list(
            expression._select_iterables(returning_cols))

        # within_columns_clause=False so that labels (foo AS bar) don't render
        columns = [
            self.process(c, within_columns_clause=False,
                         result_map=self.result_map)
            for c in self.returning_cols
        ]

        if not hasattr(self, 'returning_parameters'):
            self.returning_parameters = []

        binds = []
        for i, col in enumerate(self.returning_cols):
            dbtype = col.type.dialect_impl(
                self.dialect).get_dbapi_type(self.dialect.dbapi)
            self.returning_parameters.append((i + 1, dbtype))

            bindparam = sql.bindparam(
                "ret_%d" % i, value=ReturningParam(dbtype))
            self.binds[bindparam.key] = bindparam
            binds.append(
                self.bindparam_string(self._truncate_bindparam(bindparam)))

        return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds)


class OracleExecutionContext_zxjdbc(OracleExecutionContext):
    """Execution context that reads RETURNING values from the statement."""

    def pre_exec(self):
        """Swap the SQL string for a prepared statement when RETURNING is used."""
        if hasattr(self.compiled, 'returning_parameters'):
            # prepare a zxJDBC statement so we can grab its underlying
            # OraclePreparedStatement's getReturnResultSet later
            self.statement = self.cursor.prepare(self.statement)

    def get_result_proxy(self):
        """Return a result proxy, buffered from the RETURNING row if any.

        Raises ``zxJDBC.Error`` when fetching the return result set fails
        with a ``SQLException``.
        """
        if hasattr(self.compiled, 'returning_parameters'):
            rrs = None
            try:
                try:
                    rrs = self.statement.__statement__.getReturnResultSet()
                    # Method call on the returned result set object (not the
                    # Python iterator protocol).
                    rrs.next()
                except SQLException as sqle:
                    msg = '%s [SQLCode: %d]' % (
                        sqle.getMessage(), sqle.getErrorCode())
                    if sqle.getSQLState() is not None:
                        msg += ' [SQLState: %s]' % sqle.getSQLState()
                    raise zxJDBC.Error(msg)
                else:
                    row = tuple(
                        self.cursor.datahandler.getPyObject(
                            rrs, index, dbtype)
                        for index, dbtype in
                        self.compiled.returning_parameters)
                    return ReturningResultProxy(self, row)
            finally:
                # Best-effort close of the result set; the prepared
                # statement is always closed afterwards.
                if rrs is not None:
                    try:
                        rrs.close()
                    except SQLException:
                        pass
                self.statement.close()

        return _result.ResultProxy(self)

    def create_cursor(self):
        """Create a cursor whose datahandler is wrapped by the dialect's."""
        cursor = self._dbapi_connection.cursor()
        cursor.datahandler = self.dialect.DataHandler(cursor.datahandler)
        return cursor


class ReturningResultProxy(_result.FullyBufferedResultProxy):
    """ResultProxy backed by the RETURNING ResultSet results."""

    def __init__(self, context, returning_row):
        # The row must be set before the base initializer runs.
        # NOTE(review): presumably the base class calls _buffer_rows()
        # during __init__ -- confirm against FullyBufferedResultProxy.
        self._returning_row = returning_row
        super(ReturningResultProxy, self).__init__(context)

    def _cursor_description(self):
        """Build ``(name, type)`` pairs from the compiled returning columns."""
        ret = []
        for c in self.context.compiled.returning_cols:
            if hasattr(c, 'name'):
                ret.append((c.name, c.type))
            else:
                ret.append((c.anon_label, c.type))
        return ret

    def _buffer_rows(self):
        """Return a deque holding the single RETURNING row."""
        return collections.deque([self._returning_row])


class ReturningParam(object):
    """A bindparam value representing a RETURNING parameter.

    Specially handled by OracleReturningDataHandler.
    """

    def __init__(self, type):
        self.type = type

    def __eq__(self, other):
        if isinstance(other, ReturningParam):
            return self.type == other.type
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, ReturningParam):
            return self.type != other.type
        return NotImplemented

    def __repr__(self):
        kls = self.__class__
        return '<%s.%s object at 0x%x type=%s>' % (
            kls.__module__, kls.__name__, id(self), self.type)


class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect):
    """Oracle dialect using the zxJDBC connector."""

    jdbc_db_name = 'oracle'
    jdbc_driver_name = 'oracle.jdbc.OracleDriver'

    statement_compiler = OracleCompiler_zxjdbc
    execution_ctx_cls = OracleExecutionContext_zxjdbc

    colspecs = util.update_copy(
        OracleDialect.colspecs,
        {
            sqltypes.Date: _ZxJDBCDate,
            sqltypes.Numeric: _ZxJDBCNumeric
        }
    )

    def __init__(self, *args, **kwargs):
        super(OracleDialect_zxjdbc, self).__init__(*args, **kwargs)
        # Java imports are deferred to construction time and published as
        # module globals for the execution context to use.
        global SQLException, zxJDBC
        from java.sql import SQLException
        from com.ziclix.python.sql import zxJDBC
        from com.ziclix.python.sql.handler import OracleDataHandler

        class OracleReturningDataHandler(OracleDataHandler):
            """zxJDBC DataHandler that specially handles ReturningParam."""

            def setJDBCObject(self, statement, index, object, dbtype=None):
                if type(object) is ReturningParam:
                    statement.registerReturnParameter(index, object.type)
                elif dbtype is None:
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object)
                else:
                    OracleDataHandler.setJDBCObject(
                        self, statement, index, object, dbtype)
        self.DataHandler = OracleReturningDataHandler

    def initialize(self, connection):
        """Initialize the dialect and decide whether to use RETURNING.

        ``implicit_returning`` is enabled for driver versions 10.2 and up.
        """
        super(OracleDialect_zxjdbc, self).initialize(connection)
        # Compare numerically: as strings, '9.2' >= '10.2' would be True.
        driver_version = _version_tuple(connection.connection.driverversion)
        self.implicit_returning = driver_version >= (10, 2)

    def _create_jdbc_url(self, url):
        """Build the thin-driver JDBC URL; the port defaults to 1521."""
        return 'jdbc:oracle:thin:@%s:%s:%s' % (
            url.host, url.port or 1521, url.database)

    def _get_server_version_info(self, connection):
        """Parse the ``Release x.y.z`` number of ``dbversion`` into ints."""
        version = re.search(
            r'Release ([\d\.]+)', connection.connection.dbversion).group(1)
        return tuple(int(x) for x in version.split('.'))


dialect = OracleDialect_zxjdbc