upgrade.py 10.7 KB
Newer Older
François C. committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
# -*- encoding: utf-8 -*-
##############################################################################
#
#    OpenERP, Open Source Management Solution
#    Copyright (C) 2013 Smile (<http://www.smile.fr>). All Rights Reserved
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
##############################################################################

import psycopg2

from contextlib import contextmanager
from distutils.version import LooseVersion
import logging
import os

from openerp import api, sql_db, SUPERUSER_ID, tools
from openerp.exceptions import UserError
import openerp.modules as addons
from openerp.report.interface import report_int as ReportService
from openerp.tools.safe_eval import safe_eval as eval
from openerp.workflow.service import WorkflowService

from config import configuration as upgrade_config

# Module-level logger shared by every function and class of this file; it is
# named after the enclosing package (``__package__``) rather than the module.
_logger = logging.getLogger(__package__)


@contextmanager
def cursor(db, auto_commit=True):
    """Context manager yielding a fresh cursor obtained from ``db.cursor()``.

    :param db: object exposing a ``cursor()`` factory method
    :param auto_commit: when true (default), ``commit()`` is called on the
        cursor once the ``with`` body has finished without raising
    :return: the newly opened cursor (bound by ``with ... as``)

    The cursor is always closed on exit, whether the body succeeded or not;
    on failure no commit is attempted.
    """
    new_cursor = db.cursor()
    try:
        yield new_cursor
        # Only reached when the ``with`` body completed without exception.
        if auto_commit:
            new_cursor.commit()
    finally:
        new_cursor.close()


