MM-15889 Add unit tests for CSRF checks (#11058)

* MM-15889 Add unit tests for CSRF checks

* Moved CSRF token test to login tests

* Remove empty test

* Remove debug messages
This commit is contained in:
Harrison Healey
2019-06-11 15:09:00 -04:00
committed by GitHub
parent 28cf642ccb
commit 803ce61ef8
8 changed files with 300 additions and 80 deletions

View File

@@ -515,7 +515,6 @@ func completeOAuth(c *Context, w http.ResponseWriter, r *http.Request) {
if action == model.OAUTH_ACTION_EMAIL_TO_SSO {
redirectUrl = c.GetSiteURLHeader() + "/login?extra=signin_change"
} else if action == model.OAUTH_ACTION_SSO_TO_EMAIL {
redirectUrl = app.GetProtocol(r) + "://" + r.Host + "/claim?email=" + url.QueryEscape(props["email"])
} else {
session, err := c.App.DoLogin(w, r, user, "")
@@ -528,6 +527,8 @@ func completeOAuth(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
c.App.AttachSessionCookies(w, r, session)
c.App.Session = *session
if _, ok := props["redirect_to"]; ok {

View File

@@ -1341,6 +1341,10 @@ func login(c *Context, w http.ResponseWriter, r *http.Request) {
c.LogAuditWithUserId(user.Id, "success")
if r.Header.Get(model.HEADER_REQUESTED_WITH) == model.HEADER_REQUESTED_WITH_XML {
c.App.AttachSessionCookies(w, r, session)
}
userTermsOfService, err := c.App.GetUserTermsOfService(user.Id)
if err != nil && err.StatusCode != http.StatusNotFound {
c.Err = err

View File

@@ -6,6 +6,7 @@ package api4
import (
"fmt"
"net/http"
"regexp"
"strconv"
"strings"
"testing"
@@ -32,22 +33,7 @@ func TestCreateUser(t *testing.T) {
CheckNoError(t, resp)
CheckCreatedStatus(t, resp)
_, resp = th.Client.Login(user.Email, user.Password)
session, _ := th.App.GetSession(th.Client.AuthToken)
expectedCsrf := "MMCSRF=" + session.GetCSRF()
actualCsrf := ""
for _, cookie := range resp.Header["Set-Cookie"] {
if strings.HasPrefix(cookie, "MMCSRF") {
cookieParts := strings.Split(cookie, ";")
actualCsrf = cookieParts[0]
break
}
}
if expectedCsrf != actualCsrf {
t.Errorf("CSRF Mismatch - Expected %s, got %s", expectedCsrf, actualCsrf)
}
_, _ = th.Client.Login(user.Email, user.Password)
if ruser.Nickname != user.Nickname {
t.Fatal("nickname didn't match")
@@ -2721,33 +2707,74 @@ func TestLogin(t *testing.T) {
}
func TestLoginCookies(t *testing.T) {
th := Setup().InitBasic()
defer th.TearDown()
th.Client.Logout()
t.Run("should return cookies with X-Requested-With header", func(t *testing.T) {
th := Setup().InitBasic()
defer th.TearDown()
testCases := []struct {
Description string
SiteURL string
ExpectedSetCookieHeaderRegexp string
}{
{"no subpath", "http://localhost:8065", "^MMAUTHTOKEN=[a-z0-9]+; Path=/"},
{"subpath", "http://localhost:8065/subpath", "^MMAUTHTOKEN=[a-z0-9]+; Path=/subpath"},
}
th.Client.HttpHeader[model.HEADER_REQUESTED_WITH] = model.HEADER_REQUESTED_WITH_XML
for _, tc := range testCases {
t.Run(tc.Description, func(t *testing.T) {
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.SiteURL = tc.SiteURL
user, resp := th.Client.Login(th.BasicUser.Email, th.BasicUser.Password)
sessionCookie := ""
userCookie := ""
csrfCookie := ""
for _, cookie := range resp.Header["Set-Cookie"] {
if match := regexp.MustCompile("^" + model.SESSION_COOKIE_TOKEN + "=([a-z0-9]+)").FindStringSubmatch(cookie); match != nil {
sessionCookie = match[1]
} else if match := regexp.MustCompile("^" + model.SESSION_COOKIE_USER + "=([a-z0-9]+)").FindStringSubmatch(cookie); match != nil {
userCookie = match[1]
} else if match := regexp.MustCompile("^" + model.SESSION_COOKIE_CSRF + "=([a-z0-9]+)").FindStringSubmatch(cookie); match != nil {
csrfCookie = match[1]
}
}
session, _ := th.App.GetSession(th.Client.AuthToken)
assert.Equal(t, th.Client.AuthToken, sessionCookie)
assert.Equal(t, user.Id, userCookie)
assert.Equal(t, session.GetCSRF(), csrfCookie)
})
t.Run("should not return cookies without X-Requested-With header", func(t *testing.T) {
th := Setup().InitBasic()
defer th.TearDown()
_, resp := th.Client.Login(th.BasicUser.Email, th.BasicUser.Password)
assert.Empty(t, resp.Header.Get("Set-Cookie"))
})
t.Run("should include subpath in path", func(t *testing.T) {
th := Setup().InitBasic()
defer th.TearDown()
th.Client.HttpHeader[model.HEADER_REQUESTED_WITH] = model.HEADER_REQUESTED_WITH_XML
testCases := []struct {
Description string
SiteURL string
ExpectedSetCookieHeaderRegexp string
}{
{"no subpath", "http://localhost:8065", "^MMAUTHTOKEN=[a-z0-9]+; Path=/"},
{"subpath", "http://localhost:8065/subpath", "^MMAUTHTOKEN=[a-z0-9]+; Path=/subpath"},
}
for _, tc := range testCases {
t.Run(tc.Description, func(t *testing.T) {
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.SiteURL = tc.SiteURL
})
user, resp := th.Client.Login(th.BasicUser.Email, th.BasicUser.Password)
CheckNoError(t, resp)
assert.Equal(t, user.Id, th.BasicUser.Id)
cookies := resp.Header.Get("Set-Cookie")
assert.Regexp(t, tc.ExpectedSetCookieHeaderRegexp, cookies)
})
user, resp := th.Client.Login(th.BasicUser.Email, th.BasicUser.Password)
CheckNoError(t, resp)
assert.Equal(t, user.Id, th.BasicUser.Id)
cookies := resp.Header.Get("Set-Cookie")
assert.Regexp(t, tc.ExpectedSetCookieHeaderRegexp, cookies)
})
}
}
})
}
func TestCBALogin(t *testing.T) {

View File

@@ -15,7 +15,7 @@ import (
type TokenLocation int
const (
TokenLocationNotFound = iota
TokenLocationNotFound TokenLocation = iota
TokenLocationHeader
TokenLocationCookie
TokenLocationQueryString

View File

@@ -126,7 +126,6 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User,
session := &model.Session{UserId: user.Id, Roles: user.GetRawRoles(), DeviceId: deviceId, IsOAuth: false}
session.GenerateCSRF()
maxAge := *a.Config().ServiceSettings.SessionLengthWebInDays * 60 * 60 * 24
if len(deviceId) > 0 {
session.SetExpireInDays(*a.Config().ServiceSettings.SessionLengthMobileInDays)
@@ -159,11 +158,26 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User,
w.Header().Set(model.HEADER_TOKEN, session.Token)
if pluginsEnvironment := a.GetPluginsEnvironment(); pluginsEnvironment != nil {
a.Srv.Go(func() {
pluginContext := a.PluginContext()
pluginsEnvironment.RunMultiPluginHook(func(hooks plugin.Hooks) bool {
hooks.UserHasLoggedIn(pluginContext, user)
return true
}, plugin.UserHasLoggedInId)
})
}
return session, nil
}
func (a *App) AttachSessionCookies(w http.ResponseWriter, r *http.Request, session *model.Session) {
secure := false
if GetProtocol(r) == "https" {
secure = true
}
maxAge := *a.Config().ServiceSettings.SessionLengthWebInDays * 60 * 60 * 24
domain := a.GetCookieDomain()
subpath, _ := utils.GetSubpathFromConfig(a.Config())
@@ -181,7 +195,7 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User,
userCookie := &http.Cookie{
Name: model.SESSION_COOKIE_USER,
Value: user.Id,
Value: session.UserId,
Path: subpath,
MaxAge: maxAge,
Expires: expiresAt,
@@ -202,18 +216,6 @@ func (a *App) DoLogin(w http.ResponseWriter, r *http.Request, user *model.User,
http.SetCookie(w, sessionCookie)
http.SetCookie(w, userCookie)
http.SetCookie(w, csrfCookie)
if pluginsEnvironment := a.GetPluginsEnvironment(); pluginsEnvironment != nil {
a.Srv.Go(func() {
pluginContext := a.PluginContext()
pluginsEnvironment.RunMultiPluginHook(func(hooks plugin.Hooks) bool {
hooks.UserHasLoggedIn(pluginContext, user)
return true
}, plugin.UserHasLoggedInId)
})
}
return session, nil
}
func GetProtocol(r *http.Request) string {

View File

@@ -121,30 +121,7 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
csrfCheckPassed := false
// CSRF Check
if c.Err == nil && tokenLocation == app.TokenLocationCookie && h.RequireSession && !h.TrustRequester && r.Method != "GET" {
csrfHeader := r.Header.Get(model.HEADER_CSRF_TOKEN)
if csrfHeader == session.GetCSRF() {
csrfCheckPassed = true
} else if r.Header.Get(model.HEADER_REQUESTED_WITH) == model.HEADER_REQUESTED_WITH_XML {
// ToDo(DSchalla) 2019/01/04: Remove after deprecation period and only allow CSRF Header (MM-13657)
csrfErrorMessage := "CSRF Header check failed for request - Please upgrade your web application or custom app to set a CSRF Header"
if *c.App.Config().ServiceSettings.ExperimentalStrictCSRFEnforcement {
c.Log.Warn(csrfErrorMessage)
} else {
c.Log.Debug(csrfErrorMessage)
csrfCheckPassed = true
}
}
if !csrfCheckPassed {
token = ""
c.App.Session = model.Session{}
c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
}
}
h.checkCSRFToken(c, r, token, tokenLocation, session)
}
c.Log = c.App.Log.With(
@@ -216,3 +193,34 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
}
// checkCSRFToken performs a CSRF check on the provided request with the given CSRF token. Returns whether or not
// a CSRF check occurred and whether or not it succeeded.
func (h *Handler) checkCSRFToken(c *Context, r *http.Request, token string, tokenLocation app.TokenLocation, session *model.Session) (checked bool, passed bool) {
csrfCheckNeeded := c.Err == nil && tokenLocation == app.TokenLocationCookie && h.RequireSession && !h.TrustRequester && r.Method != "GET"
csrfCheckPassed := false
if csrfCheckNeeded {
csrfHeader := r.Header.Get(model.HEADER_CSRF_TOKEN)
if csrfHeader == session.GetCSRF() {
csrfCheckPassed = true
} else if r.Header.Get(model.HEADER_REQUESTED_WITH) == model.HEADER_REQUESTED_WITH_XML {
// ToDo(DSchalla) 2019/01/04: Remove after deprecation period and only allow CSRF Header (MM-13657)
csrfErrorMessage := "CSRF Header check failed for request - Please upgrade your web application or custom app to set a CSRF Header"
if *c.App.Config().ServiceSettings.ExperimentalStrictCSRFEnforcement {
c.Log.Warn(csrfErrorMessage)
} else {
c.Log.Debug(csrfErrorMessage)
csrfCheckPassed = true
}
}
if !csrfCheckPassed {
c.App.Session = model.Session{}
c.Err = model.NewAppError("ServeHTTP", "api.context.session_expired.app_error", nil, "token="+token+" Appears to be a CSRF attempt", http.StatusUnauthorized)
}
}
return csrfCheckNeeded, csrfCheckPassed
}

View File

@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"testing"
"github.com/mattermost/mattermost-server/app"
"github.com/mattermost/mattermost-server/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -337,3 +338,178 @@ func TestHandlerServeInvalidToken(t *testing.T) {
})
}
}
func TestCheckCSRFToken(t *testing.T) {
t.Run("should allow a POST request with a valid CSRF token header", func(t *testing.T) {
th := Setup()
defer th.TearDown()
h := &Handler{
RequireSession: true,
TrustRequester: false,
}
token := "token"
tokenLocation := app.TokenLocationCookie
c := &Context{
App: th.App,
}
r, _ := http.NewRequest(http.MethodPost, "", nil)
r.Header.Set(model.HEADER_CSRF_TOKEN, token)
session := &model.Session{
Props: map[string]string{
"csrf": token,
},
}
checked, passed := h.checkCSRFToken(c, r, token, tokenLocation, session)
assert.True(t, checked)
assert.True(t, passed)
assert.Nil(t, c.Err)
})
t.Run("should allow a POST request with an X-Requested-With header", func(t *testing.T) {
th := Setup()
defer th.TearDown()
h := &Handler{
RequireSession: true,
TrustRequester: false,
}
token := "token"
tokenLocation := app.TokenLocationCookie
c := &Context{
App: th.App,
Log: th.App.Log,
}
r, _ := http.NewRequest(http.MethodPost, "", nil)
r.Header.Set(model.HEADER_REQUESTED_WITH, model.HEADER_REQUESTED_WITH_XML)
session := &model.Session{
Props: map[string]string{
"csrf": token,
},
}
checked, passed := h.checkCSRFToken(c, r, token, tokenLocation, session)
assert.True(t, checked)
assert.True(t, passed)
assert.Nil(t, c.Err)
})
t.Run("should not allow a POST request with an X-Requested-With header with strict CSRF enforcement enabled", func(t *testing.T) {
th := Setup()
defer th.TearDown()
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.ExperimentalStrictCSRFEnforcement = true
})
h := &Handler{
RequireSession: true,
TrustRequester: false,
}
token := "token"
tokenLocation := app.TokenLocationCookie
c := &Context{
App: th.App,
Log: th.App.Log,
}
r, _ := http.NewRequest(http.MethodPost, "", nil)
r.Header.Set(model.HEADER_REQUESTED_WITH, model.HEADER_REQUESTED_WITH_XML)
session := &model.Session{
Props: map[string]string{
"csrf": token,
},
}
checked, passed := h.checkCSRFToken(c, r, token, tokenLocation, session)
assert.True(t, checked)
assert.False(t, passed)
assert.NotNil(t, c.Err)
})
t.Run("should not allow a POST request without either header", func(t *testing.T) {
th := Setup()
defer th.TearDown()
h := &Handler{
RequireSession: true,
TrustRequester: false,
}
token := "token"
tokenLocation := app.TokenLocationCookie
c := &Context{
App: th.App,
}
r, _ := http.NewRequest(http.MethodPost, "", nil)
session := &model.Session{
Props: map[string]string{
"csrf": token,
},
}
checked, passed := h.checkCSRFToken(c, r, token, tokenLocation, session)
assert.True(t, checked)
assert.False(t, passed)
assert.NotNil(t, c.Err)
})
t.Run("should not check GET requests", func(t *testing.T) {
th := Setup()
defer th.TearDown()
h := &Handler{
RequireSession: true,
TrustRequester: false,
}
token := "token"
tokenLocation := app.TokenLocationCookie
c := &Context{
App: th.App,
}
r, _ := http.NewRequest(http.MethodGet, "", nil)
checked, passed := h.checkCSRFToken(c, r, token, tokenLocation, nil)
assert.False(t, checked)
assert.False(t, passed)
assert.Nil(t, c.Err)
})
t.Run("should not check a request passing the auth token in a header", func(t *testing.T) {
th := Setup()
defer th.TearDown()
h := &Handler{
RequireSession: true,
TrustRequester: false,
}
token := "token"
tokenLocation := app.TokenLocationHeader
c := &Context{
App: th.App,
}
r, _ := http.NewRequest(http.MethodPost, "", nil)
checked, passed := h.checkCSRFToken(c, r, token, tokenLocation, nil)
assert.False(t, checked)
assert.False(t, passed)
assert.Nil(t, c.Err)
})
}

View File

@@ -132,6 +132,8 @@ func completeSaml(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
c.App.AttachSessionCookies(w, r, session)
c.App.Session = *session
if val, ok := relayProps["redirect_to"]; ok {