diff --git a/server/channels/api4/user_test.go b/server/channels/api4/user_test.go index 75ad046205..a9b9d2ea9e 100644 --- a/server/channels/api4/user_test.go +++ b/server/channels/api4/user_test.go @@ -4328,6 +4328,23 @@ func TestSwitchAccount(t *testing.T) { _, resp, err = th.Client.SwitchAccountType(context.Background(), sr) require.Error(t, err) CheckUnauthorizedStatus(t, resp) + + sr = &model.SwitchRequest{ + CurrentService: model.UserAuthServiceEmail, + NewService: model.UserAuthServiceSaml, + Email: th.BasicUser.Email, + Password: th.BasicUser.Password, + } + + link, _, err = th.Client.SwitchAccountType(context.Background(), sr) + require.NoError(t, err) + + values, parseErr := url.ParseQuery(link) + require.NoError(t, parseErr) + + appToken, tokenErr := th.App.Srv().Store().Token().GetByToken(values.Get("email_token")) + require.NoError(t, tokenErr) + require.Equal(t, th.BasicUser.Email, appToken.Extra) } func assertToken(t *testing.T, th *TestHelper, token *model.UserAccessToken, expectedUserId string) { diff --git a/server/channels/app/app_iface.go b/server/channels/app/app_iface.go index 9c22f83eff..0a7c1013c9 100644 --- a/server/channels/app/app_iface.go +++ b/server/channels/app/app_iface.go @@ -527,6 +527,7 @@ type AppIface interface { CreatePostMissingChannel(c request.CTX, post *model.Post, triggerWebhooks bool, setOnline bool) (*model.Post, *model.AppError) CreateRetentionPolicy(policy *model.RetentionPolicyWithTeamAndChannelIDs) (*model.RetentionPolicyWithTeamAndChannelCounts, *model.AppError) CreateRole(role *model.Role) (*model.Role, *model.AppError) + CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) CreateScheme(scheme *model.Scheme) (*model.Scheme, *model.AppError) CreateSession(c request.CTX, session *model.Session) (*model.Session, *model.AppError) CreateSidebarCategory(c request.CTX, userID, teamID string, newCategory *model.SidebarCategoryWithChannels) (*model.SidebarCategoryWithChannels, *model.AppError) @@ -785,6 +786,7 @@ type AppIface interface { GetRoleByName(ctx context.Context, name string) (*model.Role, *model.AppError) GetRolesByNames(names []string) ([]*model.Role, *model.AppError) GetSamlCertificateStatus() *model.SamlCertificateStatus + GetSamlEmailToken(token string) (*model.Token, *model.AppError) GetSamlMetadata(c request.CTX) (string, *model.AppError) GetSamlMetadataFromIdp(idpMetadataURL string) (*model.SamlMetadataResponse, *model.AppError) GetSanitizeOptions(asAdmin bool) map[string]bool diff --git a/server/channels/app/oauth.go b/server/channels/app/oauth.go index a2e8d704cc..191b2b4262 100644 --- a/server/channels/app/oauth.go +++ b/server/channels/app/oauth.go @@ -958,7 +958,12 @@ func (a *App) SwitchEmailToOAuth(c request.CTX, w http.ResponseWriter, r *http.R stateProps["email"] = email if service == model.UserAuthServiceSaml { - return a.GetSiteURL() + "/login/sso/saml?action=" + model.OAuthActionEmailToSSO + "&email=" + utils.URLEncode(email), nil + samlToken, samlErr := a.CreateSamlRelayToken(email) + if samlErr != nil { + return "", samlErr + } + + return a.GetSiteURL() + "/login/sso/saml?action=" + model.OAuthActionEmailToSSO + "&email_token=" + utils.URLEncode(samlToken.Token), nil } authURL, err := a.GetAuthorizationCode(c, w, r, service, stateProps, "") diff --git a/server/channels/app/opentracing/opentracing_layer.go b/server/channels/app/opentracing/opentracing_layer.go index 4defcd1e7d..1059e9bf54 100644 --- a/server/channels/app/opentracing/opentracing_layer.go +++ b/server/channels/app/opentracing/opentracing_layer.go @@ -2504,6 +2504,28 @@ func (a *OpenTracingAppLayer) CreateRole(role *model.Role) (*model.Role, *model. return resultVar0, resultVar1 } +func (a *OpenTracingAppLayer) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CreateSamlRelayToken") + + a.ctx = newCtx + a.app.Srv().Store().SetContext(newCtx) + defer func() { + a.app.Srv().Store().SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.CreateSamlRelayToken(extra) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) CreateScheme(scheme *model.Scheme) (*model.Scheme, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.CreateScheme") @@ -9038,6 +9060,28 @@ func (a *OpenTracingAppLayer) GetSamlCertificateStatus() *model.SamlCertificateS return resultVar0 } +func (a *OpenTracingAppLayer) GetSamlEmailToken(token string) (*model.Token, *model.AppError) { + origCtx := a.ctx + span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSamlEmailToken") + + a.ctx = newCtx + a.app.Srv().Store().SetContext(newCtx) + defer func() { + a.app.Srv().Store().SetContext(origCtx) + a.ctx = origCtx + }() + + defer span.Finish() + resultVar0, resultVar1 := a.app.GetSamlEmailToken(token) + + if resultVar1 != nil { + span.LogFields(spanlog.Error(resultVar1)) + ext.Error.Set(span, true) + } + + return resultVar0, resultVar1 +} + func (a *OpenTracingAppLayer) GetSamlMetadata(c request.CTX) (string, *model.AppError) { origCtx := a.ctx span, newCtx := tracing.StartSpanWithParentByContext(a.ctx, "app.GetSamlMetadata") diff --git a/server/channels/app/saml.go b/server/channels/app/saml.go index 27462a7459..80f883effc 100644 --- a/server/channels/app/saml.go +++ b/server/channels/app/saml.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "encoding/pem" "encoding/xml" + "errors" "fmt" "io" "mime/multipart" @@ -296,3 +297,32 @@ func (a *App) ResetSamlAuthDataToEmail(includeDeleted bool, dryRun bool, userIDs } return } + +func (a *App) CreateSamlRelayToken(extra string) (*model.Token, *model.AppError) { + token := model.NewToken(model.TokenTypeSaml, extra) + + if err := a.Srv().Store().Token().Save(token); err != nil { + var appErr *model.AppError + switch { + case errors.As(err, &appErr): + return nil, appErr + default: + return nil, model.NewAppError("CreateSamlRelayToken", "app.recover.save.app_error", nil, "", http.StatusInternalServerError).Wrap(err) + } + } + + return token, nil +} + +func (a *App) GetSamlEmailToken(token string) (*model.Token, *model.AppError) { + mToken, err := a.Srv().Store().Token().GetByToken(token) + if err != nil { + return nil, model.NewAppError("GetSamlEmailToken", "api.saml.invalid_email_token.app_error", nil, "", http.StatusBadRequest).Wrap(err) + } + + if mToken.Type != model.TokenTypeSaml { + return nil, model.NewAppError("GetSamlEmailToken", "api.saml.invalid_email_token.app_error", nil, "", http.StatusBadRequest) + } + + return mToken, nil +} diff --git a/server/channels/web/saml.go b/server/channels/web/saml.go index 6a4eb58e73..4ebe72b113 100644 --- a/server/channels/web/saml.go +++ b/server/channels/web/saml.go @@ -47,7 +47,7 @@ func loginWithSaml(c *Context, w http.ResponseWriter, r *http.Request) { relayProps["team_id"] = teamId relayProps["action"] = action if action == model.OAuthActionEmailToSSO { - relayProps["email"] = r.URL.Query().Get("email") + relayProps["email_token"] = r.URL.Query().Get("email_token") } } diff --git a/server/i18n/en.json b/server/i18n/en.json index dbf1b05621..195bd1ff93 100644 --- a/server/i18n/en.json +++ b/server/i18n/en.json @@ -2718,6 +2718,10 @@ "id": "api.roles.patch_roles.not_allowed_permission.error", "translation": "One or more of the following permissions that you are trying to add or remove is not allowed" }, + { + "id": "api.saml.invalid_email_token.app_error", + "translation": "Invalid email_token" + }, { "id": "api.scheme.create_scheme.license.error", "translation": "Your license does not support creating permissions schemes." diff --git a/server/public/model/token.go b/server/public/model/token.go index 90fc729fb6..c0c73e367a 100644 --- a/server/public/model/token.go +++ b/server/public/model/token.go @@ -11,6 +11,7 @@ const ( TokenSize = 64 MaxTokenExipryTime = 1000 * 60 * 60 * 48 // 48 hour TokenTypeOAuth = "oauth" + TokenTypeSaml = "saml" ) type Token struct {