diff --git a/invenio/modules/oauth2server/forms.py b/invenio/modules/oauth2server/forms.py index 8c4651a64..0ec4b5764 100644 --- a/invenio/modules/oauth2server/forms.py +++ b/invenio/modules/oauth2server/forms.py @@ -1,147 +1,159 @@ # -*- coding: utf-8 -*- ## ## This file is part of Invenio. ## Copyright (C) 2014 CERN. ## ## Invenio is free software; you can redistribute it and/or ## modify it under the terms of the GNU General Public License as ## published by the Free Software Foundation; either version 2 of the ## License, or (at your option) any later version. ## ## Invenio is distributed in the hope that it will be useful, but ## WITHOUT ANY WARRANTY; without even the implied warranty of ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU ## General Public License for more details. ## ## You should have received a copy of the GNU General Public License ## along with Invenio; if not, write to the Free Software Foundation, Inc., ## 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA. -"""Forms for generating access tokens and clients.""" +"""Define forms for generating access tokens and clients.""" from oauthlib.oauth2.rfc6749.errors import InsecureTransportError, \ InvalidRedirectURIError from wtforms_alchemy import model_form_factory from wtforms import fields, validators, widgets + +from invenio.base.i18n import _ from invenio.utils.forms import InvenioBaseForm from .models import Client from .validators import validate_redirect_uri # # Widget # def scopes_multi_checkbox(field, **kwargs): - """ Render multi checkbox widget. """ + """Render multi checkbox widget.""" kwargs.setdefault('type', 'checkbox') field_id = kwargs.pop('id', field.id) html = ['<div class="row">'] for value, label, checked in field.iter_choices(): choice_id = u'%s-%s' % (field_id, value) options = dict( kwargs, name=field.name, value=value, id=choice_id, class_=' ', ) if checked: options['checked'] = 'checked' html.append(u'<div class="col-md-3">') html.append(u'<label for="%s" class="checkbox-inline">' % field_id) html.append(u'<input %s /> ' % widgets.html_params(**options)) html.append("%s <br/><small class='text-muted'>%s</small>" % ( value, label.help_text) ) html.append(u'</label></div>') html.append(u'</div>') return u''.join(html) # # Redirect URI field # class RedirectURIField(fields.TextAreaField): - """ Process redirect URI field data. """ + """Process redirect URI field data.""" def process_formdata(self, valuelist): if valuelist: self.data = "\n".join([ x.strip() for x in filter(lambda x: x, "\n".join(valuelist).splitlines()) ]) def process_data(self, value): self.data = "\n".join(value) class RedirectURIValidator(object): - """ Validate if redirect URIs. """ + """Validate if redirect URIs.""" def __call__(self, form, field): errors = [] for v in field.data.splitlines(): try: validate_redirect_uri(v) except InsecureTransportError: errors.append(v) except InvalidRedirectURIError: errors.append(v) if errors: raise validators.ValidationError( "Invalid redirect URIs: %s" % ", ".join(errors) ) # # Forms # class ClientFormBase(model_form_factory(InvenioBaseForm)): class Meta: model = Client exclude = [ 'client_secret', 'is_internal', - 'is_confidential', ] strip_string_fields = True - field_args = dict(website=dict( - validators=[validators.Required(), validators.URL()], - widget=widgets.TextInput(), - )) + field_args = dict( + website=dict( + validators=[validators.Required(), validators.URL()], + widget=widgets.TextInput(), + ), + ) class ClientForm(ClientFormBase): + is_confidential = fields.SelectField( + label=_('Client Type'), + description=_('If you select public option, your application ' + 'MUST validate redirect URI.'), + coerce=int, + choices=[(1, _('Confidential')), (0, _('Public'))], + widget=widgets.Select(), + ) + # Trick to make redirect_uris render in the bottom of the form. redirect_uris = RedirectURIField( label="Redirect URIs (one per line)", description="One redirect URI per line. This is your applications" " authorization callback URLs. HTTPS must be used for all " "hosts except localhost (for testing purposes).", validators=[RedirectURIValidator(), validators.Required()], default='', ) class TokenForm(InvenioBaseForm): name = fields.TextField( description="Name of personal access token.", validators=[validators.Required()], ) scopes = fields.SelectMultipleField( widget=scopes_multi_checkbox, choices=[], # Must be dynamically provided in view. description="Scopes assigns permissions to your personal access token." " A personal access token works just like a normal OAuth " " access token for authentication against the API." ) diff --git a/invenio/modules/oauth2server/models.py b/invenio/modules/oauth2server/models.py index 936945de3..fee2470fc 100644 --- a/invenio/modules/oauth2server/models.py +++ b/invenio/modules/oauth2server/models.py @@ -1,334 +1,342 @@ # -*- coding: utf-8 -*- ## ## This file is part of Invenio. ## Copyright (C) 2014 CERN. ## ## Invenio is free software; you can redistribute it and/or ## modify it under the terms of the GNU General Public License as ## published by the Free Software Foundation; either version 2 of the ## License, or (at your option) any later version. ## ## Invenio is distributed in the hope that it will be useful, but ## WITHOUT ANY WARRANTY; without even the implied warranty of ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU ## General Public License for more details. ## ## You should have received a copy of the GNU General Public License ## along with Invenio; if not, write to the Free Software Foundation, Inc., ## 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA. """Database models for OAuth2 server.""" from __future__ import absolute_import +import six from flask import current_app from flask.ext.login import current_user - +from sqlalchemy_utils import URLType from werkzeug.security import gen_salt from wtforms import validators -from sqlalchemy_utils import URLType -import six +from invenio.base.i18n import _ from invenio.ext.sqlalchemy import db from invenio.ext.login.legacy_user import UserInfo from .validators import validate_redirect_uri, validate_scopes from .errors import ScopeDoesNotExists class OAuthUserProxy(object): """Proxy object to an Invenio User.""" def __init__(self, user): + """Initialize proxy object with user instance.""" self._user = user def __getattr__(self, name): """Pass any undefined attribute to the underlying object.""" return getattr(self._user, name) def __getstate__(self): return self.id def __setstate__(self, state): self._user = UserInfo(state) @property def id(self): + """Return user identifier.""" return self._user.get_id() def check_password(self, password): + """Check user password.""" return self.password == password @classmethod def get_current_user(cls): + """Return an instance of current user object.""" return cls(current_user._get_current_object()) class Scope(object): + + """OAuth scope definition.""" + def __init__(self, id_, help_text='', group='', internal=False): + """Initialize scope values.""" self.id = id_ self.group = group self.help_text = help_text self.is_internal = internal class Client(db.Model): """A client is the app which want to use the resource of a user. It is suggested that the client is registered by a user on your site, but it is not required. The client should contain at least these information: client_id: A random string client_secret: A random string client_type: A string represents if it is confidential redirect_uris: A list of redirect uris default_redirect_uri: One of the redirect uris default_scopes: Default scopes of the client But it could be better, if you implemented: allowed_grant_types: A list of grant types allowed_response_types: A list of response types validate_scopes: A function to validate scopes """ __tablename__ = 'oauth2CLIENT' name = db.Column( db.String(40), info=dict( - label='Name', - description='Name of application (displayed to users).', + label=_('Name'), + description=_('Name of application (displayed to users).'), validators=[validators.Required()] ) ) """Human readable name of the application.""" description = db.Column( db.Text(), default=u'', info=dict( - label='Description', - description='Optional. Description of the application' - ' (displayed to users).', + label=_('Description'), + description=_('Optional. Description of the application' + ' (displayed to users).'), ) ) """Human readable description.""" website = db.Column( URLType(), info=dict( - label='Website URL', - description='URL of your application (displayed to users).', + label=_('Website URL'), + description=_('URL of your application (displayed to users).'), ), default=u'', ) user_id = db.Column(db.ForeignKey('user.id')) """Creator of the client application.""" client_id = db.Column(db.String(255), primary_key=True) """Client application ID.""" client_secret = db.Column( db.String(255), unique=True, index=True, nullable=False ) """Client application secret.""" is_confidential = db.Column(db.Boolean, default=True) """Determine if client application is public or not.""" is_internal = db.Column(db.Boolean, default=False) """Determins if client application is an internal application.""" _redirect_uris = db.Column(db.Text) """A newline-separated list of redirect URIs. First is the default URI.""" _default_scopes = db.Column(db.Text) """A space-separated list of default scopes of the client. The value of the scope parameter is expressed as a list of space-delimited, case-sensitive strings. """ user = db.relationship('User') """Relationship to user.""" @property def allowed_grant_types(self): return current_app.config['OAUTH2_ALLOWED_GRANT_TYPES'] @property def allowed_response_types(self): return current_app.config['OAUTH2_ALLOWED_RESPONSE_TYPES'] # def validate_scopes(self, scopes): # return self._validate_scopes @property def client_type(self): if self.is_confidential: return 'confidential' return 'public' @property def redirect_uris(self): if self._redirect_uris: return self._redirect_uris.splitlines() return [] @redirect_uris.setter def redirect_uris(self, value): """Validate and store redirect URIs for client.""" if isinstance(value, six.text_type): value = value.split("\n") value = [v.strip() for v in value] for v in value: validate_redirect_uri(v) self._redirect_uris = "\n".join(value) or "" @property def default_redirect_uri(self): try: return self.redirect_uris[0] except IndexError: pass @property def default_scopes(self): """List of default scopes for client.""" if self._default_scopes: return self._default_scopes.split(" ") return [] @default_scopes.setter def default_scopes(self, scopes): """Set default scopes for client.""" validate_scopes(scopes) self._default_scopes = " ".join(set(scopes)) if scopes else "" def validate_scopes(self, scopes): """Validate if client is allowed to access scopes.""" try: validate_scopes(scopes) return True except ScopeDoesNotExists: return False def gen_salt(self): self.reset_client_id() self.reset_client_secret() def reset_client_id(self): self.client_id = gen_salt( current_app.config.get('OAUTH2_CLIENT_ID_SALT_LEN') ) def reset_client_secret(self): self.client_secret = gen_salt( current_app.config.get('OAUTH2_CLIENT_SECRET_SALT_LEN') ) class Token(db.Model): """A bearer token is the final token that can be used by the client.""" __tablename__ = 'oauth2TOKEN' id = db.Column(db.Integer, primary_key=True, autoincrement=True) """Object ID.""" client_id = db.Column( db.String(40), db.ForeignKey('oauth2CLIENT.client_id'), nullable=False, ) """Foreign key to client application.""" client = db.relationship('Client') """SQLAlchemy relationship to client application.""" user_id = db.Column( db.Integer, db.ForeignKey('user.id') ) """Foreign key to user.""" user = db.relationship('User') """SQLAlchemy relationship to user.""" token_type = db.Column(db.String(255), default='bearer') """Token type - only bearer is supported at the moment.""" access_token = db.Column(db.String(255), unique=True) refresh_token = db.Column(db.String(255), unique=True, nullable=True) expires = db.Column(db.DateTime, nullable=True) _scopes = db.Column(db.Text) is_personal = db.Column(db.Boolean, default=False) """Personal accesss token.""" is_internal = db.Column(db.Boolean, default=False) """Determines if token is an internally generated token.""" @property def scopes(self): if self._scopes: return self._scopes.split() return [] @scopes.setter def scopes(self, scopes): validate_scopes(scopes) self._scopes = " ".join(set(scopes)) if scopes else "" def get_visible_scopes(self): """Get list of non-internal scopes for token.""" from .registry import scopes as scopes_registry return [k for k, s in scopes_registry.choices() if k in self.scopes] @classmethod def create_personal(cls, name, user_id, scopes=None, is_internal=False): """Create a personal access token. A token that is bound to a specific user and which doesn't expire, i.e. similar to the concept of an API key. """ scopes = " ".join(scopes) if scopes else "" c = Client( name=name, user_id=user_id, is_internal=True, is_confidential=False, _default_scopes=scopes ) c.gen_salt() t = Token( client_id=c.client_id, user_id=user_id, access_token=gen_salt( current_app.config.get('OAUTH2_TOKEN_PERSONAL_SALT_LEN') ), expires=None, _scopes=scopes, is_personal=True, is_internal=is_internal, ) db.session.add(c) db.session.add(t) db.session.commit() return t diff --git a/invenio/modules/oauth2server/testsuite/test_provider.py b/invenio/modules/oauth2server/testsuite/test_provider.py index 7eb498f0c..b0a0ed2e1 100644 --- a/invenio/modules/oauth2server/testsuite/test_provider.py +++ b/invenio/modules/oauth2server/testsuite/test_provider.py @@ -1,827 +1,830 @@ # -*- coding: utf-8 -*- ## ## This file is part of Invenio. ## Copyright (C) 2014 CERN. ## ## Invenio is free software; you can redistribute it and/or ## modify it under the terms of the GNU General Public License as ## published by the Free Software Foundation; either version 2 of the ## License, or (at your option) any later version. ## ## Invenio is distributed in the hope that it will be useful, but ## WITHOUT ANY WARRANTY; without even the implied warranty of ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU ## General Public License for more details. ## ## You should have received a copy of the GNU General Public License ## along with Invenio; if not, write to the Free Software Foundation, Inc., ## 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA. from __future__ import absolute_import, print_function import logging import os from datetime import datetime, timedelta from flask import url_for from flask_oauthlib.client import prepare_request from mock import MagicMock try: from six.moves.urllib.parse import urlparse except ImportError: from urllib.parse import urlparse from invenio.base.globals import cfg from invenio.ext.sqlalchemy import db from invenio.testsuite import InvenioTestCase, make_test_suite, run_test_suite from .helpers import create_client logging.basicConfig(level=logging.DEBUG) class ProviderTestCase(InvenioTestCase): def create_app(self): try: app = super(ProviderTestCase, self).create_app() app.testing = True app.config.update(dict( OAUTH2_CACHE_TYPE='simple', )) client = create_client(app, 'oauth2test') client.http_request = MagicMock( side_effect=self.patch_request(app) ) except Exception as e: print(e) return app def patch_request(self, app): test_client = app.test_client() def make_request(uri, headers=None, data=None, method=None): uri, headers, data, method = prepare_request( uri, headers, data, method ) if not headers and data is not None: headers = { 'Content-Type': ' application/x-www-form-urlencoded' } # test client is a `werkzeug.test.Client` parsed = urlparse(uri) uri = '%s?%s' % (parsed.path, parsed.query) resp = test_client.open( uri, headers=headers, data=data, method=method, base_url=cfg['CFG_SITE_SECURE_URL'] ) # for compatible resp.code = resp.status_code return resp, resp.data return make_request def setUp(self): super(ProviderTestCase, self).setUp() # Set environment variable DEBUG to true, to allow testing without # SSL in oauthlib. if self.app.config.get('CFG_SITE_SECURE_URL').startswith('http://'): self.os_debug = os.environ.get('OAUTHLIB_INSECURE_TRANSPORT', '') os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = 'true' from ..models import Client, Scope from invenio.modules.accounts.models import User from ..registry import scopes as scopes_registry # Register a test scope scopes_registry.register(Scope('test:scope')) self.base_url = self.app.config.get('CFG_SITE_SECURE_URL') # Create needed objects u = User( email='info@invenio-software.org', nickname='tester' ) u.password = "tester" u2 = User( email='abuse@invenio-software.org', nickname='tester2' ) u2.password = "tester2" db.session.add(u) db.session.add(u2) c1 = Client( client_id='dev', client_secret='dev', name='dev', description='', is_confidential=False, user=u, _redirect_uris='%s/oauth2test/authorized' % self.base_url, _default_scopes="test:scope" ) c2 = Client( client_id='confidential', client_secret='confidential', name='confidential', description='', is_confidential=True, user=u, _redirect_uris='%s/oauth2test/authorized' % self.base_url, _default_scopes="test:scope" ) db.session.add(c1) db.session.add(c2) db.session.commit() self.objects = [u, u2, c1, c2] # Create a personal access token as well. from ..models import Token self.personal_token = Token.create_personal( 'test-personal', 1, scopes=[], is_internal=True ) def tearDown(self): super(ProviderTestCase, self).tearDown() # Set back any previous value of DEBUG environment variable. if self.app.config.get('CFG_SITE_SECURE_URL').startswith('http://'): os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = self.os_debug self.base_url = None for o in self.objects: db.session.delete(o) db.session.commit() def parse_redirect(self, location, parse_fragment=False): from werkzeug.urls import url_parse, url_decode, url_unparse scheme, netloc, script_root, qs, anchor = url_parse(location) return ( url_unparse((scheme, netloc, script_root, '', '')), url_decode(anchor if parse_fragment else qs) ) class OAuth2ProviderTestCase(ProviderTestCase): def test_client_salt(self): from ..models import Client c = Client( name='Test something', is_confidential=True, user_id=1, ) c.gen_salt() assert len(c.client_id) == \ self.app.config.get('OAUTH2_CLIENT_ID_SALT_LEN') assert len(c.client_secret) == \ self.app.config.get('OAUTH2_CLIENT_SECRET_SALT_LEN') db.session.add(c) db.session.commit() def test_invalid_authorize_requests(self): # First login on provider site self.login("tester", "tester") for client_id in ['dev', 'confidential']: redirect_uri = '%s/oauth2test/authorized' % self.base_url scope = 'test:scope' response_type = 'code' error_url = url_for('oauth2server.errors') # Valid request authorize request r = self.client.get(url_for( 'oauth2server.authorize', redirect_uri=redirect_uri, scope=scope, response_type=response_type, client_id=client_id, )) self.assertStatus(r, 200) # Invalid scope r = self.client.get(url_for( 'oauth2server.authorize', redirect_uri=redirect_uri, scope='INVALID', response_type=response_type, client_id=client_id, )) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) self.assertEqual(data['error'], 'invalid_scope') assert next_url == redirect_uri # Invalid response type r = self.client.get(url_for( 'oauth2server.authorize', redirect_uri=redirect_uri, scope=scope, response_type='invalid', client_id=client_id, )) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) self.assertEqual(data['error'], 'unauthorized_client') assert next_url == redirect_uri # Missing response_type r = self.client.get(url_for( 'oauth2server.authorize', redirect_uri=redirect_uri, scope=scope, client_id=client_id, )) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) self.assertEqual(data['error'], 'invalid_request') assert next_url == redirect_uri # Duplicate parameter r = self.client.get(url_for( 'oauth2server.authorize', redirect_uri=redirect_uri, scope=scope, response_type='invalid', client_id=client_id, ) + "&client_id=%s" % client_id) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) self.assertEqual(data['error'], 'invalid_request') assert next_url == redirect_uri # Invalid cilent_id r = self.client.get(url_for( 'oauth2server.authorize', redirect_uri=redirect_uri, scope=scope, response_type=response_type, client_id='invalid', )) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) self.assertEqual(data['error'], 'invalid_client_id') assert error_url in next_url r = self.client.get(next_url, query_string=data) assert 'invalid_client_id' in r.data # Invalid redirect uri r = self.client.get(url_for( 'oauth2server.authorize', redirect_uri='http://localhost/', scope=scope, response_type=response_type, client_id=client_id, )) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) self.assertEqual(data['error'], 'mismatching_redirect_uri') assert error_url in next_url def test_refresh_flow(self): # First login on provider site self.login("tester", "tester") data = dict( redirect_uri='%s/oauth2test/authorized' % self.base_url, scope='test:scope', response_type='code', client_id='confidential', state='mystate' ) r = self.client.get(url_for('oauth2server.authorize', **data)) self.assertStatus(r, 200) data['confirm'] = 'yes' data['scope'] = 'test:scope' data['state'] = 'mystate' # Obtain one time code r = self.client.post( url_for('oauth2server.authorize'), data=data ) self.assertStatus(r, 302) next_url, res_data = self.parse_redirect(r.location) assert res_data['code'] assert res_data['state'] == 'mystate' # Exchange one time code for access token r = self.client.post( url_for('oauth2server.access_token'), data=dict( client_id='confidential', client_secret='confidential', grant_type='authorization_code', code=res_data['code'], ) ) self.assertStatus(r, 200) assert r.json['access_token'] assert r.json['refresh_token'] assert r.json['scope'] == 'test:scope' assert r.json['token_type'] == 'Bearer' refresh_token = r.json['refresh_token'] old_access_token = r.json['access_token'] # Access token valid r = self.client.get(url_for('oauth2server.info', access_token=old_access_token)) self.assert200(r) # Obtain new access token with refresh token r = self.client.post( url_for('oauth2server.access_token'), data=dict( client_id='confidential', client_secret='confidential', grant_type='refresh_token', refresh_token=refresh_token, ) ) self.assertStatus(r, 200) assert r.json['access_token'] assert r.json['refresh_token'] assert r.json['access_token'] != old_access_token assert r.json['refresh_token'] != refresh_token assert r.json['scope'] == 'test:scope' assert r.json['token_type'] == 'Bearer' # New access token valid r = self.client.get(url_for('oauth2server.info', access_token=r.json['access_token'])) self.assert200(r) # Old access token no longer valid r = self.client.get(url_for('oauth2server.info', access_token=old_access_token,), base_url=cfg['CFG_SITE_SECURE_URL']) self.assert401(r) def test_web_auth_flow(self): # Go to login - should redirect to oauth2 server for login an # authorization r = self.client.get('/oauth2test/test-ping') # First login on provider site self.login("tester", "tester") r = self.client.get('/oauth2test/login') self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) # Authorize page r = self.client.get(next_url, query_string=data) self.assertStatus(r, 200) # User confirms request data['confirm'] = 'yes' data['scope'] = 'test:scope' data['state'] = '' r = self.client.post(next_url, data=data) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) assert next_url == '%s/oauth2test/authorized' % self.base_url assert 'code' in data # User is redirected back to client site. # - The client view /oauth2test/authorized will in the # background fetch the access token. r = self.client.get(next_url, query_string=data) self.assertStatus(r, 200) # Authentication flow has now been completed, and the access # token can be used to access protected resources. r = self.client.get('/oauth2test/test-ping') self.assert200(r) self.assertEqual(r.json, dict(ping='pong')) # Authentication flow has now been completed, and the access # token can be used to access protected resources. r = self.client.get('/oauth2test/test-ping') self.assert200(r) self.assertEqual(r.json, dict(ping='pong')) r = self.client.get('/oauth2test/test-info') self.assert200(r) assert r.json.get('client') == 'confidential' assert r.json.get('user') == self.objects[0].id assert r.json.get('scopes') == [u'test:scope'] # Access token doesn't provide access to this URL. r = self.client.get( '/oauth2test/test-invalid', base_url=cfg['CFG_SITE_SECURE_URL'] ) self.assertStatus(r, 401) # # Now logout r = self.client.get('/oauth2test/logout') self.assertStatus(r, 200) assert r.data == "logout" # And try to access the information again r = self.client.get('/oauth2test/test-ping') self.assert403(r) def test_implicit_flow(self): # First login on provider site self.login("tester", "tester") for client_id in ['dev', 'confidential']: data = dict( redirect_uri='%s/oauth2test/authorized' % self.base_url, response_type='token', # For implicit grant type client_id=client_id, scope='test:scope', state='teststate' ) # Authorize page r = self.client.get(url_for( 'oauth2server.authorize', **data )) self.assertStatus(r, 200) # User confirms request data['confirm'] = 'yes' data['scope'] = 'test:scope' data['state'] = 'teststate' r = self.client.post(url_for('oauth2server.authorize'), data=data) self.assertStatus(r, 302) # Important - access token exists in URI fragment and must not be # sent to the client. next_url, data = self.parse_redirect(r.location, parse_fragment=True) assert data['access_token'] assert data['token_type'] == 'Bearer' assert data['state'] == 'teststate' assert data['scope'] == 'test:scope' assert data.get('refresh_token') is None assert next_url == '%s/oauth2test/authorized' % self.base_url # Authentication flow has now been completed, and the client can # use the access token to make request to the provider. r = self.client.get(url_for('oauth2server.info', access_token=data['access_token'])) self.assert200(r) assert r.json.get('client') == client_id assert r.json.get('user') == self.objects[0].id assert r.json.get('scopes') == [u'test:scope'] def test_client_flow(self): data = dict( client_id='dev', client_secret='dev', # A public client should NOT do this! grant_type='client_credentials', scope='test:scope', ) # Public clients are not allowed to use grant_type=client_credentials r = self.client.post(url_for( 'oauth2server.access_token', **data )) self.assertStatus(r, 401) self.assertEqual(r.json['error'], 'invalid_client') data = dict( client_id='confidential', client_secret='confidential', grant_type='client_credentials', scope='test:scope', ) # Retrieve access token using client_crendentials r = self.client.post(url_for( 'oauth2server.access_token', **data )) self.assertStatus(r, 200) data = r.json assert data['access_token'] assert data['token_type'] == 'Bearer' assert data['scope'] == 'test:scope' assert data.get('refresh_token') is None # Authentication flow has now been completed, and the client can # use the access token to make request to the provider. r = self.client.get(url_for('oauth2server.info', access_token=data['access_token'])) self.assert200(r) assert r.json.get('client') == 'confidential' assert r.json.get('user') == self.objects[0].id assert r.json.get('scopes') == [u'test:scope'] def test_auth_flow_denied(self): # First login on provider site self.login("tester", "tester") r = self.client.get('/oauth2test/login') self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) # Authorize page r = self.client.get(next_url, query_string=data) self.assertStatus(r, 200) # User rejects request data['confirm'] = 'no' data['scope'] = 'test:scope' data['state'] = '' r = self.client.post(next_url, data=data) self.assertStatus(r, 302) next_url, data = self.parse_redirect(r.location) assert next_url == '%s/oauth2test/authorized' % self.base_url assert data.get('error') == 'access_denied' # Returned r = self.client.get(next_url, query_string=data) self.assert200(r) assert r.data == "Access denied: error=access_denied" def test_personal_access_token(self): r = self.client.get( '/oauth/ping', query_string="access_token=%s" % self.personal_token.access_token ) self.assert200(r) self.assertEqual(r.json, dict(ping='pong')) # Access token is not valid for this scope r = self.client.get( '/oauth/info/', query_string="access_token=%s" % self.personal_token.access_token, base_url=cfg['CFG_SITE_SECURE_URL'] ) self.assertStatus(r, 401) def test_resource_auth_methods(self): # Query string r = self.client.get( '/oauth/ping', query_string="access_token=%s" % self.personal_token.access_token ) self.assert200(r) self.assertEqual(r.json, dict(ping='pong')) # POST request body r = self.client.post( '/oauth/ping', data=dict(access_token=self.personal_token.access_token), ) self.assert200(r) self.assertEqual(r.json, dict(ping='pong')) # Authorization Header r = self.client.get( '/oauth/ping', headers=[ ("Authorization", "Bearer %s" % self.personal_token.access_token), ] ) self.assert200(r) self.assertEqual(r.json, dict(ping='pong')) def test_settings_index(self): # Create a remote account (linked account) r = self.client.get( url_for('oauth2server_settings.index'), base_url=cfg['CFG_SITE_SECURE_URL'], ) self.assertStatus(r, 401) self.login("tester", "tester") res = self.client.get( url_for('oauth2server_settings.index'), base_url=cfg['CFG_SITE_SECURE_URL'], ) self.assert200(res) res = self.client.get( url_for('oauth2server_settings.client_new'), base_url=cfg['CFG_SITE_SECURE_URL'], ) self.assert200(res) # Valid POST res = self.client.post( url_for('oauth2server_settings.client_new'), base_url=cfg['CFG_SITE_SECURE_URL'], data=dict( name='Test', description='Test description', website='http://invenio-software.org', + is_confidential=1, redirect_uris="http://localhost/oauth/authorized/" ) ) self.assertStatus(res, 302) # Invalid redirect_uri (must be https) res = self.client.post( url_for('oauth2server_settings.client_new'), base_url=cfg['CFG_SITE_SECURE_URL'], data=dict( name='Test', description='Test description', website='http://invenio-software.org', + is_confidential=1, redirect_uris="http://example.org/oauth/authorized/" ) ) self.assertStatus(res, 200) # Valid res = self.client.post( url_for('oauth2server_settings.client_new'), base_url=cfg['CFG_SITE_SECURE_URL'], data=dict( name='Test', description='Test description', website='http://invenio-software.org', + is_confidential=1, redirect_uris="https://example.org/oauth/authorized/\n" "http://localhost:4000/oauth/authorized/" ) ) self.assertStatus(res, 302) class OAuth2ProviderExpirationTestCase(ProviderTestCase): @property def config(self): ctx = super(OAuth2ProviderExpirationTestCase, self).config ctx.update( OAUTH2_PROVIDER_TOKEN_EXPIRES_IN=1, # make them all expired ) return ctx def test_refresh_flow(self): # First login on provider site self.login("tester", "tester") data = dict( redirect_uri='%s/oauth2test/authorized' % self.base_url, scope='test:scope', response_type='code', client_id='confidential', state='mystate' ) r = self.client.get(url_for('oauth2server.authorize', **data)) self.assertStatus(r, 200) data['confirm'] = 'yes' data['scope'] = 'test:scope' data['state'] = 'mystate' # Obtain one time code r = self.client.post( url_for('oauth2server.authorize'), data=data ) self.assertStatus(r, 302) next_url, res_data = self.parse_redirect(r.location) assert res_data['code'] assert res_data['state'] == 'mystate' # Exchange one time code for access token r = self.client.post( url_for('oauth2server.access_token'), data=dict( client_id='confidential', client_secret='confidential', grant_type='authorization_code', code=res_data['code'], ) ) self.assertStatus(r, 200) assert r.json['access_token'] assert r.json['refresh_token'] assert r.json['expires_in'] > 0 assert r.json['scope'] == 'test:scope' assert r.json['token_type'] == 'Bearer' refresh_token = r.json['refresh_token'] old_access_token = r.json['access_token'] # Access token valid r = self.client.get(url_for('oauth2server.info', access_token=old_access_token)) self.assert200(r) from ..models import Token Token.query.filter_by(access_token=old_access_token).update( dict(expires=datetime.utcnow() - timedelta(seconds=1)) ) db.session.commit() # Access token is expired r = self.client.get(url_for('oauth2server.info', access_token=old_access_token), base_url=cfg['CFG_SITE_SECURE_URL']) self.assert401(r) # Obtain new access token with refresh token r = self.client.post( url_for('oauth2server.access_token'), data=dict( client_id='confidential', client_secret='confidential', grant_type='refresh_token', refresh_token=refresh_token, ) ) self.assertStatus(r, 200) assert r.json['access_token'] assert r.json['refresh_token'] assert r.json['expires_in'] > 0 assert r.json['access_token'] != old_access_token assert r.json['refresh_token'] != refresh_token assert r.json['scope'] == 'test:scope' assert r.json['token_type'] == 'Bearer' # New access token valid r = self.client.get(url_for('oauth2server.info', access_token=r.json['access_token'])) self.assert200(r) # Old access token no longer valid r = self.client.get(url_for('oauth2server.info', access_token=old_access_token,), base_url=cfg['CFG_SITE_SECURE_URL']) self.assert401(r) def test_not_allowed_public_refresh_flow(self): # First login on provider site self.login("tester", "tester") data = dict( redirect_uri='%s/oauth2test/authorized' % self.base_url, scope='test:scope', response_type='code', client_id='dev', state='mystate' ) r = self.client.get(url_for('oauth2server.authorize', **data)) self.assertStatus(r, 200) data['confirm'] = 'yes' data['scope'] = 'test:scope' data['state'] = 'mystate' # Obtain one time code r = self.client.post( url_for('oauth2server.authorize'), data=data ) self.assertStatus(r, 302) next_url, res_data = self.parse_redirect(r.location) assert res_data['code'] assert res_data['state'] == 'mystate' # Exchange one time code for access token r = self.client.post( url_for('oauth2server.access_token'), data=dict( client_id='dev', client_secret='dev', grant_type='authorization_code', code=res_data['code'], ) ) self.assertStatus(r, 200) assert r.json['access_token'] assert r.json['refresh_token'] assert r.json['expires_in'] > 0 assert r.json['scope'] == 'test:scope' assert r.json['token_type'] == 'Bearer' refresh_token = r.json['refresh_token'] old_access_token = r.json['access_token'] # Access token valid r = self.client.get(url_for('oauth2server.info', access_token=old_access_token)) self.assert200(r) from ..models import Token Token.query.filter_by(access_token=old_access_token).update( dict(expires=datetime.utcnow() - timedelta(seconds=1)) ) db.session.commit() # Access token is expired r = self.client.get(url_for('oauth2server.info', access_token=old_access_token), follow_redirects=True) self.assert401(r) # Obtain new access token with refresh token r = self.client.post( url_for('oauth2server.access_token'), data=dict( client_id='dev', client_secret='dev', grant_type='refresh_token', refresh_token=refresh_token, ), follow_redirects=True ) # Only confidential clients can refresh expired token. self.assert401(r) TEST_SUITE = make_test_suite(OAuth2ProviderTestCase, OAuth2ProviderExpirationTestCase) if __name__ == "__main__": run_test_suite(TEST_SUITE)