Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move helper functions from views.py into separate modules #71

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions django_saml2_auth/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from django.conf import settings

from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT
from saml2.client import Saml2Client
from saml2.config import Config as Saml2Config

from django_saml2_auth.utils import get_reverse


def get_saml_client(domain):
sp_config = _get_saml_config(domain)
saml_client = Saml2Client(config=sp_config)
return saml_client


def _get_saml_config(domain):
settings = _parse_settings(domain)
sp_config = Saml2Config()
sp_config.load(settings)
sp_config.allow_unknown_attributes = True
return sp_config


def _parse_settings(domain):
acs_url = domain + get_reverse([acs, 'acs', 'django_saml2_auth:acs'])
metadata = _get_metadata()

saml_settings = {
'metadata': metadata,
'service': {
'sp': {
'endpoints': {
'assertion_consumer_service': [
(acs_url, BINDING_HTTP_REDIRECT),
(acs_url, BINDING_HTTP_POST)
],
},
'allow_unsolicited': True,
'authn_requests_signed': False,
'logout_requests_signed': True,
'want_assertions_signed': True,
'want_response_signed': False,
},
},
}

if 'ENTITY_ID' in settings.SAML2_AUTH:
saml_settings['entityid'] = settings.SAML2_AUTH['ENTITY_ID']

if 'NAME_ID_FORMAT' in settings.SAML2_AUTH:
saml_settings['service']['sp']['name_id_format'] = settings.SAML2_AUTH['NAME_ID_FORMAT']

return saml_settings


def _get_metadata():
if 'METADATA_LOCAL_FILE_PATH' in settings.SAML2_AUTH:
return {
'local': [settings.SAML2_AUTH['METADATA_LOCAL_FILE_PATH']]
}
else:
return {
'remote': [
{
"url": settings.SAML2_AUTH['METADATA_AUTO_CONF_URL'],
},
]
}
30 changes: 30 additions & 0 deletions django_saml2_auth/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from django import get_version
from django.conf import settings

from pkg_resources import parse_version


def get_sp_domain(r):
if 'ASSERTION_URL' in settings.SAML2_AUTH:
return settings.SAML2_AUTH['ASSERTION_URL']
return '{scheme}://{host}'.format(
scheme='https' if r.is_secure() else 'http',
host=r.get_host(),
)


def get_reverse(objs):
'''In order to support different django version, I have to do this '''
if parse_version(get_version()) >= parse_version('2.0'):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
if objs.__class__.__name__ not in ['list', 'tuple']:
objs = [objs]

for obj in objs:
try:
return reverse(obj)
except:
pass
raise Exception('We got a URL reverse issue: %s. This is a known issue but please still submit a ticket at https://github.com/fangli/django-saml2-auth/issues/new' % str(objs))
84 changes: 5 additions & 79 deletions django_saml2_auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

from rest_auth.utils import jwt_encode

from django_saml2_auth.conf import get_saml_client
from django_saml2_auth.utils import get_reverse, get_sp_domain


# default User or custom User. Now both will work.
User = get_user_model()
Expand All @@ -41,83 +44,6 @@
from django.utils.module_loading import import_by_path as import_string


def get_current_domain(r):
if 'ASSERTION_URL' in settings.SAML2_AUTH:
return settings.SAML2_AUTH['ASSERTION_URL']
return '{scheme}://{host}'.format(
scheme='https' if r.is_secure() else 'http',
host=r.get_host(),
)


def get_reverse(objs):
'''In order to support different django version, I have to do this '''
if parse_version(get_version()) >= parse_version('2.0'):
from django.urls import reverse
else:
from django.core.urlresolvers import reverse
if objs.__class__.__name__ not in ['list', 'tuple']:
objs = [objs]

for obj in objs:
try:
return reverse(obj)
except:
pass
raise Exception('We got a URL reverse issue: %s. This is a known issue but please still submit a ticket at https://github.com/fangli/django-saml2-auth/issues/new' % str(objs))


def _get_metadata():
if 'METADATA_LOCAL_FILE_PATH' in settings.SAML2_AUTH:
return {
'local': [settings.SAML2_AUTH['METADATA_LOCAL_FILE_PATH']]
}
else:
return {
'remote': [
{
"url": settings.SAML2_AUTH['METADATA_AUTO_CONF_URL'],
},
]
}


def _get_saml_client(domain):
acs_url = domain + get_reverse([acs, 'acs', 'django_saml2_auth:acs'])
metadata = _get_metadata()

saml_settings = {
'metadata': metadata,
'service': {
'sp': {
'endpoints': {
'assertion_consumer_service': [
(acs_url, BINDING_HTTP_REDIRECT),
(acs_url, BINDING_HTTP_POST)
],
},
'allow_unsolicited': True,
'authn_requests_signed': False,
'logout_requests_signed': True,
'want_assertions_signed': True,
'want_response_signed': False,
},
},
}

if 'ENTITY_ID' in settings.SAML2_AUTH:
saml_settings['entityid'] = settings.SAML2_AUTH['ENTITY_ID']

if 'NAME_ID_FORMAT' in settings.SAML2_AUTH:
saml_settings['service']['sp']['name_id_format'] = settings.SAML2_AUTH['NAME_ID_FORMAT']

spConfig = Saml2Config()
spConfig.load(saml_settings)
spConfig.allow_unknown_attributes = True
saml_client = Saml2Client(config=spConfig)
return saml_client


@login_required
def welcome(r):
try:
Expand Down Expand Up @@ -148,7 +74,7 @@ def _create_new_user(username, email, firstname, lastname):

@csrf_exempt
def acs(r):
saml_client = _get_saml_client(get_current_domain(r))
saml_client = get_saml_client(get_sp_domain(r))
resp = r.POST.get('SAMLResponse', None)
next_url = r.session.get('login_next_url', settings.SAML2_AUTH.get('DEFAULT_NEXT_URL', get_reverse('admin:index')))

Expand Down Expand Up @@ -234,7 +160,7 @@ def signin(r):

r.session['login_next_url'] = next_url

saml_client = _get_saml_client(get_current_domain(r))
saml_client = get_saml_client(get_sp_domain(r))
_, info = saml_client.prepare_for_authenticate()

redirect_url = None
Expand Down