Storage: Improve (some) error handling (#91373)

This commit is contained in:
Ryan McKinley 2024-08-02 10:27:10 +03:00 committed by GitHub
parent b63694d75f
commit 391284bb33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 57 additions and 47 deletions

View File

@ -12,7 +12,6 @@ import (
// Package-level errors. // Package-level errors.
var ( var (
ErrOptimisticLockingFailed = errors.New("optimistic locking failed") ErrOptimisticLockingFailed = errors.New("optimistic locking failed")
ErrUserNotFoundInContext = errors.New("user not found in context")
ErrNotImplementedYet = errors.New("not implemented yet") ErrNotImplementedYet = errors.New("not implemented yet")
) )

View File

@ -211,19 +211,15 @@ func (s *server) Stop(ctx context.Context) error {
} }
// Old value indicates an update -- otherwise a create // Old value indicates an update -- otherwise a create
func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue []byte) (*WriteEvent, error) { func (s *server) newEvent(ctx context.Context, user identity.Requester, key *ResourceKey, value, oldValue []byte) (*WriteEvent, *ErrorResult) {
user, err := identity.GetRequester(ctx)
if err != nil {
return nil, ErrUserNotFoundInContext
}
tmp := &unstructured.Unstructured{} tmp := &unstructured.Unstructured{}
err = tmp.UnmarshalJSON(value) err := tmp.UnmarshalJSON(value)
if err != nil { if err != nil {
return nil, err return nil, AsErrorResult(err)
} }
obj, err := utils.MetaAccessor(tmp) obj, err := utils.MetaAccessor(tmp)
if err != nil { if err != nil {
return nil, err return nil, AsErrorResult(err)
} }
event := &WriteEvent{ event := &WriteEvent{
@ -239,27 +235,27 @@ func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue
temp := &unstructured.Unstructured{} temp := &unstructured.Unstructured{}
err = temp.UnmarshalJSON(oldValue) err = temp.UnmarshalJSON(oldValue)
if err != nil { if err != nil {
return nil, err return nil, AsErrorResult(err)
} }
event.ObjectOld, err = utils.MetaAccessor(temp) event.ObjectOld, err = utils.MetaAccessor(temp)
if err != nil { if err != nil {
return nil, err return nil, AsErrorResult(err)
} }
} }
if key.Namespace != obj.GetNamespace() { if key.Namespace != obj.GetNamespace() {
return nil, apierrors.NewBadRequest("key/namespace do not match") return nil, NewBadRequestError("key/namespace do not match")
} }
gvk := obj.GetGroupVersionKind() gvk := obj.GetGroupVersionKind()
if gvk.Kind == "" { if gvk.Kind == "" {
return nil, apierrors.NewBadRequest("expecting resources with a kind in the body") return nil, NewBadRequestError("expecting resources with a kind in the body")
} }
if gvk.Version == "" { if gvk.Version == "" {
return nil, apierrors.NewBadRequest("expecting resources with an apiVersion") return nil, NewBadRequestError("expecting resources with an apiVersion")
} }
if gvk.Group != "" && gvk.Group != key.Group { if gvk.Group != "" && gvk.Group != key.Group {
return nil, apierrors.NewBadRequest( return nil, NewBadRequestError(
fmt.Sprintf("group in key does not match group in the body (%s != %s)", key.Group, gvk.Group), fmt.Sprintf("group in key does not match group in the body (%s != %s)", key.Group, gvk.Group),
) )
} }
@ -267,15 +263,14 @@ func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue
// This needs to be a create function // This needs to be a create function
if key.Name == "" { if key.Name == "" {
if obj.GetName() == "" { if obj.GetName() == "" {
return nil, apierrors.NewBadRequest("missing name") return nil, NewBadRequestError("missing name")
} }
key.Name = obj.GetName() key.Name = obj.GetName()
} else if key.Name != obj.GetName() { } else if key.Name != obj.GetName() {
return nil, apierrors.NewBadRequest( return nil, NewBadRequestError(
fmt.Sprintf("key/name do not match (key: %s, name: %s)", key.Name, obj.GetName())) fmt.Sprintf("key/name do not match (key: %s, name: %s)", key.Name, obj.GetName()))
} }
err = validateName(obj.GetName()) if err := validateName(obj.GetName()); err != nil {
if err != nil {
return nil, err return nil, err
} }
@ -283,17 +278,17 @@ func (s *server) newEvent(ctx context.Context, key *ResourceKey, value, oldValue
if folder != "" { if folder != "" {
err = s.access.CanWriteFolder(ctx, user, folder) err = s.access.CanWriteFolder(ctx, user, folder)
if err != nil { if err != nil {
return nil, err return nil, AsErrorResult(err)
} }
} }
origin, err := obj.GetOriginInfo() origin, err := obj.GetOriginInfo()
if err != nil { if err != nil {
return nil, apierrors.NewBadRequest("invalid origin info") return nil, NewBadRequestError("invalid origin info")
} }
if origin != nil { if origin != nil {
err = s.access.CanWriteOrigin(ctx, user, origin.Name) err = s.access.CanWriteOrigin(ctx, user, origin.Name)
if err != nil { if err != nil {
return nil, err return nil, AsErrorResult(err)
} }
} }
return event, nil return event, nil
@ -308,6 +303,15 @@ func (s *server) Create(ctx context.Context, req *CreateRequest) (*CreateRespons
} }
rsp := &CreateResponse{} rsp := &CreateResponse{}
user, err := identity.GetRequester(ctx)
if err != nil || user == nil {
rsp.Error = &ErrorResult{
Message: "no user found in context",
Code: http.StatusUnauthorized,
}
return rsp, nil
}
found := s.backend.ReadResource(ctx, &ReadRequest{Key: req.Key}) found := s.backend.ReadResource(ctx, &ReadRequest{Key: req.Key})
if found != nil && len(found.Value) > 0 { if found != nil && len(found.Value) > 0 {
rsp.Error = &ErrorResult{ rsp.Error = &ErrorResult{
@ -317,9 +321,9 @@ func (s *server) Create(ctx context.Context, req *CreateRequest) (*CreateRespons
return rsp, nil return rsp, nil
} }
event, err := s.newEvent(ctx, req.Key, req.Value, nil) event, e := s.newEvent(ctx, user, req.Key, req.Value, nil)
if err != nil { if e != nil {
rsp.Error = AsErrorResult(err) rsp.Error = e
return rsp, nil return rsp, nil
} }
@ -339,6 +343,14 @@ func (s *server) Update(ctx context.Context, req *UpdateRequest) (*UpdateRespons
} }
rsp := &UpdateResponse{} rsp := &UpdateResponse{}
user, err := identity.GetRequester(ctx)
if err != nil || user == nil {
rsp.Error = &ErrorResult{
Message: "no user found in context",
Code: http.StatusUnauthorized,
}
return rsp, nil
}
if req.ResourceVersion < 0 { if req.ResourceVersion < 0 {
rsp.Error = AsErrorResult(apierrors.NewBadRequest("update must include the previous version")) rsp.Error = AsErrorResult(apierrors.NewBadRequest("update must include the previous version"))
return rsp, nil return rsp, nil
@ -359,9 +371,9 @@ func (s *server) Update(ctx context.Context, req *UpdateRequest) (*UpdateRespons
return nil, ErrOptimisticLockingFailed return nil, ErrOptimisticLockingFailed
} }
event, err := s.newEvent(ctx, req.Key, req.Value, latest.Value) event, e := s.newEvent(ctx, user, req.Key, req.Value, latest.Value)
if err != nil { if e != nil {
rsp.Error = AsErrorResult(err) rsp.Error = e
return rsp, err return rsp, err
} }

View File

@ -1,22 +1,21 @@
package resource package resource
import ( import (
"fmt"
"regexp" "regexp"
) )
var validNameCharPattern = `a-zA-Z0-9\-\_\.` var validNameCharPattern = `a-zA-Z0-9\-\_\.`
var validNamePattern = regexp.MustCompile(`^[` + validNameCharPattern + `]*$`).MatchString var validNamePattern = regexp.MustCompile(`^[` + validNameCharPattern + `]*$`).MatchString
func validateName(name string) error { func validateName(name string) *ErrorResult {
if len(name) == 0 { if len(name) == 0 {
return fmt.Errorf("name is too short") return NewBadRequestError("name is too short")
} }
if len(name) > 64 { if len(name) > 64 {
return fmt.Errorf("name is too long") return NewBadRequestError("name is too long")
} }
if !validNamePattern(name) { if !validNamePattern(name) {
return fmt.Errorf("name includes invalid characters") return NewBadRequestError("name includes invalid characters")
} }
// In standard k8s, it must not start with a number // In standard k8s, it must not start with a number
// however that would force us to update many many many existing resources // however that would force us to update many many many existing resources

View File

@ -7,24 +7,24 @@ import (
) )
func TestNameValidation(t *testing.T) { func TestNameValidation(t *testing.T) {
require.Error(t, validateName("")) // too short require.NotNil(t, validateName("")) // too short
require.Error(t, validateName( // too long (max 64) require.NotNil(t, validateName( // too long (max 64)
"0123456789012345678901234567890123456789012345678901234567890123456789", "0123456789012345678901234567890123456789012345678901234567890123456789",
)) ))
// OK // OK
require.NoError(t, validateName("a")) require.Nil(t, validateName("a"))
require.NoError(t, validateName("hello-world")) require.Nil(t, validateName("hello-world"))
require.NoError(t, validateName("hello.world")) require.Nil(t, validateName("hello.world"))
require.NoError(t, validateName("hello_world")) require.Nil(t, validateName("hello_world"))
// Bad characters // Bad characters
require.Error(t, validateName("hello world")) require.NotNil(t, validateName("hello world"))
require.Error(t, validateName("hello!")) require.NotNil(t, validateName("hello!"))
require.Error(t, validateName("hello~")) require.NotNil(t, validateName("hello~"))
require.Error(t, validateName("hello ")) require.NotNil(t, validateName("hello "))
require.Error(t, validateName("hello*")) require.NotNil(t, validateName("hello*"))
require.Error(t, validateName("hello+")) require.NotNil(t, validateName("hello+"))
require.Error(t, validateName("hello=")) require.NotNil(t, validateName("hello="))
require.Error(t, validateName("hello%")) require.NotNil(t, validateName("hello%"))
} }