"""Helpers for calling database stored procedures/functions through Django.

The module exposes ``ProcedureManager``, a model manager that can either
return the raw rows produced by a procedure (``values_from_procedure``) or
wrap them in model instances through a ``QuerySetPrepared``
(``filter_by_procedure``).
"""
from django.conf import settings
from django.db.models.query import QuerySet, GET_ITERATOR_CHUNK_SIZE, EmptyResultSet
from django.db.models.manager import Manager
from django.db.models import Model
from django.db import backend, connection, transaction

__all__ = ['ProcedureManager']


class PreparedStatementError(Exception):
    """Base class for every error raised by this module."""
    pass


class QuerySetLimitationError(PreparedStatementError):
    """Raised when an unsupported QuerySet operation is attempted."""
    pass


class InvalidSQLProcedure(PreparedStatementError):
    """Raised when a procedure's result lacks a column the model needs."""
    pass


# MySQL invokes procedures with CALL; every other backend is treated as
# exposing set-returning functions that can be selected from.
if 'mysql' in settings.DATABASE_ENGINE.lower():
    prepared_command = 'CALL'
    db_mysql = True
else:
    prepared_command = 'SELECT * FROM'
    db_mysql = False


def _placeholders(params):
    """Return a comma-separated ``%s`` placeholder list, one per parameter."""
    return ', '.join(['%s'] * len(params))


class QuerySetPrepared(QuerySet):
    """
    A QuerySet that represents the resultset of a procedure -- either
    through MySQL's CALL or PostgreSQL's stored functions.

    USAGE
    =====
    To obtain one of these objects, simply do::

        result = Model.objects.filter_by_procedure('procedure_name', arg1, ...)

    LIMITATIONS
    ===========
    Since there are a lot of limitations in MySQL with stored procedures,
    you cannot do much with this. You cannot filter, exclude, order, or
    otherwise modify this query.
    """

    def __init__(self, *args, **kwargs):
        """
        Define the procedure variables, then defer to ``QuerySet.__init__``.
        """
        self._proc_params = ()
        self._proc_name = ''
        super(QuerySetPrepared, self).__init__(*args, **kwargs)

    def iterator(self):
        """
        Like the Django iterator except is used for calling stored
        procedures.

        Yields one model instance per row returned by the procedure. Every
        column of the row is also copied onto the instance's ``__dict__``,
        so extra columns produced by the procedure are available as
        attributes.

        Raises ``InvalidSQLProcedure`` when a row lacks one of the model's
        columns.
        """
        proc_params = self._proc_params
        proc_name = self._proc_name
        try:
            select, sql, params = self._get_sql_clause()
        except EmptyResultSet:
            # A bare return ends the generator; ``raise StopIteration``
            # inside a generator is a RuntimeError under PEP 479.
            return

        # Find where the trailing WHERE / ORDER BY / LIMIT part of the
        # generated SQL starts (the earliest of the three tokens).
        index_start = len(sql)
        for token in (' ORDER BY ', ' WHERE ', ' LIMIT ',):
            current_index = sql.find(token)
            if current_index != -1 and current_index < index_start:
                index_start = current_index

        if index_start == len(sql) or db_mysql:
            # Nothing to append, or MySQL's CALL which cannot take a tail.
            where_clause = ''
        else:
            # Drop the '"table".' qualifier: the rows now come from the
            # procedure call rather than from the model's table.
            where_clause = sql[index_start:].replace(
                '"%s".' % self.model._meta.db_table, '')

        # Normalise to a list so a tuple of procedure params can be combined
        # with the list of clause params. Clause params are only sent when
        # their placeholders are actually part of the statement.
        query_params = list(proc_params)
        if where_clause:
            query_params.extend(params)

        cursor = connection.cursor()
        cursor.execute(
            "%s %s(%s)%s" % (prepared_command, proc_name,
                             _placeholders(proc_params), where_clause),
            query_params)

        model_keys = [f.column for f in self.model._meta.fields]
        while True:
            rows = cursor.dictfetchmany(GET_ITERATOR_CHUNK_SIZE)
            if not rows:
                return
            for row in rows:
                # very simple "return result of procedure"
                try:
                    args = [row[model_key] for model_key in model_keys]
                except KeyError:
                    raise InvalidSQLProcedure("'%s' does not provide the all the correct columns for the model, %s" % (proc_name, tuple(model_keys)))
                object_ = self.model(*args)
                object_.__dict__.update(row)
                yield object_

    def count(self):
        """
        Counts the number of objects this queryset represents.

        Uses the result cache when it is already populated. On MySQL the
        rows are fetched and counted in Python; otherwise a
        ``SELECT COUNT(*)`` is run against the function and then adjusted
        for this queryset's offset and limit.
        """
        if self._result_cache is not None:
            return len(self._result_cache)

        # since we're using a stored procedure/prepared statement,
        # we cannot use COUNT
        if db_mysql:
            return len(self._get_data())

        # NOTE(review): this counts every row the function returns; any
        # WHERE clause that iterator() would append is not applied here.
        offset = self._offset
        limit = self._limit
        cursor = connection.cursor()
        cursor.execute(
            'SELECT COUNT(*) FROM %s(%s)' % (self._proc_name,
                                             _placeholders(self._proc_params)),
            self._proc_params)
        count = cursor.fetchone()[0]
        if offset:
            count = max(0, count - offset)
        if limit:
            count = min(limit, count)
        return count

    def complain(self, *args, **kwargs):
        """Reject the operation with a ``QuerySetLimitationError``."""
        raise QuerySetLimitationError("You cannot perform this operation on a query that uses prepared statements or stored procedures.")

    def complain_optionally(method, _complain=complain):
        """
        Complain only if the database backend is MySQL.

        Used while the class body executes: returns ``complain`` on MySQL
        and the untouched ``method`` otherwise. ``complain`` is bound via a
        default argument because class-body names are not visible from
        inside functions defined in that body.
        """
        if db_mysql:
            return _complain
        return method

    # These functions are not allowed when used with MySQL's Stored Procedures
    _filter_or_exclude = complain_optionally(QuerySet._filter_or_exclude)
    complex_filter = complain_optionally(QuerySet.complex_filter)
    order_by = complain_optionally(QuerySet.order_by)
    distinct = complain_optionally(QuerySet.distinct)

    # These functions will not work with any of this.
    values = complain
    dates = complain
    delete = complain
    extra = complain
    select_related = complain
    in_bulk = complain

    def __getitem__(self, k):
        """
        Index or slice the results.

        On MySQL the fetched data is indexed directly; other backends use
        the regular ``QuerySet.__getitem__``.
        """
        if db_mysql:
            return self._get_data().__getitem__(k)
        return super(QuerySetPrepared, self).__getitem__(k)

    def _clone(self, klass=None, **kwargs):
        """
        Clone this queryset to a new one, carrying over the procedure name
        and parameters.
        """
        if klass is None:
            klass = self.__class__
        c = super(QuerySetPrepared, self)._clone(klass, **kwargs)
        c._proc_name = self._proc_name
        c._proc_params = self._proc_params
        return c