class UpgradeManager(object):
    """Upgrade Manager
    * Compare code and database versions
    * Get upgrades to apply to database

    A dedicated cursor, switched to autocommit, is opened on ``db_name`` at
    construction time and kept in ``self.cr``. The instance can be used as a
    context manager: leaving the ``with`` block commits (when no exception
    occurred) and closes that cursor.
    """

    def __init__(self, db_name):
        """Connect to ``db_name`` and compute versions and pending upgrades.

        :param db_name: name of the database to inspect
        :raises UserError: propagated from ``_try_lock`` (via
            ``_get_upgrades``) with the message 'Upgrade in progress'
        """
        self.db_name = db_name
        self.db = sql_db.db_connect(db_name)
        self.cr = self.db.cursor()
        try:
            self.cr.autocommit(True)
            self.db_in_creation = self._get_db_in_creation()
            self.code_version = self._get_code_version()
            self.db_version = self._get_db_version()
            self.upgrades = self._get_upgrades()
            # Union of the modules requested by every pending upgrade
            # (duplicates removed, order not significant).
            modules = set()
            for upgrade in self.upgrades:
                modules.update(upgrade.modules_to_upgrade)
            self.modules_to_upgrade = list(modules)
            # Only the most recent upgrade decides what is installed at
            # database creation.
            if self.upgrades:
                self.modules_to_install_at_creation = self.upgrades[-1].modules_to_install_at_creation or []
            else:
                self.modules_to_install_at_creation = []
        except BaseException:
            # If construction fails the caller never gets the instance, so
            # __exit__ cannot run: close the cursor here to avoid leaking it.
            self.cr.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Commit when the ``with`` body succeeded, and always close the cursor."""
        try:
            if exc_type is None:
                self.cr.commit()
        finally:
            # Close even if the commit itself fails.
            self.cr.close()

    def _get_db_in_creation(self):
        """Return True when the ``ir_config_parameter`` relation is absent.

        The absence of that relation in ``pg_class`` is what this class uses
        to decide that the database is still being created.
        """
        self.cr.execute("SELECT relname FROM pg_class WHERE relname='ir_config_parameter'")
        return not self.cr.rowcount

    def _get_code_version(self):
        """Return the 'version' entry of the upgrades configuration.

        :return: ``LooseVersion``; '0' (with a warning) when unspecified
        """
        version = upgrade_config.get('version')
        if not version:
            _logger.warning('Unspecified version in upgrades configuration file')
            version = '0'
        _logger.debug('code version: %s', version)
        return LooseVersion(version)

    def _get_db_version(self):
        """Return the version stored in the database under 'code.version'.

        :return: ``LooseVersion``; '0' when the database is in creation, or
            when the parameter is missing or empty
        """
        if self.db_in_creation:
            # Same type as every other return path, so that comparisons with
            # the other versions never mix str and LooseVersion.
            return LooseVersion('0')
        self.cr.execute("SELECT value FROM ir_config_parameter WHERE key = 'code.version' LIMIT 1")
        row = self.cr.fetchone()
        if not row:
            _logger.warning('Unspecified version in database')
        version = row[0] if row and row[0] else '0'
        _logger.debug('database version: %s', version)
        return LooseVersion(version)

    def set_db_version(self):
        """Store the code version in ``ir_config_parameter`` ('code.version').

        The row is inserted when missing, updated otherwise.
        """
        params = (SUPERUSER_ID, str(self.code_version))
        self.cr.execute("SELECT value FROM ir_config_parameter WHERE key = 'code.version' LIMIT 1")
        if not self.cr.rowcount:
            query = """INSERT INTO ir_config_parameter (create_date, create_uid, key, value)
               VALUES (now() at time zone 'UTC', %s, 'code.version', %s)"""
        else:
            query = """UPDATE ir_config_parameter
                SET (write_date, write_uid, value) = (now() at time zone 'UTC', %s, %s)
                WHERE key = 'code.version'"""
        self.cr.execute(query, params)
        _logger.debug('database version updated to %s', self.code_version)

    def _try_lock(self, warning=None):
        """Read the 'code.version' parameter, reporting operational errors.

        :param warning: optional message; when given, a
            ``psycopg2.OperationalError`` is converted into ``UserError(warning)``
        :raises UserError: if the query fails and ``warning`` is set
        :raises psycopg2.OperationalError: if the query fails and no
            ``warning`` is set

        NOTE(review): the query carries no ``FOR UPDATE NOWAIT`` clause and
        the cursor is in autocommit mode, so it is unclear from this file
        whether a row lock is actually taken -- confirm the intent.
        """
        try:
            self.cr.execute("SELECT value FROM ir_config_parameter WHERE key = 'code.version'", log_exceptions=False)
        except psycopg2.OperationalError:
            self.cr.rollback()  # INFO: Early rollback to allow translations to work for the user feedback
            if warning:
                raise UserError(warning)
            raise

    def _get_upgrades(self):
        """Collect the upgrades found under the configured 'upgrades_path'.

        Each sub-directory must hold an ``__upgrade__.py`` file, evaluated
        with ``eval`` (which this module binds to ``safe_eval``). An upgrade
        is kept when it targets this database (or no database in
        particular) and ``db_version < upgrade.version <= code_version``.

        :return: upgrades sorted by version; only the latest one when the
            database is in creation; [] when no path is configured
        """
        upgrades_path = upgrade_config.get('upgrades_path')
        if not upgrades_path:
            return []
        if not self.db_in_creation:
            self._try_lock('Upgrade in progress')
        upgrades = []
        for entry in os.listdir(upgrades_path):
            dir_path = os.path.join(upgrades_path, entry)
            if not os.path.isdir(dir_path):
                continue
            file_path = os.path.join(dir_path, '__upgrade__.py')
            if not os.path.exists(file_path):
                _logger.warning(u"%s doesn't exist", file_path)
                continue
            if not os.path.isfile(file_path):
                _logger.warning(u'%s is not a file', file_path)
                continue
            with open(file_path) as f:
                try:
                    upgrade_infos = eval(f.read())
                    upgrade = Upgrade(dir_path, upgrade_infos)
                    targets_db = not upgrade.databases or self.db_name in upgrade.databases
                    if targets_db and self.db_version < upgrade.version <= self.code_version:
                        upgrades.append(upgrade)
                except Exception as e:
                    # A broken descriptor must not prevent the others from
                    # being loaded: report it and move on.
                    _logger.error('%s is not valid: %s', file_path, repr(e))
        upgrades.sort(key=lambda upgrade: upgrade.version)
        if upgrades and self.db_in_creation:
            upgrades = upgrades[-1:]
        return upgrades

    def _reset_services(self):
        """Empty the report registry and the workflow cache."""
        ReportService._reports = {}
        WorkflowService.CACHE = {}

    def pre_load(self):
        """Import the 'pre-load' files of every pending upgrade, then commit."""
        with cursor(self.db) as cr:
            for upgrade in self.upgrades:
                upgrade.load_files(cr, 'pre-load')

    def post_load(self):
        """Import the 'post-load' files of every pending upgrade, commit, and reset services."""
        with cursor(self.db) as cr:
            for upgrade in self.upgrades:
                upgrade.load_files(cr, 'post-load')
        self._reset_services()

    def force_modules_upgrade(self, registry, modules_to_upgrade):
        """Mark ``modules_to_upgrade`` for installation or upgrade.

        :param registry: registry providing the 'ir.module.module' model
        :param modules_to_upgrade: list of module names
        """
        uid = SUPERUSER_ID
        with cursor(self.db) as cr:
            with api.Environment.manage():
                module_obj = registry.get('ir.module.module')
                module_obj.update_list(cr, uid)
                ids_to_install = module_obj.search(cr, uid, [('name', 'in', modules_to_upgrade),
                                                             ('state', 'in', ('uninstalled', 'to install'))])
                module_obj.button_install(cr, uid, ids_to_install)
                ids_to_upgrade = module_obj.search(cr, uid, [('name', 'in', modules_to_upgrade),
                                                             ('state', 'in', ('installed', 'to upgrade'))])
                module_obj.button_upgrade(cr, uid, ids_to_upgrade)
                # Every module still flagged 'to install' is switched to
                # 'to upgrade' (not only the ones requested above).
                cr.execute("UPDATE ir_module_module SET state = 'to upgrade' WHERE state = 'to install'")
        self._reset_services()


class Upgrade(object):
    """Upgrade
    * Pre-load: accept only .sql files
    * Post-load: accept .sql, .yml, .csv and .xml files

    Every key of the ``infos`` dictionary becomes an attribute of the
    instance. Known keys that are missing fall back to an empty default
    (see ``__getattr__``).
    """

    def __init__(self, dir_path, infos):
        """
        :param dir_path: directory holding the upgrade files
        :param infos: dictionary of upgrade settings; its 'version' value is
            converted to ``LooseVersion`` ('0' when empty)
        """
        self.dir_path = dir_path
        for k, v in infos.items():
            if k == 'version':
                v = LooseVersion(v or '0')
            setattr(self, k, v)

    def __getattr__(self, key):
        """Provide defaults for the known settings absent from ``infos``.

        Only called when regular attribute lookup fails. The defaults are
        rebuilt on each call, so callers never share a mutable default.

        :raises AttributeError: for any name that is not a known setting
        """
        default_values = {'version': '', 'databases': [], 'modules_to_upgrade': [],
                          'modules_to_install_at_creation': [], 'pre-load': [], 'post-load': []}
        if key not in self.__dict__ and key not in default_values:
            raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, key))
        return self.__dict__.get(key) or default_values.get(key)

    def _sql_import(self, cr, f_obj):
        """Execute the statements of an SQL file one by one.

        The content is split on every ';' and whitespace runs are collapsed
        to a single space, so a ';' or significant whitespace inside a string
        literal is not preserved.
        """
        for query in f_obj.read().split(';'):
            clean_query = ' '.join(query.split())
            if clean_query:
                cr.execute(clean_query)

    def _import_file(self, cr, mode, f_obj, module):
        """Import ``f_obj`` with the importer matching its extension.

        :param cr: database cursor
        :param mode: loading mode; in 'pre-load' mode only .sql is accepted
        :param f_obj: open file object (its ``name`` gives the extension)
        :param module: module name passed to the yml/csv/xml importers

        An unsupported extension is logged as an error and ignored.
        """
        ext = os.path.splitext(f_obj.name)[1]
        if ext == '.sql':
            self._sql_import(cr, f_obj)
        elif mode != 'pre-load' and ext == '.yml':
            with api.Environment.manage():
                tools.convert_yaml_import(cr, module, f_obj, 'upgrade')
        elif mode != 'pre-load' and ext == '.csv':
            tools.convert_csv_import(cr, module, f_obj.name, f_obj.read(), 'upgrade')
        elif mode != 'pre-load' and ext == '.xml':
            tools.convert_xml_import(cr, module, f_obj, 'upgrade')
        else:
            _logger.error('%s extension is not supported in upgrade %sing', ext, mode)

    @staticmethod
    def _normalize_file_entry(entry):
        """Return ``(filename, error_management)`` for one files-list entry.

        An entry is either a filename or a tuple whose first item is the
        filename; the policy is the second item of a 2-tuple when it is
        non-empty, 'raise' in every other case.
        """
        if isinstance(entry, tuple):
            policy = entry[1] if len(entry) == 2 and entry[1] else 'raise'
            return entry[0], policy
        return entry, 'raise'

    def load_files(self, cr, mode):
        """Import every file listed for ``mode`` ('pre-load' or 'post-load').

        Files are looked up in the upgrade directory first (module 'base'),
        then in the addons paths (module = first segment of the filename).
        Missing files are logged and skipped.

        Error management per file:
        * 'raise' (default): the exception is propagated
        * 'rollback_and_continue': roll back to the savepoint set before the
          import, log a warning and continue
        * 'not_rollback_and_continue': log a warning and continue without
          rolling back
        * any other value: logged as unsupported, then continue

        :param cr: database cursor
        :param mode: name of the setting holding the files list
        """
        _logger.debug('%sing %s upgrade...', mode, self.version)
        files_list = getattr(self, mode, []) or []
        for fname, error_management in map(self._normalize_file_entry, files_list):
            f_name = fname.replace('/', os.path.sep)
            fp = os.path.join(self.dir_path, f_name)
            module = 'base'
            if not os.path.exists(fp):
                for adp in addons.module.ad_paths:
                    fp = os.path.join(adp, f_name)
                    if os.path.exists(fp):
                        module = fname.split('/')[0]
                        break
                else:
                    _logger.error("No such file: %s", fp)
                    continue
            with open(fp) as f_obj:
                _logger.info('importing %s file...', fname)
                cr.execute('SAVEPOINT smile_upgrades')
                try:
                    self._import_file(cr, mode, f_obj, module)
                    _logger.info('%s successfully imported', fname)
                except Exception as e:
                    if error_management == 'rollback_and_continue':
                        cr.execute("ROLLBACK TO SAVEPOINT smile_upgrades")
                        _logger.warning("%s import rollbacking: %s", fname, e)
                    elif error_management == 'raise':
                        # Bare raise keeps the original traceback.
                        raise
                    elif error_management == 'not_rollback_and_continue':
                        # Deliberately best-effort, but never silent.
                        _logger.warning("%s import failed, continuing without rollback: %s", fname, e)
                    else:
                        _logger.error('%s value not supported in error management', error_management)