import MySQLdb
import CommandLineApp # Doug Hellmann's CommandLineApp http://snurl.com/2tela
import os
import sys
import re
from datetime import datetime
import hashlib
from subprocess import Popen, PIPE

DB_NAME = ''
DB_HOST = ''
DB_USER = ''
DB_PASS = ''

reSQL = re.compile(";\s*$", re.MULTILINE)

db_connection = MySQLdb.connect(host=DB_HOST, 
                                db=DB_NAME, 
                                user=DB_USER, 
                                passwd=DB_PASS)


def get_version_list():
    c = db_connection.cursor()
    query = """select distinct version, svn_version 
                 from versions
             order by version;"""
    count = c.execute(query)
    versions = []
    if count > 0:
        versions = [version for version in c.fetchall()]
    return versions
    

class upgrade(CommandLineApp.CommandLineApp):
    """Upgrades the website database."""
    
    EXAMPLES_DESCRIPTION = """
To execute all scripts within a directory:

  $ upgrade --execute --dir /path/to/sqlscripts
  
To execute only some scripts within a directory:

  $ upgrade --execute --dir /path/to/sqlscripts 2008*
  
To just print out script, but don't execute:

  $ upgrade --dir /path/to/sqlscripts 2008*

To just print out scripts in the current working directory, 
but don't execute:

  $ upgrade

To execute all scripts in the current working directory:

  $ upgrade --execute
"""
    def __init__(self, commandLineOptions=sys.argv[1:]):
        super(upgrade, self).__init__(commandLineOptions=sys.argv[1:])
        
    execute = False
    def optionHandler_execute(self):
        """Turn on the excution option.  Defaults to False."""
        self.execute = True
        return
    
    dir = os.getcwd()
    def optionHandler_dir(self, name):
        """Set the directory that contains the scripts to execute.  
		Defaults to current working directory."""
        self.dir = name
        return
        
    def filter_down(self, *cherries):
        already_applied = get_version_list()
        print "# ALREADY APPLIED:"
        for x in already_applied:
            print '#\t%s\tr%s' % (x[0], x[1])
        in_directory = os.listdir(self.dir)
        in_directory.sort()
        sql_in_directory = []
        for sql in in_directory:
            if os.path.splitext(sql)[-1] == '.sql':
                sql_in_directory.append(sql)
        to_execute = []
        if cherries:
            for cherry in cherries:
                if cherry not in already_applied:
                    if cherry[0] in sql_in_directory:
                        to_execute.append(cherry[0])
        else:
            for sql in sql_in_directory:
                if sql not in [info[0] for info in already_applied]:
                    to_execute.append(sql)
        print '# TO EXECUTE:'
        for x in to_execute:
            print '#\t%s' % x
        return to_execute
        
    def get_rev(self, sql):
        svninfo = Popen(["svn", "info", sql], stdout=PIPE).stdout.readlines()
        for info in svninfo:
            tokens = info.split(':')
            if tokens[0].strip() == 'Last Changed Rev':
                return tokens[1].strip()
        return 0
        
    def split_file(self, sql):
        full_path = os.path.join(self.dir, sql)
        contents = open(full_path, 'r').read()
        size = os.stat(full_path).st_size
        sha1 = hashlib.sha1(contents).hexdigest()
        rev = self.get_rev(full_path)
        print "## Processing %s, %s bytes, sha1 %s, svn rev %s" % \
                (sql, size, sha1, rev)
        return {
			'statements':reSQL.split(contents), 
			'full_path':full_path, 
			'contents':contents, 
			'size':size, 
			'sha1':sha1, 
			'rev':rev
		}
        
    def execute_sql(self, statement, segment_number):
        segment = (
            segment_number, 
            len(statement), 
            hashlib.sha1(statement).hexdigest()
        )
        if not self.execute:
            print "### printing segment %s, %s bytes, sha1 %s" % segment
            print "%s;" % statement
        if self.execute:
            print "### executing segment %s, %s bytes, sha1 %s" % segment
            print "%s;" % statement
            c = db_connection.cursor()
            count = c.execute(statement)
            print "### SUCCESS, %s rows affected." % count
            return True
        return True
        
    def stamp_database(self, sql, statements, svn):
        print "## DB status updated: %s" % sql
        if self.execute:
            c = db_connection.cursor()
            count = c.execute("""
INSERT INTO versions (version, date_created, sql_executed, svn_version) 
VALUES (%(sql)s, %(date)s, %(statements)s, %(revision)s);""", {
    'sql':sql, 
    'date':datetime.now(), 
    'statements':';\n'.join(statements)+';', 
    'revision':svn
            })       
    
    def main(self, *args):
        print "# Script began %s" % datetime.now()
        to_execute = self.filter_down(*args)
        for sql in to_execute:
            split_data = self.split_file(sql)
            statements = split_data['statements']
            executed_stmts = []
            segment_num = 0
            for statement in  statements:
                stmt = statement.strip()
                if stmt and stmt not in ('BEGIN', 'COMMIT',) and \
                    "UPDATE `version` SET `serial_number`" not in stmt:
                    if self.execute_sql(stmt, segment_num):
                        executed_stmts.append(stmt)
                segment_num += 1
            # for each file update the database version
            self.stamp_database(sql, executed_stmts, split_data['rev'])
            print "-"*70
        print "# Script ended %s" % datetime.now()
        
        
if __name__ == "__main__":
    up = upgrade()
    up.run()