class ProcedureManager(Manager):
    """
    ``ProcedureManager`` allows Django Models to easily call procedures
    from the database.

    This manager exposes two additional functions to ``Model.objects``::

        - ``values_from_procedure``: Returns a list of dictionaries, one
          per row returned from the call.

        - ``filter_by_procedure``: Returns a ``QuerySetPrepared`` that
          represents the list of objects returned by that procedure.

    USAGE
    =====
    To use, simply add the objects statement in your model. For example::

        class Article(models.Model):
            objects = ProcedureManager()

    Then just call it like any filter::

        Article.objects.filter_by_procedure('articles_with_author', request.user)
    """

    def values_from_procedure(self, proc_name, *proc_params):
        """
        Return whatever a result of a procedure is.

        The proc_name is the name of a stored procedure or function. It is
        interpolated into the SQL text as-is, so it must not come from
        untrusted input. The remaining arguments are passed through
        ``clean_param`` and bound as query parameters.

        This will return a list of dictionaries representing the rows and
        columns of the result.
        """
        new_params = [clean_param(param) for param in proc_params]
        cursor = connection.cursor()
        cursor.execute(
            "%s %s(%s)" % (prepared_command, proc_name,
                           _placeholders(new_params)),
            new_params)

        # Drain the cursor chunk by chunk until it reports no more rows.
        results = []
        rows = cursor.dictfetchmany(GET_ITERATOR_CHUNK_SIZE)
        while rows:
            results.extend(rows)
            rows = cursor.dictfetchmany(GET_ITERATOR_CHUNK_SIZE)
        return results

    def filter_by_procedure(self, proc_name, *proc_params):
        """
        Use this to get a QuerySetPrepared of objects by a database
        procedure.

        The state of the manager's default queryset is copied onto a fresh
        ``QuerySetPrepared``, which is then given the procedure name and
        the parameters (each passed through ``clean_param``).
        """
        query_set = self.get_query_set()
        proc_query_set = QuerySetPrepared()
        proc_query_set.__dict__.update(query_set.__dict__)
        proc_query_set._proc_name = proc_name
        proc_query_set._proc_params = [clean_param(param)
                                       for param in proc_params]
        return proc_query_set


def clean_param(param):
    """
    Convert a procedure argument to the string sent to the database.

    Objects exposing ``_get_pk_val`` are replaced by their primary key,
    callables are called and their result used, and anything else is
    passed to ``str`` directly.
    """
    if hasattr(param, '_get_pk_val'):
        # has a pk value -- must be a model
        return str(param._get_pk_val())
    if callable(param):
        # it's callable, should call it.
        return str(param())
    return str(param)