diff --git a/invenio/ext/sqlalchemy/__init__.py b/invenio/ext/sqlalchemy/__init__.py index 880c20bee..f67795b1b 100644 --- a/invenio/ext/sqlalchemy/__init__.py +++ b/invenio/ext/sqlalchemy/__init__.py @@ -1,248 +1,241 @@ # -*- coding: utf-8 -*- # # This file is part of Invenio. # Copyright (C) 2011, 2012, 2013, 2014, 2015 CERN. # # Invenio is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License as # published by the Free Software Foundation; either version 2 of the # License, or (at your option) any later version. # # Invenio is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Invenio; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA. -""" - invenio.ext.sqlalchemy. - - This module provides initialization and configuration for - `flask_sqlalchemy` module. -""" +"""Initialization and configuration for `flask_sqlalchemy` module.""" import sqlalchemy -from flask_registry import RegistryProxy, ModuleAutoDiscoveryRegistry +from flask_registry import ModuleAutoDiscoveryRegistry, RegistryProxy + from flask_sqlalchemy import SQLAlchemy as FlaskSQLAlchemy -from sqlalchemy import event -from sqlalchemy.ext.hybrid import hybrid_property, Comparator -from sqlalchemy.pool import Pool -from sqlalchemy_utils import JSONType +from invenio.ext.sqlalchemy.types import LegacyBigInteger, LegacyInteger, \ + LegacyMediumInteger, LegacySmallInteger, LegacyTinyInteger from invenio.utils.hash import md5 + +from sqlalchemy import event, types +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.ext.hybrid import Comparator, hybrid_property +from sqlalchemy.pool import Pool + +from sqlalchemy_utils.types import JSONType + from .expressions import AsBINARY -from .types import MarshalBinary, PickleBinary, GUID +from .types import GUID, MarshalBinary, PickleBinary from .utils import get_model_type -from invenio.ext.sqlalchemy.types import (LegacyInteger, LegacyMediumInteger, - LegacySmallInteger, - LegacyTinyInteger, - LegacyBigInteger) def _include_sqlalchemy(obj, engine=None): """Init all required SQLAlchemy's types.""" # for module in sqlalchemy, sqlalchemy.orm: # for key in module.__all__: # if not hasattr(obj, key): # setattr(obj, key, # getattr(module, key)) if engine == 'mysql': from sqlalchemy.dialects import mysql as engine_types else: from sqlalchemy import types as engine_types # Length is provided to JSONType to ensure MySQL uses LONGTEXT instead # of TEXT which only provides for 64kb storage compared to 4gb for # LONGTEXT. setattr(obj, 'JSON', JSONType(length=2 ** 32 - 2)) setattr(obj, 'Char', engine_types.CHAR) try: setattr(obj, 'TinyText', engine_types.TINYTEXT) - except: + except Exception: setattr(obj, 'TinyText', engine_types.TEXT) setattr(obj, 'hybrid_property', hybrid_property) try: setattr(obj, 'Double', engine_types.DOUBLE) - except: + except Exception: setattr(obj, 'Double', engine_types.FLOAT) setattr(obj, 'Binary', sqlalchemy.types.LargeBinary) setattr(obj, 'iBinary', sqlalchemy.types.LargeBinary) setattr(obj, 'iLargeBinary', sqlalchemy.types.LargeBinary) setattr(obj, 'iMediumBinary', sqlalchemy.types.LargeBinary) setattr(obj, 'UUID', GUID) setattr(obj, 'Integer', LegacyInteger) setattr(obj, 'MediumInteger', LegacyMediumInteger) setattr(obj, 'SmallInteger', LegacySmallInteger) setattr(obj, 'TinyInteger', LegacyTinyInteger) setattr(obj, 'BigInteger', LegacyBigInteger) if engine == 'mysql': from .engines import mysql as dummy_mysql # noqa # module = invenio.sqlalchemyutils_mysql # for key in module.__dict__: # setattr(obj, key, # getattr(module, key)) obj.AsBINARY = AsBINARY obj.MarshalBinary = MarshalBinary obj.PickleBinary = PickleBinary - ## Overwrite :meth:`MutableDick.update` to detect changes. + # Overwrite :meth:`MutableDick.update` to detect changes. from sqlalchemy.ext.mutable import MutableDict def update_mutable_dict(self, *args, **kwargs): super(MutableDict, self).update(*args, **kwargs) self.changed() MutableDict.update = update_mutable_dict obj.MutableDict = MutableDict -from sqlalchemy.ext.compiler import compiles -from sqlalchemy import types -import sqlalchemy.dialects.postgresql - - @compiles(types.Text, 'postgresql') @compiles(sqlalchemy.dialects.postgresql.TEXT, 'postgresql') def compile_text(element, compiler, **kw): """Redefine Text filed type for PostgreSQL.""" return 'TEXT' @compiles(types.VARBINARY, 'postgresql') -def compile_text(element, compiler, **kw): +def compile_varbinary(element, compiler, **kw): """Redefine VARBINARY filed type for PostgreSQL.""" return 'BYTEA' class PasswordComparator(Comparator): """Implement a password comparator.""" def __eq__(self, other): """Implement the equal operator.""" return self.__clause_element__() == self.hash(other) def hash(self, password): """Generate a hashed version of the password.""" if db.engine.name != 'mysql': return md5(password).digest() email = self.__clause_element__().table.columns.email return db.func.aes_encrypt(email, password) def autocommit_on_checkin(dbapi_con, con_record): """Call autocommit on raw mysql connection for fixing bug in MySQL 5.5.""" try: dbapi_con.autocommit(True) - except: + except Exception: pass # FIXME # from invenio.ext.logging import register_exception # register_exception() # Possibly register globally. # event.listen(Pool, 'checkin', autocommit_on_checkin) class SQLAlchemy(FlaskSQLAlchemy): """Database object.""" PasswordComparator = PasswordComparator def init_app(self, app): """Init application.""" super(self.__class__, self).init_app(app) engine = app.config.get('CFG_DATABASE_TYPE', 'mysql') self.Model = get_model_type(self.Model) if engine == 'mysql': # Override MySQL parameters to force MyISAM engine mysql_parameters = {'keep_existing': True, 'extend_existing': False, 'mysql_engine': 'MyISAM', 'mysql_charset': 'utf8'} original_table = self.Table def table_with_myisam(*args, **kwargs): """Use same MySQL parameters that are used for ORM models.""" new_kwargs = dict(mysql_parameters) new_kwargs.update(kwargs) return original_table(*args, **new_kwargs) self.Table = table_with_myisam self.Model.__table_args__ = mysql_parameters _include_sqlalchemy(self, engine=engine) def __getattr__(self, name): """ Called when the normal mechanism fails. This is only called when the normal mechanism fails, so in practice should never be called. It is only provided to satisfy pylint that it is okay not to raise E1101 errors in the client code. :see http://stackoverflow.com/a/3515234/780928 """ raise AttributeError("%r instance has no attribute %r" % (self, name)) def schemadiff(self, excludeTables=None): """Generate a schema diff.""" from migrate.versioning import schemadiff return schemadiff \ .getDiffOfModelAgainstDatabase(self.metadata, self.engine, excludeTables=excludeTables) def apply_driver_hacks(self, app, info, options): """Called before engine creation.""" # Don't forget to apply hacks defined on parent object. super(self.__class__, self).apply_driver_hacks(app, info, options) if info.drivername == 'mysql': options.setdefault('execution_options', { # Autocommit cause Exception in SQLAlchemy >= 0.9. # @see http://docs.sqlalchemy.org/en/rel_0_9/ # core/connections.html#understanding-autocommit # 'autocommit': True, 'use_unicode': False, 'charset': 'utf8mb4', }) event.listen(Pool, 'checkin', autocommit_on_checkin) db = SQLAlchemy() """ Provides access to :class:`~.SQLAlchemy` instance. """ models = RegistryProxy('models', ModuleAutoDiscoveryRegistry, 'models') def setup_app(app): """Setup SQLAlchemy extension.""" if 'SQLALCHEMY_DATABASE_URI' not in app.config: from sqlalchemy.engine.url import URL cfg = app.config app.config['SQLALCHEMY_DATABASE_URI'] = URL( cfg.get('CFG_DATABASE_TYPE', 'mysql'), username=cfg.get('CFG_DATABASE_USER'), password=cfg.get('CFG_DATABASE_PASS'), host=cfg.get('CFG_DATABASE_HOST'), database=cfg.get('CFG_DATABASE_NAME'), port=cfg.get('CFG_DATABASE_PORT'), ) - ## Let's initialize database. + # Let's initialize database. db.init_app(app) return app diff --git a/invenio/modules/oauth2server/models.py b/invenio/modules/oauth2server/models.py index 65098a90d..db5942bac 100644 --- a/invenio/modules/oauth2server/models.py +++ b/invenio/modules/oauth2server/models.py @@ -1,390 +1,391 @@ # -*- coding: utf-8 -*- # # This file is part of Invenio. # Copyright (C) 2014, 2015 CERN. # # Invenio is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License as # published by the Free Software Foundation; either version 2 of the # License, or (at your option) any later version. # # Invenio is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Invenio; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA. """Database models for OAuth2 server.""" from __future__ import absolute_import from flask import current_app + from flask_login import current_user from invenio.base.i18n import _ from invenio.config import SECRET_KEY as secret_key from invenio.ext.login.legacy_user import UserInfo from invenio.ext.sqlalchemy import db import six -from sqlalchemy_utils import URLType +from sqlalchemy_utils.types import URLType from sqlalchemy_utils.types.encrypted import AesEngine, EncryptedType from werkzeug.security import gen_salt from wtforms import validators from .errors import ScopeDoesNotExists from .validators import validate_redirect_uri, validate_scopes class NoneAesEngine(AesEngine): """Filter None values from encrypting.""" def encrypt(self, value): """Encrypt a value on the way in.""" if value is not None: return super(NoneAesEngine, self).encrypt(value) def decrypt(self, value): """Decrypt value on the way out.""" if value is not None: return super(NoneAesEngine, self).decrypt(value) class String255EncryptedType(EncryptedType): """String encrypted type.""" impl = db.String(255) class OAuthUserProxy(object): """Proxy object to an Invenio User.""" def __init__(self, user): """Initialize proxy object with user instance.""" self._user = user def __getattr__(self, name): """Pass any undefined attribute to the underlying object.""" return getattr(self._user, name) def __getstate__(self): """Return the id.""" return self.id def __setstate__(self, state): """Set user info.""" self._user = UserInfo(state) @property def id(self): """Return user identifier.""" return self._user.get_id() def check_password(self, password): """Check user password.""" return self.password == password @classmethod def get_current_user(cls): """Return an instance of current user object.""" return cls(current_user._get_current_object()) class Scope(object): """OAuth scope definition.""" def __init__(self, id_, help_text='', group='', internal=False): """Initialize scope values.""" self.id = id_ self.group = group self.help_text = help_text self.is_internal = internal class Client(db.Model): """A client is the app which want to use the resource of a user. It is suggested that the client is registered by a user on your site, but it is not required. The client should contain at least these information: client_id: A random string client_secret: A random string client_type: A string represents if it is confidential redirect_uris: A list of redirect uris default_redirect_uri: One of the redirect uris default_scopes: Default scopes of the client But it could be better, if you implemented: allowed_grant_types: A list of grant types allowed_response_types: A list of response types validate_scopes: A function to validate scopes """ __tablename__ = 'oauth2CLIENT' name = db.Column( db.String(40), info=dict( label=_('Name'), description=_('Name of application (displayed to users).'), validators=[validators.DataRequired()] ) ) """Human readable name of the application.""" description = db.Column( db.Text(), default=u'', info=dict( label=_('Description'), description=_('Optional. Description of the application' ' (displayed to users).'), ) ) """Human readable description.""" website = db.Column( URLType(), info=dict( label=_('Website URL'), description=_('URL of your application (displayed to users).'), ), default=u'', ) user_id = db.Column(db.ForeignKey('user.id')) """Creator of the client application.""" client_id = db.Column(db.String(255), primary_key=True) """Client application ID.""" client_secret = db.Column( db.String(255), unique=True, index=True, nullable=False ) """Client application secret.""" is_confidential = db.Column(db.Boolean, default=True) """Determine if client application is public or not.""" is_internal = db.Column(db.Boolean, default=False) """Determins if client application is an internal application.""" _redirect_uris = db.Column(db.Text) """A newline-separated list of redirect URIs. First is the default URI.""" _default_scopes = db.Column(db.Text) """A space-separated list of default scopes of the client. The value of the scope parameter is expressed as a list of space-delimited, case-sensitive strings. """ user = db.relationship('User') """Relationship to user.""" @property def allowed_grant_types(self): """Return allowed grant types.""" return current_app.config['OAUTH2_ALLOWED_GRANT_TYPES'] @property def allowed_response_types(self): """Return allowed response types.""" return current_app.config['OAUTH2_ALLOWED_RESPONSE_TYPES'] # def validate_scopes(self, scopes): # return self._validate_scopes @property def client_type(self): """Return client type.""" if self.is_confidential: return 'confidential' return 'public' @property def redirect_uris(self): """Return redirect uris.""" if self._redirect_uris: return self._redirect_uris.splitlines() return [] @redirect_uris.setter def redirect_uris(self, value): """Validate and store redirect URIs for client.""" if isinstance(value, six.text_type): value = value.split("\n") value = [v.strip() for v in value] for v in value: validate_redirect_uri(v) self._redirect_uris = "\n".join(value) or "" @property def default_redirect_uri(self): """Return default redirect uri.""" try: return self.redirect_uris[0] except IndexError: pass @property def default_scopes(self): """List of default scopes for client.""" if self._default_scopes: return self._default_scopes.split(" ") return [] @default_scopes.setter def default_scopes(self, scopes): """Set default scopes for client.""" validate_scopes(scopes) self._default_scopes = " ".join(set(scopes)) if scopes else "" def validate_scopes(self, scopes): """Validate if client is allowed to access scopes.""" try: validate_scopes(scopes) return True except ScopeDoesNotExists: return False def gen_salt(self): """Generate salt.""" self.reset_client_id() self.reset_client_secret() def reset_client_id(self): """Reset client id.""" self.client_id = gen_salt( current_app.config.get('OAUTH2_CLIENT_ID_SALT_LEN') ) def reset_client_secret(self): """Reset client secret.""" self.client_secret = gen_salt( current_app.config.get('OAUTH2_CLIENT_SECRET_SALT_LEN') ) class Token(db.Model): """A bearer token is the final token that can be used by the client.""" __tablename__ = 'oauth2TOKEN' id = db.Column(db.Integer, primary_key=True, autoincrement=True) """Object ID.""" client_id = db.Column( db.String(40), db.ForeignKey('oauth2CLIENT.client_id'), nullable=False, ) """Foreign key to client application.""" client = db.relationship('Client') """SQLAlchemy relationship to client application.""" user_id = db.Column( db.Integer, db.ForeignKey('user.id') ) """Foreign key to user.""" user = db.relationship('User') """SQLAlchemy relationship to user.""" token_type = db.Column(db.String(255), default='bearer') """Token type - only bearer is supported at the moment.""" access_token = db.Column(String255EncryptedType( type_in=db.String(255), key=secret_key), unique=True ) refresh_token = db.Column(String255EncryptedType( type_in=db.String(255), key=secret_key, engine=NoneAesEngine), unique=True, nullable=True ) expires = db.Column(db.DateTime, nullable=True) _scopes = db.Column(db.Text) is_personal = db.Column(db.Boolean, default=False) """Personal accesss token.""" is_internal = db.Column(db.Boolean, default=False) """Determines if token is an internally generated token.""" @property def scopes(self): """Return all scopes.""" if self._scopes: return self._scopes.split() return [] @scopes.setter def scopes(self, scopes): """Set scopes.""" validate_scopes(scopes) self._scopes = " ".join(set(scopes)) if scopes else "" def get_visible_scopes(self): """Get list of non-internal scopes for token.""" from .registry import scopes as scopes_registry return [k for k, s in scopes_registry.choices() if k in self.scopes] @classmethod def create_personal(cls, name, user_id, scopes=None, is_internal=False): """Create a personal access token. A token that is bound to a specific user and which doesn't expire, i.e. similar to the concept of an API key. """ scopes = " ".join(scopes) if scopes else "" c = Client( name=name, user_id=user_id, is_internal=True, is_confidential=False, _default_scopes=scopes ) c.gen_salt() t = Token( client_id=c.client_id, user_id=user_id, access_token=gen_salt( current_app.config.get('OAUTH2_TOKEN_PERSONAL_SALT_LEN') ), expires=None, _scopes=scopes, is_personal=True, is_internal=is_internal, ) db.session.add(c) db.session.add(t) db.session.commit() return t diff --git a/invenio/modules/oauth2server/upgrades/oauth2server_2014_02_17_initial.py b/invenio/modules/oauth2server/upgrades/oauth2server_2014_02_17_initial.py index a0a15ce26..cb4138dd0 100644 --- a/invenio/modules/oauth2server/upgrades/oauth2server_2014_02_17_initial.py +++ b/invenio/modules/oauth2server/upgrades/oauth2server_2014_02_17_initial.py @@ -1,89 +1,96 @@ # -*- coding: utf-8 -*- # # This file is part of Invenio. # Copyright (C) 2014, 2015 CERN. # # Invenio is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License as # published by the Free Software Foundation; either version 2 of the # License, or (at your option) any later version. # # Invenio is distributed in the hope that it will be useful, but # WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU # General Public License for more details. # # You should have received a copy of the GNU General Public License # along with Invenio; if not, write to the Free Software Foundation, Inc., # 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA. +"""Upgrade recipe.""" + import warnings -from sqlalchemy import * from invenio.ext.sqlalchemy import db -from sqlalchemy_utils import URLType from invenio.modules.upgrader.api import op +from sqlalchemy_utils.types import URLType + + depends_on = [] def info(): + """Info.""" return "Tables for oauth2server" def do_upgrade(): - """ Implement your upgrades here """ + """Implement your upgrades here.""" if not op.has_table('oauth2CLIENT'): op.create_table( 'oauth2CLIENT', db.Column('name', db.String(length=40), nullable=True), db.Column('description', db.Text(), nullable=True), db.Column('website', URLType(), nullable=True), db.Column('user_id', db.Integer(15, unsigned=True), nullable=True), db.Column('client_id', db.String(length=255), nullable=False), db.Column('client_secret', db.String(length=255), nullable=False), db.Column('is_confidential', db.Boolean(), nullable=True), db.Column('is_internal', db.Boolean(), nullable=True), db.Column('_redirect_uris', db.Text(), nullable=True), db.Column('_default_scopes', db.Text(), nullable=True), db.ForeignKeyConstraint(['user_id'], ['user.id'], ), db.PrimaryKeyConstraint('client_id'), mysql_charset='utf8', mysql_engine='MyISAM' ) else: warnings.warn("*** Creation of table 'oauth2CLIENT' skipped!") if not op.has_table('oauth2TOKEN'): op.create_table( 'oauth2TOKEN', db.Column('id', db.Integer(15, unsigned=True), autoincrement=True, nullable=False), db.Column('client_id', db.String(length=40), nullable=False), db.Column('user_id', db.Integer(15, unsigned=True), nullable=True), db.Column('token_type', db.String(length=255), nullable=True), db.Column('access_token', db.String(length=255), nullable=True), db.Column('refresh_token', db.String(length=255), nullable=True), db.Column('expires', db.DateTime(), nullable=True), db.Column('_scopes', db.Text(), nullable=True), db.Column('is_personal', db.Boolean(), nullable=True), db.Column('is_internal', db.Boolean(), nullable=True), - db.ForeignKeyConstraint(['client_id'], ['oauth2CLIENT.client_id'], ), + db.ForeignKeyConstraint( + ['client_id'], ['oauth2CLIENT.client_id'],), db.ForeignKeyConstraint(['user_id'], ['user.id'], ), db.PrimaryKeyConstraint('id'), db.UniqueConstraint('access_token'), db.UniqueConstraint('refresh_token'), mysql_charset='utf8', mysql_engine='MyISAM' ) else: warnings.warn("*** Creation of table 'oauth2TOKEN' skipped!") # # Following create index causes problems # op.create_index( # 'ix_oauth2CLIENT_client_secret', 'oauth2CLIENT', ['client_secret'], # unique=True # ) + def estimate(): + """Estimate.""" return 1