OAuth: Fix group mapping when use generic OAuth (#41418)

* Fix group mapping when use generic OAuth

* Fix lint

* Fix lint
This commit is contained in:
Selene 2021-11-12 10:22:04 +01:00 committed by GitHub
parent fdd1364ddd
commit 25ad916473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 5 deletions

View File

@ -156,7 +156,7 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
}
}
if userInfo.Groups != nil && len(userInfo.Groups) == 0 {
if len(userInfo.Groups) == 0 {
groups, err := s.extractGroups(data)
if err != nil {
s.log.Warn("Failed to extract groups", "err", err)

View File

@ -5,15 +5,14 @@ import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/inconshreveable/log15"
"github.com/mattn/go-isatty"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"github.com/grafana/grafana/pkg/infra/log"
"golang.org/x/oauth2"
)
@ -646,6 +645,75 @@ func TestUserInfoSearchesForName(t *testing.T) {
})
}
func TestUserInfoSearchesForGroup(t *testing.T) {
t.Run("Given a generic OAuth provider", func(t *testing.T) {
provider := SocialGenericOAuth{
SocialBase: &SocialBase{
log: newLogger("generic_oauth_test", log15.LvlDebug),
},
}
tests := []struct {
name string
groupsAttributePath string
responseBody interface{}
expectedResult []string
}{
{
name: "If groups are not set, user groups are nil",
groupsAttributePath: "",
expectedResult: nil,
},
{
name: "If groups are empty, user groups are nil",
groupsAttributePath: "info.groups",
responseBody: map[string]interface{}{
"info": map[string]interface{}{
"groups": []string{},
},
},
expectedResult: nil,
},
{
name: "If groups are set, user groups are set",
groupsAttributePath: "info.groups",
responseBody: map[string]interface{}{
"info": map[string]interface{}{
"groups": []string{"foo", "bar"},
},
},
expectedResult: []string{"foo", "bar"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
provider.groupsAttributePath = test.groupsAttributePath
body, err := json.Marshal(test.responseBody)
require.NoError(t, err)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
t.Log("Writing fake API response body", "body", test.responseBody)
_, err := w.Write(body)
require.NoError(t, err)
}))
provider.apiUrl = ts.URL
token := &oauth2.Token{
AccessToken: "",
TokenType: "",
RefreshToken: "",
Expiry: time.Now(),
}
userInfo, err := provider.UserInfo(ts.Client(), token)
assert.NoError(t, err)
assert.Equal(t, test.expectedResult, userInfo.Groups)
})
}
})
}
func TestPayloadCompression(t *testing.T) {
provider := SocialGenericOAuth{
SocialBase: &SocialBase{