diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 0bc56314ca..48e1603cc8 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -43,7 +43,7 @@ jobs: go run main.go init --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml go run main.go setup --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml - name: Run integration tests - run: go test -tags=integration -race -p 1 -v -coverprofile=profile.cov -coverpkg=./internal/...,./cmd/... ./internal/integration ./internal/api/grpc/... ./internal/notification/handlers/... + run: go test -tags=integration -race -p 1 -v -coverprofile=profile.cov -coverpkg=./internal/...,./cmd/... ./internal/integration ./internal/api/grpc/... ./internal/notification/handlers/... ./internal/api/oidc - name: Publish go coverage uses: codecov/codecov-action@v3.1.0 with: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 32de6b633c..e076a00b7d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -224,7 +224,7 @@ export INTEGRATION_DB_FLAVOR="cockroach" ZITADEL_MASTERKEY="MasterkeyNeedsToHave docker compose -f internal/integration/config/docker-compose.yaml up --wait ${INTEGRATION_DB_FLAVOR} go run main.go init --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml go run main.go setup --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml -go test -count 1 -tags=integration -race -p 1 ./internal/integration ./internal/api/grpc/... +go test -count 1 -tags=integration -race -p 1 ./internal/integration ./internal/api/grpc/... ./internal/api/oidc docker compose -f internal/integration/config/docker-compose.yaml down ``` diff --git a/build/zitadel/generate-grpc.sh b/build/zitadel/generate-grpc.sh index 865952fb66..d53654c96c 100755 --- a/build/zitadel/generate-grpc.sh +++ b/build/zitadel/generate-grpc.sh @@ -108,4 +108,16 @@ protoc \ --validate_out=lang=go:${GOPATH}/src \ ${PROTO_PATH}/settings/v2alpha/settings_service.proto +protoc \ + -I=/proto/include \ + --grpc-gateway_out ${GOPATH}/src \ + --grpc-gateway_opt logtostderr=true \ + --grpc-gateway_opt allow_delete_body=true \ + --openapiv2_out ${OPENAPI_PATH} \ + --openapiv2_opt logtostderr=true \ + --openapiv2_opt allow_delete_body=true \ + --zitadel_out=${GOPATH}/src \ + --validate_out=lang=go:${GOPATH}/src \ + ${PROTO_PATH}/oidc/v2alpha/oidc_service.proto + echo "done generating grpc" diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 5b6806a262..9d599fe04a 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -270,6 +270,7 @@ OIDC: Path: /oauth/v2/keys DeviceAuth: Path: /oauth/v2/device_authorization + DefaultLoginURLV2: "/login?authRequest=" SAML: ProviderConfig: diff --git a/cmd/setup/03.go b/cmd/setup/03.go index ca1e0e21ca..867e9d4981 100644 --- a/cmd/setup/03.go +++ b/cmd/setup/03.go @@ -88,6 +88,9 @@ func (mig *FirstInstance) Execute(ctx context.Context) error { nil, nil, nil, + 0, + 0, + 0, ) if err != nil { return err diff --git a/cmd/setup/config_change.go b/cmd/setup/config_change.go index 14a6849cea..7f35f7ee74 100644 --- a/cmd/setup/config_change.go +++ b/cmd/setup/config_change.go @@ -53,6 +53,9 @@ func (mig *externalConfigChange) Execute(ctx context.Context) error { nil, nil, nil, + 0, + 0, + 0, ) if err != nil { diff --git a/cmd/start/start.go b/cmd/start/start.go index 1a90cda42a..7d9c87d7c0 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -32,6 +32,7 @@ import ( "github.com/zitadel/zitadel/internal/api/grpc/admin" "github.com/zitadel/zitadel/internal/api/grpc/auth" "github.com/zitadel/zitadel/internal/api/grpc/management" + oidc_v2 "github.com/zitadel/zitadel/internal/api/grpc/oidc/v2" "github.com/zitadel/zitadel/internal/api/grpc/session/v2" "github.com/zitadel/zitadel/internal/api/grpc/settings/v2" "github.com/zitadel/zitadel/internal/api/grpc/system" @@ -192,6 +193,9 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error &http.Client{}, permissionCheck, sessionTokenVerifier, + config.OIDC.DefaultAccessTokenLifetime, + config.OIDC.DefaultRefreshTokenExpiration, + config.OIDC.DefaultRefreshTokenIdleExpiration, ) if err != nil { return fmt.Errorf("cannot start commands: %w", err) @@ -344,6 +348,7 @@ func startAPIs( if err := apis.RegisterService(ctx, session.CreateServer(commands, queries, permissionCheck)); err != nil { return err } + if err := apis.RegisterService(ctx, settings.CreateServer(commands, queries, config.ExternalSecure)); err != nil { return err } @@ -397,6 +402,11 @@ func startAPIs( apis.RegisterHandlerOnPrefix(login.HandlerPrefix, l.Handler()) apis.HandleFunc(login.EndpointDeviceAuth, login.RedirectDeviceAuthToPrefix) + // After OIDC provider so that the callback endpoint can be used + if err := apis.RegisterService(ctx, oidc_v2.CreateServer(commands, queries, oidcProvider, config.ExternalSecure)); err != nil { + return err + } + // handle grpc at last to be able to handle the root, because grpc and gateway require a lot of different prefixes apis.RouteGRPC() return nil diff --git a/docs/docusaurus.config.js b/docs/docusaurus.config.js index ccac55b08e..c7131a2267 100644 --- a/docs/docusaurus.config.js +++ b/docs/docusaurus.config.js @@ -274,6 +274,13 @@ module.exports = { groupPathsBy: "tag", }, }, + oidc: { + specPath: ".artifacts/openapi/zitadel/oidc/v2alpha/oidc_service.swagger.json", + outputDir: "docs/apis/resources/oidc_service", + sidebarOptions: { + groupPathsBy: "tag", + }, + }, settings: { specPath: ".artifacts/openapi/zitadel/settings/v2alpha/settings_service.swagger.json", outputDir: "docs/apis/resources/settings_service", diff --git a/docs/sidebars.js b/docs/sidebars.js index 1d3ce7b10c..98c752eeb5 100644 --- a/docs/sidebars.js +++ b/docs/sidebars.js @@ -484,6 +484,20 @@ module.exports = { }, items: require("./docs/apis/resources/session_service/sidebar.js"), }, + { + type: "category", + label: "OIDC lifecycle (Alpha)", + link: { + type: "generated-index", + title: "OIDC service API (Alpha)", + slug: "/apis/resources/oidc_service", + description: + "Get OIDC Auth Request details and create callback URLs.\n"+ + "\n"+ + "This project is in alpha state. It can AND will continue breaking until the services provide the same functionality as the current login.", + }, + items: require("./docs/apis/resources/oidc_service/sidebar.js"), + }, { type: "category", label: "Settings lifecycle (alpha)", diff --git a/internal/api/authz/context_mock.go b/internal/api/authz/context_mock.go index e7b5c4017e..6badf15862 100644 --- a/internal/api/authz/context_mock.go +++ b/internal/api/authz/context_mock.go @@ -4,11 +4,11 @@ import "context" func NewMockContext(instanceID, orgID, userID string) context.Context { ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID}) - return context.WithValue(ctx, instanceKey, instanceID) + return context.WithValue(ctx, instanceKey, &instance{id: instanceID}) } func NewMockContextWithPermissions(instanceID, orgID, userID string, permissions []string) context.Context { ctx := context.WithValue(context.Background(), dataKey, CtxData{UserID: userID, OrgID: orgID}) - ctx = context.WithValue(ctx, instanceKey, instanceID) + ctx = context.WithValue(ctx, instanceKey, &instance{id: instanceID}) return context.WithValue(ctx, requestPermissionsKey, permissions) } diff --git a/internal/api/grpc/oidc/v2/oidc.go b/internal/api/grpc/oidc/v2/oidc.go new file mode 100644 index 0000000000..4675dfc6f1 --- /dev/null +++ b/internal/api/grpc/oidc/v2/oidc.go @@ -0,0 +1,204 @@ +package oidc + +import ( + "context" + + "github.com/zitadel/logging" + "github.com/zitadel/oidc/v2/pkg/op" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/grpc/object/v2" + "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" +) + +func (s *Server) GetAuthRequest(ctx context.Context, req *oidc_pb.GetAuthRequestRequest) (*oidc_pb.GetAuthRequestResponse, error) { + authRequest, err := s.query.AuthRequestByID(ctx, true, req.GetAuthRequestId(), true) + if err != nil { + logging.WithError(err).Error("query authRequest by ID") + return nil, err + } + return &oidc_pb.GetAuthRequestResponse{ + AuthRequest: authRequestToPb(authRequest), + }, nil +} + +func authRequestToPb(a *query.AuthRequest) *oidc_pb.AuthRequest { + pba := &oidc_pb.AuthRequest{ + Id: a.ID, + CreationDate: timestamppb.New(a.CreationDate), + ClientId: a.ClientID, + Scope: a.Scope, + RedirectUri: a.RedirectURI, + Prompt: promptsToPb(a.Prompt), + UiLocales: a.UiLocales, + LoginHint: a.LoginHint, + HintUserId: a.HintUserID, + } + if a.MaxAge != nil { + pba.MaxAge = durationpb.New(*a.MaxAge) + } + return pba +} + +func promptsToPb(promps []domain.Prompt) []oidc_pb.Prompt { + out := make([]oidc_pb.Prompt, len(promps)) + for i, p := range promps { + out[i] = promptToPb(p) + } + return out +} + +func promptToPb(p domain.Prompt) oidc_pb.Prompt { + switch p { + case domain.PromptUnspecified: + return oidc_pb.Prompt_PROMPT_UNSPECIFIED + case domain.PromptNone: + return oidc_pb.Prompt_PROMPT_NONE + case domain.PromptLogin: + return oidc_pb.Prompt_PROMPT_LOGIN + case domain.PromptConsent: + return oidc_pb.Prompt_PROMPT_CONSENT + case domain.PromptSelectAccount: + return oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT + case domain.PromptCreate: + return oidc_pb.Prompt_PROMPT_CREATE + default: + return oidc_pb.Prompt_PROMPT_UNSPECIFIED + } +} + +func (s *Server) CreateCallback(ctx context.Context, req *oidc_pb.CreateCallbackRequest) (*oidc_pb.CreateCallbackResponse, error) { + switch v := req.GetCallbackKind().(type) { + case *oidc_pb.CreateCallbackRequest_Error: + return s.failAuthRequest(ctx, req.GetAuthRequestId(), v.Error) + case *oidc_pb.CreateCallbackRequest_Session: + return s.linkSessionToAuthRequest(ctx, req.GetAuthRequestId(), v.Session) + default: + return nil, errors.ThrowUnimplementedf(nil, "OIDCv2-zee7A", "verification oneOf %T in method CreateCallback not implemented", v) + } +} + +func (s *Server) failAuthRequest(ctx context.Context, authRequestID string, ae *oidc_pb.AuthorizationError) (*oidc_pb.CreateCallbackResponse, error) { + details, aar, err := s.command.FailAuthRequest(ctx, authRequestID, errorReasonToDomain(ae.GetError())) + if err != nil { + return nil, err + } + authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar} + callback, err := oidc.CreateErrorCallbackURL(authReq, errorReasonToOIDC(ae.GetError()), ae.GetErrorDescription(), ae.GetErrorUri(), s.op) + if err != nil { + return nil, err + } + return &oidc_pb.CreateCallbackResponse{ + Details: object.DomainToDetailsPb(details), + CallbackUrl: callback, + }, nil +} + +func (s *Server) linkSessionToAuthRequest(ctx context.Context, authRequestID string, session *oidc_pb.Session) (*oidc_pb.CreateCallbackResponse, error) { + details, aar, err := s.command.LinkSessionToAuthRequest(ctx, authRequestID, session.GetSessionId(), session.GetSessionToken(), true) + if err != nil { + return nil, err + } + authReq := &oidc.AuthRequestV2{CurrentAuthRequest: aar} + ctx = op.ContextWithIssuer(ctx, http.BuildOrigin(authz.GetInstance(ctx).RequestedHost(), s.externalSecure)) + var callback string + if aar.ResponseType == domain.OIDCResponseTypeCode { + callback, err = oidc.CreateCodeCallbackURL(ctx, authReq, s.op) + } else { + callback, err = oidc.CreateTokenCallbackURL(ctx, authReq, s.op) + } + if err != nil { + return nil, err + } + return &oidc_pb.CreateCallbackResponse{ + Details: object.DomainToDetailsPb(details), + CallbackUrl: callback, + }, nil +} + +func errorReasonToDomain(errorReason oidc_pb.ErrorReason) domain.OIDCErrorReason { + switch errorReason { + case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED: + return domain.OIDCErrorReasonUnspecified + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST: + return domain.OIDCErrorReasonInvalidRequest + case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT: + return domain.OIDCErrorReasonUnauthorizedClient + case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED: + return domain.OIDCErrorReasonAccessDenied + case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE: + return domain.OIDCErrorReasonUnsupportedResponseType + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE: + return domain.OIDCErrorReasonInvalidScope + case oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR: + return domain.OIDCErrorReasonServerError + case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE: + return domain.OIDCErrorReasonTemporaryUnavailable + case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED: + return domain.OIDCErrorReasonInteractionRequired + case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED: + return domain.OIDCErrorReasonLoginRequired + case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED: + return domain.OIDCErrorReasonAccountSelectionRequired + case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED: + return domain.OIDCErrorReasonConsentRequired + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI: + return domain.OIDCErrorReasonInvalidRequestURI + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT: + return domain.OIDCErrorReasonInvalidRequestObject + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED: + return domain.OIDCErrorReasonRequestNotSupported + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED: + return domain.OIDCErrorReasonRequestURINotSupported + case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED: + return domain.OIDCErrorReasonRegistrationNotSupported + default: + return domain.OIDCErrorReasonUnspecified + } +} + +func errorReasonToOIDC(reason oidc_pb.ErrorReason) string { + switch reason { + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST: + return "invalid_request" + case oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT: + return "unauthorized_client" + case oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED: + return "access_denied" + case oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE: + return "unsupported_response_type" + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE: + return "invalid_scope" + case oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE: + return "temporarily_unavailable" + case oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED: + return "interaction_required" + case oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED: + return "login_required" + case oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED: + return "account_selection_required" + case oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED: + return "consent_required" + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI: + return "invalid_request_uri" + case oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT: + return "invalid_request_object" + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED: + return "request_not_supported" + case oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED: + return "request_uri_not_supported" + case oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED: + return "registration_not_supported" + case oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR: + fallthrough + default: + return "server_error" + } +} diff --git a/internal/api/grpc/oidc/v2/oidc_integration_test.go b/internal/api/grpc/oidc/v2/oidc_integration_test.go new file mode 100644 index 0000000000..78d04f64be --- /dev/null +++ b/internal/api/grpc/oidc/v2/oidc_integration_test.go @@ -0,0 +1,249 @@ +//go:build integration + +package oidc_test + +import ( + "context" + "net/url" + "os" + "regexp" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/integration" + object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" + session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha" + user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha" +) + +var ( + CTX context.Context + Tester *integration.Tester + Client oidc_pb.OIDCServiceClient + User *user.AddHumanUserResponse +) + +const ( + redirectURI = "oidcIntegrationTest://callback" + redirectURIImplicit = "http://localhost:9999/callback" +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, _, cancel := integration.Contexts(5 * time.Minute) + defer cancel() + + Tester = integration.NewTester(ctx) + defer Tester.Done() + Client = Tester.Client.OIDCv2 + + CTX = Tester.WithAuthorization(ctx, integration.OrgOwner) + User = Tester.CreateHumanUser(CTX) + return m.Run() + }()) +} + +func TestServer_GetAuthRequest(t *testing.T) { + client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI) + require.NoError(t, err) + authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI) + require.NoError(t, err) + now := time.Now() + + tests := []struct { + name string + AuthRequestID string + want *oidc_pb.GetAuthRequestResponse + wantErr bool + }{ + { + name: "Not found", + AuthRequestID: "123", + wantErr: true, + }, + { + name: "success", + AuthRequestID: authRequestID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Client.GetAuthRequest(CTX, &oidc_pb.GetAuthRequestRequest{ + AuthRequestId: tt.AuthRequestID, + }) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + authRequest := got.GetAuthRequest() + assert.NotNil(t, authRequest) + assert.Equal(t, authRequestID, authRequest.GetId()) + assert.WithinRange(t, authRequest.GetCreationDate().AsTime(), now.Add(-time.Second), now.Add(time.Second)) + assert.Contains(t, authRequest.GetScope(), "openid") + }) + } +} + +func TestServer_CreateCallback(t *testing.T) { + client, err := Tester.CreateOIDCNativeClient(CTX, redirectURI) + require.NoError(t, err) + sessionResp, err := Tester.Client.SessionV2.CreateSession(CTX, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{ + UserId: Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, + }, + }, + }, + }) + require.NoError(t, err) + + tests := []struct { + name string + req *oidc_pb.CreateCallbackRequest + AuthError string + want *oidc_pb.CreateCallbackResponse + wantURL *url.URL + wantErr bool + }{ + { + name: "Not found", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: "123", + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + }, + }, + wantErr: true, + }, + { + name: "session not found", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: "foo", + SessionToken: "bar", + }, + }, + }, + wantErr: true, + }, + { + name: "session token invalid", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: "bar", + }, + }, + }, + wantErr: true, + }, + { + name: "fail callback", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Error{ + Error: &oidc_pb.AuthorizationError{ + Error: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED, + ErrorDescription: gu.Ptr("nope"), + ErrorUri: gu.Ptr("https://example.com/docs"), + }, + }, + }, + want: &oidc_pb.CreateCallbackResponse{ + CallbackUrl: regexp.QuoteMeta(`oidcintegrationtest://callback?error=access_denied&error_description=nope&error_uri=https%3A%2F%2Fexample.com%2Fdocs&state=state`), + Details: &object.Details{ + ResourceOwner: Tester.Instance.InstanceID(), + }, + }, + wantErr: false, + }, + { + name: "code callback", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + authRequestID, err := Tester.CreateOIDCAuthRequest(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURI) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + }, + }, + want: &oidc_pb.CreateCallbackResponse{ + CallbackUrl: `oidcintegrationtest:\/\/callback\?code=(.*)&state=state`, + Details: &object.Details{ + ResourceOwner: Tester.Instance.InstanceID(), + }, + }, + wantErr: false, + }, + { + name: "implicit", + req: &oidc_pb.CreateCallbackRequest{ + AuthRequestId: func() string { + client, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit) + require.NoError(t, err) + authRequestID, err := Tester.CreateOIDCAuthRequestImplicit(client.GetClientId(), Tester.Users[integration.FirstInstanceUsersKey][integration.OrgOwner].ID, redirectURIImplicit) + require.NoError(t, err) + return authRequestID + }(), + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionResp.GetSessionId(), + SessionToken: sessionResp.GetSessionToken(), + }, + }, + }, + want: &oidc_pb.CreateCallbackResponse{ + CallbackUrl: `http:\/\/localhost:9999\/callback#access_token=(.*)&expires_in=(.*)&id_token=(.*)&state=state&token_type=Bearer`, + Details: &object.Details{ + ResourceOwner: Tester.Instance.InstanceID(), + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := Client.CreateCallback(CTX, tt.req) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + integration.AssertDetails(t, tt.want, got) + if tt.want != nil { + assert.Regexp(t, regexp.MustCompile(tt.want.CallbackUrl), got.GetCallbackUrl()) + } + }) + } +} diff --git a/internal/api/grpc/oidc/v2/oidc_test.go b/internal/api/grpc/oidc/v2/oidc_test.go new file mode 100644 index 0000000000..91d795a00a --- /dev/null +++ b/internal/api/grpc/oidc/v2/oidc_test.go @@ -0,0 +1,150 @@ +package oidc + +import ( + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" +) + +func Test_authRequestToPb(t *testing.T) { + now := time.Now() + arg := &query.AuthRequest{ + ID: "authID", + CreationDate: now, + ClientID: "clientID", + Scope: []string{"a", "b", "c"}, + RedirectURI: "callbackURI", + Prompt: []domain.Prompt{ + domain.PromptUnspecified, + domain.PromptNone, + domain.PromptLogin, + domain.PromptConsent, + domain.PromptSelectAccount, + domain.PromptCreate, + 999, + }, + UiLocales: []string{"en", "fi"}, + LoginHint: gu.Ptr("foo@bar.com"), + MaxAge: gu.Ptr(time.Minute), + HintUserID: gu.Ptr("userID"), + } + want := &oidc_pb.AuthRequest{ + Id: "authID", + CreationDate: timestamppb.New(now), + ClientId: "clientID", + RedirectUri: "callbackURI", + Prompt: []oidc_pb.Prompt{ + oidc_pb.Prompt_PROMPT_UNSPECIFIED, + oidc_pb.Prompt_PROMPT_NONE, + oidc_pb.Prompt_PROMPT_LOGIN, + oidc_pb.Prompt_PROMPT_CONSENT, + oidc_pb.Prompt_PROMPT_SELECT_ACCOUNT, + oidc_pb.Prompt_PROMPT_CREATE, + oidc_pb.Prompt_PROMPT_UNSPECIFIED, + }, + UiLocales: []string{"en", "fi"}, + Scope: []string{"a", "b", "c"}, + LoginHint: gu.Ptr("foo@bar.com"), + MaxAge: durationpb.New(time.Minute), + HintUserId: gu.Ptr("userID"), + } + got := authRequestToPb(arg) + if !proto.Equal(want, got) { + t.Errorf("authRequestToPb() =\n%v\nwant\n%v\n", got, want) + } +} + +func Test_errorReasonToOIDC(t *testing.T) { + tests := []struct { + reason oidc_pb.ErrorReason + want string + }{ + { + reason: oidc_pb.ErrorReason_ERROR_REASON_UNSPECIFIED, + want: "server_error", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST, + want: "invalid_request", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_UNAUTHORIZED_CLIENT, + want: "unauthorized_client", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_ACCESS_DENIED, + want: "access_denied", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE, + want: "unsupported_response_type", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_SCOPE, + want: "invalid_scope", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_SERVER_ERROR, + want: "server_error", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_TEMPORARY_UNAVAILABLE, + want: "temporarily_unavailable", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_INTERACTION_REQUIRED, + want: "interaction_required", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_LOGIN_REQUIRED, + want: "login_required", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_ACCOUNT_SELECTION_REQUIRED, + want: "account_selection_required", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_CONSENT_REQUIRED, + want: "consent_required", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_URI, + want: "invalid_request_uri", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_INVALID_REQUEST_OBJECT, + want: "invalid_request_object", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_NOT_SUPPORTED, + want: "request_not_supported", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_REQUEST_URI_NOT_SUPPORTED, + want: "request_uri_not_supported", + }, + { + reason: oidc_pb.ErrorReason_ERROR_REASON_REGISTRATION_NOT_SUPPORTED, + want: "registration_not_supported", + }, + { + reason: 99999, + want: "server_error", + }, + } + for _, tt := range tests { + t.Run(tt.reason.String(), func(t *testing.T) { + got := errorReasonToOIDC(tt.reason) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/api/grpc/oidc/v2/server.go b/internal/api/grpc/oidc/v2/server.go new file mode 100644 index 0000000000..f15401d1bc --- /dev/null +++ b/internal/api/grpc/oidc/v2/server.go @@ -0,0 +1,59 @@ +package oidc + +import ( + "github.com/zitadel/oidc/v2/pkg/op" + "google.golang.org/grpc" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/grpc/server" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/query" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" +) + +var _ oidc_pb.OIDCServiceServer = (*Server)(nil) + +type Server struct { + oidc_pb.UnimplementedOIDCServiceServer + command *command.Commands + query *query.Queries + + op op.OpenIDProvider + externalSecure bool +} + +type Config struct{} + +func CreateServer( + command *command.Commands, + query *query.Queries, + op op.OpenIDProvider, + externalSecure bool, +) *Server { + return &Server{ + command: command, + query: query, + op: op, + externalSecure: externalSecure, + } +} + +func (s *Server) RegisterServer(grpcServer *grpc.Server) { + oidc_pb.RegisterOIDCServiceServer(grpcServer, s) +} + +func (s *Server) AppName() string { + return oidc_pb.OIDCService_ServiceDesc.ServiceName +} + +func (s *Server) MethodPrefix() string { + return oidc_pb.OIDCService_ServiceDesc.ServiceName +} + +func (s *Server) AuthMethods() authz.MethodMapping { + return oidc_pb.OIDCService_AuthMethods +} + +func (s *Server) RegisterGateway() server.RegisterGatewayFunc { + return oidc_pb.RegisterOIDCServiceHandler +} diff --git a/internal/api/grpc/session/v2/session_integration_test.go b/internal/api/grpc/session/v2/session_integration_test.go index fa29183536..6d63981726 100644 --- a/internal/api/grpc/session/v2/session_integration_test.go +++ b/internal/api/grpc/session/v2/session_integration_test.go @@ -18,11 +18,10 @@ import ( ) var ( - CTX context.Context - Tester *integration.Tester - Client session.SessionServiceClient - User *user.AddHumanUserResponse - GenericOAuthIDPID string + CTX context.Context + Tester *integration.Tester + Client session.SessionServiceClient + User *user.AddHumanUserResponse ) func TestMain(m *testing.M) { diff --git a/internal/api/grpc/user/v2/user_integration_test.go b/internal/api/grpc/user/v2/user_integration_test.go index e43b6a2c93..299027d4ed 100644 --- a/internal/api/grpc/user/v2/user_integration_test.go +++ b/internal/api/grpc/user/v2/user_integration_test.go @@ -540,7 +540,7 @@ func TestServer_StartIdentityProviderFlow(t *testing.T) { ResourceOwner: Tester.Organisation.ID, }, NextStep: &user.StartIdentityProviderFlowResponse_AuthUrl{ - AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=https%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=", + AuthUrl: "https://example.com/oauth/v2/authorize?client_id=clientID&prompt=select_account&redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fidps%2Fcallback&response_type=code&scope=openid+profile+email&state=", }, }, wantErr: false, diff --git a/internal/api/oidc/amr/amr.go b/internal/api/oidc/amr/amr.go new file mode 100644 index 0000000000..1791f767c8 --- /dev/null +++ b/internal/api/oidc/amr/amr.go @@ -0,0 +1,43 @@ +// Package amr maps zitadel session factors to Authentication Method Reference Values +// as defined in [RFC 8176, section 2]. +// +// [RFC 8176, section 2]: https://datatracker.ietf.org/doc/html/rfc8176#section-2 +package amr + +const ( + // Password states that the users password has been verified + // Deprecated: use `PWD` instead + Password = "password" + // PWD states that the users password has been verified + PWD = "pwd" + // MFA states that multiple factors have been verified (e.g. pwd and otp or passkey) + MFA = "mfa" + // OTP states that a one time password has been verified (e.g. TOTP) + OTP = "otp" + // UserPresence states that the end users presence has been verified (e.g. passkey and u2f) + UserPresence = "user" +) + +type AuthenticationMethodReference interface { + IsPasswordChecked() bool + IsPasskeyChecked() bool + IsU2FChecked() bool + IsOTPChecked() bool +} + +func List(model AuthenticationMethodReference) []string { + amr := make([]string, 0) + if model.IsPasswordChecked() { + amr = append(amr, PWD) + } + if model.IsPasskeyChecked() || model.IsU2FChecked() { + amr = append(amr, UserPresence) + } + if model.IsOTPChecked() { + amr = append(amr, OTP) + } + if model.IsPasskeyChecked() || len(amr) >= 2 { + amr = append(amr, MFA) + } + return amr +} diff --git a/internal/api/oidc/amr/amr_test.go b/internal/api/oidc/amr/amr_test.go new file mode 100644 index 0000000000..f2c5189bcf --- /dev/null +++ b/internal/api/oidc/amr/amr_test.go @@ -0,0 +1,93 @@ +package amr + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAMR(t *testing.T) { + type args struct { + model AuthenticationMethodReference + } + tests := []struct { + name string + args args + want []string + }{ + { + "no checks, empty", + args{ + new(test), + }, + []string{}, + }, + { + "pw checked", + args{ + &test{pwChecked: true}, + }, + []string{PWD}, + }, + { + "passkey checked", + args{ + &test{passkeyChecked: true}, + }, + []string{UserPresence, MFA}, + }, + { + "u2f checked", + args{ + &test{u2fChecked: true}, + }, + []string{UserPresence}, + }, + { + "otp checked", + args{ + &test{otpChecked: true}, + }, + []string{OTP}, + }, + { + "multiple checked", + args{ + &test{ + pwChecked: true, + u2fChecked: true, + }, + }, + []string{PWD, UserPresence, MFA}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := List(tt.args.model) + assert.Equal(t, tt.want, got) + }) + } +} + +type test struct { + pwChecked bool + passkeyChecked bool + u2fChecked bool + otpChecked bool +} + +func (t test) IsPasswordChecked() bool { + return t.pwChecked +} + +func (t test) IsPasskeyChecked() bool { + return t.passkeyChecked +} + +func (t test) IsU2FChecked() bool { + return t.u2fChecked +} + +func (t test) IsOTPChecked() bool { + return t.otpChecked +} diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 0c943bfd6f..e2ccc8f86d 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -2,6 +2,7 @@ package oidc import ( "context" + "encoding/base64" "strings" "time" @@ -10,16 +11,75 @@ import ( "github.com/zitadel/oidc/v2/pkg/op" "github.com/zitadel/zitadel/internal/api/authz" + http_utils "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/api/http/middleware" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/user/model" ) +const ( + LoginClientHeader = "x-zitadel-login-client" +) + func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + + headers, _ := http_utils.HeadersFromCtx(ctx) + if loginClient := headers.Get(LoginClientHeader); loginClient != "" { + return o.createAuthRequestLoginClient(ctx, req, userID, loginClient) + } + + return o.createAuthRequest(ctx, req, userID) +} + +func (o *OPStorage) createAuthRequestLoginClient(ctx context.Context, req *oidc.AuthRequest, hintUserID, loginClient string) (op.AuthRequest, error) { + project, err := o.query.ProjectByClientID(ctx, req.ClientID, false) + if err != nil { + return nil, err + } + scope, err := o.assertProjectRoleScopesByProject(ctx, project, req.Scopes) + if err != nil { + return nil, err + } + audience, err := o.audienceFromProjectID(ctx, project.ID) + if err != nil { + return nil, err + } + audience = domain.AddAudScopeToAudience(ctx, audience, scope) + authRequest := &command.AuthRequest{ + LoginClient: loginClient, + ClientID: req.ClientID, + RedirectURI: req.RedirectURI, + State: req.State, + Nonce: req.Nonce, + Scope: scope, + Audience: audience, + ResponseType: ResponseTypeToBusiness(req.ResponseType), + CodeChallenge: CodeChallengeToBusiness(req.CodeChallenge, req.CodeChallengeMethod), + Prompt: PromptToBusiness(req.Prompt), + UILocales: UILocalesToBusiness(req.UILocales), + MaxAge: MaxAgeToBusiness(req.MaxAge), + } + if req.LoginHint != "" { + authRequest.LoginHint = &req.LoginHint + } + if hintUserID != "" { + authRequest.HintUserID = &hintUserID + } + + aar, err := o.command.AddAuthRequest(ctx, authRequest) + if err != nil { + return nil, err + } + return &AuthRequestV2{aar}, nil +} + +func (o *OPStorage) createAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) { userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) if !ok { return nil, errors.ThrowPreconditionFailed(nil, "OIDC-sd436", "no user agent id") @@ -36,9 +96,31 @@ func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest return AuthRequestFromBusiness(resp) } +func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) ([]string, error) { + projectIDQuery, err := query.NewAppProjectIDSearchQuery(projectID) + if err != nil { + return nil, err + } + appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false) + if err != nil { + return nil, err + } + + return append(appIDs, projectID), nil +} + func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + + if strings.HasPrefix(id, command.IDPrefixV2) { + req, err := o.command.GetCurrentAuthRequest(ctx, id) + if err != nil { + return nil, err + } + return &AuthRequestV2{req}, nil + } + userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) if !ok { return nil, errors.ThrowPreconditionFailed(nil, "OIDC-D3g21", "no user agent id") @@ -54,6 +136,17 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + plainCode, err := o.decryptGrant(code) + if err != nil { + return nil, err + } + if strings.HasPrefix(plainCode, command.IDPrefixV2) { + authReq, err := o.command.ExchangeAuthCode(ctx, plainCode) + if err != nil { + return nil, err + } + return &AuthRequestV2{authReq}, nil + } resp, err := o.repo.AuthRequestByCode(ctx, code) if err != nil { return nil, err @@ -61,9 +154,23 @@ func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.Au return AuthRequestFromBusiness(resp) } +// decryptGrant decrypts a code or refresh_token +func (o *OPStorage) decryptGrant(grant string) (string, error) { + decodedGrant, err := base64.RawURLEncoding.DecodeString(grant) + if err != nil { + return "", err + } + return o.encAlg.DecryptString(decodedGrant, o.encAlg.EncryptionKeyID()) +} + func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + + if strings.HasPrefix(id, command.IDPrefixV2) { + return o.command.AddAuthRequestCode(ctx, id, code) + } + userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) if !ok { return errors.ThrowPreconditionFailed(nil, "OIDC-Dgus2", "no user agent id") @@ -81,12 +188,15 @@ func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + var userAgentID, applicationID, userOrgID string - authReq, ok := req.(*AuthRequest) - if ok { + switch authReq := req.(type) { + case *AuthRequest: userAgentID = authReq.AgentID applicationID = authReq.ApplicationID userOrgID = authReq.UserOrgID + case *AuthRequestV2: + return o.command.AddOIDCSessionAccessToken(setContextUserSystem(ctx), authReq.GetID()) } accessTokenLifetime, _, _, _, err := o.getOIDCSettings(ctx) @@ -104,6 +214,15 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + + // handle V2 request directly + switch tokenReq := req.(type) { + case *AuthRequestV2: + return o.command.AddOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.GetID()) + case *RefreshTokenRequestV2: + return o.command.ExchangeOIDCSessionRefreshAndAccessToken(setContextUserSystem(ctx), tokenReq.OIDCSessionWriteModel.AggregateID, refreshToken, tokenReq.RequestedScopes) + } + userAgentID, applicationID, userOrgID, authTime, authMethodsReferences := getInfoFromRequest(req) scopes, err := o.assertProjectRoleScopes(ctx, applicationID, req.GetScopes()) if err != nil { @@ -142,7 +261,22 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time, return "", "", "", time.Time{}, nil } -func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { +func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + plainCode, err := o.decryptGrant(refreshToken) + if err != nil { + return nil, err + } + if strings.HasPrefix(plainCode, command.IDPrefixV2) { + oidcSession, err := o.command.OIDCSessionByRefreshToken(ctx, plainCode) + if err != nil { + return nil, err + } + return &RefreshTokenRequestV2{OIDCSessionWriteModel: oidcSession}, nil + } + tokenView, err := o.repo.RefreshTokenByToken(ctx, refreshToken) if err != nil { return nil, err @@ -245,6 +379,29 @@ func (o *OPStorage) assertProjectRoleScopes(ctx context.Context, clientID string return scopes, nil } +func (o *OPStorage) assertProjectRoleScopesByProject(ctx context.Context, project *query.Project, scopes []string) ([]string, error) { + for _, scope := range scopes { + if strings.HasPrefix(scope, ScopeProjectRolePrefix) { + return scopes, nil + } + } + if !project.ProjectRoleAssertion { + return scopes, nil + } + projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(project.ID) + if err != nil { + return nil, errors.ThrowInternal(err, "OIDC-Cyc78", "Errors.Internal") + } + roles, err := o.query.SearchProjectRoles(ctx, true, &query.ProjectRoleSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false) + if err != nil { + return nil, err + } + for _, role := range roles.ProjectRoles { + scopes = append(scopes, ScopeProjectRolePrefix+role.Key) + } + return scopes, nil +} + func (o *OPStorage) assertClientScopesForPAT(ctx context.Context, token *model.TokenView, clientID, projectID string) error { token.Audience = append(token.Audience, clientID) projectIDQuery, err := query.NewProjectRoleProjectIDSearchQuery(projectID) @@ -279,3 +436,58 @@ func (o *OPStorage) getOIDCSettings(ctx context.Context) (accessTokenLifetime, i } return o.defaultAccessTokenLifetime, o.defaultIdTokenLifetime, o.defaultRefreshTokenIdleExpiration, o.defaultRefreshTokenExpiration, nil } + +func CreateErrorCallbackURL(authReq op.AuthRequest, reason, description, uri string, authorizer op.Authorizer) (string, error) { + e := struct { + Error string `schema:"error"` + Description string `schema:"error_description,omitempty"` + URI string `schema:"error_uri,omitempty"` + State string `schema:"state,omitempty"` + }{ + Error: reason, + Description: description, + URI: uri, + State: authReq.GetState(), + } + callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), e, authorizer.Encoder()) + if err != nil { + return "", err + } + return callback, nil +} + +func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authorizer op.Authorizer) (string, error) { + code, err := op.CreateAuthRequestCode(ctx, authReq, authorizer.Storage(), authorizer.Crypto()) + if err != nil { + return "", err + } + codeResponse := struct { + code string + state string + }{ + code: code, + state: authReq.GetState(), + } + callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder()) + if err != nil { + return "", err + } + return callback, err +} + +func CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest, authorizer op.Authorizer) (string, error) { + client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID()) + if err != nil { + return "", err + } + createAccessToken := req.GetResponseType() != oidc.ResponseTypeIDTokenOnly + resp, err := op.CreateTokenResponse(ctx, req, client, authorizer, createAccessToken, "", "") + if err != nil { + return "", err + } + callback, err := op.AuthResponseURL(req.GetRedirectURI(), req.GetResponseType(), req.GetResponseMode(), resp, authorizer.Encoder()) + if err != nil { + return "", err + } + return callback, err +} diff --git a/internal/api/oidc/auth_request_converter.go b/internal/api/oidc/auth_request_converter.go index 6473460843..02fd3273ed 100644 --- a/internal/api/oidc/auth_request_converter.go +++ b/internal/api/oidc/auth_request_converter.go @@ -12,20 +12,12 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" http_utils "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/api/oidc/amr" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/user/model" ) -const ( - // DEPRECATED: use `amrPWD` instead - amrPassword = "password" - amrPWD = "pwd" - amrMFA = "mfa" - amrOTP = "otp" - amrUserPresence = "user" -) - type AuthRequest struct { *domain.AuthRequest } @@ -40,19 +32,19 @@ func (a *AuthRequest) GetACR() string { } func (a *AuthRequest) GetAMR() []string { - amr := make([]string, 0) + list := make([]string, 0) if a.PasswordVerified { - amr = append(amr, amrPassword, amrPWD) + list = append(list, amr.Password, amr.PWD) } if len(a.MFAsVerified) > 0 { - amr = append(amr, amrMFA) + list = append(list, amr.MFA) for _, mfa := range a.MFAsVerified { if amrMFA := AMRFromMFAType(mfa); amrMFA != "" { - amr = append(amr, amrMFA) + list = append(list, amrMFA) } } } - return amr + return list } func (a *AuthRequest) GetAudience() []string { @@ -271,10 +263,10 @@ func CodeChallengeToOIDC(challenge *domain.OIDCCodeChallenge) *oidc.CodeChalleng func AMRFromMFAType(mfaType domain.MFAType) string { switch mfaType { case domain.MFATypeOTP: - return amrOTP + return amr.OTP case domain.MFATypeU2F, domain.MFATypeU2FUserVerification: - return amrUserPresence + return amr.UserPresence default: return "" } diff --git a/internal/api/oidc/auth_request_converter_v2.go b/internal/api/oidc/auth_request_converter_v2.go new file mode 100644 index 0000000000..fd9c5f48ef --- /dev/null +++ b/internal/api/oidc/auth_request_converter_v2.go @@ -0,0 +1,106 @@ +package oidc + +import ( + "time" + + "github.com/zitadel/oidc/v2/pkg/oidc" + + "github.com/zitadel/zitadel/internal/command" +) + +type AuthRequestV2 struct { + *command.CurrentAuthRequest +} + +func (a *AuthRequestV2) GetID() string { + return a.ID +} + +func (a *AuthRequestV2) GetACR() string { + return "" //PLANNED: impl +} + +func (a *AuthRequestV2) GetAMR() []string { + return a.AMR +} + +func (a *AuthRequestV2) GetAudience() []string { + return a.Audience +} + +func (a *AuthRequestV2) GetAuthTime() time.Time { + return a.AuthTime +} + +func (a *AuthRequestV2) GetClientID() string { + return a.ClientID +} + +func (a *AuthRequestV2) GetCodeChallenge() *oidc.CodeChallenge { + return CodeChallengeToOIDC(a.CodeChallenge) +} + +func (a *AuthRequestV2) GetNonce() string { + return a.Nonce +} + +func (a *AuthRequestV2) GetRedirectURI() string { + return a.RedirectURI +} + +func (a *AuthRequestV2) GetResponseType() oidc.ResponseType { + return ResponseTypeToOIDC(a.ResponseType) +} + +func (a *AuthRequestV2) GetResponseMode() oidc.ResponseMode { + return "" +} + +func (a *AuthRequestV2) GetScopes() []string { + return a.Scope +} + +func (a *AuthRequestV2) GetState() string { + return a.State +} + +func (a *AuthRequestV2) GetSubject() string { + return a.UserID +} + +func (a *AuthRequestV2) Done() bool { + return a.UserID != "" && a.SessionID != "" +} + +type RefreshTokenRequestV2 struct { + *command.OIDCSessionWriteModel + RequestedScopes []string +} + +func (r *RefreshTokenRequestV2) GetAMR() []string { + return r.AuthMethodsReferences +} + +func (r *RefreshTokenRequestV2) GetAudience() []string { + return r.Audience +} + +func (r *RefreshTokenRequestV2) GetAuthTime() time.Time { + return r.AuthTime +} + +func (r *RefreshTokenRequestV2) GetClientID() string { + return r.ClientID +} + +func (r *RefreshTokenRequestV2) GetScopes() []string { + return r.Scope +} + +func (r *RefreshTokenRequestV2) GetSubject() string { + return r.UserID +} + +func (r *RefreshTokenRequestV2) SetCurrentScopes(scopes []string) { + r.RequestedScopes = scopes +} diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go new file mode 100644 index 0000000000..d58958256f --- /dev/null +++ b/internal/api/oidc/auth_request_integration_test.go @@ -0,0 +1,275 @@ +//go:build integration + +package oidc_test + +import ( + "context" + "net/url" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v2/pkg/client/rp" + "github.com/zitadel/oidc/v2/pkg/oidc" + "golang.org/x/oauth2" + + "github.com/zitadel/zitadel/internal/api/oidc/amr" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/internal/integration" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" + user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha" +) + +var ( + CTX context.Context + CTXLOGIN context.Context + Tester *integration.Tester + User *user.AddHumanUserResponse +) + +const ( + redirectURI = "oidcIntegrationTest://callback" + redirectURIImplicit = "http://localhost:9999/callback" +) + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, errCtx, cancel := integration.Contexts(5 * time.Minute) + defer cancel() + + Tester = integration.NewTester(ctx) + defer Tester.Done() + + CTX, _ = Tester.WithAuthorization(ctx, integration.OrgOwner), errCtx + User = Tester.CreateHumanUser(CTX) + Tester.RegisterUserPasskey(CTX, User.GetUserId()) + CTXLOGIN, _ = Tester.WithAuthorization(ctx, integration.Login), errCtx + return m.Run() + }()) +} + +func createClient(t testing.TB) string { + app, err := Tester.CreateOIDCNativeClient(CTX, redirectURI) + require.NoError(t, err) + return app.GetClientId() +} + +func createImplicitClient(t testing.TB) string { + app, err := Tester.CreateOIDCImplicitFlowClient(CTX, redirectURIImplicit) + require.NoError(t, err) + return app.GetClientId() +} + +func createAuthRequest(t testing.TB, clientID, redirectURI string, scope ...string) string { + redURL, err := Tester.CreateOIDCAuthRequest(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...) + require.NoError(t, err) + return redURL +} + +func createAuthRequestImplicit(t testing.TB, clientID, redirectURI string, scope ...string) string { + redURL, err := Tester.CreateOIDCAuthRequestImplicit(clientID, Tester.Users[integration.FirstInstanceUsersKey][integration.Login].ID, redirectURI, scope...) + require.NoError(t, err) + return redURL +} + +func TestOPStorage_CreateAuthRequest(t *testing.T) { + clientID := createClient(t) + + id := createAuthRequest(t, clientID, redirectURI) + require.Contains(t, id, command.IDPrefixV2) +} + +func TestOPStorage_CreateAccessToken_code(t *testing.T) { + clientID := createClient(t) + authRequestID := createAuthRequest(t, clientID, redirectURI) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // test code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, false) + assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime) + + // callback on a succeeded request must fail + linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.Error(t, err) + + // exchange with a used code must fail + _, err = exchangeTokens(t, clientID, code) + require.Error(t, err) +} + +func TestOPStorage_CreateAccessToken_implicit(t *testing.T) { + clientID := createImplicitClient(t) + authRequestID := createAuthRequestImplicit(t, clientID, redirectURIImplicit) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // test implicit callback + callback, err := url.Parse(linkResp.GetCallbackUrl()) + require.NoError(t, err) + values, err := url.ParseQuery(callback.Fragment) + require.NoError(t, err) + accessToken := values.Get("access_token") + idToken := values.Get("id_token") + refreshToken := values.Get("refresh_token") + assert.NotEmpty(t, accessToken) + assert.NotEmpty(t, idToken) + assert.Empty(t, refreshToken) + assert.NotEmpty(t, values.Get("expires_in")) + assert.Equal(t, oidc.BearerToken, values.Get("token_type")) + assert.Equal(t, "state", values.Get("state")) + + // check id_token / claims + provider, err := Tester.CreateRelyingParty(clientID, redirectURIImplicit) + require.NoError(t, err) + claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), accessToken, idToken, provider.IDTokenVerifier()) + require.NoError(t, err) + assertTokenClaims(t, claims, startTime, changeTime) + + // callback on a succeeded request must fail + linkResp, err = Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.Error(t, err) +} + +func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) { + clientID := createClient(t) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // test code exchange (expect refresh token to be returned) + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime) +} + +func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { + clientID := createClient(t) + provider, err := Tester.CreateRelyingParty(clientID, redirectURI) + require.NoError(t, err) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, oidc.ScopeOfflineAccess) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasskeySession(t, CTXLOGIN, User.GetUserId()) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code) + require.NoError(t, err) + assertTokens(t, tokens, true) + assertTokenClaims(t, tokens.IDTokenClaims, startTime, changeTime) + + // test actual refresh grant + newTokens, err := refreshTokens(t, clientID, tokens.RefreshToken) + require.NoError(t, err) + idToken, _ := newTokens.Extra("id_token").(string) + assert.NotEmpty(t, idToken) + assert.NotEmpty(t, newTokens.AccessToken) + assert.NotEmpty(t, newTokens.RefreshToken) + claims, err := rp.VerifyTokens[*oidc.IDTokenClaims](context.Background(), newTokens.AccessToken, idToken, provider.IDTokenVerifier()) + require.NoError(t, err) + // auth time must still be the initial + assertTokenClaims(t, claims, startTime, changeTime) + + // refresh with an old refresh_token must fail + _, err = rp.RefreshAccessToken(provider, tokens.RefreshToken, "", "") + require.Error(t, err) +} + +func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) { + provider, err := Tester.CreateRelyingParty(clientID, redirectURI) + require.NoError(t, err) + + codeVerifier := "codeVerifier" + return rp.CodeExchange[*oidc.IDTokenClaims](context.Background(), code, provider, rp.WithCodeVerifier(codeVerifier)) +} + +func refreshTokens(t testing.TB, clientID, refreshToken string) (*oauth2.Token, error) { + provider, err := Tester.CreateRelyingParty(clientID, redirectURI) + require.NoError(t, err) + + return rp.RefreshAccessToken(provider, refreshToken, "", "") +} + +func assertCodeResponse(t *testing.T, callback string) string { + callbackURL, err := url.Parse(callback) + require.NoError(t, err) + code := callbackURL.Query().Get("code") + require.NotEmpty(t, code) + assert.Equal(t, "state", callbackURL.Query().Get("state")) + return code +} + +func assertTokens(t *testing.T, tokens *oidc.Tokens[*oidc.IDTokenClaims], requireRefreshToken bool) { + assert.NotEmpty(t, tokens.AccessToken) + assert.NotEmpty(t, tokens.IDToken) + if requireRefreshToken { + assert.NotEmpty(t, tokens.RefreshToken) + } else { + assert.Empty(t, tokens.RefreshToken) + } +} + +func assertTokenClaims(t *testing.T, claims *oidc.IDTokenClaims, sessionStart, sessionChange time.Time) { + assert.Equal(t, User.GetUserId(), claims.Subject) + assert.Equal(t, []string{amr.UserPresence, amr.MFA}, claims.AuthenticationMethodsReferences) + assert.WithinRange(t, claims.AuthTime.AsTime().UTC(), sessionStart.Add(-1*time.Second), sessionChange.Add(1*time.Second)) +} diff --git a/internal/api/oidc/client.go b/internal/api/oidc/client.go index e1a26f866f..a8c8a3febf 100644 --- a/internal/api/oidc/client.go +++ b/internal/api/oidc/client.go @@ -66,7 +66,7 @@ func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Cl return nil, err } - return ClientFromBusiness(client, o.defaultLoginURL, accessTokenLifetime, idTokenLifetime, allowedScopes) + return ClientFromBusiness(client, o.defaultLoginURL, o.defaultLoginURLV2, accessTokenLifetime, idTokenLifetime, allowedScopes) } func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID string) (_ *jose.JSONWebKey, err error) { @@ -153,7 +153,7 @@ func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.Us return o.setUserinfo(ctx, userInfo, userID, applicationID, scopes, nil) } -func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error { +func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) { token, err := o.repo.TokenByIDs(ctx, subject, tokenID) if err != nil { return errors.ThrowPermissionDenied(nil, "OIDC-Dsfb2", "token is not valid or has expired") diff --git a/internal/api/oidc/client_converter.go b/internal/api/oidc/client_converter.go index 6b32f38927..5f3f25f759 100644 --- a/internal/api/oidc/client_converter.go +++ b/internal/api/oidc/client_converter.go @@ -7,6 +7,7 @@ import ( "github.com/zitadel/oidc/v2/pkg/oidc" "github.com/zitadel/oidc/v2/pkg/op" + "github.com/zitadel/zitadel/internal/command" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/query" @@ -15,18 +16,20 @@ import ( type Client struct { app *query.App defaultLoginURL string + defaultLoginURLV2 string defaultAccessTokenLifetime time.Duration defaultIdTokenLifetime time.Duration allowedScopes []string } -func ClientFromBusiness(app *query.App, defaultLoginURL string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) { +func ClientFromBusiness(app *query.App, defaultLoginURL, defaultLoginURLV2 string, defaultAccessTokenLifetime, defaultIdTokenLifetime time.Duration, allowedScopes []string) (op.Client, error) { if app.OIDCConfig == nil { return nil, errors.ThrowInvalidArgument(nil, "OIDC-d5bhD", "client is not a proper oidc application") } return &Client{ app: app, defaultLoginURL: defaultLoginURL, + defaultLoginURLV2: defaultLoginURLV2, defaultAccessTokenLifetime: defaultAccessTokenLifetime, defaultIdTokenLifetime: defaultIdTokenLifetime, allowedScopes: allowedScopes}, @@ -46,6 +49,9 @@ func (c *Client) GetID() string { } func (c *Client) LoginURL(id string) string { + if strings.HasPrefix(id, command.IDPrefixV2) { + return c.defaultLoginURLV2 + id + } return c.defaultLoginURL + id } diff --git a/internal/api/oidc/op.go b/internal/api/oidc/op.go index 9574f561b6..038f741388 100644 --- a/internal/api/oidc/op.go +++ b/internal/api/oidc/op.go @@ -41,6 +41,7 @@ type Config struct { Cache *middleware.CacheConfig CustomEndpoints *EndpointConfig DeviceAuth *DeviceAuthorizationConfig + DefaultLoginURLV2 string } type EndpointConfig struct { @@ -65,6 +66,7 @@ type OPStorage struct { query *query.Queries eventstore *eventstore.Eventstore defaultLoginURL string + defaultLoginURLV2 string defaultAccessTokenLifetime time.Duration defaultIdTokenLifetime time.Duration signingKeyAlgorithm string @@ -181,6 +183,7 @@ func newStorage(config Config, command *command.Commands, query *query.Queries, query: query, eventstore: es, defaultLoginURL: fmt.Sprintf("%s%s?%s=", login.HandlerPrefix, login.EndpointLogin, login.QueryAuthRequestID), + defaultLoginURLV2: config.DefaultLoginURLV2, signingKeyAlgorithm: config.SigningKeyAlgorithm, defaultAccessTokenLifetime: config.DefaultAccessTokenLifetime, defaultIdTokenLifetime: config.DefaultIdTokenLifetime, diff --git a/internal/command/auth_request.go b/internal/command/auth_request.go new file mode 100644 index 0000000000..1ba0a55173 --- /dev/null +++ b/internal/command/auth_request.go @@ -0,0 +1,215 @@ +package command + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/oidc/amr" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +type AuthRequest struct { + ID string + LoginClient string + ClientID string + RedirectURI string + State string + Nonce string + Scope []string + Audience []string + ResponseType domain.OIDCResponseType + CodeChallenge *domain.OIDCCodeChallenge + Prompt []domain.Prompt + UILocales []string + MaxAge *time.Duration + LoginHint *string + HintUserID *string +} + +type CurrentAuthRequest struct { + *AuthRequest + SessionID string + UserID string + AMR []string + AuthTime time.Time +} + +const IDPrefixV2 = "V2_" + +func (c *Commands) AddAuthRequest(ctx context.Context, authRequest *AuthRequest) (_ *CurrentAuthRequest, err error) { + authRequestID, err := c.idGenerator.Next() + if err != nil { + return nil, err + } + authRequest.ID = IDPrefixV2 + authRequestID + writeModel, err := c.getAuthRequestWriteModel(ctx, authRequest.ID) + if err != nil { + return nil, err + } + if writeModel.AuthRequestState != domain.AuthRequestStateUnspecified { + return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting") + } + err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewAddedEvent( + ctx, + &authrequest.NewAggregate(authRequest.ID, authz.GetInstance(ctx).InstanceID()).Aggregate, + authRequest.LoginClient, + authRequest.ClientID, + authRequest.RedirectURI, + authRequest.State, + authRequest.Nonce, + authRequest.Scope, + authRequest.Audience, + authRequest.ResponseType, + authRequest.CodeChallenge, + authRequest.Prompt, + authRequest.UILocales, + authRequest.MaxAge, + authRequest.LoginHint, + authRequest.HintUserID, + )) + if err != nil { + return nil, err + } + return authRequestWriteModelToCurrentAuthRequest(writeModel), nil +} + +func (c *Commands) LinkSessionToAuthRequest(ctx context.Context, id, sessionID, sessionToken string, checkLoginClient bool) (*domain.ObjectDetails, *CurrentAuthRequest, error) { + writeModel, err := c.getAuthRequestWriteModel(ctx, id) + if err != nil { + return nil, nil, err + } + if writeModel.AuthRequestState == domain.AuthRequestStateUnspecified { + return nil, nil, errors.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting") + } + if writeModel.AuthRequestState != domain.AuthRequestStateAdded { + return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled") + } + if checkLoginClient && authz.GetCtxData(ctx).UserID != writeModel.LoginClient { + return nil, nil, errors.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient") + } + sessionWriteModel := NewSessionWriteModel(sessionID, authz.GetCtxData(ctx).OrgID) + err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel) + if err != nil { + return nil, nil, err + } + if sessionWriteModel.State == domain.SessionStateUnspecified { + return nil, nil, errors.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting") + } + if err := c.sessionPermission(ctx, sessionWriteModel, sessionToken, domain.PermissionSessionWrite); err != nil { + return nil, nil, err + } + + if err := c.pushAppendAndReduce(ctx, writeModel, authrequest.NewSessionLinkedEvent( + ctx, &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate, + sessionID, + sessionWriteModel.UserID, + sessionWriteModel.AuthenticationTime(), + amr.List(sessionWriteModel), + )); err != nil { + return nil, nil, err + } + return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil +} + +func (c *Commands) FailAuthRequest(ctx context.Context, id string, reason domain.OIDCErrorReason) (*domain.ObjectDetails, *CurrentAuthRequest, error) { + writeModel, err := c.getAuthRequestWriteModel(ctx, id) + if err != nil { + return nil, nil, err + } + if writeModel.AuthRequestState != domain.AuthRequestStateAdded { + return nil, nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled") + } + err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewFailedEvent( + ctx, + &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate, + reason, + )) + if err != nil { + return nil, nil, err + } + return writeModelToObjectDetails(&writeModel.WriteModel), authRequestWriteModelToCurrentAuthRequest(writeModel), nil +} + +func (c *Commands) AddAuthRequestCode(ctx context.Context, authRequestID, code string) (err error) { + if code == "" { + return errors.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode") + } + writeModel, err := c.getAuthRequestWriteModel(ctx, authRequestID) + if err != nil { + return err + } + if writeModel.AuthRequestState != domain.AuthRequestStateAdded || writeModel.SessionID == "" { + return errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled") + } + return c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeAddedEvent(ctx, + &authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate)) +} + +func (c *Commands) ExchangeAuthCode(ctx context.Context, code string) (authRequest *CurrentAuthRequest, err error) { + if code == "" { + return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode") + } + writeModel, err := c.getAuthRequestWriteModel(ctx, code) + if err != nil { + return nil, err + } + if writeModel.AuthRequestState != domain.AuthRequestStateCodeAdded { + return nil, errors.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode") + } + err = c.pushAppendAndReduce(ctx, writeModel, authrequest.NewCodeExchangedEvent(ctx, + &authrequest.NewAggregate(writeModel.AggregateID, authz.GetInstance(ctx).InstanceID()).Aggregate)) + if err != nil { + return nil, err + } + return authRequestWriteModelToCurrentAuthRequest(writeModel), nil +} + +func authRequestWriteModelToCurrentAuthRequest(writeModel *AuthRequestWriteModel) (_ *CurrentAuthRequest) { + return &CurrentAuthRequest{ + AuthRequest: &AuthRequest{ + ID: writeModel.AggregateID, + LoginClient: writeModel.LoginClient, + ClientID: writeModel.ClientID, + RedirectURI: writeModel.RedirectURI, + State: writeModel.State, + Nonce: writeModel.Nonce, + Scope: writeModel.Scope, + Audience: writeModel.Audience, + ResponseType: writeModel.ResponseType, + CodeChallenge: writeModel.CodeChallenge, + Prompt: writeModel.Prompt, + UILocales: writeModel.UILocales, + MaxAge: writeModel.MaxAge, + LoginHint: writeModel.LoginHint, + HintUserID: writeModel.HintUserID, + }, + SessionID: writeModel.SessionID, + UserID: writeModel.UserID, + AMR: writeModel.AMR, + AuthTime: writeModel.AuthTime, + } +} + +func (c *Commands) GetCurrentAuthRequest(ctx context.Context, id string) (_ *CurrentAuthRequest, err error) { + wm, err := c.getAuthRequestWriteModel(ctx, id) + if err != nil { + return nil, err + } + return authRequestWriteModelToCurrentAuthRequest(wm), nil +} + +func (c *Commands) getAuthRequestWriteModel(ctx context.Context, id string) (writeModel *AuthRequestWriteModel, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + writeModel = NewAuthRequestWriteModel(ctx, id) + err = c.eventstore.FilterToQueryReducer(ctx, writeModel) + if err != nil { + return nil, err + } + return writeModel, nil +} diff --git a/internal/command/auth_request_model.go b/internal/command/auth_request_model.go new file mode 100644 index 0000000000..91bbaf955a --- /dev/null +++ b/internal/command/auth_request_model.go @@ -0,0 +1,110 @@ +package command + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/authrequest" +) + +type AuthRequestWriteModel struct { + eventstore.WriteModel + aggregate *eventstore.Aggregate + + LoginClient string + ClientID string + RedirectURI string + State string + Nonce string + Scope []string + Audience []string + ResponseType domain.OIDCResponseType + CodeChallenge *domain.OIDCCodeChallenge + Prompt []domain.Prompt + UILocales []string + MaxAge *time.Duration + LoginHint *string + HintUserID *string + SessionID string + UserID string + AuthTime time.Time + AMR []string + AuthRequestState domain.AuthRequestState +} + +func NewAuthRequestWriteModel(ctx context.Context, id string) *AuthRequestWriteModel { + return &AuthRequestWriteModel{ + WriteModel: eventstore.WriteModel{ + AggregateID: id, + }, + aggregate: &authrequest.NewAggregate(id, authz.GetInstance(ctx).InstanceID()).Aggregate, + } +} + +func (m *AuthRequestWriteModel) Reduce() error { + for _, event := range m.Events { + switch e := event.(type) { + case *authrequest.AddedEvent: + m.LoginClient = e.LoginClient + m.ClientID = e.ClientID + m.RedirectURI = e.RedirectURI + m.State = e.State + m.Nonce = e.Nonce + m.Scope = e.Scope + m.Audience = e.Audience + m.ResponseType = e.ResponseType + m.CodeChallenge = e.CodeChallenge + m.Prompt = e.Prompt + m.UILocales = e.UILocales + m.MaxAge = e.MaxAge + m.LoginHint = e.LoginHint + m.HintUserID = e.HintUserID + m.AuthRequestState = domain.AuthRequestStateAdded + case *authrequest.SessionLinkedEvent: + m.SessionID = e.SessionID + m.UserID = e.UserID + m.AuthTime = e.AuthTime + m.AMR = e.AMR + case *authrequest.CodeAddedEvent: + m.AuthRequestState = domain.AuthRequestStateCodeAdded + case *authrequest.FailedEvent: + m.AuthRequestState = domain.AuthRequestStateFailed + case *authrequest.CodeExchangedEvent: + m.AuthRequestState = domain.AuthRequestStateCodeExchanged + case *authrequest.SucceededEvent: + m.AuthRequestState = domain.AuthRequestStateSucceeded + } + } + + return m.WriteModel.Reduce() +} + +func (m *AuthRequestWriteModel) Query() *eventstore.SearchQueryBuilder { + return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AddQuery(). + AggregateTypes(authrequest.AggregateType). + AggregateIDs(m.AggregateID). + Builder() +} + +// CheckAuthenticated checks that the auth request exists, a session must have been linked +// and in case of a Code Flow the code must have been exchanged +func (m *AuthRequestWriteModel) CheckAuthenticated() error { + if m.SessionID == "" { + return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated") + } + // in case of OIDC Code Flow, the code must have been exchanged + if m.ResponseType == domain.OIDCResponseTypeCode && m.AuthRequestState == domain.AuthRequestStateCodeExchanged { + return nil + } + // in case of OIDC Implicit Flow, check that the requests exists, but has not succeeded yet + if (m.ResponseType == domain.OIDCResponseTypeIDToken || m.ResponseType == domain.OIDCResponseTypeIDTokenToken) && + m.AuthRequestState == domain.AuthRequestStateAdded { + return nil + } + return caos_errs.ThrowPreconditionFailed(nil, "AUTHR-sajk3", "Errors.AuthRequest.NotAuthenticated") +} diff --git a/internal/command/auth_request_test.go b/internal/command/auth_request_test.go new file mode 100644 index 0000000000..ce12a119c1 --- /dev/null +++ b/internal/command/auth_request_test.go @@ -0,0 +1,998 @@ +package command + +import ( + "context" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/oidc/amr" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/id" + "github.com/zitadel/zitadel/internal/id/mock" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/session" +) + +func TestCommands_AddAuthRequest(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + } + type args struct { + ctx context.Context + request *AuthRequest + } + tests := []struct { + name string + fields fields + args args + want *CurrentAuthRequest + wantErr error + }{ + { + "already exists error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"), + }, + args{ + ctx: mockCtx, + request: &AuthRequest{}, + }, + nil, + caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3gt", "Errors.AuthRequest.AlreadyExisting"), + }, + { + "added", + fields{ + eventstore: eventstoreExpect(t, + expectFilter(), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + }), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "id"), + }, + args{ + ctx: mockCtx, + request: &AuthRequest{ + LoginClient: "loginClient", + ClientID: "clientID", + RedirectURI: "redirectURI", + State: "state", + Nonce: "nonce", + Scope: []string{"openid"}, + Audience: []string{"audience"}, + ResponseType: domain.OIDCResponseTypeCode, + CodeChallenge: &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + Prompt: []domain.Prompt{domain.PromptNone}, + UILocales: []string{"en", "de"}, + MaxAge: gu.Ptr(time.Duration(0)), + LoginHint: gu.Ptr("loginHint"), + HintUserID: gu.Ptr("hintUserID"), + }, + }, + &CurrentAuthRequest{ + AuthRequest: &AuthRequest{ + ID: "V2_id", + LoginClient: "loginClient", + ClientID: "clientID", + RedirectURI: "redirectURI", + State: "state", + Nonce: "nonce", + Scope: []string{"openid"}, + Audience: []string{"audience"}, + ResponseType: domain.OIDCResponseTypeCode, + CodeChallenge: &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + Prompt: []domain.Prompt{domain.PromptNone}, + UILocales: []string{"en", "de"}, + MaxAge: gu.Ptr(time.Duration(0)), + LoginHint: gu.Ptr("loginHint"), + HintUserID: gu.Ptr("hintUserID"), + }, + }, + nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + } + got, err := c.AddAuthRequest(tt.args.ctx, tt.args.request) + require.ErrorIs(t, tt.wantErr, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestCommands_LinkSessionToAuthRequest(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore *eventstore.Eventstore + tokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) + checkPermission domain.PermissionCheck + } + type args struct { + ctx context.Context + id string + sessionID string + sessionToken string + checkLoginClient bool + } + type res struct { + details *domain.ObjectDetails + authReq *CurrentAuthRequest + wantErr error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "authRequest not found", + fields{ + eventstore: eventstoreExpect(t, + expectFilter(), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return nil + }, + checkPermission: newMockPermissionCheckNotAllowed(), + }, + args{ + ctx: mockCtx, + id: "id", + sessionID: "sessionID", + }, + res{ + wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-jae5P", "Errors.AuthRequest.NotExisting"), + }, + }, + { + "authRequest not existing", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + eventFromEventPusher( + authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate, + domain.OIDCErrorReasonUnspecified), + ), + ), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return nil + }, + checkPermission: newMockPermissionCheckAllowed(), + }, + args{ + ctx: mockCtx, + id: "id", + sessionID: "sessionID", + }, + res{ + wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx208nt", "Errors.AuthRequest.AlreadyHandled"), + }, + }, + { + "wrong login client", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return nil + }, + checkPermission: newMockPermissionCheckAllowed(), + }, + args{ + ctx: authz.NewMockContext("instanceID", "orgID", "wrongLoginClient"), + id: "id", + sessionID: "sessionID", + sessionToken: "token", + checkLoginClient: true, + }, + res{ + wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-rai9Y", "Errors.AuthRequest.WrongLoginClient"), + }, + }, + { + "session not existing", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + expectFilter(), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return nil + }, + checkPermission: newMockPermissionCheckNotAllowed(), + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + }, + res{ + wantErr: caos_errs.ThrowNotFound(nil, "COMMAND-x0099887", "Errors.Session.NotExisting"), + }, + }, + { + "missing permission", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), + ), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return nil + }, + checkPermission: newMockPermissionCheckNotAllowed(), + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + }, + res{ + wantErr: caos_errs.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"), + }, + }, + { + "invalid session token", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld")), + ), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid") + }, + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + sessionToken: "invalid", + }, + res{ + wantErr: caos_errs.ThrowPermissionDenied(nil, "COMMAND-sGr42", "Errors.Session.Token.Invalid"), + }, + }, + { + "linked", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"), + ), + eventFromEventPusher( + session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, + "userID", testNow), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, + testNow), + ), + ), + expectPush( + []*repository.Event{eventFromEventPusherWithInstanceID( + "instanceID", + authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + )}), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return nil + }, + checkPermission: newMockPermissionCheckAllowed(), + }, + args{ + ctx: mockCtx, + id: "V2_id", + sessionID: "sessionID", + sessionToken: "token", + }, + res{ + details: &domain.ObjectDetails{ResourceOwner: "instanceID"}, + authReq: &CurrentAuthRequest{ + AuthRequest: &AuthRequest{ + ID: "V2_id", + LoginClient: "loginClient", + ClientID: "clientID", + RedirectURI: "redirectURI", + State: "state", + Nonce: "nonce", + Scope: []string{"openid"}, + Audience: []string{"audience"}, + ResponseType: domain.OIDCResponseTypeCode, + }, + SessionID: "sessionID", + UserID: "userID", + AMR: []string{amr.PWD}, + }, + }, + }, + { + "linked with login client check", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, "domain.tld"), + ), + eventFromEventPusher( + session.NewUserCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, + "userID", testNow), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(mockCtx, &session.NewAggregate("sessionID", "org1").Aggregate, + testNow), + ), + ), + expectPush( + []*repository.Event{eventFromEventPusherWithInstanceID( + "instanceID", + authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + )}), + ), + tokenVerifier: func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) { + return nil + }, + checkPermission: newMockPermissionCheckAllowed(), + }, + args{ + ctx: authz.NewMockContext("instanceID", "orgID", "loginClient"), + id: "V2_id", + sessionID: "sessionID", + sessionToken: "token", + checkLoginClient: true, + }, + res{ + details: &domain.ObjectDetails{ResourceOwner: "instanceID"}, + authReq: &CurrentAuthRequest{ + AuthRequest: &AuthRequest{ + ID: "V2_id", + LoginClient: "loginClient", + ClientID: "clientID", + RedirectURI: "redirectURI", + State: "state", + Nonce: "nonce", + Scope: []string{"openid"}, + Audience: []string{"audience"}, + ResponseType: domain.OIDCResponseTypeCode, + }, + SessionID: "sessionID", + UserID: "userID", + AMR: []string{amr.PWD}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + sessionTokenVerifier: tt.fields.tokenVerifier, + checkPermission: tt.fields.checkPermission, + } + details, got, err := c.LinkSessionToAuthRequest(tt.args.ctx, tt.args.id, tt.args.sessionID, tt.args.sessionToken, tt.args.checkLoginClient) + require.ErrorIs(t, err, tt.res.wantErr) + assert.Equal(t, tt.res.details, details) + if err == nil { + assert.WithinRange(t, got.AuthTime, testNow, testNow) + got.AuthTime = time.Time{} + } + assert.Equal(t, tt.res.authReq, got) + }) + } +} + +func TestCommands_FailAuthRequest(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + id string + reason domain.OIDCErrorReason + } + type res struct { + details *domain.ObjectDetails + authReq *CurrentAuthRequest + wantErr error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "authRequest not existing", + fields{ + eventstore: eventstoreExpect(t, + expectFilter(), + ), + }, + args{ + ctx: mockCtx, + id: "foo", + reason: domain.OIDCErrorReasonLoginRequired, + }, + res{ + wantErr: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sx202nt", "Errors.AuthRequest.AlreadyHandled"), + }, + }, + { + "failed", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + nil, + nil, + nil, + nil, + nil, + nil, + ), + ), + ), + expectPush( + []*repository.Event{eventFromEventPusherWithInstanceID( + "instanceID", + authrequest.NewFailedEvent(mockCtx, &authrequest.NewAggregate("V2_id", "instanceID").Aggregate, + domain.OIDCErrorReasonLoginRequired), + )}), + ), + }, + args{ + ctx: mockCtx, + id: "V2_id", + reason: domain.OIDCErrorReasonLoginRequired, + }, + res{ + details: &domain.ObjectDetails{ResourceOwner: "instanceID"}, + authReq: &CurrentAuthRequest{ + AuthRequest: &AuthRequest{ + ID: "V2_id", + LoginClient: "loginClient", + ClientID: "clientID", + RedirectURI: "redirectURI", + State: "state", + Nonce: "nonce", + Scope: []string{"openid"}, + Audience: []string{"audience"}, + ResponseType: domain.OIDCResponseTypeCode, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + } + details, got, err := c.FailAuthRequest(tt.args.ctx, tt.args.id, tt.args.reason) + require.ErrorIs(t, err, tt.res.wantErr) + assert.Equal(t, tt.res.details, details) + assert.Equal(t, tt.res.authReq, got) + }) + } +} + +func TestCommands_AddAuthRequestCode(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + id string + code string + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + { + "empty code error", + fields{ + eventstore: eventstoreExpect(t), + }, + args{ + ctx: mockCtx, + id: "V2_authRequestID", + code: "", + }, + caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Ht52d", "Errors.AuthRequest.InvalidCode"), + }, + { + "no session linked error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + ), + ), + }, + args{ + ctx: mockCtx, + id: "V2_authRequestID", + code: "V2_authRequestID", + }, + caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.AlreadyHandled"), + }, + { + "success", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + eventFromEventPusher( + authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + ), + ), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + }, + ), + ), + }, + args{ + ctx: mockCtx, + id: "V2_authRequestID", + code: "V2_authRequestID", + }, + nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + } + err := c.AddAuthRequestCode(tt.args.ctx, tt.args.id, tt.args.code) + assert.ErrorIs(t, tt.wantErr, err) + }) + } +} + +func TestCommands_ExchangeAuthCode(t *testing.T) { + mockCtx := authz.NewMockContext("instanceID", "orgID", "loginClient") + type fields struct { + eventstore *eventstore.Eventstore + } + type args struct { + ctx context.Context + code string + } + type res struct { + authRequest *CurrentAuthRequest + err error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "empty code error", + fields{ + eventstore: eventstoreExpect(t), + }, + args{ + ctx: mockCtx, + code: "", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-Sf3g2", "Errors.AuthRequest.InvalidCode"), + }, + }, + { + "no code added error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + ), + ), + }, + args{ + ctx: mockCtx, + code: "V2_authRequestID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "COMMAND-SFwd2", "Errors.AuthRequest.NoCode"), + }, + }, + { + "code exchanged", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + eventFromEventPusher( + authrequest.NewSessionLinkedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + ), + eventFromEventPusher( + authrequest.NewCodeAddedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + ), + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + authrequest.NewCodeExchangedEvent(mockCtx, &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + }, + ), + ), + }, + args{ + ctx: mockCtx, + code: "V2_authRequestID", + }, + res{ + authRequest: &CurrentAuthRequest{ + AuthRequest: &AuthRequest{ + ID: "V2_authRequestID", + LoginClient: "loginClient", + ClientID: "clientID", + RedirectURI: "redirectURI", + State: "state", + Nonce: "nonce", + Scope: []string{"openid"}, + Audience: []string{"audience"}, + ResponseType: domain.OIDCResponseTypeCode, + CodeChallenge: &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + Prompt: []domain.Prompt{domain.PromptNone}, + UILocales: []string{"en", "de"}, + MaxAge: gu.Ptr(time.Duration(0)), + LoginHint: gu.Ptr("loginHint"), + HintUserID: gu.Ptr("hintUserID"), + }, + SessionID: "sessionID", + UserID: "userID", + AMR: []string{"pwd"}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + } + got, err := c.ExchangeAuthCode(tt.args.ctx, tt.args.code) + assert.ErrorIs(t, tt.res.err, err) + + if err == nil { + // equal on time won't work -> test separately and clear it before comparing the rest + assert.WithinRange(t, got.AuthTime, testNow, testNow) + got.AuthTime = time.Time{} + } + assert.Equal(t, tt.res.authRequest, got) + }) + } +} diff --git a/internal/command/command.go b/internal/command/command.go index 3080bd31f0..2c89b78a2d 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -15,10 +15,12 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/id" "github.com/zitadel/zitadel/internal/repository/action" + "github.com/zitadel/zitadel/internal/repository/authrequest" "github.com/zitadel/zitadel/internal/repository/idpintent" instance_repo "github.com/zitadel/zitadel/internal/repository/instance" "github.com/zitadel/zitadel/internal/repository/keypair" "github.com/zitadel/zitadel/internal/repository/milestone" + "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/org" proj_repo "github.com/zitadel/zitadel/internal/repository/project" "github.com/zitadel/zitadel/internal/repository/quota" @@ -43,18 +45,21 @@ type Commands struct { externalSecure bool externalPort uint16 - idpConfigEncryption crypto.EncryptionAlgorithm - smtpEncryption crypto.EncryptionAlgorithm - smsEncryption crypto.EncryptionAlgorithm - userEncryption crypto.EncryptionAlgorithm - userPasswordAlg crypto.HashAlgorithm - machineKeySize int - applicationKeySize int - domainVerificationAlg crypto.EncryptionAlgorithm - domainVerificationGenerator crypto.Generator - domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error - sessionTokenCreator func(sessionID string) (id string, token string, err error) - sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) + idpConfigEncryption crypto.EncryptionAlgorithm + smtpEncryption crypto.EncryptionAlgorithm + smsEncryption crypto.EncryptionAlgorithm + userEncryption crypto.EncryptionAlgorithm + userPasswordAlg crypto.HashAlgorithm + machineKeySize int + applicationKeySize int + domainVerificationAlg crypto.EncryptionAlgorithm + domainVerificationGenerator crypto.Generator + domainVerificationValidator func(domain, token, verifier string, checkType api_http.CheckType) error + sessionTokenCreator func(sessionID string) (id string, token string, err error) + sessionTokenVerifier func(ctx context.Context, sessionToken, sessionID, tokenID string) (err error) + defaultAccessTokenLifetime time.Duration + defaultRefreshTokenLifetime time.Duration + defaultRefreshTokenIdleLifetime time.Duration multifactors domain.MultifactorConfigs webauthnConfig *webauthn_helper.Config @@ -80,6 +85,9 @@ func StartCommands( httpClient *http.Client, permissionCheck domain.PermissionCheck, sessionTokenVerifier func(ctx context.Context, sessionToken string, sessionID string, tokenID string) (err error), + defaultAccessTokenLifetime, + defaultRefreshTokenLifetime, + defaultRefreshTokenIdleLifetime time.Duration, ) (repo *Commands, err error) { if externalDomain == "" { return nil, errors.ThrowInvalidArgument(nil, "COMMAND-Df21s", "no external domain specified") @@ -88,31 +96,34 @@ func StartCommands( // reuse the oidcEncryption to be able to handle both tokens in the interceptor later on sessionAlg := oidcEncryption repo = &Commands{ - eventstore: es, - static: staticStore, - idGenerator: idGenerator, - zitadelRoles: zitadelRoles, - externalDomain: externalDomain, - externalSecure: externalSecure, - externalPort: externalPort, - keySize: defaults.KeyConfig.Size, - certKeySize: defaults.KeyConfig.CertificateSize, - privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime, - publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime, - certificateLifetime: defaults.KeyConfig.CertificateLifetime, - idpConfigEncryption: idpConfigEncryption, - smtpEncryption: smtpEncryption, - smsEncryption: smsEncryption, - userEncryption: userEncryption, - domainVerificationAlg: domainVerificationEncryption, - keyAlgorithm: oidcEncryption, - certificateAlgorithm: samlEncryption, - webauthnConfig: webAuthN, - httpClient: httpClient, - checkPermission: permissionCheck, - newCode: newCryptoCode, - sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg), - sessionTokenVerifier: sessionTokenVerifier, + eventstore: es, + static: staticStore, + idGenerator: idGenerator, + zitadelRoles: zitadelRoles, + externalDomain: externalDomain, + externalSecure: externalSecure, + externalPort: externalPort, + keySize: defaults.KeyConfig.Size, + certKeySize: defaults.KeyConfig.CertificateSize, + privateKeyLifetime: defaults.KeyConfig.PrivateKeyLifetime, + publicKeyLifetime: defaults.KeyConfig.PublicKeyLifetime, + certificateLifetime: defaults.KeyConfig.CertificateLifetime, + idpConfigEncryption: idpConfigEncryption, + smtpEncryption: smtpEncryption, + smsEncryption: smsEncryption, + userEncryption: userEncryption, + domainVerificationAlg: domainVerificationEncryption, + keyAlgorithm: oidcEncryption, + certificateAlgorithm: samlEncryption, + webauthnConfig: webAuthN, + httpClient: httpClient, + checkPermission: permissionCheck, + newCode: newCryptoCode, + sessionTokenCreator: sessionTokenCreator(idGenerator, sessionAlg), + sessionTokenVerifier: sessionTokenVerifier, + defaultAccessTokenLifetime: defaultAccessTokenLifetime, + defaultRefreshTokenLifetime: defaultRefreshTokenLifetime, + defaultRefreshTokenIdleLifetime: defaultRefreshTokenIdleLifetime, } instance_repo.RegisterEventMappers(repo.eventstore) @@ -125,6 +136,8 @@ func StartCommands( quota.RegisterEventMappers(repo.eventstore) session.RegisterEventMappers(repo.eventstore) idpintent.RegisterEventMappers(repo.eventstore) + authrequest.RegisterEventMappers(repo.eventstore) + oidcsession.RegisterEventMappers(repo.eventstore) milestone.RegisterEventMappers(repo.eventstore) repo.userPasswordAlg = crypto.NewBCrypt(defaults.SecretGenerators.PasswordSaltCost) diff --git a/internal/command/main_test.go b/internal/command/main_test.go index 184b0a9b04..9f778af7bf 100644 --- a/internal/command/main_test.go +++ b/internal/command/main_test.go @@ -16,9 +16,11 @@ import ( "github.com/zitadel/zitadel/internal/eventstore/repository" "github.com/zitadel/zitadel/internal/eventstore/repository/mock" action_repo "github.com/zitadel/zitadel/internal/repository/action" + "github.com/zitadel/zitadel/internal/repository/authrequest" "github.com/zitadel/zitadel/internal/repository/idpintent" iam_repo "github.com/zitadel/zitadel/internal/repository/instance" key_repo "github.com/zitadel/zitadel/internal/repository/keypair" + "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/org" proj_repo "github.com/zitadel/zitadel/internal/repository/project" "github.com/zitadel/zitadel/internal/repository/session" @@ -43,6 +45,8 @@ func eventstoreExpect(t *testing.T, expects ...expect) *eventstore.Eventstore { action_repo.RegisterEventMappers(es) session.RegisterEventMappers(es) idpintent.RegisterEventMappers(es) + authrequest.RegisterEventMappers(es) + oidcsession.RegisterEventMappers(es) return es } diff --git a/internal/command/oidc_session.go b/internal/command/oidc_session.go new file mode 100644 index 0000000000..6af09d91b0 --- /dev/null +++ b/internal/command/oidc_session.go @@ -0,0 +1,281 @@ +package command + +import ( + "context" + "encoding/base64" + "strings" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/oidc/amr" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/id" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/oidcsession" +) + +// AddOIDCSessionAccessToken creates a new OIDC Session, creates an access token and returns its id and expiration. +// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged. +func (c *Commands) AddOIDCSessionAccessToken(ctx context.Context, authRequestID string) (string, time.Time, error) { + cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID) + if err != nil { + return "", time.Time{}, err + } + cmd.AddSession(ctx) + if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil { + return "", time.Time{}, err + } + cmd.SetAuthRequestSuccessful(ctx) + accessTokenID, _, accessTokenExpiration, err := cmd.PushEvents(ctx) + return accessTokenID, accessTokenExpiration, err +} + +// AddOIDCSessionRefreshAndAccessToken creates a new OIDC Session, creates an access token and refresh token. +// It returns the access token id, expiration and the refresh token. +// If the underlying [AuthRequest] is a OIDC Auth Code Flow, it will set the code as exchanged. +func (c *Commands) AddOIDCSessionRefreshAndAccessToken(ctx context.Context, authRequestID string) (tokenID, refreshToken string, tokenExpiration time.Time, err error) { + cmd, err := c.newOIDCSessionAddEvents(ctx, authRequestID) + if err != nil { + return "", "", time.Time{}, err + } + cmd.AddSession(ctx) + if err = cmd.AddAccessToken(ctx, cmd.authRequestWriteModel.Scope); err != nil { + return "", "", time.Time{}, err + } + if err = cmd.AddRefreshToken(ctx); err != nil { + return "", "", time.Time{}, err + } + cmd.SetAuthRequestSuccessful(ctx) + return cmd.PushEvents(ctx) +} + +// ExchangeOIDCSessionRefreshAndAccessToken updates an existing OIDC Session, creates a new access and refresh token. +// It returns the access token id and expiration and the new refresh token. +func (c *Commands) ExchangeOIDCSessionRefreshAndAccessToken(ctx context.Context, oidcSessionID, refreshToken string, scope []string) (tokenID, newRefreshToken string, tokenExpiration time.Time, err error) { + cmd, err := c.newOIDCSessionUpdateEvents(ctx, oidcSessionID, refreshToken) + if err != nil { + return "", "", time.Time{}, err + } + if err = cmd.AddAccessToken(ctx, scope); err != nil { + return "", "", time.Time{}, err + } + if err = cmd.RenewRefreshToken(ctx); err != nil { + return "", "", time.Time{}, err + } + return cmd.PushEvents(ctx) +} + +// OIDCSessionByRefreshToken computes the current state of an existing OIDCSession by a refresh_token (to start a Refresh Token Grant). +// If either the session is not active, the token is invalid or expired (incl. idle expiration) an invalid refresh token error will be returned. +func (c *Commands) OIDCSessionByRefreshToken(ctx context.Context, refreshToken string) (*OIDCSessionWriteModel, error) { + split := strings.Split(refreshToken, ":") + if len(split) != 2 { + return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid") + } + writeModel := NewOIDCSessionWriteModel(split[0], "") + err := c.eventstore.FilterToQueryReducer(ctx, writeModel) + if err != nil { + return nil, caos_errs.ThrowPreconditionFailed(err, "OIDCS-SAF31", "Errors.OIDCSession.RefreshTokenInvalid") + } + if err = writeModel.CheckRefreshToken(split[1]); err != nil { + return nil, err + } + return writeModel, nil +} + +func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, authRequestID string) (*OIDCSessionEvents, error) { + authRequestWriteModel, err := c.getAuthRequestWriteModel(ctx, authRequestID) + if err != nil { + return nil, err + } + if err = authRequestWriteModel.CheckAuthenticated(); err != nil { + return nil, err + } + sessionWriteModel := NewSessionWriteModel(authRequestWriteModel.SessionID, authz.GetCtxData(ctx).OrgID) + err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel) + if err != nil { + return nil, err + } + if sessionWriteModel.State != domain.SessionStateActive { + return nil, caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated") + } + accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx) + if err != nil { + return nil, err + } + sessionID, err := c.idGenerator.Next() + if err != nil { + return nil, err + } + sessionID = IDPrefixV2 + sessionID + return &OIDCSessionEvents{ + eventstore: c.eventstore, + idGenerator: c.idGenerator, + encryptionAlg: c.keyAlgorithm, + oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, authz.GetInstance(ctx).InstanceID()), + sessionWriteModel: sessionWriteModel, + authRequestWriteModel: authRequestWriteModel, + accessTokenLifetime: accessTokenLifetime, + refreshTokenLifeTime: refreshTokenLifeTime, + refreshTokenIdleLifetime: refreshTokenIdleLifetime, + }, nil +} + +func (c *Commands) decryptRefreshToken(refreshToken string) (refreshTokenID string, err error) { + decoded, err := base64.RawURLEncoding.DecodeString(refreshToken) + if err != nil { + return "", err + } + decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID()) + if err != nil { + return "", err + } + split := strings.Split(decrypted, ":") + if len(split) != 2 { + return "", caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid") + } + return split[1], nil +} + +func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, oidcSessionID, refreshToken string) (*OIDCSessionEvents, error) { + refreshTokenID, err := c.decryptRefreshToken(refreshToken) + if err != nil { + return nil, err + } + sessionWriteModel := NewOIDCSessionWriteModel(oidcSessionID, authz.GetInstance(ctx).InstanceID()) + if err = c.eventstore.FilterToQueryReducer(ctx, sessionWriteModel); err != nil { + return nil, err + } + if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil { + return nil, err + } + accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx) + if err != nil { + return nil, err + } + return &OIDCSessionEvents{ + eventstore: c.eventstore, + idGenerator: c.idGenerator, + encryptionAlg: c.keyAlgorithm, + oidcSessionWriteModel: sessionWriteModel, + accessTokenLifetime: accessTokenLifetime, + refreshTokenLifeTime: refreshTokenLifeTime, + refreshTokenIdleLifetime: refreshTokenIdleLifetime, + }, nil +} + +type OIDCSessionEvents struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + encryptionAlg crypto.EncryptionAlgorithm + events []eventstore.Command + oidcSessionWriteModel *OIDCSessionWriteModel + sessionWriteModel *SessionWriteModel + authRequestWriteModel *AuthRequestWriteModel + accessTokenLifetime time.Duration + refreshTokenLifeTime time.Duration + refreshTokenIdleLifetime time.Duration + + // accessTokenID is set by the command + accessTokenID string + + // refreshToken is set by the command + refreshToken string +} + +func (c *OIDCSessionEvents) AddSession(ctx context.Context) { + c.events = append(c.events, oidcsession.NewAddedEvent( + ctx, + c.oidcSessionWriteModel.aggregate, + c.sessionWriteModel.UserID, + c.sessionWriteModel.AggregateID, + c.authRequestWriteModel.ClientID, + c.authRequestWriteModel.Audience, + c.authRequestWriteModel.Scope, + amr.List(c.sessionWriteModel), + c.sessionWriteModel.AuthenticationTime(), + )) +} + +func (c *OIDCSessionEvents) SetAuthRequestSuccessful(ctx context.Context) { + c.events = append(c.events, authrequest.NewSucceededEvent(ctx, c.authRequestWriteModel.aggregate)) +} + +func (c *OIDCSessionEvents) AddAccessToken(ctx context.Context, scope []string) (err error) { + c.accessTokenID, err = c.idGenerator.Next() + if err != nil { + return err + } + c.events = append(c.events, oidcsession.NewAccessTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, c.accessTokenID, scope, c.accessTokenLifetime)) + return nil +} + +func (c *OIDCSessionEvents) AddRefreshToken(ctx context.Context) (err error) { + var refreshTokenID string + refreshTokenID, c.refreshToken, err = c.generateRefreshToken() + if err != nil { + return err + } + c.events = append(c.events, oidcsession.NewRefreshTokenAddedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenLifeTime, c.refreshTokenIdleLifetime)) + return nil +} + +func (c *OIDCSessionEvents) RenewRefreshToken(ctx context.Context) (err error) { + var refreshTokenID string + refreshTokenID, c.refreshToken, err = c.generateRefreshToken() + if err != nil { + return err + } + c.events = append(c.events, oidcsession.NewRefreshTokenRenewedEvent(ctx, c.oidcSessionWriteModel.aggregate, refreshTokenID, c.refreshTokenIdleLifetime)) + return nil +} + +func (c *OIDCSessionEvents) generateRefreshToken() (refreshTokenID, refreshToken string, err error) { + refreshTokenID, err = c.idGenerator.Next() + if err != nil { + return "", "", err + } + token, err := c.encryptionAlg.Encrypt([]byte(c.oidcSessionWriteModel.AggregateID + ":" + refreshTokenID)) + if err != nil { + return "", "", err + } + return refreshTokenID, base64.RawURLEncoding.EncodeToString(token), nil +} + +func (c *OIDCSessionEvents) PushEvents(ctx context.Context) (accessTokenID string, refreshToken string, accessTokenExpiration time.Time, err error) { + pushedEvents, err := c.eventstore.Push(ctx, c.events...) + if err != nil { + return "", "", time.Time{}, err + } + err = AppendAndReduce(c.oidcSessionWriteModel, pushedEvents...) + if err != nil { + return "", "", time.Time{}, err + } + // prefix the returned id with the oidcSessionID so that we can retrieve it later on + // we need to use `-` as a delimiter because the OIDC library uses `:` and will check for a length of 2 parts + return c.oidcSessionWriteModel.AggregateID + "-" + c.accessTokenID, c.refreshToken, c.oidcSessionWriteModel.AccessTokenExpiration, nil +} + +func (c *Commands) tokenTokenLifetimes(ctx context.Context) (accessTokenLifetime time.Duration, refreshTokenLifetime time.Duration, refreshTokenIdleLifetime time.Duration, err error) { + oidcSettings := NewInstanceOIDCSettingsWriteModel(ctx) + err = c.eventstore.FilterToQueryReducer(ctx, oidcSettings) + if err != nil { + return 0, 0, 0, err + } + accessTokenLifetime = c.defaultAccessTokenLifetime + refreshTokenLifetime = c.defaultRefreshTokenLifetime + refreshTokenIdleLifetime = c.defaultRefreshTokenIdleLifetime + if oidcSettings.AccessTokenLifetime > 0 { + accessTokenLifetime = oidcSettings.AccessTokenLifetime + } + if oidcSettings.RefreshTokenExpiration > 0 { + refreshTokenLifetime = oidcSettings.RefreshTokenExpiration + } + if oidcSettings.RefreshTokenIdleExpiration > 0 { + refreshTokenIdleLifetime = oidcSettings.RefreshTokenIdleExpiration + } + return accessTokenLifetime, refreshTokenLifetime, refreshTokenIdleLifetime, nil +} diff --git a/internal/command/oidc_session_model.go b/internal/command/oidc_session_model.go new file mode 100644 index 0000000000..f1c117f2b2 --- /dev/null +++ b/internal/command/oidc_session_model.go @@ -0,0 +1,114 @@ +package command + +import ( + "time" + + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/repository/oidcsession" +) + +type OIDCSessionWriteModel struct { + eventstore.WriteModel + + UserID string + SessionID string + ClientID string + Audience []string + Scope []string + AuthMethodsReferences []string + AuthTime time.Time + State domain.OIDCSessionState + AccessTokenExpiration time.Time + RefreshTokenID string + RefreshTokenExpiration time.Time + RefreshTokenIdleExpiration time.Time + + aggregate *eventstore.Aggregate +} + +func NewOIDCSessionWriteModel(id string, resourceOwner string) *OIDCSessionWriteModel { + return &OIDCSessionWriteModel{ + WriteModel: eventstore.WriteModel{ + AggregateID: id, + ResourceOwner: resourceOwner, + }, + aggregate: &oidcsession.NewAggregate(id, resourceOwner).Aggregate, + } +} + +func (wm *OIDCSessionWriteModel) Reduce() error { + for _, event := range wm.Events { + switch e := event.(type) { + case *oidcsession.AddedEvent: + wm.reduceAdded(e) + case *oidcsession.AccessTokenAddedEvent: + wm.reduceAccessTokenAdded(e) + case *oidcsession.RefreshTokenAddedEvent: + wm.reduceRefreshTokenAdded(e) + case *oidcsession.RefreshTokenRenewedEvent: + wm.reduceRefreshTokenRenewed(e) + } + } + return wm.WriteModel.Reduce() +} + +func (wm *OIDCSessionWriteModel) Query() *eventstore.SearchQueryBuilder { + query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). + AddQuery(). + AggregateTypes(oidcsession.AggregateType). + AggregateIDs(wm.AggregateID). + EventTypes( + oidcsession.AddedType, + oidcsession.AccessTokenAddedType, + oidcsession.RefreshTokenAddedType, + oidcsession.RefreshTokenRenewedType, + ). + Builder() + + if wm.ResourceOwner != "" { + query.ResourceOwner(wm.ResourceOwner) + } + return query +} + +func (wm *OIDCSessionWriteModel) reduceAdded(e *oidcsession.AddedEvent) { + wm.UserID = e.UserID + wm.SessionID = e.SessionID + wm.ClientID = e.ClientID + wm.Audience = e.Audience + wm.Scope = e.Scope + wm.AuthMethodsReferences = e.AuthMethodsReferences + wm.AuthTime = e.AuthTime + wm.State = domain.OIDCSessionStateActive +} + +func (wm *OIDCSessionWriteModel) reduceAccessTokenAdded(e *oidcsession.AccessTokenAddedEvent) { + wm.AccessTokenExpiration = e.CreationDate().Add(e.Lifetime) +} + +func (wm *OIDCSessionWriteModel) reduceRefreshTokenAdded(e *oidcsession.RefreshTokenAddedEvent) { + wm.RefreshTokenID = e.ID + wm.RefreshTokenExpiration = e.CreationDate().Add(e.Lifetime) + wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime) +} + +func (wm *OIDCSessionWriteModel) reduceRefreshTokenRenewed(e *oidcsession.RefreshTokenRenewedEvent) { + wm.RefreshTokenID = e.ID + wm.RefreshTokenIdleExpiration = e.CreationDate().Add(e.IdleLifetime) +} + +func (wm *OIDCSessionWriteModel) CheckRefreshToken(refreshTokenID string) error { + if wm.State != domain.OIDCSessionStateActive { + return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid") + } + if wm.RefreshTokenID != refreshTokenID { + return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid") + } + now := time.Now() + if wm.RefreshTokenExpiration.Before(now) || wm.RefreshTokenIdleExpiration.Before(now) { + return caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid") + } + return nil +} diff --git a/internal/command/oidc_session_test.go b/internal/command/oidc_session_test.go new file mode 100644 index 0000000000..2a1dfe08af --- /dev/null +++ b/internal/command/oidc_session_test.go @@ -0,0 +1,795 @@ +package command + +import ( + "context" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/oidc/amr" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" + caos_errs "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/id" + "github.com/zitadel/zitadel/internal/id/mock" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/oidcsession" + "github.com/zitadel/zitadel/internal/repository/session" +) + +var ( + testNow = time.Now() + tokenCreationNow = time.Time{} +) + +func TestCommands_AddOIDCSessionAccessToken(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + defaultAccessTokenLifetime time.Duration + defaultRefreshTokenLifetime time.Duration + defaultRefreshTokenIdleLifetime time.Duration + keyAlgorithm crypto.EncryptionAlgorithm + } + type args struct { + ctx context.Context + authRequestID string + } + type res struct { + id string + expiration time.Time + err error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "unauthenticated error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter(), + ), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + authRequestID: "V2_authRequestID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"), + }, + }, + { + "inactive session error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + eventFromEventPusher( + authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + ), + eventFromEventPusher( + authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + eventFromEventPusher( + authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + ), + expectFilter(), + ), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + authRequestID: "V2_authRequestID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"), + }, + }, + { + "add successful", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + eventFromEventPusher( + authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + ), + eventFromEventPusher( + authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + eventFromEventPusher( + authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"), + ), + eventFromEventPusher( + session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + "userID", testNow), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + testNow), + ), + ), + expectFilter(), // token lifetime + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid"}, time.Hour), + ), + eventFromEventPusherWithInstanceID("instanceID", + authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + }, + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID"), + defaultAccessTokenLifetime: time.Hour, + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + authRequestID: "V2_authRequestID", + }, + res{ + id: "V2_oidcSessionID-accessTokenID", + expiration: tokenCreationNow.Add(time.Hour), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime, + defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime, + defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime, + keyAlgorithm: tt.fields.keyAlgorithm, + } + gotID, gotExpiration, err := c.AddOIDCSessionAccessToken(tt.args.ctx, tt.args.authRequestID) + assert.Equal(t, tt.res.id, gotID) + assert.Equal(t, tt.res.expiration, gotExpiration) + assert.ErrorIs(t, err, tt.res.err) + }) + } +} + +func TestCommands_AddOIDCSessionRefreshAndAccessToken(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + defaultAccessTokenLifetime time.Duration + defaultRefreshTokenLifetime time.Duration + defaultRefreshTokenIdleLifetime time.Duration + keyAlgorithm crypto.EncryptionAlgorithm + } + type args struct { + ctx context.Context + authRequestID string + } + type res struct { + id string + refreshToken string + expiration time.Time + err error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "unauthenticated error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter(), + ), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + authRequestID: "V2_authRequestID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "AUTHR-SF2r2", "Errors.AuthRequest.NotAuthenticated"), + }, + }, + { + "inactive session error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid", "offline_access"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + eventFromEventPusher( + authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + ), + eventFromEventPusher( + authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + eventFromEventPusher( + authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + ), + expectFilter(), + ), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + authRequestID: "V2_authRequestID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-sjkl3", "Errors.Session.Terminated"), + }, + }, + { + "add successful", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "loginClient", + "clientID", + "redirectURI", + "state", + "nonce", + []string{"openid", "offline_access"}, + []string{"audience"}, + domain.OIDCResponseTypeCode, + &domain.OIDCCodeChallenge{ + Challenge: "challenge", + Method: domain.CodeChallengeMethodS256, + }, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + ), + ), + eventFromEventPusher( + authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate, + "sessionID", + "userID", + testNow, + []string{amr.PWD}, + ), + ), + eventFromEventPusher( + authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + eventFromEventPusher( + authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + ), + expectFilter( + eventFromEventPusher( + session.NewAddedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, "domain.tld"), + ), + eventFromEventPusher( + session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + "userID", testNow), + ), + eventFromEventPusher( + session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate, + testNow), + ), + ), + expectFilter(), // token lifetime + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "offline_access"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "offline_access"}, time.Hour), + ), + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + ), + eventFromEventPusherWithInstanceID("instanceID", + authrequest.NewSucceededEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), + ), + }, + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID", "accessTokenID", "refreshTokenID"), + defaultAccessTokenLifetime: time.Hour, + defaultRefreshTokenLifetime: 7 * 24 * time.Hour, + defaultRefreshTokenIdleLifetime: 24 * time.Hour, + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + authRequestID: "V2_authRequestID", + }, + res{ + id: "V2_oidcSessionID-accessTokenID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", //V2_oidcSessionID:refreshTokenID + expiration: tokenCreationNow.Add(time.Hour), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime, + defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime, + defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime, + keyAlgorithm: tt.fields.keyAlgorithm, + } + gotID, gotRefreshToken, gotExpiration, err := c.AddOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.authRequestID) + assert.Equal(t, tt.res.id, gotID) + assert.Equal(t, tt.res.refreshToken, gotRefreshToken) + assert.Equal(t, tt.res.expiration, gotExpiration) + assert.ErrorIs(t, err, tt.res.err) + }) + } +} + +func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + defaultAccessTokenLifetime time.Duration + defaultRefreshTokenLifetime time.Duration + defaultRefreshTokenIdleLifetime time.Duration + keyAlgorithm crypto.EncryptionAlgorithm + } + type args struct { + ctx context.Context + oidcSessionID string + refreshToken string + scope []string + } + type res struct { + id string + refreshToken string + expiration time.Time + err error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "invalid refresh token format error", + fields{ + eventstore: eventstoreExpect(t), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + oidcSessionID: "V2_oidcSessionID", + refreshToken: "aW52YWxpZA", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-Sj3lk", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "inactive session error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter(), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + oidcSessionID: "V2_oidcSessionID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "invalid refresh token error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusher( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + oidcSessionID: "V2_oidcSessionID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "expired refresh token error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusher( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + eventFromEventPusher( + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + oidcSessionID: "V2_oidcSessionID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "refresh successful", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusherWithCreationDateNow( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + ), + ), + expectFilter(), // token lifetime + expectPush( + []*repository.Event{ + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "offline_access"}, time.Hour), + ), + eventFromEventPusherWithInstanceID("instanceID", + oidcsession.NewRefreshTokenRenewedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "refreshTokenID2", 24*time.Hour), + ), + }, + ), + ), + idGenerator: mock.NewIDGeneratorExpectIDs(t, "accessTokenID", "refreshTokenID2"), + defaultAccessTokenLifetime: time.Hour, + defaultRefreshTokenLifetime: 7 * 24 * time.Hour, + defaultRefreshTokenIdleLifetime: 24 * time.Hour, + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + oidcSessionID: "V2_oidcSessionID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRA", + scope: []string{"openid", "offline_access"}, + }, + res{ + id: "V2_oidcSessionID-accessTokenID", + refreshToken: "VjJfb2lkY1Nlc3Npb25JRDpyZWZyZXNoVG9rZW5JRDI", + expiration: time.Time{}.Add(time.Hour), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime, + defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime, + defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime, + keyAlgorithm: tt.fields.keyAlgorithm, + } + gotID, gotRefreshToken, gotExpiration, err := c.ExchangeOIDCSessionRefreshAndAccessToken(tt.args.ctx, tt.args.oidcSessionID, tt.args.refreshToken, tt.args.scope) + assert.Equal(t, tt.res.id, gotID) + assert.Equal(t, tt.res.refreshToken, gotRefreshToken) + assert.Equal(t, tt.res.expiration, gotExpiration) + assert.ErrorIs(t, err, tt.res.err) + }) + } +} + +func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { + type fields struct { + eventstore *eventstore.Eventstore + idGenerator id.Generator + defaultAccessTokenLifetime time.Duration + defaultRefreshTokenLifetime time.Duration + defaultRefreshTokenIdleLifetime time.Duration + keyAlgorithm crypto.EncryptionAlgorithm + } + type args struct { + ctx context.Context + refreshToken string + } + type res struct { + model *OIDCSessionWriteModel + err error + } + tests := []struct { + name string + fields fields + args args + res res + }{ + { + "invalid refresh token format error", + fields{ + eventstore: eventstoreExpect(t), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + refreshToken: "invalid", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-JOI23", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "inactive session error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter(), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + refreshToken: "V2_oidcSessionID:refreshTokenID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-s3hjk", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "invalid refresh token error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusher( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + refreshToken: "V2_oidcSessionID:refreshTokenID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-28ubl", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "expired refresh token error", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusher( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusher( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + eventFromEventPusher( + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + refreshToken: "V2_oidcSessionID:refreshTokenID", + }, + res{ + err: caos_errs.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"), + }, + }, + { + "get successful", + fields{ + eventstore: eventstoreExpect(t, + expectFilter( + eventFromEventPusherWithCreationDateNow( + oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "userID", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"}, []string{amr.PWD}, testNow), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour), + ), + eventFromEventPusherWithCreationDateNow( + oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "instanceID").Aggregate, + "refreshTokenID", 7*24*time.Hour, 24*time.Hour), + ), + ), + ), + keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), + }, + args{ + ctx: authz.WithInstanceID(context.Background(), "instanceID"), + refreshToken: "V2_oidcSessionID:refreshTokenID", + }, + res{ + model: &OIDCSessionWriteModel{ + WriteModel: eventstore.WriteModel{ + AggregateID: "V2_oidcSessionID", + ChangeDate: testNow, + }, + UserID: "userID", + SessionID: "sessionID", + ClientID: "clientID", + Audience: []string{"audience"}, + Scope: []string{"openid", "profile", "offline_access"}, + AuthMethodsReferences: []string{amr.PWD}, + AuthTime: testNow, + State: domain.OIDCSessionStateActive, + RefreshTokenID: "refreshTokenID", + RefreshTokenExpiration: testNow.Add(7 * 24 * time.Hour), + RefreshTokenIdleExpiration: testNow.Add(24 * time.Hour), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Commands{ + eventstore: tt.fields.eventstore, + idGenerator: tt.fields.idGenerator, + defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime, + defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime, + defaultRefreshTokenIdleLifetime: tt.fields.defaultRefreshTokenIdleLifetime, + keyAlgorithm: tt.fields.keyAlgorithm, + } + got, err := c.OIDCSessionByRefreshToken(tt.args.ctx, tt.args.refreshToken) + require.ErrorIs(t, err, tt.res.err) + if tt.res.err == nil { + assert.WithinRange(t, got.ChangeDate, tt.res.model.ChangeDate.Add(-2*time.Second), tt.res.model.ChangeDate.Add(2*time.Second)) + assert.Equal(t, tt.res.model.AggregateID, got.AggregateID) + assert.Equal(t, tt.res.model.UserID, got.UserID) + assert.Equal(t, tt.res.model.SessionID, got.SessionID) + assert.Equal(t, tt.res.model.ClientID, got.ClientID) + assert.Equal(t, tt.res.model.Audience, got.Audience) + assert.Equal(t, tt.res.model.Scope, got.Scope) + assert.Equal(t, tt.res.model.AuthMethodsReferences, got.AuthMethodsReferences) + assert.WithinRange(t, got.AuthTime, tt.res.model.AuthTime.Add(-2*time.Second), tt.res.model.AuthTime.Add(2*time.Second)) + assert.Equal(t, tt.res.model.State, got.State) + assert.Equal(t, tt.res.model.RefreshTokenID, got.RefreshTokenID) + assert.WithinRange(t, got.RefreshTokenExpiration, tt.res.model.RefreshTokenExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenExpiration.Add(2*time.Second)) + assert.WithinRange(t, got.RefreshTokenIdleExpiration, tt.res.model.RefreshTokenIdleExpiration.Add(-2*time.Second), tt.res.model.RefreshTokenIdleExpiration.Add(2*time.Second)) + } + }) + } +} diff --git a/internal/command/session_model.go b/internal/command/session_model.go index aed5426976..2779c7dc37 100644 --- a/internal/command/session_model.go +++ b/internal/command/session_model.go @@ -52,6 +52,24 @@ type SessionWriteModel struct { aggregate *eventstore.Aggregate } +func (wm *SessionWriteModel) IsPasswordChecked() bool { + return !wm.PasswordCheckedAt.IsZero() +} + +func (wm *SessionWriteModel) IsPasskeyChecked() bool { + return !wm.PasskeyCheckedAt.IsZero() +} + +func (wm *SessionWriteModel) IsU2FChecked() bool { + // TODO: implement with https://github.com/zitadel/zitadel/issues/5477 + return false +} + +func (wm *SessionWriteModel) IsOTPChecked() bool { + // TODO: implement with https://github.com/zitadel/zitadel/issues/5477 + return false +} + func NewSessionWriteModel(sessionID string, resourceOwner string) *SessionWriteModel { return &SessionWriteModel{ WriteModel: eventstore.WriteModel{ @@ -210,3 +228,19 @@ func (wm *SessionWriteModel) ChangeMetadata(ctx context.Context, metadata map[st wm.commands = append(wm.commands, session.NewMetadataSetEvent(ctx, wm.aggregate, wm.Metadata)) } } + +// AuthenticationTime returns the time the user authenticated using the latest time of all checks +func (wm *SessionWriteModel) AuthenticationTime() time.Time { + var authTime time.Time + for _, check := range []time.Time{ + wm.PasswordCheckedAt, + wm.PasskeyCheckedAt, + wm.IntentCheckedAt, + // TODO: add U2F and OTP check https://github.com/zitadel/zitadel/issues/5477 + } { + if check.After(authTime) { + authTime = check + } + } + return authTime +} diff --git a/internal/command/user_human_otp_test.go b/internal/command/user_human_otp_test.go index e11f771e6f..ae44489aee 100644 --- a/internal/command/user_human_otp_test.go +++ b/internal/command/user_human_otp_test.go @@ -528,7 +528,7 @@ func TestCommands_createHumanOTP(t *testing.T) { } func TestCommands_HumanCheckMFAOTPSetup(t *testing.T) { - ctx := authz.NewMockContext("inst1", "org1", "user1") + ctx := authz.NewMockContext("", "org1", "user1") cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t)) key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg) diff --git a/internal/command/user_v2_totp_test.go b/internal/command/user_v2_totp_test.go index d7b9f82463..90596428ae 100644 --- a/internal/command/user_v2_totp_test.go +++ b/internal/command/user_v2_totp_test.go @@ -188,7 +188,7 @@ func TestCommands_AddUserTOTP(t *testing.T) { } func TestCommands_CheckUserTOTP(t *testing.T) { - ctx := authz.NewMockContext("inst1", "org1", "user1") + ctx := authz.NewMockContext("", "org1", "user1") cryptoAlg := crypto.CreateMockEncryptionAlg(gomock.NewController(t)) key, secret, err := domain.NewOTPKey("example.com", "user1", cryptoAlg) diff --git a/internal/domain/auth_request.go b/internal/domain/auth_request.go index c0b9e1a40a..765fead8a9 100644 --- a/internal/domain/auth_request.go +++ b/internal/domain/auth_request.go @@ -116,6 +116,17 @@ const ( MFALevelMultiFactorCertified ) +type AuthRequestState int + +const ( + AuthRequestStateUnspecified AuthRequestState = iota + AuthRequestStateAdded + AuthRequestStateCodeAdded + AuthRequestStateCodeExchanged + AuthRequestStateFailed + AuthRequestStateSucceeded +) + func NewAuthRequestFromType(requestType AuthRequestType) (*AuthRequest, error) { switch requestType { case AuthRequestTypeOIDC: diff --git a/internal/domain/oidc_error_reason.go b/internal/domain/oidc_error_reason.go new file mode 100644 index 0000000000..5a4c8d2c7a --- /dev/null +++ b/internal/domain/oidc_error_reason.go @@ -0,0 +1,23 @@ +package domain + +type OIDCErrorReason int32 + +const ( + OIDCErrorReasonUnspecified OIDCErrorReason = iota + OIDCErrorReasonInvalidRequest + OIDCErrorReasonUnauthorizedClient + OIDCErrorReasonAccessDenied + OIDCErrorReasonUnsupportedResponseType + OIDCErrorReasonInvalidScope + OIDCErrorReasonServerError + OIDCErrorReasonTemporaryUnavailable + OIDCErrorReasonInteractionRequired + OIDCErrorReasonLoginRequired + OIDCErrorReasonAccountSelectionRequired + OIDCErrorReasonConsentRequired + OIDCErrorReasonInvalidRequestURI + OIDCErrorReasonInvalidRequestObject + OIDCErrorReasonRequestNotSupported + OIDCErrorReasonRequestURINotSupported + OIDCErrorReasonRegistrationNotSupported +) diff --git a/internal/domain/oidc_session.go b/internal/domain/oidc_session.go new file mode 100644 index 0000000000..ad09f10737 --- /dev/null +++ b/internal/domain/oidc_session.go @@ -0,0 +1,9 @@ +package domain + +type OIDCSessionState int32 + +const ( + OIDCSessionStateUnspecified OIDCSessionState = iota + OIDCSessionStateActive + OIDCSessionStateTerminated +) diff --git a/internal/integration/client.go b/internal/integration/client.go index fecfecac44..d4849a0fad 100644 --- a/internal/integration/client.go +++ b/internal/integration/client.go @@ -19,6 +19,7 @@ import ( "github.com/zitadel/zitadel/pkg/grpc/admin" mgmt "github.com/zitadel/zitadel/pkg/grpc/management" object "github.com/zitadel/zitadel/pkg/grpc/object/v2alpha" + oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha" session "github.com/zitadel/zitadel/pkg/grpc/session/v2alpha" "github.com/zitadel/zitadel/pkg/grpc/system" user "github.com/zitadel/zitadel/pkg/grpc/user/v2alpha" @@ -30,6 +31,7 @@ type Client struct { Mgmt mgmt.ManagementServiceClient UserV2 user.UserServiceClient SessionV2 session.SessionServiceClient + OIDCv2 oidc_pb.OIDCServiceClient System system.SystemServiceClient } @@ -40,6 +42,7 @@ func newClient(cc *grpc.ClientConn) Client { Mgmt: mgmt.NewManagementServiceClient(cc), UserV2: user.NewUserServiceClient(cc), SessionV2: session.NewSessionServiceClient(cc), + OIDCv2: oidc_pb.NewOIDCServiceClient(cc), System: system.NewSystemServiceClient(cc), } } @@ -62,11 +65,9 @@ func (t *Tester) UseIsolatedInstance(iamOwnerCtx, systemCtx context.Context) (pr } t.createClientConn(iamOwnerCtx, grpc.WithAuthority(primaryDomain)) instanceId = instance.GetInstanceId() - t.Users[instanceId] = map[UserType]User{ - IAMOwner: { - Token: instance.GetPat(), - }, - } + t.Users.Set(instanceId, IAMOwner, &User{ + Token: instance.GetPat(), + }) return primaryDomain, instanceId, t.WithInstanceAuthorization(iamOwnerCtx, IAMOwner, instanceId) } @@ -187,3 +188,34 @@ func (s *Tester) CreateSuccessfulIntent(t *testing.T, idpID, userID, idpUserID s require.NoError(t, err) return intentID, token, writeModel.ChangeDate, writeModel.ProcessedSequence } + +func (s *Tester) CreatePasskeySession(t *testing.T, ctx context.Context, userID string) (id, token string, start, change time.Time) { + createResp, err := s.Client.SessionV2.CreateSession(ctx, &session.CreateSessionRequest{ + Checks: &session.Checks{ + User: &session.CheckUser{ + Search: &session.CheckUser_UserId{UserId: userID}, + }, + }, + Challenges: []session.ChallengeKind{ + session.ChallengeKind_CHALLENGE_KIND_PASSKEY, + }, + Domain: s.Config.ExternalDomain, + }) + require.NoError(t, err) + + assertion, err := s.WebAuthN.CreateAssertionResponse(createResp.GetChallenges().GetPasskey().GetPublicKeyCredentialRequestOptions()) + require.NoError(t, err) + + updateResp, err := s.Client.SessionV2.SetSession(ctx, &session.SetSessionRequest{ + SessionId: createResp.GetSessionId(), + SessionToken: createResp.GetSessionToken(), + Checks: &session.Checks{ + Passkey: &session.CheckPasskey{ + CredentialAssertionData: assertion, + }, + }, + }) + require.NoError(t, err) + return createResp.GetSessionId(), updateResp.GetSessionToken(), + createResp.GetDetails().GetChangeDate().AsTime(), updateResp.GetDetails().GetChangeDate().AsTime() +} diff --git a/internal/integration/config/zitadel.yaml b/internal/integration/config/zitadel.yaml index ba52818bce..14222e263d 100644 --- a/internal/integration/config/zitadel.yaml +++ b/internal/integration/config/zitadel.yaml @@ -1,6 +1,8 @@ Log: Level: debug +ExternalSecure: false + TLS: Enabled: false diff --git a/internal/integration/integration.go b/internal/integration/integration.go index 4fec3ca523..fe5e081efe 100644 --- a/internal/integration/integration.go +++ b/internal/integration/integration.go @@ -57,6 +57,7 @@ type UserType int const ( Unspecified UserType = iota OrgOwner + Login IAMOwner SystemUser // SystemUser is a user with access to the system service. ) @@ -71,13 +72,29 @@ type User struct { Token string } +type InstanceUserMap map[string]map[UserType]*User + +func (m InstanceUserMap) Set(instanceID string, typ UserType, user *User) { + if m[instanceID] == nil { + m[instanceID] = make(map[UserType]*User) + } + m[instanceID][typ] = user +} + +func (m InstanceUserMap) Get(instanceID string, typ UserType) *User { + if users, ok := m[instanceID]; ok { + return users[typ] + } + return nil +} + // Tester is a Zitadel server and client with all resources available for testing. type Tester struct { *start.Server Instance authz.Instance Organisation *query.Org - Users map[string]map[UserType]User + Users InstanceUserMap Client Client WebAuthN *webauthn.Client @@ -135,6 +152,7 @@ func (s *Tester) pollHealth(ctx context.Context) (err error) { } const ( + LoginUser = "loginClient" MachineUser = "integration" ) @@ -148,10 +166,9 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) { s.Organisation, err = s.Queries.OrgByID(ctx, true, s.Instance.DefaultOrganisationID()) logging.OnError(err).Fatal("query organisation") - query, err := query.NewUserUsernameSearchQuery(MachineUser, query.TextEquals) + usernameQuery, err := query.NewUserUsernameSearchQuery(MachineUser, query.TextEquals) logging.OnError(err).Fatal("user query") - user, err := s.Queries.GetUser(ctx, true, true, query) - + user, err := s.Queries.GetUser(ctx, true, true, usernameQuery) if errors.Is(err, sql.ErrNoRows) { _, err = s.Commands.AddMachine(ctx, &command.Machine{ ObjectRoot: models.ObjectRoot{ @@ -162,11 +179,10 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) { Description: "who cares?", AccessTokenType: domain.OIDCTokenTypeJWT, }) - logging.OnError(err).Fatal("add machine user") - user, err = s.Queries.GetUser(ctx, true, true, query) - + logging.WithFields("username", SystemUser).OnError(err).Fatal("add machine user") + user, err = s.Queries.GetUser(ctx, true, true, usernameQuery) } - logging.OnError(err).Fatal("get user") + logging.WithFields("username", SystemUser).OnError(err).Fatal("get user") _, err = s.Commands.AddOrgMember(ctx, s.Organisation.ID, user.ID, "ORG_OWNER") target := new(caos_errs.AlreadyExistsError) @@ -177,18 +193,50 @@ func (s *Tester) createMachineUser(ctx context.Context, instanceId string) { scopes := []string{oidc.ScopeOpenID, oidc.ScopeProfile, z_oidc.ScopeUserMetaData, z_oidc.ScopeResourceOwner} pat := command.NewPersonalAccessToken(user.ResourceOwner, user.ID, time.Now().Add(time.Hour), scopes, domain.UserTypeMachine) _, err = s.Commands.AddPersonalAccessToken(ctx, pat) - logging.OnError(err).Fatal("add pat") - - if s.Users == nil { - s.Users = make(map[string]map[UserType]User) - } - if s.Users[instanceId] == nil { - s.Users[instanceId] = make(map[UserType]User) - } - s.Users[instanceId][OrgOwner] = User{ + logging.WithFields("username", SystemUser).OnError(err).Fatal("add pat") + s.Users.Set(instanceId, OrgOwner, &User{ User: user, Token: pat.Token, + }) +} + +func (s *Tester) createLoginClient(ctx context.Context) { + var err error + + s.Instance, err = s.Queries.InstanceByHost(ctx, s.Host()) + logging.OnError(err).Fatal("query instance") + ctx = authz.WithInstance(ctx, s.Instance) + + s.Organisation, err = s.Queries.OrgByID(ctx, true, s.Instance.DefaultOrganisationID()) + logging.OnError(err).Fatal("query organisation") + + usernameQuery, err := query.NewUserUsernameSearchQuery(LoginUser, query.TextEquals) + logging.WithFields("username", LoginUser).OnError(err).Fatal("user query") + user, err := s.Queries.GetUser(ctx, true, true, usernameQuery) + if errors.Is(err, sql.ErrNoRows) { + _, err = s.Commands.AddMachine(ctx, &command.Machine{ + ObjectRoot: models.ObjectRoot{ + ResourceOwner: s.Organisation.ID, + }, + Username: LoginUser, + Name: LoginUser, + Description: "who cares?", + AccessTokenType: domain.OIDCTokenTypeJWT, + }) + logging.WithFields("username", LoginUser).OnError(err).Fatal("add machine user") + user, err = s.Queries.GetUser(ctx, true, true, usernameQuery) } + logging.WithFields("username", LoginUser).OnError(err).Fatal("get user") + + scopes := []string{oidc.ScopeOpenID, z_oidc.ScopeUserMetaData, z_oidc.ScopeResourceOwner} + pat := command.NewPersonalAccessToken(user.ResourceOwner, user.ID, time.Now().Add(time.Hour), scopes, domain.UserTypeMachine) + _, err = s.Commands.AddPersonalAccessToken(ctx, pat) + logging.OnError(err).Fatal("add pat") + + s.Users.Set(FirstInstanceUsersKey, Login, &User{ + User: user, + Token: pat.Token, + }) } func (s *Tester) WithAuthorization(ctx context.Context, u UserType) context.Context { @@ -199,15 +247,12 @@ func (s *Tester) WithInstanceAuthorization(ctx context.Context, u UserType, inst if u == SystemUser { s.ensureSystemUser() } - return metadata.AppendToOutgoingContext(ctx, "Authorization", fmt.Sprintf("Bearer %s", s.Users[instanceID][u].Token)) + return metadata.AppendToOutgoingContext(ctx, "Authorization", fmt.Sprintf("Bearer %s", s.Users.Get(instanceID, u).Token)) } func (s *Tester) ensureSystemUser() { const ISSUER = "tester" - if s.Users[FirstInstanceUsersKey] == nil { - s.Users[FirstInstanceUsersKey] = make(map[UserType]User) - } - if _, ok := s.Users[FirstInstanceUsersKey][SystemUser]; ok { + if s.Users.Get(FirstInstanceUsersKey, SystemUser) != nil { return } audience := http_util.BuildOrigin(s.Host(), s.Server.Config.ExternalSecure) @@ -215,7 +260,11 @@ func (s *Tester) ensureSystemUser() { logging.OnError(err).Fatal("system key signer") jwt, err := client.SignedJWTProfileAssertion(ISSUER, []string{audience}, time.Hour, signer) logging.OnError(err).Fatal("system key jwt") - s.Users[FirstInstanceUsersKey][SystemUser] = User{Token: jwt} + s.Users.Set(FirstInstanceUsersKey, SystemUser, &User{Token: jwt}) +} + +func (s *Tester) WithSystemAuthorizationHTTP(u UserType) map[string]string { + return map[string]string{"Authorization": fmt.Sprintf("Bearer %s", s.Users.Get(FirstInstanceUsersKey, u).Token)} } // Done send an interrupt signal to cleanly shutdown the server. @@ -263,9 +312,7 @@ func NewTester(ctx context.Context) *Tester { logging.OnError(err).Fatal() tester := Tester{ - Users: map[string]map[UserType]User{ - FirstInstanceUsersKey: make(map[UserType]User), - }, + Users: make(InstanceUserMap), } tester.wg.Add(1) go func(wg *sync.WaitGroup) { @@ -279,6 +326,8 @@ func NewTester(ctx context.Context) *Tester { logging.OnError(ctx.Err()).Fatal("waiting for integration tester server") } tester.createClientConn(ctx) + tester.createLoginClient(ctx) + tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, http_util.BuildOrigin(tester.Host(), tester.Config.ExternalSecure)) tester.createMachineUser(ctx, FirstInstanceUsersKey) tester.WebAuthN = webauthn.NewClient(tester.Config.WebAuthNName, tester.Config.ExternalDomain, "https://"+tester.Host()) diff --git a/internal/integration/oidc.go b/internal/integration/oidc.go new file mode 100644 index 0000000000..42f2ab174c --- /dev/null +++ b/internal/integration/oidc.go @@ -0,0 +1,163 @@ +package integration + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/zitadel/oidc/v2/pkg/client/rp" + "github.com/zitadel/oidc/v2/pkg/oidc" + + http_util "github.com/zitadel/zitadel/internal/api/http" + oidc_internal "github.com/zitadel/zitadel/internal/api/oidc" + "github.com/zitadel/zitadel/internal/command" + "github.com/zitadel/zitadel/pkg/grpc/app" + "github.com/zitadel/zitadel/pkg/grpc/management" +) + +func (s *Tester) CreateOIDCNativeClient(ctx context.Context, redirectURI string) (*management.AddOIDCAppResponse, error) { + project, err := s.Client.Mgmt.AddProject(ctx, &management.AddProjectRequest{ + Name: fmt.Sprintf("project-%d", time.Now().UnixNano()), + }) + if err != nil { + return nil, err + } + return s.Client.Mgmt.AddOIDCApp(ctx, &management.AddOIDCAppRequest{ + ProjectId: project.GetId(), + Name: fmt.Sprintf("app-%d", time.Now().UnixNano()), + RedirectUris: []string{redirectURI}, + ResponseTypes: []app.OIDCResponseType{app.OIDCResponseType_OIDC_RESPONSE_TYPE_CODE}, + GrantTypes: []app.OIDCGrantType{app.OIDCGrantType_OIDC_GRANT_TYPE_AUTHORIZATION_CODE, app.OIDCGrantType_OIDC_GRANT_TYPE_REFRESH_TOKEN}, + AppType: app.OIDCAppType_OIDC_APP_TYPE_NATIVE, + AuthMethodType: app.OIDCAuthMethodType_OIDC_AUTH_METHOD_TYPE_NONE, + PostLogoutRedirectUris: nil, + Version: app.OIDCVersion_OIDC_VERSION_1_0, + DevMode: false, + AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT, + AccessTokenRoleAssertion: false, + IdTokenRoleAssertion: false, + IdTokenUserinfoAssertion: false, + ClockSkew: nil, + AdditionalOrigins: nil, + SkipNativeAppSuccessPage: false, + }) +} + +func (s *Tester) CreateOIDCImplicitFlowClient(ctx context.Context, redirectURI string) (*management.AddOIDCAppResponse, error) { + project, err := s.Client.Mgmt.AddProject(ctx, &management.AddProjectRequest{ + Name: fmt.Sprintf("project-%d", time.Now().UnixNano()), + }) + if err != nil { + return nil, err + } + return s.Client.Mgmt.AddOIDCApp(ctx, &management.AddOIDCAppRequest{ + ProjectId: project.GetId(), + Name: fmt.Sprintf("app-%d", time.Now().UnixNano()), + RedirectUris: []string{redirectURI}, + ResponseTypes: []app.OIDCResponseType{app.OIDCResponseType_OIDC_RESPONSE_TYPE_ID_TOKEN_TOKEN}, + GrantTypes: []app.OIDCGrantType{app.OIDCGrantType_OIDC_GRANT_TYPE_IMPLICIT}, + AppType: app.OIDCAppType_OIDC_APP_TYPE_USER_AGENT, + AuthMethodType: app.OIDCAuthMethodType_OIDC_AUTH_METHOD_TYPE_NONE, + PostLogoutRedirectUris: nil, + Version: app.OIDCVersion_OIDC_VERSION_1_0, + DevMode: true, + AccessTokenType: app.OIDCTokenType_OIDC_TOKEN_TYPE_JWT, + AccessTokenRoleAssertion: false, + IdTokenRoleAssertion: false, + IdTokenUserinfoAssertion: false, + ClockSkew: nil, + AdditionalOrigins: nil, + SkipNativeAppSuccessPage: false, + }) +} + +func (s *Tester) CreateOIDCAuthRequest(clientID, loginClient, redirectURI string, scope ...string) (authRequestID string, err error) { + provider, err := s.CreateRelyingParty(clientID, redirectURI, scope...) + if err != nil { + return "", err + } + + codeVerifier := "codeVerifier" + codeChallenge := oidc.NewSHACodeChallenge(codeVerifier) + authURL := rp.AuthURL("state", provider, rp.WithCodeChallenge(codeChallenge)) + + loc, err := CheckRedirect(authURL, map[string]string{oidc_internal.LoginClientHeader: loginClient}) + if err != nil { + return "", err + } + + prefixWithHost := provider.Issuer() + s.Config.OIDC.DefaultLoginURLV2 + if !strings.HasPrefix(loc.String(), prefixWithHost) { + return "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String()) + } + return strings.TrimPrefix(loc.String(), prefixWithHost), nil +} + +func (s *Tester) CreateOIDCAuthRequestImplicit(clientID, loginClient, redirectURI string, scope ...string) (authRequestID string, err error) { + provider, err := s.CreateRelyingParty(clientID, redirectURI, scope...) + if err != nil { + return "", err + } + + authURL := rp.AuthURL("state", provider) + + // implicit is not natively supported so let's just overwrite the response type + parsed, _ := url.Parse(authURL) + queries := parsed.Query() + queries.Set("response_type", string(oidc.ResponseTypeIDToken)) + parsed.RawQuery = queries.Encode() + authURL = parsed.String() + + loc, err := CheckRedirect(authURL, map[string]string{oidc_internal.LoginClientHeader: loginClient}) + if err != nil { + return "", err + } + + prefixWithHost := provider.Issuer() + s.Config.OIDC.DefaultLoginURLV2 + if !strings.HasPrefix(loc.String(), prefixWithHost) { + return "", fmt.Errorf("login location has not prefix %s, but is %s", prefixWithHost, loc.String()) + } + return strings.TrimPrefix(loc.String(), prefixWithHost), nil +} + +func (s *Tester) CreateRelyingParty(clientID, redirectURI string, scope ...string) (rp.RelyingParty, error) { + issuer := http_util.BuildHTTP(s.Config.ExternalDomain, s.Config.Port, s.Config.ExternalSecure) + if len(scope) == 0 { + scope = []string{oidc.ScopeOpenID} + } + return rp.NewRelyingPartyOIDC(issuer, clientID, "", redirectURI, scope) +} + +func CheckRedirect(url string, headers map[string]string) (*url.URL, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } + for key, value := range headers { + req.Header.Set(key, value) + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return resp.Location() +} + +func (s *Tester) CreateSession(ctx context.Context, userID string) (string, string, error) { + session, err := s.Commands.CreateSession(ctx, []command.SessionCommand{command.CheckUser(userID)}, "domain.tld", nil) + if err != nil { + return "", "", err + } + return session.ID, session.NewToken, nil +} diff --git a/internal/query/auth_request.go b/internal/query/auth_request.go new file mode 100644 index 0000000000..ee91bdebea --- /dev/null +++ b/internal/query/auth_request.go @@ -0,0 +1,88 @@ +package query + +import ( + "context" + "database/sql" + _ "embed" + errs "errors" + "fmt" + "time" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/api/call" + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +type AuthRequest struct { + ID string + CreationDate time.Time + LoginClient string + ClientID string + Scope []string + RedirectURI string + Prompt []domain.Prompt + UiLocales []string + LoginHint *string + MaxAge *time.Duration + HintUserID *string +} + +func (a *AuthRequest) checkLoginClient(ctx context.Context) error { + if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient { + return errors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.AuthRequest.WrongLoginClient") + } + return nil +} + +//go:embed embed/auth_request_by_id.sql +var authRequestByIDQuery string + +func (q *Queries) authRequestByIDQuery(ctx context.Context) string { + return fmt.Sprintf(authRequestByIDQuery, q.client.Timetravel(call.Took(ctx))) +} + +func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *AuthRequest, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + if shouldTriggerBulk { + ctx = projection.AuthRequestProjection.Trigger(ctx) + } + + var ( + scope database.StringArray + prompt database.EnumArray[domain.Prompt] + locales database.StringArray + ) + + dst := new(AuthRequest) + err = q.client.DB.QueryRowContext( + ctx, q.authRequestByIDQuery(ctx), + id, authz.GetInstance(ctx).InstanceID(), + ).Scan( + &dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.ClientID, &scope, &dst.RedirectURI, + &prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID, + ) + if errs.Is(err, sql.ErrNoRows) { + return nil, errors.ThrowNotFound(err, "QUERY-Thee9", "Errors.AuthRequest.NotExisting") + } + if err != nil { + return nil, errors.ThrowInternal(err, "QUERY-Ou8ue", "Errors.Internal") + } + + dst.Scope = scope + dst.Prompt = prompt + dst.UiLocales = locales + + if checkLoginClient { + if err = dst.checkLoginClient(ctx); err != nil { + return nil, err + } + } + + return dst, nil +} diff --git a/internal/query/auth_request_test.go b/internal/query/auth_request_test.go new file mode 100644 index 0000000000..7348855e75 --- /dev/null +++ b/internal/query/auth_request_test.go @@ -0,0 +1,180 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + _ "embed" + "fmt" + "regexp" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/query/projection" +) + +func TestQueries_AuthRequestByID(t *testing.T) { + expQuery := regexp.QuoteMeta(fmt.Sprintf( + authRequestByIDQuery, + asOfSystemTime, + )) + + cols := []string{ + projection.AuthRequestColumnID, + projection.AuthRequestColumnCreationDate, + projection.AuthRequestColumnLoginClient, + projection.AuthRequestColumnClientID, + projection.AuthRequestColumnScope, + projection.AuthRequestColumnRedirectURI, + projection.AuthRequestColumnPrompt, + projection.AuthRequestColumnUILocales, + projection.AuthRequestColumnLoginHint, + projection.AuthRequestColumnMaxAge, + projection.AuthRequestColumnHintUserID, + } + type args struct { + shouldTriggerBulk bool + id string + checkLoginClient bool + } + tests := []struct { + name string + args args + expect sqlExpectation + want *AuthRequest + wantErr error + }{ + { + name: "success, all values", + args: args{ + shouldTriggerBulk: false, + id: "123", + checkLoginClient: true, + }, + expect: mockQuery(expQuery, cols, []driver.Value{ + "id", + testNow, + "loginClient", + "clientID", + database.StringArray{"a", "b", "c"}, + "example.com", + database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent}, + database.StringArray{"en", "fi"}, + "me@example.com", + int64(time.Minute), + "userID", + }, "123", "instanceID"), + want: &AuthRequest{ + ID: "id", + CreationDate: testNow, + LoginClient: "loginClient", + ClientID: "clientID", + Scope: []string{"a", "b", "c"}, + RedirectURI: "example.com", + Prompt: []domain.Prompt{domain.PromptLogin, domain.PromptConsent}, + UiLocales: []string{"en", "fi"}, + LoginHint: gu.Ptr("me@example.com"), + MaxAge: gu.Ptr(time.Minute), + HintUserID: gu.Ptr("userID"), + }, + }, + { + name: "success, null values", + args: args{ + shouldTriggerBulk: false, + id: "123", + checkLoginClient: true, + }, + expect: mockQuery(expQuery, cols, []driver.Value{ + "id", + testNow, + "loginClient", + "clientID", + database.StringArray{"a", "b", "c"}, + "example.com", + database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent}, + database.StringArray{"en", "fi"}, + sql.NullString{}, + sql.NullInt64{}, + sql.NullString{}, + }, "123", "instanceID"), + want: &AuthRequest{ + ID: "id", + CreationDate: testNow, + LoginClient: "loginClient", + ClientID: "clientID", + Scope: []string{"a", "b", "c"}, + RedirectURI: "example.com", + Prompt: []domain.Prompt{domain.PromptLogin, domain.PromptConsent}, + UiLocales: []string{"en", "fi"}, + LoginHint: nil, + MaxAge: nil, + HintUserID: nil, + }, + }, + { + name: "no rows", + args: args{ + shouldTriggerBulk: false, + id: "123", + }, + expect: mockQuery(expQuery, cols, nil, "123", "instanceID"), + wantErr: errors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.AuthRequest.NotExisting"), + }, + { + name: "query error", + args: args{ + shouldTriggerBulk: false, + id: "123", + }, + expect: mockQueryErr(expQuery, sql.ErrConnDone, "123", "instanceID"), + wantErr: errors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"), + }, + { + name: "wrong login client", + args: args{ + shouldTriggerBulk: false, + id: "123", + checkLoginClient: true, + }, + expect: mockQuery(expQuery, cols, []driver.Value{ + "id", + testNow, + "wrongLoginClient", + "clientID", + database.StringArray{"a", "b", "c"}, + "example.com", + database.EnumArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent}, + database.StringArray{"en", "fi"}, + sql.NullString{}, + sql.NullInt64{}, + sql.NullString{}, + }, "123", "instanceID"), + wantErr: errors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.AuthRequest.WrongLoginClient"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + execMock(t, tt.expect, func(db *sql.DB) { + q := &Queries{ + client: &database.DB{ + DB: db, + Database: &prepareDB{}, + }, + } + ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") + + got, err := q.AuthRequestByID(ctx, tt.args.shouldTriggerBulk, tt.args.id, tt.args.checkLoginClient) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + }) + }) + } +} diff --git a/internal/query/embed/auth_request_by_id.sql b/internal/query/embed/auth_request_by_id.sql new file mode 100644 index 0000000000..ffc18fccd6 --- /dev/null +++ b/internal/query/embed/auth_request_by_id.sql @@ -0,0 +1,15 @@ +select + id, + creation_date, + login_client, + client_id, + scope, + redirect_uri, + prompt, + ui_locales, + login_hint, + max_age, + hint_user_id +from projections.auth_requests %s +where id = $1 and instance_id = $2 +limit 1; diff --git a/internal/query/prepare_test.go b/internal/query/prepare_test.go index db309805ca..cb69fb21e5 100644 --- a/internal/query/prepare_test.go +++ b/internal/query/prepare_test.go @@ -14,6 +14,7 @@ import ( "github.com/DATA-DOG/go-sqlmock" sq "github.com/Masterminds/squirrel" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -74,9 +75,9 @@ type checkErr func(error) (err error, ok bool) type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock -func mockQuery(stmt string, cols []string, row []driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { +func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { - q := m.ExpectQuery(stmt) + q := m.ExpectQuery(stmt).WithArgs(args...) result := sqlmock.NewRows(cols) if len(row) > 0 { result.AddRow(row...) @@ -111,6 +112,15 @@ func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.S } } +func execMock(t testing.TB, exp sqlExpectation, run func(db *sql.DB)) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + mock = exp(mock) + run(db) + assert.NoError(t, mock.ExpectationsWereMet()) +} + var ( rowType = reflect.TypeOf(&sql.Row{}) rowsType = reflect.TypeOf(&sql.Rows{}) @@ -317,7 +327,9 @@ func TestValidatePrepare(t *testing.T) { type prepareDB struct{} -func (_ *prepareDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " } +const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' " + +func (*prepareDB) Timetravel(time.Duration) string { return asOfSystemTime } var defaultPrepareArgs = []reflect.Value{reflect.ValueOf(context.Background()), reflect.ValueOf(new(prepareDB))} diff --git a/internal/query/projection/auth_request.go b/internal/query/projection/auth_request.go new file mode 100644 index 0000000000..b01ce175ae --- /dev/null +++ b/internal/query/projection/auth_request.go @@ -0,0 +1,142 @@ +package projection + +import ( + "context" + + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/eventstore/handler/crdb" + "github.com/zitadel/zitadel/internal/repository/authrequest" + "github.com/zitadel/zitadel/internal/repository/instance" +) + +const ( + AuthRequestsProjectionTable = "projections.auth_requests" + + AuthRequestColumnID = "id" + AuthRequestColumnCreationDate = "creation_date" + AuthRequestColumnChangeDate = "change_date" + AuthRequestColumnSequence = "sequence" + AuthRequestColumnResourceOwner = "resource_owner" + AuthRequestColumnInstanceID = "instance_id" + AuthRequestColumnLoginClient = "login_client" + AuthRequestColumnClientID = "client_id" + AuthRequestColumnRedirectURI = "redirect_uri" + AuthRequestColumnScope = "scope" + AuthRequestColumnPrompt = "prompt" + AuthRequestColumnUILocales = "ui_locales" + AuthRequestColumnMaxAge = "max_age" + AuthRequestColumnLoginHint = "login_hint" + AuthRequestColumnHintUserID = "hint_user_id" +) + +type authRequestProjection struct { + crdb.StatementHandler +} + +func newAuthRequestProjection(ctx context.Context, config crdb.StatementHandlerConfig) *authRequestProjection { + p := new(authRequestProjection) + config.ProjectionName = AuthRequestsProjectionTable + config.Reducers = p.reducers() + config.InitCheck = crdb.NewMultiTableCheck( + crdb.NewTable([]*crdb.Column{ + crdb.NewColumn(AuthRequestColumnID, crdb.ColumnTypeText), + crdb.NewColumn(AuthRequestColumnCreationDate, crdb.ColumnTypeTimestamp), + crdb.NewColumn(AuthRequestColumnChangeDate, crdb.ColumnTypeTimestamp), + crdb.NewColumn(AuthRequestColumnSequence, crdb.ColumnTypeInt64), + crdb.NewColumn(AuthRequestColumnResourceOwner, crdb.ColumnTypeText), + crdb.NewColumn(AuthRequestColumnInstanceID, crdb.ColumnTypeText), + crdb.NewColumn(AuthRequestColumnLoginClient, crdb.ColumnTypeText), + crdb.NewColumn(AuthRequestColumnClientID, crdb.ColumnTypeText), + crdb.NewColumn(AuthRequestColumnRedirectURI, crdb.ColumnTypeText), + crdb.NewColumn(AuthRequestColumnScope, crdb.ColumnTypeTextArray), + crdb.NewColumn(AuthRequestColumnPrompt, crdb.ColumnTypeEnumArray, crdb.Nullable()), + crdb.NewColumn(AuthRequestColumnUILocales, crdb.ColumnTypeTextArray, crdb.Nullable()), + crdb.NewColumn(AuthRequestColumnMaxAge, crdb.ColumnTypeInt64, crdb.Nullable()), + crdb.NewColumn(AuthRequestColumnLoginHint, crdb.ColumnTypeText, crdb.Nullable()), + crdb.NewColumn(AuthRequestColumnHintUserID, crdb.ColumnTypeText, crdb.Nullable()), + }, + crdb.NewPrimaryKey(AuthRequestColumnInstanceID, AuthRequestColumnID), + ), + ) + p.StatementHandler = crdb.NewStatementHandler(ctx, config) + return p +} + +func (p *authRequestProjection) reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: authrequest.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: authrequest.AddedType, + Reduce: p.reduceAuthRequestAdded, + }, + { + Event: authrequest.SucceededType, + Reduce: p.reduceAuthRequestEnded, + }, + { + Event: authrequest.FailedType, + Reduce: p.reduceAuthRequestEnded, + }, + }, + }, + { + Aggregate: instance.AggregateType, + EventRedusers: []handler.EventReducer{ + { + Event: instance.InstanceRemovedEventType, + Reduce: reduceInstanceRemovedHelper(AuthRequestColumnInstanceID), + }, + }, + }, + } +} + +func (p *authRequestProjection) reduceAuthRequestAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*authrequest.AddedEvent) + if !ok { + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-Sfwfa", "reduce.wrong.event.type %s", authrequest.AddedType) + } + + return crdb.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(AuthRequestColumnID, e.Aggregate().ID), + handler.NewCol(AuthRequestColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(AuthRequestColumnCreationDate, e.CreationDate()), + handler.NewCol(AuthRequestColumnChangeDate, e.CreationDate()), + handler.NewCol(AuthRequestColumnResourceOwner, e.Aggregate().ResourceOwner), + handler.NewCol(AuthRequestColumnSequence, e.Sequence()), + handler.NewCol(AuthRequestColumnLoginClient, e.LoginClient), + handler.NewCol(AuthRequestColumnClientID, e.ClientID), + handler.NewCol(AuthRequestColumnRedirectURI, e.RedirectURI), + handler.NewCol(AuthRequestColumnScope, e.Scope), + handler.NewCol(AuthRequestColumnPrompt, e.Prompt), + handler.NewCol(AuthRequestColumnUILocales, e.UILocales), + handler.NewCol(AuthRequestColumnMaxAge, e.MaxAge), + handler.NewCol(AuthRequestColumnLoginHint, e.LoginHint), + handler.NewCol(AuthRequestColumnHintUserID, e.HintUserID), + }, + ), nil +} + +func (p *authRequestProjection) reduceAuthRequestEnded(event eventstore.Event) (*handler.Statement, error) { + switch event.(type) { + case *authrequest.SucceededEvent, + *authrequest.FailedEvent: + break + default: + return nil, errors.ThrowInvalidArgumentf(nil, "HANDL-ASF3h", "reduce.wrong.event.type %s", []eventstore.EventType{authrequest.SucceededType, authrequest.FailedType}) + } + + return crdb.NewDeleteStatement( + event, + []handler.Condition{ + handler.NewCond(AuthRequestColumnID, event.Aggregate().ID), + handler.NewCond(AuthRequestColumnInstanceID, event.Aggregate().InstanceID), + }, + ), nil +} diff --git a/internal/query/projection/auth_request_test.go b/internal/query/projection/auth_request_test.go new file mode 100644 index 0000000000..f0ce0651d4 --- /dev/null +++ b/internal/query/projection/auth_request_test.go @@ -0,0 +1,134 @@ +package projection + +import ( + "testing" + "time" + + "github.com/muhlemmer/gu" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/repository/authrequest" +) + +func TestAuthRequestProjection_reduces(t *testing.T) { + type args struct { + event func(t *testing.T) eventstore.Event + } + tests := []struct { + name string + args args + reduce func(event eventstore.Event) (*handler.Statement, error) + want wantReduce + }{ + { + name: "reduceAuthRequestAdded", + args: args{ + event: getEvent(testEvent( + authrequest.AddedType, + authrequest.AggregateType, + []byte(`{"login_client": "loginClient", "client_id":"clientId","redirect_uri": "redirectURI", "scope": ["openid"], "prompt": [1], "ui_locales": ["en","de"], "max_age": 0, "login_hint": "loginHint", "hint_user_id": "hintUserID"}`), + ), authrequest.AddedEventMapper), + }, + reduce: (&authRequestProjection{}).reduceAuthRequestAdded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("auth_request"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "INSERT INTO projections.auth_requests (id, instance_id, creation_date, change_date, resource_owner, sequence, login_client, client_id, redirect_uri, scope, prompt, ui_locales, max_age, login_hint, hint_user_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + anyArg{}, + anyArg{}, + "ro-id", + uint64(15), + "loginClient", + "clientId", + "redirectURI", + []string{"openid"}, + []domain.Prompt{domain.PromptNone}, + []string{"en", "de"}, + gu.Ptr(time.Duration(0)), + gu.Ptr("loginHint"), + gu.Ptr("hintUserID"), + }, + }, + }, + }, + }, + }, + { + name: "reduceAuthRequestFailed", + args: args{ + event: getEvent(testEvent( + authrequest.FailedType, + authrequest.AggregateType, + []byte(`{"reason": 0}`), + ), authrequest.FailedEventMapper), + }, + reduce: (&authRequestProjection{}).reduceAuthRequestEnded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("auth_request"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.auth_requests WHERE (id = $1) AND (instance_id = $2)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + }, + }, + }, + }, + }, + }, + { + name: "reduceAuthRequestSucceeded", + args: args{ + event: getEvent(testEvent( + authrequest.SucceededType, + authrequest.AggregateType, + nil, + ), authrequest.SucceededEventMapper), + }, + reduce: (&authRequestProjection{}).reduceAuthRequestEnded, + want: wantReduce{ + aggregateType: eventstore.AggregateType("auth_request"), + sequence: 15, + previousSequence: 10, + executer: &testExecuter{ + executions: []execution{ + { + expectedStmt: "DELETE FROM projections.auth_requests WHERE (id = $1) AND (instance_id = $2)", + expectedArgs: []interface{}{ + "agg-id", + "instance-id", + }, + }, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := baseEvent(t) + got, err := tt.reduce(event) + if !errors.IsErrorInvalidArgument(err) { + t.Errorf("no wrong event mapping: %v, got: %v", err, got) + } + + event = tt.args.event(t) + got, err = tt.reduce(event) + assertReduce(t, got, err, AuthRequestsProjectionTable, tt.want) + }) + } +} diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index 6da0638347..e8f2bd1563 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -67,6 +67,7 @@ var ( TelemetryPusherProjection interface{} DeviceAuthProjection *deviceAuthProjection SessionProjection *sessionProjection + AuthRequestProjection *authRequestProjection MilestoneProjection *milestoneProjection ) @@ -145,6 +146,7 @@ func Create(ctx context.Context, sqlClient *database.DB, es *eventstore.Eventsto NotificationPolicyProjection = newNotificationPolicyProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["notification_policies"])) DeviceAuthProjection = newDeviceAuthProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["device_auth"])) SessionProjection = newSessionProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["sessions"])) + AuthRequestProjection = newAuthRequestProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["auth_requests"])) MilestoneProjection = newMilestoneProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["milestones"])) newProjectionsList() return nil @@ -243,6 +245,7 @@ func newProjectionsList() { NotificationPolicyProjection, DeviceAuthProjection, SessionProjection, + AuthRequestProjection, MilestoneProjection, } } diff --git a/internal/query/query.go b/internal/query/query.go index 464cf59aa4..9355265e1c 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "regexp" "sync" "time" @@ -18,9 +19,11 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/action" + "github.com/zitadel/zitadel/internal/repository/authrequest" "github.com/zitadel/zitadel/internal/repository/idpintent" iam_repo "github.com/zitadel/zitadel/internal/repository/instance" "github.com/zitadel/zitadel/internal/repository/keypair" + "github.com/zitadel/zitadel/internal/repository/oidcsession" "github.com/zitadel/zitadel/internal/repository/org" "github.com/zitadel/zitadel/internal/repository/project" "github.com/zitadel/zitadel/internal/repository/session" @@ -88,6 +91,8 @@ func StartQueries( usergrant.RegisterEventMappers(repo.eventstore) session.RegisterEventMappers(repo.eventstore) idpintent.RegisterEventMappers(repo.eventstore) + authrequest.RegisterEventMappers(repo.eventstore) + oidcsession.RegisterEventMappers(repo.eventstore) repo.idpConfigEncryption = idpConfigEncryption repo.multifactors = domain.MultifactorConfigs{ @@ -115,3 +120,19 @@ func (q *Queries) Health(ctx context.Context) error { type prepareDatabase interface { Timetravel(d time.Duration) string } + +// cleanStaticQueries removes whitespaces, +// such as ` `, \t, \n, from queries to improve +// readability in logs and errors. +func cleanStaticQueries(qs ...*string) { + regex := regexp.MustCompile(`\s+`) + for _, q := range qs { + *q = regex.ReplaceAllString(*q, " ") + } +} + +func init() { + cleanStaticQueries( + &authRequestByIDQuery, + ) +} diff --git a/internal/query/query_test.go b/internal/query/query_test.go new file mode 100644 index 0000000000..fc16ee9fad --- /dev/null +++ b/internal/query/query_test.go @@ -0,0 +1,17 @@ +package query + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_cleanStaticQueries(t *testing.T) { + query := `select + foo, + bar +from table;` + want := "select foo, bar from table;" + cleanStaticQueries(&query) + assert.Equal(t, want, query) +} diff --git a/internal/repository/authrequest/aggregate.go b/internal/repository/authrequest/aggregate.go new file mode 100644 index 0000000000..ce6d8adc21 --- /dev/null +++ b/internal/repository/authrequest/aggregate.go @@ -0,0 +1,26 @@ +package authrequest + +import ( + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + AggregateType = "auth_request" + AggregateVersion = "v1" +) + +type Aggregate struct { + eventstore.Aggregate +} + +func NewAggregate(id, instanceID string) *Aggregate { + return &Aggregate{ + Aggregate: eventstore.Aggregate{ + Type: AggregateType, + Version: AggregateVersion, + ID: id, + ResourceOwner: instanceID, + InstanceID: instanceID, + }, + } +} diff --git a/internal/repository/authrequest/auth_request.go b/internal/repository/authrequest/auth_request.go new file mode 100644 index 0000000000..06ef5fea9e --- /dev/null +++ b/internal/repository/authrequest/auth_request.go @@ -0,0 +1,287 @@ +package authrequest + +import ( + "context" + "encoding/json" + "time" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" +) + +const ( + authRequestEventPrefix = "auth_request." + AddedType = authRequestEventPrefix + "added" + FailedType = authRequestEventPrefix + "failed" + CodeAddedType = authRequestEventPrefix + "code.added" + SessionLinkedType = authRequestEventPrefix + "session.linked" + CodeExchangedType = authRequestEventPrefix + "code.exchanged" + SucceededType = authRequestEventPrefix + "succeeded" +) + +type AddedEvent struct { + eventstore.BaseEvent `json:"-"` + + LoginClient string `json:"login_client"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + State string `json:"state,omitempty"` + Nonce string `json:"nonce,omitempty"` + Scope []string `json:"scope,omitempty"` + Audience []string `json:"audience,omitempty"` + ResponseType domain.OIDCResponseType `json:"response_type,omitempty"` + CodeChallenge *domain.OIDCCodeChallenge `json:"code_challenge,omitempty"` + Prompt []domain.Prompt `json:"prompt,omitempty"` + UILocales []string `json:"ui_locales,omitempty"` + MaxAge *time.Duration `json:"max_age,omitempty"` + LoginHint *string `json:"login_hint,omitempty"` + HintUserID *string `json:"hint_user_id,omitempty"` +} + +func (e *AddedEvent) Data() interface{} { + return e +} + +func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewAddedEvent(ctx context.Context, + aggregate *eventstore.Aggregate, + loginClient, + clientID, + redirectURI, + state, + nonce string, + scope, + audience []string, + responseType domain.OIDCResponseType, + codeChallenge *domain.OIDCCodeChallenge, + prompt []domain.Prompt, + uiLocales []string, + maxAge *time.Duration, + loginHint, + hintUserID *string, +) *AddedEvent { + return &AddedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + AddedType, + ), + LoginClient: loginClient, + ClientID: clientID, + RedirectURI: redirectURI, + State: state, + Nonce: nonce, + Scope: scope, + Audience: audience, + ResponseType: responseType, + CodeChallenge: codeChallenge, + Prompt: prompt, + UILocales: uiLocales, + MaxAge: maxAge, + LoginHint: loginHint, + HintUserID: hintUserID, + } +} + +func AddedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &AddedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "AUTHR-DG4gn", "unable to unmarshal auth request added") + } + + return added, nil +} + +type SessionLinkedEvent struct { + eventstore.BaseEvent `json:"-"` + + SessionID string `json:"session_id"` + UserID string `json:"user_id"` + AuthTime time.Time `json:"auth_time"` + AMR []string `json:"amr"` +} + +func (e *SessionLinkedEvent) Data() interface{} { + return e +} + +func (e *SessionLinkedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewSessionLinkedEvent(ctx context.Context, + aggregate *eventstore.Aggregate, + sessionID, + userID string, + authTime time.Time, + amr []string, +) *SessionLinkedEvent { + return &SessionLinkedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + SessionLinkedType, + ), + SessionID: sessionID, + UserID: userID, + AuthTime: authTime, + AMR: amr, + } +} + +func SessionLinkedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &SessionLinkedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request session linked") + } + + return added, nil +} + +type FailedEvent struct { + eventstore.BaseEvent `json:"-"` + + Reason domain.OIDCErrorReason `json:"reason,omitempty"` +} + +func (e *FailedEvent) Data() interface{} { + return e +} + +func (e *FailedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewFailedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + reason domain.OIDCErrorReason, +) *FailedEvent { + return &FailedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + FailedType, + ), + Reason: reason, + } +} + +func FailedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &FailedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request session linked") + } + + return added, nil +} + +type CodeAddedEvent struct { + eventstore.BaseEvent `json:"-"` +} + +func (e *CodeAddedEvent) Data() interface{} { + return e +} + +func (e *CodeAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewCodeAddedEvent(ctx context.Context, + aggregate *eventstore.Aggregate, +) *CodeAddedEvent { + return &CodeAddedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + CodeAddedType, + ), + } +} + +func CodeAddedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &CodeAddedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "AUTHR-Sfe3w", "unable to unmarshal auth request code added") + } + + return added, nil +} + +type CodeExchangedEvent struct { + eventstore.BaseEvent `json:"-"` +} + +func (e *CodeExchangedEvent) Data() interface{} { + return nil +} + +func (e *CodeExchangedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewCodeExchangedEvent(ctx context.Context, + aggregate *eventstore.Aggregate, +) *CodeExchangedEvent { + return &CodeExchangedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + CodeExchangedType, + ), + } +} + +func CodeExchangedEventMapper(event *repository.Event) (eventstore.Event, error) { + return &CodeExchangedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + }, nil +} + +type SucceededEvent struct { + eventstore.BaseEvent `json:"-"` +} + +func (e *SucceededEvent) Data() interface{} { + return nil +} + +func (e *SucceededEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewSucceededEvent(ctx context.Context, + aggregate *eventstore.Aggregate, +) *SucceededEvent { + return &SucceededEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + SucceededType, + ), + } +} + +func SucceededEventMapper(event *repository.Event) (eventstore.Event, error) { + return &SucceededEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + }, nil +} diff --git a/internal/repository/authrequest/eventstore.go b/internal/repository/authrequest/eventstore.go new file mode 100644 index 0000000000..c1bfed7dc4 --- /dev/null +++ b/internal/repository/authrequest/eventstore.go @@ -0,0 +1,12 @@ +package authrequest + +import "github.com/zitadel/zitadel/internal/eventstore" + +func RegisterEventMappers(es *eventstore.Eventstore) { + es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper). + RegisterFilterEventMapper(AggregateType, SessionLinkedType, SessionLinkedEventMapper). + RegisterFilterEventMapper(AggregateType, CodeAddedType, CodeAddedEventMapper). + RegisterFilterEventMapper(AggregateType, CodeExchangedType, CodeExchangedEventMapper). + RegisterFilterEventMapper(AggregateType, FailedType, FailedEventMapper). + RegisterFilterEventMapper(AggregateType, SucceededType, SucceededEventMapper) +} diff --git a/internal/repository/oidcsession/aggregate.go b/internal/repository/oidcsession/aggregate.go new file mode 100644 index 0000000000..920f6c9fce --- /dev/null +++ b/internal/repository/oidcsession/aggregate.go @@ -0,0 +1,25 @@ +package oidcsession + +import ( + "github.com/zitadel/zitadel/internal/eventstore" +) + +const ( + AggregateType = "oidc_session" + AggregateVersion = "v1" +) + +type Aggregate struct { + eventstore.Aggregate +} + +func NewAggregate(id, resourceOwner string) *Aggregate { + return &Aggregate{ + Aggregate: eventstore.Aggregate{ + Type: AggregateType, + Version: AggregateVersion, + ID: id, + ResourceOwner: resourceOwner, + }, + } +} diff --git a/internal/repository/oidcsession/eventstore.go b/internal/repository/oidcsession/eventstore.go new file mode 100644 index 0000000000..88c78f4593 --- /dev/null +++ b/internal/repository/oidcsession/eventstore.go @@ -0,0 +1,11 @@ +package oidcsession + +import "github.com/zitadel/zitadel/internal/eventstore" + +func RegisterEventMappers(es *eventstore.Eventstore) { + es.RegisterFilterEventMapper(AggregateType, AddedType, AddedEventMapper). + RegisterFilterEventMapper(AggregateType, AccessTokenAddedType, AccessTokenAddedEventMapper). + RegisterFilterEventMapper(AggregateType, RefreshTokenAddedType, RefreshTokenAddedEventMapper). + RegisterFilterEventMapper(AggregateType, RefreshTokenRenewedType, RefreshTokenRenewedEventMapper) + +} diff --git a/internal/repository/oidcsession/oidc_session.go b/internal/repository/oidcsession/oidc_session.go new file mode 100644 index 0000000000..842013b34f --- /dev/null +++ b/internal/repository/oidcsession/oidc_session.go @@ -0,0 +1,215 @@ +package oidcsession + +import ( + "context" + "encoding/json" + "time" + + "github.com/zitadel/zitadel/internal/errors" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" +) + +const ( + oidcSessionEventPrefix = "oidc_session." + AddedType = oidcSessionEventPrefix + "added" + AccessTokenAddedType = oidcSessionEventPrefix + "access_token.added" + RefreshTokenAddedType = oidcSessionEventPrefix + "refresh_token.added" + RefreshTokenRenewedType = oidcSessionEventPrefix + "refresh_token.renewed" +) + +type AddedEvent struct { + eventstore.BaseEvent `json:"-"` + + UserID string `json:"userID"` + SessionID string `json:"sessionID"` + ClientID string `json:"clientID"` + Audience []string `json:"audience"` + Scope []string `json:"scope"` + AuthMethodsReferences []string `json:"authMethodsReferences"` + AuthTime time.Time `json:"authTime"` +} + +func (e *AddedEvent) Data() interface{} { + return e +} + +func (e *AddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewAddedEvent(ctx context.Context, + aggregate *eventstore.Aggregate, + userID, + sessionID, + clientID string, + audience, + scope []string, + authMethodsReferences []string, + authTime time.Time, +) *AddedEvent { + return &AddedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + AddedType, + ), + UserID: userID, + SessionID: sessionID, + ClientID: clientID, + Audience: audience, + Scope: scope, + AuthMethodsReferences: authMethodsReferences, + AuthTime: authTime, + } +} + +func AddedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &AddedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "OIDCS-DG4gn", "unable to unmarshal oidc session added") + } + + return added, nil +} + +type AccessTokenAddedEvent struct { + eventstore.BaseEvent `json:"-"` + + ID string `json:"id"` + Scope []string `json:"scope"` + Lifetime time.Duration `json:"lifetime"` +} + +func (e *AccessTokenAddedEvent) Data() interface{} { + return e +} + +func (e *AccessTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewAccessTokenAddedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + id string, + scope []string, + lifetime time.Duration, +) *AccessTokenAddedEvent { + return &AccessTokenAddedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + AccessTokenAddedType, + ), + ID: id, + Scope: scope, + Lifetime: lifetime, + } +} + +func AccessTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &AccessTokenAddedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "OIDCS-DSGn5", "unable to unmarshal access token added") + } + + return added, nil +} + +type RefreshTokenAddedEvent struct { + eventstore.BaseEvent `json:"-"` + + ID string `json:"id"` + Lifetime time.Duration `json:"lifetime"` + IdleLifetime time.Duration `json:"idleLifetime"` +} + +func (e *RefreshTokenAddedEvent) Data() interface{} { + return e +} + +func (e *RefreshTokenAddedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewRefreshTokenAddedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + id string, + lifetime, + idleLifetime time.Duration, +) *RefreshTokenAddedEvent { + return &RefreshTokenAddedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + RefreshTokenAddedType, + ), + ID: id, + Lifetime: lifetime, + IdleLifetime: idleLifetime, + } +} + +func RefreshTokenAddedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &RefreshTokenAddedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "OIDCS-aW3gqq", "unable to unmarshal refresh token added") + } + + return added, nil +} + +type RefreshTokenRenewedEvent struct { + eventstore.BaseEvent `json:"-"` + + ID string `json:"id"` + IdleLifetime time.Duration `json:"idleLifetime"` +} + +func (e *RefreshTokenRenewedEvent) Data() interface{} { + return e +} + +func (e *RefreshTokenRenewedEvent) UniqueConstraints() []*eventstore.EventUniqueConstraint { + return nil +} + +func NewRefreshTokenRenewedEvent( + ctx context.Context, + aggregate *eventstore.Aggregate, + id string, + idleLifetime time.Duration, +) *RefreshTokenRenewedEvent { + return &RefreshTokenRenewedEvent{ + BaseEvent: *eventstore.NewBaseEventForPush( + ctx, + aggregate, + RefreshTokenRenewedType, + ), + ID: id, + IdleLifetime: idleLifetime, + } +} + +func RefreshTokenRenewedEventMapper(event *repository.Event) (eventstore.Event, error) { + added := &RefreshTokenRenewedEvent{ + BaseEvent: *eventstore.BaseEventFromRepo(event), + } + err := json.Unmarshal(event.Data, added) + if err != nil { + return nil, errors.ThrowInternal(err, "OIDCS-SF3fc", "unable to unmarshal refresh token renewed") + } + + return added, nil +} diff --git a/internal/static/i18n/bg.yaml b/internal/static/i18n/bg.yaml index 6950078880..289300493d 100644 --- a/internal/static/i18n/bg.yaml +++ b/internal/static/i18n/bg.yaml @@ -506,6 +506,12 @@ Errors: TokenCreationFailed: Неуспешно създаване на токен InvalidToken: Знакът за намерение е невалиден OtherUser: Намерение, предназначено за друг потребител + AuthRequest: + AlreadyExists: Auth Request вече съществува + NotExisting: Auth Request не съществува + WrongLoginClient: Auth Request, създаден от друг клиент за влизане + OIDCSession: + RefreshTokenInvalid: Токенът за опресняване е невалиден AggregateTypes: action: Действие diff --git a/internal/static/i18n/de.yaml b/internal/static/i18n/de.yaml index 846e09dff0..1e941ffbf1 100644 --- a/internal/static/i18n/de.yaml +++ b/internal/static/i18n/de.yaml @@ -488,7 +488,12 @@ Errors: TokenCreationFailed: Tokenerstellung schlug fehl InvalidToken: Intent Token ist ungültig OtherUser: Intent ist für anderen Benutzer gedacht - + AuthRequest: + AlreadyExists: Auth Request existiert bereits + NotExisting: Auth Request existiert nicht + WrongLoginClient: Auth Request wurde von einem anderen Login-Client erstellt + OIDCSession: + RefreshTokenInvalid: Refresh Token ist ungültig AggregateTypes: action: Action instance: Instanz diff --git a/internal/static/i18n/en.yaml b/internal/static/i18n/en.yaml index f74322cb39..688b75af2a 100644 --- a/internal/static/i18n/en.yaml +++ b/internal/static/i18n/en.yaml @@ -488,6 +488,12 @@ Errors: TokenCreationFailed: Token creation failed InvalidToken: Intent Token is invalid OtherUser: Intent meant for another user + AuthRequest: + AlreadyExists: Auth Request already exists + NotExisting: Auth Request does not exist + WrongLoginClient: Auth Request created by other login client + OIDCSession: + RefreshTokenInvalid: Refresh Token is invalid AggregateTypes: action: Action diff --git a/internal/static/i18n/es.yaml b/internal/static/i18n/es.yaml index 1adb107632..748a90aabb 100644 --- a/internal/static/i18n/es.yaml +++ b/internal/static/i18n/es.yaml @@ -488,6 +488,12 @@ Errors: TokenCreationFailed: Fallo en la creación del token InvalidToken: El token de la intención no es válido OtherUser: Destinado a otro usuario + AuthRequest: + AlreadyExists: Auth Request ya existe + NotExisting: Auth Request no existe + WrongLoginClient: Auth Request creado por otro cliente de inicio de sesión + OIDCSession: + RefreshTokenInvalid: El token de refresco no es válido AggregateTypes: action: Acción diff --git a/internal/static/i18n/fr.yaml b/internal/static/i18n/fr.yaml index 5789a815c6..feb89a8302 100644 --- a/internal/static/i18n/fr.yaml +++ b/internal/static/i18n/fr.yaml @@ -488,6 +488,12 @@ Errors: TokenCreationFailed: La création du token a échoué InvalidToken: Le jeton d'intention n'est pas valide OtherUser: Intention destinée à un autre utilisateur + AuthRequest: + AlreadyExists: Auth Request existe déjà + NotExisting: Auth Request n'existe pas + WrongLoginClient: Auth Request créé par un autre client de connexion + OIDCSession: + RefreshTokenInvalid: Le jeton de rafraîchissement n'est pas valide AggregateTypes: action: Action diff --git a/internal/static/i18n/it.yaml b/internal/static/i18n/it.yaml index 4e51b0fc3e..afe26d6bcb 100644 --- a/internal/static/i18n/it.yaml +++ b/internal/static/i18n/it.yaml @@ -488,6 +488,12 @@ Errors: TokenCreationFailed: creazione del token fallita InvalidToken: Il token dell'intento non è valido OtherUser: Intento destinato a un altro utente + AuthRequest: + AlreadyExists: Auth Request esiste già + NotExisting: Auth Request non esiste + WrongLoginClient: Auth Request creato da un altro client di accesso + OIDCSession: + RefreshTokenInvalid: Refresh Token non è valido AggregateTypes: action: Azione diff --git a/internal/static/i18n/ja.yaml b/internal/static/i18n/ja.yaml index ed6d070d1e..4dfe8c42e6 100644 --- a/internal/static/i18n/ja.yaml +++ b/internal/static/i18n/ja.yaml @@ -477,6 +477,12 @@ Errors: TokenCreationFailed: トークンの作成に失敗しました InvalidToken: インテントのトークンが無効である OtherUser: 他のユーザーを意図している + AuthRequest: + AlreadyExists: AuthRequestはすでに存在する + NotExisting: AuthRequest が存在しません + WrongLoginClient: 他のログインクライアントによって作成された AuthRequest + OIDCSession: + RefreshTokenInvalid: 無効なリフレッシュトークンです AggregateTypes: action: アクション diff --git a/internal/static/i18n/pl.yaml b/internal/static/i18n/pl.yaml index 9e907dcbea..0c9e39f7fb 100644 --- a/internal/static/i18n/pl.yaml +++ b/internal/static/i18n/pl.yaml @@ -488,6 +488,12 @@ Errors: TokenCreationFailed: Tworzenie tokena nie powiodło się InvalidToken: Token intencji jest nieprawidłowy OtherUser: Intencja przeznaczona dla innego użytkownika + AuthRequest: + AlreadyExists: Auth Request już istnieje + NotExisting: Auth Request nie istnieje + WrongLoginClient: Auth Request utworzony przez innego klienta logowania + OIDCSession: + RefreshTokenInvalid: Refresh Token jest nieprawidłowy AggregateTypes: action: Działanie diff --git a/internal/static/i18n/zh.yaml b/internal/static/i18n/zh.yaml index 2a9f2400f9..8e76e9fbfc 100644 --- a/internal/static/i18n/zh.yaml +++ b/internal/static/i18n/zh.yaml @@ -488,6 +488,12 @@ Errors: TokenCreationFailed: 令牌创建失败 InvalidToken: 意图令牌是无效的 OtherUser: 意图是为另一个用户准备的 + AuthRequest: + AlreadyExists: AuthRequest已经存在 + NotExisting: AuthRequest不存在 + WrongLoginClient: 其他登录客户端创建的AuthRequest + OIDCSession: + RefreshTokenInvalid: Refresh Token 无效 AggregateTypes: action: 动作 diff --git a/proto/zitadel/oidc/v2alpha/authorization.proto b/proto/zitadel/oidc/v2alpha/authorization.proto new file mode 100644 index 0000000000..69a14a3d07 --- /dev/null +++ b/proto/zitadel/oidc/v2alpha/authorization.proto @@ -0,0 +1,117 @@ +syntax = "proto3"; + +package zitadel.oidc.v2alpha; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "protoc-gen-openapiv2/options/annotations.proto"; + +option go_package = "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha;oidc"; + +message AuthRequest{ + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_schema) = { + external_docs: { + url: "https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest"; + description: "Find out more about OIDC Auth Request parameters"; + } + }; + + string id = 1 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "ID of the authorization request"; + } + ]; + + google.protobuf.Timestamp creation_date = 2 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Time when the auth request was created"; + } + ]; + + string client_id = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "OIDC client ID of the application that created the auth request"; + } + ]; + + repeated string scope = 4 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Requested scopes by the application, which the user must consent to."; + } + ]; + + string redirect_uri = 5 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Base URI that points back to the application"; + } + ]; + + repeated Prompt prompt = 6 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Prompts that must be displayed to the user"; + } + ]; + + repeated string ui_locales = 7 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "End-User's preferred languages and scripts for the user interface, represented as a list of BCP47 [RFC5646] language tag values, ordered by preference. For instance, the value [fr-CA, fr, en] represents a preference for French as spoken in Canada, then French (without a region designation), followed by English (without a region designation). An error SHOULD NOT result if some or all of the requested locales are not supported."; + } + ]; + + optional string login_hint = 8 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Login hint can be set by the application with a user identifier such as an email or phone number."; + } + ]; + + optional google.protobuf.Duration max_age = 9 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Specifies the allowable elapsed time in seconds since the last time the End-User was actively authenticated. If the elapsed time is greater than this value, or the field is present with 0 duration, the user must be re-authenticated."; + } + ]; + + optional string hint_user_id = 10 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "User ID taken from a ID Token Hint if it was present and valid."; + } + ]; +} + +enum Prompt { + PROMPT_UNSPECIFIED = 0; + PROMPT_NONE = 1; + PROMPT_LOGIN = 2; + PROMPT_CONSENT = 3; + PROMPT_SELECT_ACCOUNT = 4; + PROMPT_CREATE = 5; +} + +message AuthorizationError { + ErrorReason error = 1; + optional string error_description = 2; + optional string error_uri = 3; +} + +enum ErrorReason { + ERROR_REASON_UNSPECIFIED = 0; + + // Error states from https://datatracker.ietf.org/doc/html/rfc6749#section-4.2.2.1 + ERROR_REASON_INVALID_REQUEST = 1; + ERROR_REASON_UNAUTHORIZED_CLIENT = 2; + ERROR_REASON_ACCESS_DENIED = 3; + ERROR_REASON_UNSUPPORTED_RESPONSE_TYPE = 4; + ERROR_REASON_INVALID_SCOPE = 5; + ERROR_REASON_SERVER_ERROR = 6; + ERROR_REASON_TEMPORARY_UNAVAILABLE = 7; + + // Error states from https://openid.net/specs/openid-connect-core-1_0.html#AuthError + ERROR_REASON_INTERACTION_REQUIRED = 8; + ERROR_REASON_LOGIN_REQUIRED = 9; + ERROR_REASON_ACCOUNT_SELECTION_REQUIRED = 10; + ERROR_REASON_CONSENT_REQUIRED = 11; + ERROR_REASON_INVALID_REQUEST_URI = 12; + ERROR_REASON_INVALID_REQUEST_OBJECT = 13; + ERROR_REASON_REQUEST_NOT_SUPPORTED = 14; + ERROR_REASON_REQUEST_URI_NOT_SUPPORTED = 15; + ERROR_REASON_REGISTRATION_NOT_SUPPORTED = 16; +} \ No newline at end of file diff --git a/proto/zitadel/oidc/v2alpha/oidc_service.proto b/proto/zitadel/oidc/v2alpha/oidc_service.proto new file mode 100644 index 0000000000..94ba36413b --- /dev/null +++ b/proto/zitadel/oidc/v2alpha/oidc_service.proto @@ -0,0 +1,190 @@ +syntax = "proto3"; + +package zitadel.oidc.v2alpha; + +import "zitadel/object/v2alpha/object.proto"; +import "zitadel/protoc_gen_zitadel/v2/options.proto"; +import "zitadel/oidc/v2alpha/authorization.proto"; +import "google/api/annotations.proto"; +import "google/api/field_behavior.proto"; +import "protoc-gen-openapiv2/options/annotations.proto"; +import "validate/validate.proto"; + +option go_package = "github.com/zitadel/zitadel/pkg/grpc/oidc/v2alpha;oidc"; + +option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_swagger) = { + info: { + title: "OIDC Service"; + version: "2.0-alpha"; + description: "Get OIDC Auth Request details and create callback URLs. This project is in alpha state. It can AND will continue breaking until the services provide the same functionality as the current login."; + contact:{ + name: "ZITADEL" + url: "https://zitadel.com" + email: "hi@zitadel.com" + } + license: { + name: "Apache 2.0", + url: "https://github.com/zitadel/zitadel/blob/main/LICENSE"; + }; + }; + schemes: HTTPS; + schemes: HTTP; + + consumes: "application/json"; + consumes: "application/grpc"; + + produces: "application/json"; + produces: "application/grpc"; + + consumes: "application/grpc-web+proto"; + produces: "application/grpc-web+proto"; + + host: "$ZITADEL_DOMAIN"; + base_path: "/"; + + external_docs: { + description: "Detailed information about ZITADEL", + url: "https://zitadel.com/docs" + } + + responses: { + key: "403"; + value: { + description: "Returned when the user does not have permission to access the resource."; + schema: { + json_schema: { + ref: "#/definitions/rpcStatus"; + } + } + } + } + responses: { + key: "404"; + value: { + description: "Returned when the resource does not exist."; + schema: { + json_schema: { + ref: "#/definitions/rpcStatus"; + } + } + } + } +}; + +service OIDCService { + rpc GetAuthRequest (GetAuthRequestRequest) returns (GetAuthRequestResponse) { + option (google.api.http) = { + get: "/v2alpha/oidc/auth_requests/{auth_request_id}" + }; + + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "authenticated" + } + }; + + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + summary: "Get OIDC Auth Request details"; + description: "Get OIDC Auth Request details by ID, obtained from the redirect URL. Returns details that are parsed from the application's Auth Request." + responses: { + key: "200" + value: { + description: "OK"; + } + }; + }; + } + + rpc CreateCallback (CreateCallbackRequest) returns (CreateCallbackResponse) { + option (google.api.http) = { + post: "/v2alpha/oidc/auth_requests/{auth_request_id}" + body: "*" + }; + + option (zitadel.protoc_gen_zitadel.v2.options) = { + auth_option: { + permission: "authenticated" + } + }; + + option (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_operation) = { + summary: "Finalize an Auth Request and get the callback URL."; + description: "Finalize an Auth Request and get the callback URL for success or failure. The user must be redirected to the URL in order to inform the application about the success or failure. On success, the URL contains details for the application to obtain the tokens. This method can only be called once for an Auth request." + responses: { + key: "200" + value: { + description: "OK"; + } + }; + }; + } +} + +message GetAuthRequestRequest { + string auth_request_id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + max_length: 200; + description: "ID of the Auth Request, as obtained from the redirect URL."; + example: "\"163840776835432705\""; + } + ]; +} + +message GetAuthRequestResponse { + AuthRequest auth_request = 1; +} + +message CreateCallbackRequest { + string auth_request_id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set."; + ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError"; + } + ]; + + oneof callback_kind { + option (validate.required) = true; + Session session = 2; + AuthorizationError error = 3 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Set this field when the authorization flow failed. It creates a callback URL to the application, with the error details set."; + ref: "https://openid.net/specs/openid-connect-core-1_0.html#AuthError"; + } + ]; + } +} + +message Session { + string session_id = 1 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + max_length: 200; + description: "ID of the session, used to login the user. Connects the session to the Auth Request."; + example: "\"163840776835432705\""; + } + ]; + + string session_token = 2 [ + (validate.rules).string = {min_len: 1, max_len: 200}, + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + min_length: 1; + max_length: 200; + description: "Token to verify the session is valid"; + } + ]; +} + +message CreateCallbackResponse { + zitadel.object.v2alpha.Details details = 1; + string callback_url = 2 [ + (grpc.gateway.protoc_gen_openapiv2.options.openapiv2_field) = { + description: "Callback URL where the user should be redirected, using a \"302 FOUND\" status. Contains details for the application to obtain the tokens on success, or error details on failure. Note that this field must be treated as credentials, as the contained code can be used to obtain tokens on behalve of the user."; + example: "\"https://client.example.org/cb?code=SplxlOBeZQQYbYS6WxSbIA&state=af0ifjsldkj\"" + } + ]; +} +