grafana/vendor/github.com/crewjam/saml/identity_provider.go
gotjosh e6b8a1529b
SAML: Configuration defaults, examples and dependencies (#17954)
* Add SAML configuration options

* Add crewjam/saml as a depdency

Needed as part of the enterprise SAML integration.

* Vendor github.com/stretchr/testify/require

The package require implements the same assertions as the `assert` package but stops test execution when a test fails.
2019-07-05 11:27:14 +01:00

953 lines
30 KiB
Go

package saml
import (
"bytes"
"compress/flate"
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/xml"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
"regexp"
"strconv"
"text/template"
"time"
"github.com/beevik/etree"
"github.com/crewjam/saml/logger"
"github.com/crewjam/saml/xmlenc"
dsig "github.com/russellhaering/goxmldsig"
)
// Session represents a user session. It is returned by the
// SessionProvider implementation's GetSession method. Fields here
// are used to set fields in the SAML assertion.
type Session struct {
ID string
CreateTime time.Time
ExpireTime time.Time
Index string
NameID string
Groups []string
UserName string
UserEmail string
UserCommonName string
UserSurname string
UserGivenName string
}
// SessionProvider is an interface used by IdentityProvider to determine the
// Session associated with a request. For an example implementation, see
// GetSession in the samlidp package.
type SessionProvider interface {
// GetSession returns the remote user session associated with the http.Request.
//
// If (and only if) the request is not associated with a session then GetSession
// must complete the HTTP request and return nil.
GetSession(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session
}
// ServiceProviderProvider is an interface used by IdentityProvider to look up
// service provider metadata for a request.
type ServiceProviderProvider interface {
// GetServiceProvider returns the Service Provider metadata for the
// service provider ID, which is typically the service provider's
// metadata URL. If an appropriate service provider cannot be found then
// the returned error must be os.ErrNotExist.
GetServiceProvider(r *http.Request, serviceProviderID string) (*EntityDescriptor, error)
}
// AssertionMaker is an interface used by IdentityProvider to construct the
// assertion for a request. The default implementation is DefaultAssertionMaker,
// which is used if not AssertionMaker is specified.
type AssertionMaker interface {
// MakeAssertion constructs an assertion from session and the request and
// assigns it to req.Assertion.
MakeAssertion(req *IdpAuthnRequest, session *Session) error
}
// IdentityProvider implements the SAML Identity Provider role (IDP).
//
// An identity provider receives SAML assertion requests and responds
// with SAML Assertions.
//
// You must provide a keypair that is used to
// sign assertions.
//
// You must provide an implementation of ServiceProviderProvider which
// returns
//
// You must provide an implementation of the SessionProvider which
// handles the actual authentication (i.e. prompting for a username
// and password).
type IdentityProvider struct {
Key crypto.PrivateKey
Logger logger.Interface
Certificate *x509.Certificate
Intermediates []*x509.Certificate
MetadataURL url.URL
SSOURL url.URL
LogoutURL url.URL
ServiceProviderProvider ServiceProviderProvider
SessionProvider SessionProvider
AssertionMaker AssertionMaker
SignatureMethod string
}
// Metadata returns the metadata structure for this identity provider.
func (idp *IdentityProvider) Metadata() *EntityDescriptor {
certStr := base64.StdEncoding.EncodeToString(idp.Certificate.Raw)
ed := &EntityDescriptor{
EntityID: idp.MetadataURL.String(),
ValidUntil: TimeNow().Add(DefaultValidDuration),
CacheDuration: DefaultValidDuration,
IDPSSODescriptors: []IDPSSODescriptor{
IDPSSODescriptor{
SSODescriptor: SSODescriptor{
RoleDescriptor: RoleDescriptor{
ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
KeyDescriptors: []KeyDescriptor{
{
Use: "signing",
KeyInfo: KeyInfo{
Certificate: certStr,
},
},
{
Use: "encryption",
KeyInfo: KeyInfo{
Certificate: certStr,
},
EncryptionMethods: []EncryptionMethod{
{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc"},
{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes192-cbc"},
{Algorithm: "http://www.w3.org/2001/04/xmlenc#aes256-cbc"},
{Algorithm: "http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"},
},
},
},
},
NameIDFormats: []NameIDFormat{NameIDFormat("urn:oasis:names:tc:SAML:2.0:nameid-format:transient")},
},
SingleSignOnServices: []Endpoint{
{
Binding: HTTPRedirectBinding,
Location: idp.SSOURL.String(),
},
{
Binding: HTTPPostBinding,
Location: idp.SSOURL.String(),
},
},
},
},
}
if idp.LogoutURL.String() != "" {
ed.IDPSSODescriptors[0].SSODescriptor.SingleLogoutServices = []Endpoint{
{
Binding: HTTPRedirectBinding,
Location: idp.LogoutURL.String(),
},
}
}
return ed
}
// Handler returns an http.Handler that serves the metadata and SSO
// URLs
func (idp *IdentityProvider) Handler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc(idp.MetadataURL.Path, idp.ServeMetadata)
mux.HandleFunc(idp.SSOURL.Path, idp.ServeSSO)
return mux
}
// ServeMetadata is an http.HandlerFunc that serves the IDP metadata
func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, r *http.Request) {
buf, _ := xml.MarshalIndent(idp.Metadata(), "", " ")
w.Header().Set("Content-Type", "application/samlmetadata+xml")
w.Write(buf)
}
// ServeSSO handles SAML auth requests.
//
// When it gets a request for a user that does not have a valid session,
// then it prompts the user via XXX.
//
// If the session already exists, then it produces a SAML assertion and
// returns an HTTP response according to the specified binding. The
// only supported binding right now is the HTTP-POST binding which returns
// an HTML form in the appropriate format with Javascript to automatically
// submit that form the to service provider's Assertion Customer Service
// endpoint.
//
// If the SAML request is invalid or cannot be verified a simple StatusBadRequest
// response is sent.
//
// If the assertion cannot be created or returned, a StatusInternalServerError
// response is sent.
func (idp *IdentityProvider) ServeSSO(w http.ResponseWriter, r *http.Request) {
req, err := NewIdpAuthnRequest(idp, r)
if err != nil {
idp.Logger.Printf("failed to parse request: %s", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
if err := req.Validate(); err != nil {
idp.Logger.Printf("failed to validate request: %s", err)
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
// TODO(ross): we must check that the request ID has not been previously
// issued.
session := idp.SessionProvider.GetSession(w, r, req)
if session == nil {
return
}
assertionMaker := idp.AssertionMaker
if assertionMaker == nil {
assertionMaker = DefaultAssertionMaker{}
}
if err := assertionMaker.MakeAssertion(req, session); err != nil {
idp.Logger.Printf("failed to make assertion: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := req.WriteResponse(w); err != nil {
idp.Logger.Printf("failed to write response: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}
// ServeIDPInitiated handes an IDP-initiated authorization request. Requests of this
// type require us to know a registered service provider and (optionally) the RelayState
// that will be passed to the application.
func (idp *IdentityProvider) ServeIDPInitiated(w http.ResponseWriter, r *http.Request, serviceProviderID string, relayState string) {
req := &IdpAuthnRequest{
IDP: idp,
HTTPRequest: r,
RelayState: relayState,
Now: TimeNow(),
}
session := idp.SessionProvider.GetSession(w, r, req)
if session == nil {
// If GetSession returns nil, it must have written an HTTP response, per the interface
// (this is probably because it drew a login form or something)
return
}
var err error
req.ServiceProviderMetadata, err = idp.ServiceProviderProvider.GetServiceProvider(r, serviceProviderID)
if err == os.ErrNotExist {
idp.Logger.Printf("cannot find service provider: %s", serviceProviderID)
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return
} else if err != nil {
idp.Logger.Printf("cannot find service provider %s: %v", serviceProviderID, err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
// find an ACS endpoint that we can use
for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
for _, endpoint := range spssoDescriptor.AssertionConsumerServices {
if endpoint.Binding == HTTPPostBinding {
req.ACSEndpoint = &endpoint
req.SPSSODescriptor = &spssoDescriptor
break
}
}
if req.ACSEndpoint != nil {
break
}
}
if req.ACSEndpoint == nil {
idp.Logger.Printf("saml metadata does not contain an Assertion Customer Service url")
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
assertionMaker := idp.AssertionMaker
if assertionMaker == nil {
assertionMaker = DefaultAssertionMaker{}
}
if err := assertionMaker.MakeAssertion(req, session); err != nil {
idp.Logger.Printf("failed to make assertion: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if err := req.WriteResponse(w); err != nil {
idp.Logger.Printf("failed to write response: %s", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
}
// IdpAuthnRequest is used by IdentityProvider to handle a single authentication request.
type IdpAuthnRequest struct {
IDP *IdentityProvider
HTTPRequest *http.Request
RelayState string
RequestBuffer []byte
Request AuthnRequest
ServiceProviderMetadata *EntityDescriptor
SPSSODescriptor *SPSSODescriptor
ACSEndpoint *IndexedEndpoint
Assertion *Assertion
AssertionEl *etree.Element
ResponseEl *etree.Element
Now time.Time
}
// NewIdpAuthnRequest returns a new IdpAuthnRequest for the given HTTP request to the authorization
// service.
func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnRequest, error) {
req := &IdpAuthnRequest{
IDP: idp,
HTTPRequest: r,
Now: TimeNow(),
}
switch r.Method {
case "GET":
compressedRequest, err := base64.StdEncoding.DecodeString(r.URL.Query().Get("SAMLRequest"))
if err != nil {
return nil, fmt.Errorf("cannot decode request: %s", err)
}
req.RequestBuffer, err = ioutil.ReadAll(flate.NewReader(bytes.NewReader(compressedRequest)))
if err != nil {
return nil, fmt.Errorf("cannot decompress request: %s", err)
}
req.RelayState = r.URL.Query().Get("RelayState")
case "POST":
if err := r.ParseForm(); err != nil {
return nil, err
}
var err error
req.RequestBuffer, err = base64.StdEncoding.DecodeString(r.PostForm.Get("SAMLRequest"))
if err != nil {
return nil, err
}
req.RelayState = r.PostForm.Get("RelayState")
default:
return nil, fmt.Errorf("method not allowed")
}
return req, nil
}
// Validate checks that the authentication request is valid and assigns
// the AuthnRequest and Metadata properties. Returns a non-nil error if the
// request is not valid.
func (req *IdpAuthnRequest) Validate() error {
if err := xml.Unmarshal(req.RequestBuffer, &req.Request); err != nil {
return err
}
// We always have exactly one IDP SSO descriptor
if len(req.IDP.Metadata().IDPSSODescriptors) != 1 {
panic("expected exactly one IDP SSO descriptor in IDP metadata")
}
idpSsoDescriptor := req.IDP.Metadata().IDPSSODescriptors[0]
// TODO(ross): support signed authn requests
// For now we do the safe thing and fail in the case where we think
// requests might be signed.
if idpSsoDescriptor.WantAuthnRequestsSigned != nil && *idpSsoDescriptor.WantAuthnRequestsSigned {
return fmt.Errorf("Authn request signature checking is not currently supported")
}
// In http://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf §3.4.5.2
// we get a description of the Destination attribute:
//
// If the message is signed, the Destination XML attribute in the root SAML
// element of the protocol message MUST contain the URL to which the sender
// has instructed the user agent to deliver the message. The recipient MUST
// then verify that the value matches the location at which the message has
// been received.
//
// We require the destination be correct either (a) if signing is enabled or
// (b) if it was provided.
mustHaveDestination := idpSsoDescriptor.WantAuthnRequestsSigned != nil && *idpSsoDescriptor.WantAuthnRequestsSigned
mustHaveDestination = mustHaveDestination || req.Request.Destination != ""
if mustHaveDestination {
if req.Request.Destination != req.IDP.SSOURL.String() {
return fmt.Errorf("expected destination to be %q, not %q", req.IDP.SSOURL.String(), req.Request.Destination)
}
}
if req.Request.IssueInstant.Add(MaxIssueDelay).Before(req.Now) {
return fmt.Errorf("request expired at %s",
req.Request.IssueInstant.Add(MaxIssueDelay))
}
if req.Request.Version != "2.0" {
return fmt.Errorf("expected SAML request version 2.0 got %v", req.Request.Version)
}
// find the service provider
serviceProviderID := req.Request.Issuer.Value
serviceProvider, err := req.IDP.ServiceProviderProvider.GetServiceProvider(req.HTTPRequest, serviceProviderID)
if err == os.ErrNotExist {
return fmt.Errorf("cannot handle request from unknown service provider %s", serviceProviderID)
} else if err != nil {
return fmt.Errorf("cannot find service provider %s: %v", serviceProviderID, err)
}
req.ServiceProviderMetadata = serviceProvider
// Check that the ACS URL matches an ACS endpoint in the SP metadata.
if err := req.getACSEndpoint(); err != nil {
return fmt.Errorf("cannot find assertion consumer service: %v", err)
}
return nil
}
func (req *IdpAuthnRequest) getACSEndpoint() error {
if req.Request.AssertionConsumerServiceIndex != "" {
for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
if strconv.Itoa(spAssertionConsumerService.Index) == req.Request.AssertionConsumerServiceIndex {
req.SPSSODescriptor = &spssoDescriptor
req.ACSEndpoint = &spAssertionConsumerService
return nil
}
}
}
}
if req.Request.AssertionConsumerServiceURL != "" {
for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
if spAssertionConsumerService.Location == req.Request.AssertionConsumerServiceURL {
req.SPSSODescriptor = &spssoDescriptor
req.ACSEndpoint = &spAssertionConsumerService
return nil
}
}
}
}
// Some service providers, like the Microsoft Azure AD service provider, issue
// assertion requests that don't specify an ACS url at all.
if req.Request.AssertionConsumerServiceURL == "" && req.Request.AssertionConsumerServiceIndex == "" {
// find a default ACS binding in the metadata that we can use
for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
if spAssertionConsumerService.IsDefault != nil && *spAssertionConsumerService.IsDefault {
switch spAssertionConsumerService.Binding {
case HTTPPostBinding, HTTPRedirectBinding:
req.SPSSODescriptor = &spssoDescriptor
req.ACSEndpoint = &spAssertionConsumerService
return nil
}
}
}
}
// if we can't find a default, use *any* ACS binding
for _, spssoDescriptor := range req.ServiceProviderMetadata.SPSSODescriptors {
for _, spAssertionConsumerService := range spssoDescriptor.AssertionConsumerServices {
switch spAssertionConsumerService.Binding {
case HTTPPostBinding, HTTPRedirectBinding:
req.SPSSODescriptor = &spssoDescriptor
req.ACSEndpoint = &spAssertionConsumerService
return nil
}
}
}
}
return os.ErrNotExist // no ACS url found or specified
}
// DefaultAssertionMaker produces a SAML assertion for the
// given request and assigns it to req.Assertion.
type DefaultAssertionMaker struct {
}
// MakeAssertion implements AssertionMaker. It produces a SAML assertion from the
// given request and assigns it to req.Assertion.
func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Session) error {
attributes := []Attribute{}
var attributeConsumingService *AttributeConsumingService
for _, acs := range req.SPSSODescriptor.AttributeConsumingServices {
if acs.IsDefault != nil && *acs.IsDefault {
attributeConsumingService = &acs
break
}
}
if attributeConsumingService == nil {
for _, acs := range req.SPSSODescriptor.AttributeConsumingServices {
attributeConsumingService = &acs
break
}
}
if attributeConsumingService == nil {
attributeConsumingService = &AttributeConsumingService{}
}
for _, requestedAttribute := range attributeConsumingService.RequestedAttributes {
if requestedAttribute.NameFormat == "urn:oasis:names:tc:SAML:2.0:attrname-format:basic" || requestedAttribute.NameFormat == "urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified" {
attrName := requestedAttribute.Name
attrName = regexp.MustCompile("[^A-Za-z0-9]+").ReplaceAllString(attrName, "")
switch attrName {
case "email", "emailaddress":
attributes = append(attributes, Attribute{
FriendlyName: requestedAttribute.FriendlyName,
Name: requestedAttribute.Name,
NameFormat: requestedAttribute.NameFormat,
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserEmail,
}},
})
case "name", "fullname", "cn", "commonname":
attributes = append(attributes, Attribute{
FriendlyName: requestedAttribute.FriendlyName,
Name: requestedAttribute.Name,
NameFormat: requestedAttribute.NameFormat,
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserCommonName,
}},
})
case "givenname", "firstname":
attributes = append(attributes, Attribute{
FriendlyName: requestedAttribute.FriendlyName,
Name: requestedAttribute.Name,
NameFormat: requestedAttribute.NameFormat,
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserGivenName,
}},
})
case "surname", "lastname", "familyname":
attributes = append(attributes, Attribute{
FriendlyName: requestedAttribute.FriendlyName,
Name: requestedAttribute.Name,
NameFormat: requestedAttribute.NameFormat,
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserSurname,
}},
})
case "uid", "user", "userid":
attributes = append(attributes, Attribute{
FriendlyName: requestedAttribute.FriendlyName,
Name: requestedAttribute.Name,
NameFormat: requestedAttribute.NameFormat,
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserName,
}},
})
}
}
}
if session.UserName != "" {
attributes = append(attributes, Attribute{
FriendlyName: "uid",
Name: "urn:oid:0.9.2342.19200300.100.1.1",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserName,
}},
})
}
if session.UserEmail != "" {
attributes = append(attributes, Attribute{
FriendlyName: "eduPersonPrincipalName",
Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.6",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserEmail,
}},
})
}
if session.UserSurname != "" {
attributes = append(attributes, Attribute{
FriendlyName: "sn",
Name: "urn:oid:2.5.4.4",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserSurname,
}},
})
}
if session.UserGivenName != "" {
attributes = append(attributes, Attribute{
FriendlyName: "givenName",
Name: "urn:oid:2.5.4.42",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserGivenName,
}},
})
}
if session.UserCommonName != "" {
attributes = append(attributes, Attribute{
FriendlyName: "cn",
Name: "urn:oid:2.5.4.3",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
Values: []AttributeValue{{
Type: "xs:string",
Value: session.UserCommonName,
}},
})
}
if len(session.Groups) != 0 {
groupMemberAttributeValues := []AttributeValue{}
for _, group := range session.Groups {
groupMemberAttributeValues = append(groupMemberAttributeValues, AttributeValue{
Type: "xs:string",
Value: group,
})
}
attributes = append(attributes, Attribute{
FriendlyName: "eduPersonAffiliation",
Name: "urn:oid:1.3.6.1.4.1.5923.1.1.1.1",
NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri",
Values: groupMemberAttributeValues,
})
}
// allow for some clock skew in the validity period using the
// issuer's apparent clock.
notBefore := req.Now.Add(-1 * MaxClockSkew)
notOnOrAfterAfter := req.Now.Add(MaxIssueDelay)
if notBefore.Before(req.Request.IssueInstant) {
notBefore = req.Request.IssueInstant
notOnOrAfterAfter = notBefore.Add(MaxIssueDelay)
}
req.Assertion = &Assertion{
ID: fmt.Sprintf("id-%x", randomBytes(20)),
IssueInstant: TimeNow(),
Version: "2.0",
Issuer: Issuer{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
Value: req.IDP.Metadata().EntityID,
},
Subject: &Subject{
NameID: &NameID{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:transient",
NameQualifier: req.IDP.Metadata().EntityID,
SPNameQualifier: req.ServiceProviderMetadata.EntityID,
Value: session.NameID,
},
SubjectConfirmations: []SubjectConfirmation{
SubjectConfirmation{
Method: "urn:oasis:names:tc:SAML:2.0:cm:bearer",
SubjectConfirmationData: &SubjectConfirmationData{
Address: req.HTTPRequest.RemoteAddr,
InResponseTo: req.Request.ID,
NotOnOrAfter: req.Now.Add(MaxIssueDelay),
Recipient: req.ACSEndpoint.Location,
},
},
},
},
Conditions: &Conditions{
NotBefore: notBefore,
NotOnOrAfter: notOnOrAfterAfter,
AudienceRestrictions: []AudienceRestriction{
AudienceRestriction{
Audience: Audience{Value: req.ServiceProviderMetadata.EntityID},
},
},
},
AuthnStatements: []AuthnStatement{
AuthnStatement{
AuthnInstant: session.CreateTime,
SessionIndex: session.Index,
SubjectLocality: &SubjectLocality{
Address: req.HTTPRequest.RemoteAddr,
},
AuthnContext: AuthnContext{
AuthnContextClassRef: &AuthnContextClassRef{
Value: "urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport",
},
},
},
},
AttributeStatements: []AttributeStatement{
AttributeStatement{
Attributes: attributes,
},
},
}
return nil
}
// The Canonicalizer prefix list MUST be empty. Various implementations
// (maybe ours?) do not appear to support non-empty prefix lists in XML C14N.
const canonicalizerPrefixList = ""
// MakeAssertionEl sets `AssertionEl` to a signed, possibly encrypted, version of `Assertion`.
func (req *IdpAuthnRequest) MakeAssertionEl() error {
keyPair := tls.Certificate{
Certificate: [][]byte{req.IDP.Certificate.Raw},
PrivateKey: req.IDP.Key,
Leaf: req.IDP.Certificate,
}
for _, cert := range req.IDP.Intermediates {
keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
}
keyStore := dsig.TLSCertKeyStore(keyPair)
signatureMethod := req.IDP.SignatureMethod
if signatureMethod == "" {
signatureMethod = dsig.RSASHA1SignatureMethod
}
signingContext := dsig.NewDefaultSigningContext(keyStore)
signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
return err
}
assertionEl := req.Assertion.Element()
signedAssertionEl, err := signingContext.SignEnveloped(assertionEl)
if err != nil {
return err
}
sigEl := signedAssertionEl.Child[len(signedAssertionEl.Child)-1]
req.Assertion.Signature = sigEl.(*etree.Element)
signedAssertionEl = req.Assertion.Element()
certBuf, err := req.getSPEncryptionCert()
if err == os.ErrNotExist {
req.AssertionEl = signedAssertionEl
return nil
} else if err != nil {
return err
}
var signedAssertionBuf []byte
{
doc := etree.NewDocument()
doc.SetRoot(signedAssertionEl)
signedAssertionBuf, err = doc.WriteToBytes()
if err != nil {
return err
}
}
encryptor := xmlenc.OAEP()
encryptor.BlockCipher = xmlenc.AES128CBC
encryptor.DigestMethod = &xmlenc.SHA1
encryptedDataEl, err := encryptor.Encrypt(certBuf, signedAssertionBuf)
if err != nil {
return err
}
encryptedDataEl.CreateAttr("Type", "http://www.w3.org/2001/04/xmlenc#Element")
encryptedAssertionEl := etree.NewElement("saml:EncryptedAssertion")
encryptedAssertionEl.AddChild(encryptedDataEl)
req.AssertionEl = encryptedAssertionEl
return nil
}
// WriteResponse writes the `Response` to the http.ResponseWriter. If
// `Response` is not already set, it calls MakeResponse to produce it.
func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error {
if req.ResponseEl == nil {
if err := req.MakeResponse(); err != nil {
return err
}
}
doc := etree.NewDocument()
doc.SetRoot(req.ResponseEl)
responseBuf, err := doc.WriteToBytes()
if err != nil {
return err
}
// the only supported binding is the HTTP-POST binding
switch req.ACSEndpoint.Binding {
case HTTPPostBinding:
tmpl := template.Must(template.New("saml-post-form").Parse(`<html>` +
`<form method="post" action="{{.URL}}" id="SAMLResponseForm">` +
`<input type="hidden" name="SAMLResponse" value="{{.SAMLResponse}}" />` +
`<input type="hidden" name="RelayState" value="{{.RelayState}}" />` +
`<input id="SAMLSubmitButton" type="submit" value="Continue" />` +
`</form>` +
`<script>document.getElementById('SAMLSubmitButton').style.visibility='hidden';</script>` +
`<script>document.getElementById('SAMLResponseForm').submit();</script>` +
`</html>`))
data := struct {
URL string
SAMLResponse string
RelayState string
}{
URL: req.ACSEndpoint.Location,
SAMLResponse: base64.StdEncoding.EncodeToString(responseBuf),
RelayState: req.RelayState,
}
buf := bytes.NewBuffer(nil)
if err := tmpl.Execute(buf, data); err != nil {
return err
}
if _, err := io.Copy(w, buf); err != nil {
return err
}
return nil
default:
return fmt.Errorf("%s: unsupported binding %s",
req.ServiceProviderMetadata.EntityID,
req.ACSEndpoint.Binding)
}
}
// getSPEncryptionCert returns the certificate which we can use to encrypt things
// to the SP in PEM format, or nil if no such certificate is found.
func (req *IdpAuthnRequest) getSPEncryptionCert() (*x509.Certificate, error) {
certStr := ""
for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors {
if keyDescriptor.Use == "encryption" {
certStr = keyDescriptor.KeyInfo.Certificate
break
}
}
// If there are no certs explicitly labeled for encryption, return the first
// non-empty cert we find.
if certStr == "" {
for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors {
if keyDescriptor.Use == "" && keyDescriptor.KeyInfo.Certificate != "" {
certStr = keyDescriptor.KeyInfo.Certificate
break
}
}
}
if certStr == "" {
return nil, os.ErrNotExist
}
// cleanup whitespace and re-encode a PEM
certStr = regexp.MustCompile(`\s+`).ReplaceAllString(certStr, "")
certBytes, err := base64.StdEncoding.DecodeString(certStr)
if err != nil {
return nil, fmt.Errorf("cannot decode certificate base64: %v", err)
}
cert, err := x509.ParseCertificate(certBytes)
if err != nil {
return nil, fmt.Errorf("cannot parse certificate: %v", err)
}
return cert, nil
}
// unmarshalEtreeHack parses `el` and sets values in the structure `v`.
//
// This is a hack -- it first serializes the element, then uses xml.Unmarshal.
func unmarshalEtreeHack(el *etree.Element, v interface{}) error {
doc := etree.NewDocument()
doc.SetRoot(el)
buf, err := doc.WriteToBytes()
if err != nil {
return err
}
return xml.Unmarshal(buf, v)
}
// MakeResponse creates and assigns a new SAML response in ResponseEl. `Assertion` must
// be non-nil. If MakeAssertionEl() has not been called, this function calls it for
// you.
func (req *IdpAuthnRequest) MakeResponse() error {
if req.AssertionEl == nil {
if err := req.MakeAssertionEl(); err != nil {
return err
}
}
response := &Response{
Destination: req.ACSEndpoint.Location,
ID: fmt.Sprintf("id-%x", randomBytes(20)),
InResponseTo: req.Request.ID,
IssueInstant: req.Now,
Version: "2.0",
Issuer: &Issuer{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
Value: req.IDP.MetadataURL.String(),
},
Status: Status{
StatusCode: StatusCode{
Value: StatusSuccess,
},
},
}
responseEl := response.Element()
responseEl.AddChild(req.AssertionEl) // AssertionEl either an EncryptedAssertion or Assertion element
// Sign the response element (we've already signed the Assertion element)
{
keyPair := tls.Certificate{
Certificate: [][]byte{req.IDP.Certificate.Raw},
PrivateKey: req.IDP.Key,
Leaf: req.IDP.Certificate,
}
for _, cert := range req.IDP.Intermediates {
keyPair.Certificate = append(keyPair.Certificate, cert.Raw)
}
keyStore := dsig.TLSCertKeyStore(keyPair)
signatureMethod := req.IDP.SignatureMethod
if signatureMethod == "" {
signatureMethod = dsig.RSASHA1SignatureMethod
}
signingContext := dsig.NewDefaultSigningContext(keyStore)
signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList)
if err := signingContext.SetSignatureMethod(signatureMethod); err != nil {
return err
}
signedResponseEl, err := signingContext.SignEnveloped(responseEl)
if err != nil {
return err
}
sigEl := signedResponseEl.ChildElements()[len(signedResponseEl.ChildElements())-1]
response.Signature = sigEl
responseEl = response.Element()
responseEl.AddChild(req.AssertionEl)
}
req.ResponseEl = responseEl
return nil
}