mirror of
https://github.com/grafana/grafana.git
synced 2024-11-25 10:20:29 -06:00
OAuth: Fix group mapping when use generic OAuth (#41418)
* Fix group mapping when use generic OAuth * Fix lint * Fix lint
This commit is contained in:
parent
fdd1364ddd
commit
25ad916473
@ -156,7 +156,7 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
|
||||
}
|
||||
}
|
||||
|
||||
if userInfo.Groups != nil && len(userInfo.Groups) == 0 {
|
||||
if len(userInfo.Groups) == 0 {
|
||||
groups, err := s.extractGroups(data)
|
||||
if err != nil {
|
||||
s.log.Warn("Failed to extract groups", "err", err)
|
||||
|
@ -5,15 +5,14 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"github.com/inconshreveable/log15"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"testing"
|
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
@ -646,6 +645,75 @@ func TestUserInfoSearchesForName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserInfoSearchesForGroup(t *testing.T) {
|
||||
t.Run("Given a generic OAuth provider", func(t *testing.T) {
|
||||
provider := SocialGenericOAuth{
|
||||
SocialBase: &SocialBase{
|
||||
log: newLogger("generic_oauth_test", log15.LvlDebug),
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
groupsAttributePath string
|
||||
responseBody interface{}
|
||||
expectedResult []string
|
||||
}{
|
||||
{
|
||||
name: "If groups are not set, user groups are nil",
|
||||
groupsAttributePath: "",
|
||||
expectedResult: nil,
|
||||
},
|
||||
{
|
||||
name: "If groups are empty, user groups are nil",
|
||||
groupsAttributePath: "info.groups",
|
||||
responseBody: map[string]interface{}{
|
||||
"info": map[string]interface{}{
|
||||
"groups": []string{},
|
||||
},
|
||||
},
|
||||
expectedResult: nil,
|
||||
},
|
||||
{
|
||||
name: "If groups are set, user groups are set",
|
||||
groupsAttributePath: "info.groups",
|
||||
responseBody: map[string]interface{}{
|
||||
"info": map[string]interface{}{
|
||||
"groups": []string{"foo", "bar"},
|
||||
},
|
||||
},
|
||||
expectedResult: []string{"foo", "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
provider.groupsAttributePath = test.groupsAttributePath
|
||||
body, err := json.Marshal(test.responseBody)
|
||||
require.NoError(t, err)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
t.Log("Writing fake API response body", "body", test.responseBody)
|
||||
_, err := w.Write(body)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
provider.apiUrl = ts.URL
|
||||
token := &oauth2.Token{
|
||||
AccessToken: "",
|
||||
TokenType: "",
|
||||
RefreshToken: "",
|
||||
Expiry: time.Now(),
|
||||
}
|
||||
|
||||
userInfo, err := provider.UserInfo(ts.Client(), token)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.expectedResult, userInfo.Groups)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPayloadCompression(t *testing.T) {
|
||||
provider := SocialGenericOAuth{
|
||||
SocialBase: &SocialBase{
|
||||
|
Loading…
Reference in New Issue
Block a user