mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
OAuth: Improve domain validation (#83110)
* enforce hd claim validation * add tests
This commit is contained in:
parent
1394b3341f
commit
7e8b679237
@ -17,6 +17,7 @@ import (
|
|||||||
ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models"
|
ssoModels "github.com/grafana/grafana/pkg/services/ssosettings/models"
|
||||||
"github.com/grafana/grafana/pkg/services/ssosettings/validation"
|
"github.com/grafana/grafana/pkg/services/ssosettings/validation"
|
||||||
"github.com/grafana/grafana/pkg/setting"
|
"github.com/grafana/grafana/pkg/setting"
|
||||||
|
"github.com/grafana/grafana/pkg/util/errutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -37,6 +38,7 @@ type googleUserData struct {
|
|||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
EmailVerified bool `json:"email_verified"`
|
EmailVerified bool `json:"email_verified"`
|
||||||
|
HD string `json:"hd"`
|
||||||
rawJSON []byte `json:"-"`
|
rawJSON []byte `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -115,6 +117,10 @@ func (s *SocialGoogle) UserInfo(ctx context.Context, client *http.Client, token
|
|||||||
return nil, fmt.Errorf("user email is not verified")
|
return nil, fmt.Errorf("user email is not verified")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.isHDAllowed(data.HD); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
groups, errPage := s.retrieveGroups(ctx, client, data)
|
groups, errPage := s.retrieveGroups(ctx, client, data)
|
||||||
if errPage != nil {
|
if errPage != nil {
|
||||||
s.log.Warn("Error retrieving groups", "error", errPage)
|
s.log.Warn("Error retrieving groups", "error", errPage)
|
||||||
@ -290,3 +296,17 @@ func (s *SocialGoogle) getGroupsPage(ctx context.Context, client *http.Client, u
|
|||||||
|
|
||||||
return &data, nil
|
return &data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SocialGoogle) isHDAllowed(hd string) error {
|
||||||
|
if len(s.info.AllowedDomains) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, allowedDomain := range s.info.AllowedDomains {
|
||||||
|
if hd == allowedDomain {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return errutil.Forbidden("the hd claim found in the ID token is not present in the allowed domains", errutil.WithPublicMessage("Invalid domain"))
|
||||||
|
}
|
||||||
|
@ -890,3 +890,47 @@ func TestSocialGoogle_Reload(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsHDAllowed(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
email string
|
||||||
|
allowedDomains []string
|
||||||
|
expectedErrorMessage string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should not fail if no allowed domains are set",
|
||||||
|
email: "mycompany.com",
|
||||||
|
allowedDomains: []string{},
|
||||||
|
expectedErrorMessage: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not fail if email is from allowed domain",
|
||||||
|
email: "mycompany.com",
|
||||||
|
allowedDomains: []string{"grafana.com", "mycompany.com", "example.com"},
|
||||||
|
expectedErrorMessage: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should fail if email is not from allowed domain",
|
||||||
|
email: "mycompany.com",
|
||||||
|
allowedDomains: []string{"grafana.com", "example.com"},
|
||||||
|
expectedErrorMessage: "the hd claim found in the ID token is not present in the allowed domains",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
info := &social.OAuthInfo{}
|
||||||
|
info.AllowedDomains = tc.allowedDomains
|
||||||
|
s := NewGoogleProvider(info, &setting.Cfg{}, &ssosettingstests.MockService{}, featuremgmt.WithFeatures())
|
||||||
|
err := s.isHDAllowed(tc.email)
|
||||||
|
|
||||||
|
if tc.expectedErrorMessage != "" {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), tc.expectedErrorMessage)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user