From 391284bb3322f391100f6d68ab4b2f06395c2873 Mon Sep 17 00:00:00 2001 From: Ryan McKinley Date: Fri, 2 Aug 2024 10:27:10 +0300 Subject: [PATCH] Storage: Improve (some) error handling (#91373) --- pkg/storage/unified/resource/errors.go | 1 - pkg/storage/unified/resource/server.go | 66 +++++++++++-------- pkg/storage/unified/resource/validation.go | 9 ++- .../unified/resource/validation_test.go | 28 ++++---- 4 files changed, 57 insertions(+), 47 deletions(-) diff --git a/pkg/storage/unified/resource/errors.go b/pkg/storage/unified/resource/errors.go index 088e6683630..7f1d3a3ca56 100644 --- a/pkg/storage/unified/resource/errors.go +++ b/pkg/storage/unified/resource/errors.go @@ -12,7 +12,6 @@ import ( // Package-level errors. var ( ErrOptimisticLockingFailed = errors.New("optimistic locking failed") - ErrUserNotFoundInContext = errors.New("user not found in context") ErrNotImplementedYet = errors.New("not implemented yet") ) diff --git a/pkg/storage/unified/resource/server.go b/pkg/storage/unified/resource/server.go index c301e37288d..a1190e061fb 100644 --- a/pkg/storage/unified/resource/server.go +++ b/pkg/storage/unified/resource/server.go @@ -211,19 +211,15 @@ func (s *server) Stop(ctx context.Context) error { } // Old value indicates an update -- otherwise a create -func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue []byte) (*WriteEvent, error) { - user, err := identity.GetRequester(ctx) - if err != nil { - return nil, ErrUserNotFoundInContext - } +func (s *server) newEvent(ctx context.Context, user identity.Requester, key *ResourceKey, value, oldValue []byte) (*WriteEvent, *ErrorResult) { tmp := &unstructured.Unstructured{} - err = tmp.UnmarshalJSON(value) + err := tmp.UnmarshalJSON(value) if err != nil { - return nil, err + return nil, AsErrorResult(err) } obj, err := utils.MetaAccessor(tmp) if err != nil { - return nil, err + return nil, AsErrorResult(err) } event := &WriteEvent{ @@ -239,27 +235,27 @@ func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue temp := &unstructured.Unstructured{} err = temp.UnmarshalJSON(oldValue) if err != nil { - return nil, err + return nil, AsErrorResult(err) } event.ObjectOld, err = utils.MetaAccessor(temp) if err != nil { - return nil, err + return nil, AsErrorResult(err) } } if key.Namespace != obj.GetNamespace() { - return nil, apierrors.NewBadRequest("key/namespace do not match") + return nil, NewBadRequestError("key/namespace do not match") } gvk := obj.GetGroupVersionKind() if gvk.Kind == "" { - return nil, apierrors.NewBadRequest("expecting resources with a kind in the body") + return nil, NewBadRequestError("expecting resources with a kind in the body") } if gvk.Version == "" { - return nil, apierrors.NewBadRequest("expecting resources with an apiVersion") + return nil, NewBadRequestError("expecting resources with an apiVersion") } if gvk.Group != "" && gvk.Group != key.Group { - return nil, apierrors.NewBadRequest( + return nil, NewBadRequestError( fmt.Sprintf("group in key does not match group in the body (%s != %s)", key.Group, gvk.Group), ) } @@ -267,15 +263,14 @@ func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue // This needs to be a create function if key.Name == "" { if obj.GetName() == "" { - return nil, apierrors.NewBadRequest("missing name") + return nil, NewBadRequestError("missing name") } key.Name = obj.GetName() } else if key.Name != obj.GetName() { - return nil, apierrors.NewBadRequest( + return nil, NewBadRequestError( fmt.Sprintf("key/name do not match (key: %s, name: %s)", key.Name, obj.GetName())) } - err = validateName(obj.GetName()) - if err != nil { + if err := validateName(obj.GetName()); err != nil { return nil, err } @@ -283,17 +278,17 @@ func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue if folder != "" { err = s.access.CanWriteFolder(ctx, user, folder) if err != nil { - return nil, err + return nil, AsErrorResult(err) } } origin, err := obj.GetOriginInfo() if err != nil { - return nil, apierrors.NewBadRequest("invalid origin info") + return nil, NewBadRequestError("invalid origin info") } if origin != nil { err = s.access.CanWriteOrigin(ctx, user, origin.Name) if err != nil { - return nil, err + return nil, AsErrorResult(err) } } return event, nil @@ -308,6 +303,15 @@ func (s *server) Create(ctx context.Context, req *CreateRequest) (*CreateRespons } rsp := &CreateResponse{} + user, err := identity.GetRequester(ctx) + if err != nil || user == nil { + rsp.Error = &ErrorResult{ + Message: "no user found in context", + Code: http.StatusUnauthorized, + } + return rsp, nil + } + found := s.backend.ReadResource(ctx, &ReadRequest{Key: req.Key}) if found != nil && len(found.Value) > 0 { rsp.Error = &ErrorResult{ @@ -317,9 +321,9 @@ func (s *server) Create(ctx context.Context, req *CreateRequest) (*CreateRespons return rsp, nil } - event, err := s.newEvent(ctx, req.Key, req.Value, nil) - if err != nil { - rsp.Error = AsErrorResult(err) + event, e := s.newEvent(ctx, user, req.Key, req.Value, nil) + if e != nil { + rsp.Error = e return rsp, nil } @@ -339,6 +343,14 @@ func (s *server) Update(ctx context.Context, req *UpdateRequest) (*UpdateRespons } rsp := &UpdateResponse{} + user, err := identity.GetRequester(ctx) + if err != nil || user == nil { + rsp.Error = &ErrorResult{ + Message: "no user found in context", + Code: http.StatusUnauthorized, + } + return rsp, nil + } if req.ResourceVersion < 0 { rsp.Error = AsErrorResult(apierrors.NewBadRequest("update must include the previous version")) return rsp, nil @@ -359,9 +371,9 @@ func (s *server) Update(ctx context.Context, req *UpdateRequest) (*UpdateRespons return nil, ErrOptimisticLockingFailed } - event, err := s.newEvent(ctx, req.Key, req.Value, latest.Value) - if err != nil { - rsp.Error = AsErrorResult(err) + event, e := s.newEvent(ctx, user, req.Key, req.Value, latest.Value) + if e != nil { + rsp.Error = e return rsp, err } diff --git a/pkg/storage/unified/resource/validation.go b/pkg/storage/unified/resource/validation.go index 5ce5362717a..dce25a1b8fe 100644 --- a/pkg/storage/unified/resource/validation.go +++ b/pkg/storage/unified/resource/validation.go @@ -1,22 +1,21 @@ package resource import ( - "fmt" "regexp" ) var validNameCharPattern = `a-zA-Z0-9\-\_\.` var validNamePattern = regexp.MustCompile(`^[` + validNameCharPattern + `]*$`).MatchString -func validateName(name string) error { +func validateName(name string) *ErrorResult { if len(name) == 0 { - return fmt.Errorf("name is too short") + return NewBadRequestError("name is too short") } if len(name) > 64 { - return fmt.Errorf("name is too long") + return NewBadRequestError("name is too long") } if !validNamePattern(name) { - return fmt.Errorf("name includes invalid characters") + return NewBadRequestError("name includes invalid characters") } // In standard k8s, it must not start with a number // however that would force us to update many many many existing resources diff --git a/pkg/storage/unified/resource/validation_test.go b/pkg/storage/unified/resource/validation_test.go index 9b4907d8630..b9f9a51d7b5 100644 --- a/pkg/storage/unified/resource/validation_test.go +++ b/pkg/storage/unified/resource/validation_test.go @@ -7,24 +7,24 @@ import ( ) func TestNameValidation(t *testing.T) { - require.Error(t, validateName("")) // too short - require.Error(t, validateName( // too long (max 64) + require.NotNil(t, validateName("")) // too short + require.NotNil(t, validateName( // too long (max 64) "0123456789012345678901234567890123456789012345678901234567890123456789", )) // OK - require.NoError(t, validateName("a")) - require.NoError(t, validateName("hello-world")) - require.NoError(t, validateName("hello.world")) - require.NoError(t, validateName("hello_world")) + require.Nil(t, validateName("a")) + require.Nil(t, validateName("hello-world")) + require.Nil(t, validateName("hello.world")) + require.Nil(t, validateName("hello_world")) // Bad characters - require.Error(t, validateName("hello world")) - require.Error(t, validateName("hello!")) - require.Error(t, validateName("hello~")) - require.Error(t, validateName("hello ")) - require.Error(t, validateName("hello*")) - require.Error(t, validateName("hello+")) - require.Error(t, validateName("hello=")) - require.Error(t, validateName("hello%")) + require.NotNil(t, validateName("hello world")) + require.NotNil(t, validateName("hello!")) + require.NotNil(t, validateName("hello~")) + require.NotNil(t, validateName("hello ")) + require.NotNil(t, validateName("hello*")) + require.NotNil(t, validateName("hello+")) + require.NotNil(t, validateName("hello=")) + require.NotNil(t, validateName("hello%")) }