LDAP: reduce API and allow its extension (#17209)

* Removes Add/Remove methods

* Publicise necessary fields and methods so we could extend it

* Publicise mock API

* More comments and additional simplifications

* Sync with master

Still having low coverage :/ - should be addressed in #17208
This commit is contained in:
Oleg Gaidarenko 2019-05-27 10:36:49 +03:00 committed by GitHub
parent 5884e235fc
commit de92c360a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 209 additions and 408 deletions

View File

@ -29,18 +29,17 @@ type IConnection interface {
// IServer is interface for LDAP authorization // IServer is interface for LDAP authorization
type IServer interface { type IServer interface {
Login(*models.LoginUserQuery) (*models.ExternalUserInfo, error) Login(*models.LoginUserQuery) (*models.ExternalUserInfo, error)
Add(string, map[string][]string) error
Remove(string) error
Users([]string) ([]*models.ExternalUserInfo, error) Users([]string) ([]*models.ExternalUserInfo, error)
ExtractGrafanaUser(*UserInfo) (*models.ExternalUserInfo, error) ExtractGrafanaUser(*UserInfo) (*models.ExternalUserInfo, error)
InitialBind(string, string) error
Dial() error Dial() error
Close() Close()
} }
// Server is basic struct of LDAP authorization // Server is basic struct of LDAP authorization
type Server struct { type Server struct {
config *ServerConfig Config *ServerConfig
connection IConnection Connection IConnection
requireSecondBind bool requireSecondBind bool
log log.Logger log log.Logger
} }
@ -49,7 +48,6 @@ var (
// ErrInvalidCredentials is returned if username and password do not match // ErrInvalidCredentials is returned if username and password do not match
ErrInvalidCredentials = errors.New("Invalid Username or Password") ErrInvalidCredentials = errors.New("Invalid Username or Password")
ErrLDAPUserNotFound = errors.New("LDAP user not found")
) )
var dial = func(network, addr string) (IConnection, error) { var dial = func(network, addr string) (IConnection, error) {
@ -59,7 +57,7 @@ var dial = func(network, addr string) (IConnection, error) {
// New creates the new LDAP auth // New creates the new LDAP auth
func New(config *ServerConfig) IServer { func New(config *ServerConfig) IServer {
return &Server{ return &Server{
config: config, Config: config,
log: log.New("ldap"), log: log.New("ldap"),
} }
} }
@ -68,9 +66,9 @@ func New(config *ServerConfig) IServer {
func (server *Server) Dial() error { func (server *Server) Dial() error {
var err error var err error
var certPool *x509.CertPool var certPool *x509.CertPool
if server.config.RootCACert != "" { if server.Config.RootCACert != "" {
certPool = x509.NewCertPool() certPool = x509.NewCertPool()
for _, caCertFile := range strings.Split(server.config.RootCACert, " ") { for _, caCertFile := range strings.Split(server.Config.RootCACert, " ") {
pem, err := ioutil.ReadFile(caCertFile) pem, err := ioutil.ReadFile(caCertFile)
if err != nil { if err != nil {
return err return err
@ -81,35 +79,35 @@ func (server *Server) Dial() error {
} }
} }
var clientCert tls.Certificate var clientCert tls.Certificate
if server.config.ClientCert != "" && server.config.ClientKey != "" { if server.Config.ClientCert != "" && server.Config.ClientKey != "" {
clientCert, err = tls.LoadX509KeyPair(server.config.ClientCert, server.config.ClientKey) clientCert, err = tls.LoadX509KeyPair(server.Config.ClientCert, server.Config.ClientKey)
if err != nil { if err != nil {
return err return err
} }
} }
for _, host := range strings.Split(server.config.Host, " ") { for _, host := range strings.Split(server.Config.Host, " ") {
address := fmt.Sprintf("%s:%d", host, server.config.Port) address := fmt.Sprintf("%s:%d", host, server.Config.Port)
if server.config.UseSSL { if server.Config.UseSSL {
tlsCfg := &tls.Config{ tlsCfg := &tls.Config{
InsecureSkipVerify: server.config.SkipVerifySSL, InsecureSkipVerify: server.Config.SkipVerifySSL,
ServerName: host, ServerName: host,
RootCAs: certPool, RootCAs: certPool,
} }
if len(clientCert.Certificate) > 0 { if len(clientCert.Certificate) > 0 {
tlsCfg.Certificates = append(tlsCfg.Certificates, clientCert) tlsCfg.Certificates = append(tlsCfg.Certificates, clientCert)
} }
if server.config.StartTLS { if server.Config.StartTLS {
server.connection, err = dial("tcp", address) server.Connection, err = dial("tcp", address)
if err == nil { if err == nil {
if err = server.connection.StartTLS(tlsCfg); err == nil { if err = server.Connection.StartTLS(tlsCfg); err == nil {
return nil return nil
} }
} }
} else { } else {
server.connection, err = ldap.DialTLS("tcp", address, tlsCfg) server.Connection, err = ldap.DialTLS("tcp", address, tlsCfg)
} }
} else { } else {
server.connection, err = dial("tcp", address) server.Connection, err = dial("tcp", address)
} }
if err == nil { if err == nil {
@ -121,16 +119,16 @@ func (server *Server) Dial() error {
// Close closes the LDAP connection // Close closes the LDAP connection
func (server *Server) Close() { func (server *Server) Close() {
server.connection.Close() server.Connection.Close()
} }
// Log in user by searching and serializing it // Login user by searching and serializing it
func (server *Server) Login(query *models.LoginUserQuery) ( func (server *Server) Login(query *models.LoginUserQuery) (
*models.ExternalUserInfo, error, *models.ExternalUserInfo, error,
) { ) {
// Perform initial authentication // Perform initial authentication
err := server.initialBind(query.Username, query.Password) err := server.InitialBind(query.Username, query.Password)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -160,56 +158,6 @@ func (server *Server) Login(query *models.LoginUserQuery) (
return user, nil return user, nil
} }
// Add adds stuff to LDAP
func (server *Server) Add(dn string, values map[string][]string) error {
err := server.initialBind(
server.config.BindDN,
server.config.BindPassword,
)
if err != nil {
return err
}
attributes := make([]ldap.Attribute, 0)
for key, value := range values {
attributes = append(attributes, ldap.Attribute{
Type: key,
Vals: value,
})
}
request := &ldap.AddRequest{
DN: dn,
Attributes: attributes,
}
err = server.connection.Add(request)
if err != nil {
return err
}
return nil
}
// Remove removes stuff from LDAP
func (server *Server) Remove(dn string) error {
err := server.initialBind(
server.config.BindDN,
server.config.BindPassword,
)
if err != nil {
return err
}
request := ldap.NewDelRequest(dn, nil)
err = server.connection.Del(request)
if err != nil {
return err
}
return nil
}
// Users gets LDAP users // Users gets LDAP users
func (server *Server) Users(logins []string) ( func (server *Server) Users(logins []string) (
[]*models.ExternalUserInfo, []*models.ExternalUserInfo,
@ -217,10 +165,10 @@ func (server *Server) Users(logins []string) (
) { ) {
var result *ldap.SearchResult var result *ldap.SearchResult
var err error var err error
var config = server.config var Config = server.Config
for _, base := range config.SearchBaseDNs { for _, base := range Config.SearchBaseDNs {
result, err = server.connection.Search( result, err = server.Connection.Search(
server.getSearchRequest(base, logins), server.getSearchRequest(base, logins),
) )
if err != nil { if err != nil {
@ -254,7 +202,7 @@ func (server *Server) ExtractGrafanaUser(user *UserInfo) (*models.ExternalUserIn
// If there are no ldap group mappings access is true // If there are no ldap group mappings access is true
// otherwise a single group must match // otherwise a single group must match
func (server *Server) validateGrafanaUser(user *models.ExternalUserInfo) error { func (server *Server) validateGrafanaUser(user *models.ExternalUserInfo) error {
if len(server.config.Groups) > 0 && len(user.OrgRoles) < 1 { if len(server.Config.Groups) > 0 && len(user.OrgRoles) < 1 {
server.log.Error( server.log.Error(
"user does not belong in any of the specified LDAP groups", "user does not belong in any of the specified LDAP groups",
"username", user.Login, "username", user.Login,
@ -301,7 +249,7 @@ func (server *Server) getSearchRequest(
) *ldap.SearchRequest { ) *ldap.SearchRequest {
attributes := []string{} attributes := []string{}
inputs := server.config.Attr inputs := server.Config.Attr
attributes = appendIfNotEmpty( attributes = appendIfNotEmpty(
attributes, attributes,
inputs.Username, inputs.Username,
@ -314,7 +262,7 @@ func (server *Server) getSearchRequest(
search := "" search := ""
for _, login := range logins { for _, login := range logins {
query := strings.Replace( query := strings.Replace(
server.config.SearchFilter, server.Config.SearchFilter,
"%s", ldap.EscapeFilter(login), "%s", ldap.EscapeFilter(login),
-1, -1,
) )
@ -347,7 +295,7 @@ func (server *Server) buildGrafanaUser(user *UserInfo) *models.ExternalUserInfo
OrgRoles: map[int64]models.RoleType{}, OrgRoles: map[int64]models.RoleType{},
} }
for _, group := range server.config.Groups { for _, group := range server.Config.Groups {
// only use the first match for each org // only use the first match for each org
if extUser.OrgRoles[group.OrgId] != "" { if extUser.OrgRoles[group.OrgId] != "" {
continue continue
@ -366,15 +314,15 @@ func (server *Server) buildGrafanaUser(user *UserInfo) *models.ExternalUserInfo
func (server *Server) serverBind() error { func (server *Server) serverBind() error {
bindFn := func() error { bindFn := func() error {
return server.connection.Bind( return server.Connection.Bind(
server.config.BindDN, server.Config.BindDN,
server.config.BindPassword, server.Config.BindPassword,
) )
} }
if server.config.BindPassword == "" { if server.Config.BindPassword == "" {
bindFn = func() error { bindFn = func() error {
return server.connection.UnauthenticatedBind(server.config.BindDN) return server.Connection.UnauthenticatedBind(server.Config.BindDN)
} }
} }
@ -397,7 +345,7 @@ func (server *Server) secondBind(
user *models.ExternalUserInfo, user *models.ExternalUserInfo,
userPassword string, userPassword string,
) error { ) error {
err := server.connection.Bind(user.AuthId, userPassword) err := server.Connection.Bind(user.AuthId, userPassword)
if err != nil { if err != nil {
server.log.Info("Second bind failed", "error", err) server.log.Info("Second bind failed", "error", err)
@ -412,24 +360,25 @@ func (server *Server) secondBind(
return nil return nil
} }
func (server *Server) initialBind(username, userPassword string) error { // InitialBind intiates first bind to LDAP server
if server.config.BindPassword != "" || server.config.BindDN == "" { func (server *Server) InitialBind(username, userPassword string) error {
userPassword = server.config.BindPassword if server.Config.BindPassword != "" || server.Config.BindDN == "" {
userPassword = server.Config.BindPassword
server.requireSecondBind = true server.requireSecondBind = true
} }
bindPath := server.config.BindDN bindPath := server.Config.BindDN
if strings.Contains(bindPath, "%s") { if strings.Contains(bindPath, "%s") {
bindPath = fmt.Sprintf(server.config.BindDN, username) bindPath = fmt.Sprintf(server.Config.BindDN, username)
} }
bindFn := func() error { bindFn := func() error {
return server.connection.Bind(bindPath, userPassword) return server.Connection.Bind(bindPath, userPassword)
} }
if userPassword == "" { if userPassword == "" {
bindFn = func() error { bindFn = func() error {
return server.connection.UnauthenticatedBind(bindPath) return server.Connection.UnauthenticatedBind(bindPath)
} }
} }
@ -451,16 +400,16 @@ func (server *Server) initialBind(username, userPassword string) error {
func (server *Server) requestMemberOf(searchResult *ldap.SearchResult) ([]string, error) { func (server *Server) requestMemberOf(searchResult *ldap.SearchResult) ([]string, error) {
var memberOf []string var memberOf []string
for _, groupSearchBase := range server.config.GroupSearchBaseDNs { for _, groupSearchBase := range server.Config.GroupSearchBaseDNs {
var filterReplace string var filterReplace string
if server.config.GroupSearchFilterUserAttribute == "" { if server.Config.GroupSearchFilterUserAttribute == "" {
filterReplace = getLDAPAttr(server.config.Attr.Username, searchResult) filterReplace = getLDAPAttr(server.Config.Attr.Username, searchResult)
} else { } else {
filterReplace = getLDAPAttr(server.config.GroupSearchFilterUserAttribute, searchResult) filterReplace = getLDAPAttr(server.Config.GroupSearchFilterUserAttribute, searchResult)
} }
filter := strings.Replace( filter := strings.Replace(
server.config.GroupSearchFilter, "%s", server.Config.GroupSearchFilter, "%s",
ldap.EscapeFilter(filterReplace), ldap.EscapeFilter(filterReplace),
-1, -1,
) )
@ -468,7 +417,7 @@ func (server *Server) requestMemberOf(searchResult *ldap.SearchResult) ([]string
server.log.Info("Searching for user's groups", "filter", filter) server.log.Info("Searching for user's groups", "filter", filter)
// support old way of reading settings // support old way of reading settings
groupIDAttribute := server.config.Attr.MemberOf groupIDAttribute := server.Config.Attr.MemberOf
// but prefer dn attribute if default settings are used // but prefer dn attribute if default settings are used
if groupIDAttribute == "" || groupIDAttribute == "memberOf" { if groupIDAttribute == "" || groupIDAttribute == "memberOf" {
groupIDAttribute = "dn" groupIDAttribute = "dn"
@ -482,7 +431,7 @@ func (server *Server) requestMemberOf(searchResult *ldap.SearchResult) ([]string
Filter: filter, Filter: filter,
} }
groupSearchResult, err := server.connection.Search(&groupSearchReq) groupSearchResult, err := server.Connection.Search(&groupSearchReq)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -518,22 +467,22 @@ func (server *Server) serializeUsers(
index, index,
), ),
LastName: getLDAPAttrN( LastName: getLDAPAttrN(
server.config.Attr.Surname, server.Config.Attr.Surname,
users, users,
index, index,
), ),
FirstName: getLDAPAttrN( FirstName: getLDAPAttrN(
server.config.Attr.Name, server.Config.Attr.Name,
users, users,
index, index,
), ),
Username: getLDAPAttrN( Username: getLDAPAttrN(
server.config.Attr.Username, server.Config.Attr.Username,
users, users,
index, index,
), ),
Email: getLDAPAttrN( Email: getLDAPAttrN(
server.config.Attr.Email, server.Config.Attr.Email,
users, users,
index, index,
), ),
@ -553,8 +502,8 @@ func (server *Server) serializeUsers(
func (server *Server) getMemberOf(search *ldap.SearchResult) ( func (server *Server) getMemberOf(search *ldap.SearchResult) (
[]string, error, []string, error,
) { ) {
if server.config.GroupSearchFilter == "" { if server.Config.GroupSearchFilter == "" {
memberOf := getLDAPAttrArray(server.config.Attr.MemberOf, search) memberOf := getLDAPAttrArray(server.Config.Attr.MemberOf, search)
return memberOf, nil return memberOf, nil
} }

View File

@ -13,7 +13,7 @@ func TestLDAPHelpers(t *testing.T) {
Convey("serializeUsers()", t, func() { Convey("serializeUsers()", t, func() {
Convey("simple case", func() { Convey("simple case", func() {
server := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
Attr: AttributeMap{ Attr: AttributeMap{
Username: "username", Username: "username",
Name: "name", Name: "name",
@ -22,7 +22,7 @@ func TestLDAPHelpers(t *testing.T) {
}, },
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: &mockConnection{}, Connection: &MockConnection{},
log: log.New("test-logger"), log: log.New("test-logger"),
} }
@ -46,7 +46,7 @@ func TestLDAPHelpers(t *testing.T) {
Convey("without lastname", func() { Convey("without lastname", func() {
server := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
Attr: AttributeMap{ Attr: AttributeMap{
Username: "username", Username: "username",
Name: "name", Name: "name",
@ -55,7 +55,7 @@ func TestLDAPHelpers(t *testing.T) {
}, },
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: &mockConnection{}, Connection: &MockConnection{},
log: log.New("test-logger"), log: log.New("test-logger"),
} }
@ -75,74 +75,9 @@ func TestLDAPHelpers(t *testing.T) {
}) })
}) })
Convey("initialBind", t, func() {
Convey("Given bind dn and password configured", func() {
connection := &mockConnection{}
var actualUsername, actualPassword string
connection.bindProvider = func(username, password string) error {
actualUsername = username
actualPassword = password
return nil
}
server := &Server{
connection: connection,
config: &ServerConfig{
BindDN: "cn=%s,o=users,dc=grafana,dc=org",
BindPassword: "bindpwd",
},
}
err := server.initialBind("user", "pwd")
So(err, ShouldBeNil)
So(server.requireSecondBind, ShouldBeTrue)
So(actualUsername, ShouldEqual, "cn=user,o=users,dc=grafana,dc=org")
So(actualPassword, ShouldEqual, "bindpwd")
})
Convey("Given bind dn configured", func() {
connection := &mockConnection{}
var actualUsername, actualPassword string
connection.bindProvider = func(username, password string) error {
actualUsername = username
actualPassword = password
return nil
}
server := &Server{
connection: connection,
config: &ServerConfig{
BindDN: "cn=%s,o=users,dc=grafana,dc=org",
},
}
err := server.initialBind("user", "pwd")
So(err, ShouldBeNil)
So(server.requireSecondBind, ShouldBeFalse)
So(actualUsername, ShouldEqual, "cn=user,o=users,dc=grafana,dc=org")
So(actualPassword, ShouldEqual, "pwd")
})
Convey("Given empty bind dn and password", func() {
connection := &mockConnection{}
unauthenticatedBindWasCalled := false
var actualUsername string
connection.unauthenticatedBindProvider = func(username string) error {
unauthenticatedBindWasCalled = true
actualUsername = username
return nil
}
server := &Server{
connection: connection,
config: &ServerConfig{},
}
err := server.initialBind("user", "pwd")
So(err, ShouldBeNil)
So(server.requireSecondBind, ShouldBeTrue)
So(unauthenticatedBindWasCalled, ShouldBeTrue)
So(actualUsername, ShouldBeEmpty)
})
})
Convey("serverBind()", t, func() { Convey("serverBind()", t, func() {
Convey("Given bind dn and password configured", func() { Convey("Given bind dn and password configured", func() {
connection := &mockConnection{} connection := &MockConnection{}
var actualUsername, actualPassword string var actualUsername, actualPassword string
connection.bindProvider = func(username, password string) error { connection.bindProvider = func(username, password string) error {
actualUsername = username actualUsername = username
@ -150,8 +85,8 @@ func TestLDAPHelpers(t *testing.T) {
return nil return nil
} }
server := &Server{ server := &Server{
connection: connection, Connection: connection,
config: &ServerConfig{ Config: &ServerConfig{
BindDN: "o=users,dc=grafana,dc=org", BindDN: "o=users,dc=grafana,dc=org",
BindPassword: "bindpwd", BindPassword: "bindpwd",
}, },
@ -163,7 +98,7 @@ func TestLDAPHelpers(t *testing.T) {
}) })
Convey("Given bind dn configured", func() { Convey("Given bind dn configured", func() {
connection := &mockConnection{} connection := &MockConnection{}
unauthenticatedBindWasCalled := false unauthenticatedBindWasCalled := false
var actualUsername string var actualUsername string
connection.unauthenticatedBindProvider = func(username string) error { connection.unauthenticatedBindProvider = func(username string) error {
@ -172,8 +107,8 @@ func TestLDAPHelpers(t *testing.T) {
return nil return nil
} }
server := &Server{ server := &Server{
connection: connection, Connection: connection,
config: &ServerConfig{ Config: &ServerConfig{
BindDN: "o=users,dc=grafana,dc=org", BindDN: "o=users,dc=grafana,dc=org",
}, },
} }
@ -184,7 +119,7 @@ func TestLDAPHelpers(t *testing.T) {
}) })
Convey("Given empty bind dn and password", func() { Convey("Given empty bind dn and password", func() {
connection := &mockConnection{} connection := &MockConnection{}
unauthenticatedBindWasCalled := false unauthenticatedBindWasCalled := false
var actualUsername string var actualUsername string
connection.unauthenticatedBindProvider = func(username string) error { connection.unauthenticatedBindProvider = func(username string) error {
@ -193,8 +128,8 @@ func TestLDAPHelpers(t *testing.T) {
return nil return nil
} }
server := &Server{ server := &Server{
connection: connection, Connection: connection,
config: &ServerConfig{}, Config: &ServerConfig{},
} }
err := server.serverBind() err := server.serverBind()
So(err, ShouldBeNil) So(err, ShouldBeNil)

View File

@ -13,12 +13,12 @@ import (
func TestLDAPLogin(t *testing.T) { func TestLDAPLogin(t *testing.T) {
Convey("Login()", t, func() { Convey("Login()", t, func() {
authScenario("When user is log in and updated", func(sc *scenarioContext) { serverScenario("When user is log in and updated", func(sc *scenarioContext) {
// arrange // arrange
mockConnection := &mockConnection{} mockConnection := &MockConnection{}
auth := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
Host: "", Host: "",
RootCACert: "", RootCACert: "",
Groups: []*GroupToOrgRole{ Groups: []*GroupToOrgRole{
@ -33,7 +33,7 @@ func TestLDAPLogin(t *testing.T) {
}, },
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: mockConnection, Connection: mockConnection,
log: log.New("test-logger"), log: log.New("test-logger"),
} }
@ -61,7 +61,7 @@ func TestLDAPLogin(t *testing.T) {
sc.userOrgsQueryReturns([]*models.UserOrgDTO{}) sc.userOrgsQueryReturns([]*models.UserOrgDTO{})
// act // act
extUser, _ := auth.Login(query) extUser, _ := server.Login(query)
userInfo, err := user.Upsert(&user.UpsertArgs{ userInfo, err := user.Upsert(&user.UpsertArgs{
SignupAllowed: true, SignupAllowed: true,
ExternalUser: extUser, ExternalUser: extUser,
@ -73,7 +73,7 @@ func TestLDAPLogin(t *testing.T) {
So(err, ShouldBeNil) So(err, ShouldBeNil)
// User should be searched in ldap // User should be searched in ldap
So(mockConnection.searchCalled, ShouldBeTrue) So(mockConnection.SearchCalled, ShouldBeTrue)
// Info should be updated (email differs) // Info should be updated (email differs)
So(userInfo.Email, ShouldEqual, "roel@test.com") So(userInfo.Email, ShouldEqual, "roel@test.com")
@ -82,8 +82,8 @@ func TestLDAPLogin(t *testing.T) {
So(sc.addOrgUserCmd.Role, ShouldEqual, "Admin") So(sc.addOrgUserCmd.Role, ShouldEqual, "Admin")
}) })
authScenario("When login with invalid credentials", func(scenario *scenarioContext) { serverScenario("When login with invalid credentials", func(scenario *scenarioContext) {
connection := &mockConnection{} connection := &MockConnection{}
entry := ldap.Entry{} entry := ldap.Entry{}
result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}}
connection.setSearchResult(&result) connection.setSearchResult(&result)
@ -93,8 +93,8 @@ func TestLDAPLogin(t *testing.T) {
ResultCode: 49, ResultCode: 49,
} }
} }
auth := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
Attr: AttributeMap{ Attr: AttributeMap{
Username: "username", Username: "username",
Name: "name", Name: "name",
@ -102,19 +102,19 @@ func TestLDAPLogin(t *testing.T) {
}, },
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: connection, Connection: connection,
log: log.New("test-logger"), log: log.New("test-logger"),
} }
_, err := auth.Login(scenario.loginUserQuery) _, err := server.Login(scenario.loginUserQuery)
Convey("it should return invalid credentials error", func() { Convey("it should return invalid credentials error", func() {
So(err, ShouldEqual, ErrInvalidCredentials) So(err, ShouldEqual, ErrInvalidCredentials)
}) })
}) })
authScenario("When login with valid credentials", func(scenario *scenarioContext) { serverScenario("When login with valid credentials", func(scenario *scenarioContext) {
connection := &mockConnection{} connection := &MockConnection{}
entry := ldap.Entry{ entry := ldap.Entry{
DN: "dn", Attributes: []*ldap.EntryAttribute{ DN: "dn", Attributes: []*ldap.EntryAttribute{
{Name: "username", Values: []string{"markelog"}}, {Name: "username", Values: []string{"markelog"}},
@ -130,8 +130,8 @@ func TestLDAPLogin(t *testing.T) {
connection.bindProvider = func(username, password string) error { connection.bindProvider = func(username, password string) error {
return nil return nil
} }
auth := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
Attr: AttributeMap{ Attr: AttributeMap{
Username: "username", Username: "username",
Name: "name", Name: "name",
@ -139,18 +139,18 @@ func TestLDAPLogin(t *testing.T) {
}, },
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: connection, Connection: connection,
log: log.New("test-logger"), log: log.New("test-logger"),
} }
resp, err := auth.Login(scenario.loginUserQuery) resp, err := server.Login(scenario.loginUserQuery)
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(resp.Login, ShouldEqual, "markelog") So(resp.Login, ShouldEqual, "markelog")
}) })
authScenario("When user not found in LDAP, but exist in Grafana", func(scenario *scenarioContext) { serverScenario("When user not found in LDAP, but exist in Grafana", func(scenario *scenarioContext) {
connection := &mockConnection{} connection := &MockConnection{}
result := ldap.SearchResult{Entries: []*ldap.Entry{}} result := ldap.SearchResult{Entries: []*ldap.Entry{}}
connection.setSearchResult(&result) connection.setSearchResult(&result)
@ -160,15 +160,15 @@ func TestLDAPLogin(t *testing.T) {
connection.bindProvider = func(username, password string) error { connection.bindProvider = func(username, password string) error {
return nil return nil
} }
auth := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: connection, Connection: connection,
log: log.New("test-logger"), log: log.New("test-logger"),
} }
_, err := auth.Login(scenario.loginUserQuery) _, err := server.Login(scenario.loginUserQuery)
Convey("it should disable user", func() { Convey("it should disable user", func() {
So(scenario.disableExternalUserCalled, ShouldBeTrue) So(scenario.disableExternalUserCalled, ShouldBeTrue)
@ -181,8 +181,8 @@ func TestLDAPLogin(t *testing.T) {
}) })
}) })
authScenario("When user not found in LDAP, and disabled in Grafana already", func(scenario *scenarioContext) { serverScenario("When user not found in LDAP, and disabled in Grafana already", func(scenario *scenarioContext) {
connection := &mockConnection{} connection := &MockConnection{}
result := ldap.SearchResult{Entries: []*ldap.Entry{}} result := ldap.SearchResult{Entries: []*ldap.Entry{}}
connection.setSearchResult(&result) connection.setSearchResult(&result)
@ -192,15 +192,15 @@ func TestLDAPLogin(t *testing.T) {
connection.bindProvider = func(username, password string) error { connection.bindProvider = func(username, password string) error {
return nil return nil
} }
auth := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: connection, Connection: connection,
log: log.New("test-logger"), log: log.New("test-logger"),
} }
_, err := auth.Login(scenario.loginUserQuery) _, err := server.Login(scenario.loginUserQuery)
Convey("it should't call disable function", func() { Convey("it should't call disable function", func() {
So(scenario.disableExternalUserCalled, ShouldBeFalse) So(scenario.disableExternalUserCalled, ShouldBeFalse)
@ -211,8 +211,8 @@ func TestLDAPLogin(t *testing.T) {
}) })
}) })
authScenario("When user found in LDAP, and disabled in Grafana", func(scenario *scenarioContext) { serverScenario("When user found in LDAP, and disabled in Grafana", func(scenario *scenarioContext) {
connection := &mockConnection{} connection := &MockConnection{}
entry := ldap.Entry{} entry := ldap.Entry{}
result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}}
connection.setSearchResult(&result) connection.setSearchResult(&result)
@ -221,15 +221,15 @@ func TestLDAPLogin(t *testing.T) {
connection.bindProvider = func(username, password string) error { connection.bindProvider = func(username, password string) error {
return nil return nil
} }
auth := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: connection, Connection: connection,
log: log.New("test-logger"), log: log.New("test-logger"),
} }
extUser, _ := auth.Login(scenario.loginUserQuery) extUser, _ := server.Login(scenario.loginUserQuery)
_, err := user.Upsert(&user.UpsertArgs{ _, err := user.Upsert(&user.UpsertArgs{
SignupAllowed: true, SignupAllowed: true,
ExternalUser: extUser, ExternalUser: extUser,

View File

@ -9,114 +9,10 @@ import (
"github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/log"
) )
func TestAuth(t *testing.T) { func TestPublicAPI(t *testing.T) {
Convey("Add()", t, func() {
connection := &mockConnection{}
auth := &Server{
config: &ServerConfig{
SearchBaseDNs: []string{"BaseDNHere"},
},
connection: connection,
log: log.New("test-logger"),
}
Convey("Adds user", func() {
err := auth.Add(
"cn=ldap-tuz,ou=users,dc=grafana,dc=org",
map[string][]string{
"mail": {"ldap-viewer@grafana.com"},
"userPassword": {"grafana"},
"objectClass": {
"person",
"top",
"inetOrgPerson",
"organizationalPerson",
},
"sn": {"ldap-tuz"},
"cn": {"ldap-tuz"},
},
)
hasMail := false
hasUserPassword := false
hasObjectClass := false
hasSN := false
hasCN := false
So(err, ShouldBeNil)
So(connection.addParams.Controls, ShouldBeNil)
So(connection.addCalled, ShouldBeTrue)
So(
connection.addParams.DN,
ShouldEqual,
"cn=ldap-tuz,ou=users,dc=grafana,dc=org",
)
attrs := connection.addParams.Attributes
for _, value := range attrs {
if value.Type == "mail" {
So(value.Vals, ShouldContain, "ldap-viewer@grafana.com")
hasMail = true
}
if value.Type == "userPassword" {
hasUserPassword = true
So(value.Vals, ShouldContain, "grafana")
}
if value.Type == "objectClass" {
hasObjectClass = true
So(value.Vals, ShouldContain, "person")
So(value.Vals, ShouldContain, "top")
So(value.Vals, ShouldContain, "inetOrgPerson")
So(value.Vals, ShouldContain, "organizationalPerson")
}
if value.Type == "sn" {
hasSN = true
So(value.Vals, ShouldContain, "ldap-tuz")
}
if value.Type == "cn" {
hasCN = true
So(value.Vals, ShouldContain, "ldap-tuz")
}
}
So(hasMail, ShouldBeTrue)
So(hasUserPassword, ShouldBeTrue)
So(hasObjectClass, ShouldBeTrue)
So(hasSN, ShouldBeTrue)
So(hasCN, ShouldBeTrue)
})
})
Convey("Remove()", t, func() {
connection := &mockConnection{}
auth := &Server{
config: &ServerConfig{
SearchBaseDNs: []string{"BaseDNHere"},
},
connection: connection,
log: log.New("test-logger"),
}
Convey("Removes the user", func() {
dn := "cn=ldap-tuz,ou=users,dc=grafana,dc=org"
err := auth.Remove(dn)
So(err, ShouldBeNil)
So(connection.delCalled, ShouldBeTrue)
So(connection.delParams.Controls, ShouldBeNil)
So(connection.delParams.DN, ShouldEqual, dn)
})
})
Convey("Users()", t, func() { Convey("Users()", t, func() {
Convey("find one user", func() { Convey("find one user", func() {
mockConnection := &mockConnection{} MockConnection := &MockConnection{}
entry := ldap.Entry{ entry := ldap.Entry{
DN: "dn", Attributes: []*ldap.EntryAttribute{ DN: "dn", Attributes: []*ldap.EntryAttribute{
{Name: "username", Values: []string{"roelgerrits"}}, {Name: "username", Values: []string{"roelgerrits"}},
@ -126,11 +22,11 @@ func TestAuth(t *testing.T) {
{Name: "memberof", Values: []string{"admins"}}, {Name: "memberof", Values: []string{"admins"}},
}} }}
result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}} result := ldap.SearchResult{Entries: []*ldap.Entry{&entry}}
mockConnection.setSearchResult(&result) MockConnection.setSearchResult(&result)
// Set up attribute map without surname and email // Set up attribute map without surname and email
server := &Server{ server := &Server{
config: &ServerConfig{ Config: &ServerConfig{
Attr: AttributeMap{ Attr: AttributeMap{
Username: "username", Username: "username",
Name: "name", Name: "name",
@ -138,7 +34,7 @@ func TestAuth(t *testing.T) {
}, },
SearchBaseDNs: []string{"BaseDNHere"}, SearchBaseDNs: []string{"BaseDNHere"},
}, },
connection: mockConnection, Connection: MockConnection,
log: log.New("test-logger"), log: log.New("test-logger"),
} }
@ -148,10 +44,75 @@ func TestAuth(t *testing.T) {
So(searchResult, ShouldNotBeNil) So(searchResult, ShouldNotBeNil)
// User should be searched in ldap // User should be searched in ldap
So(mockConnection.searchCalled, ShouldBeTrue) So(MockConnection.SearchCalled, ShouldBeTrue)
// No empty attributes should be added to the search request // No empty attributes should be added to the search request
So(len(mockConnection.searchAttributes), ShouldEqual, 3) So(len(MockConnection.SearchAttributes), ShouldEqual, 3)
})
})
Convey("InitialBind", t, func() {
Convey("Given bind dn and password configured", func() {
connection := &MockConnection{}
var actualUsername, actualPassword string
connection.bindProvider = func(username, password string) error {
actualUsername = username
actualPassword = password
return nil
}
server := &Server{
Connection: connection,
Config: &ServerConfig{
BindDN: "cn=%s,o=users,dc=grafana,dc=org",
BindPassword: "bindpwd",
},
}
err := server.InitialBind("user", "pwd")
So(err, ShouldBeNil)
So(server.requireSecondBind, ShouldBeTrue)
So(actualUsername, ShouldEqual, "cn=user,o=users,dc=grafana,dc=org")
So(actualPassword, ShouldEqual, "bindpwd")
})
Convey("Given bind dn configured", func() {
connection := &MockConnection{}
var actualUsername, actualPassword string
connection.bindProvider = func(username, password string) error {
actualUsername = username
actualPassword = password
return nil
}
server := &Server{
Connection: connection,
Config: &ServerConfig{
BindDN: "cn=%s,o=users,dc=grafana,dc=org",
},
}
err := server.InitialBind("user", "pwd")
So(err, ShouldBeNil)
So(server.requireSecondBind, ShouldBeFalse)
So(actualUsername, ShouldEqual, "cn=user,o=users,dc=grafana,dc=org")
So(actualPassword, ShouldEqual, "pwd")
})
Convey("Given empty bind dn and password", func() {
connection := &MockConnection{}
unauthenticatedBindWasCalled := false
var actualUsername string
connection.unauthenticatedBindProvider = func(username string) error {
unauthenticatedBindWasCalled = true
actualUsername = username
return nil
}
server := &Server{
Connection: connection,
Config: &ServerConfig{},
}
err := server.InitialBind("user", "pwd")
So(err, ShouldBeNil)
So(server.requireSecondBind, ShouldBeTrue)
So(unauthenticatedBindWasCalled, ShouldBeTrue)
So(actualUsername, ShouldBeEmpty)
}) })
}) })
} }

View File

@ -12,22 +12,24 @@ import (
"github.com/grafana/grafana/pkg/services/login" "github.com/grafana/grafana/pkg/services/login"
) )
type mockConnection struct { // MockConnection struct for testing
searchResult *ldap.SearchResult type MockConnection struct {
searchCalled bool SearchResult *ldap.SearchResult
searchAttributes []string SearchCalled bool
SearchAttributes []string
addParams *ldap.AddRequest AddParams *ldap.AddRequest
addCalled bool AddCalled bool
delParams *ldap.DelRequest DelParams *ldap.DelRequest
delCalled bool DelCalled bool
bindProvider func(username, password string) error bindProvider func(username, password string) error
unauthenticatedBindProvider func(username string) error unauthenticatedBindProvider func(username string) error
} }
func (c *mockConnection) Bind(username, password string) error { // Bind mocks Bind connection function
func (c *MockConnection) Bind(username, password string) error {
if c.bindProvider != nil { if c.bindProvider != nil {
return c.bindProvider(username, password) return c.bindProvider(username, password)
} }
@ -35,7 +37,8 @@ func (c *mockConnection) Bind(username, password string) error {
return nil return nil
} }
func (c *mockConnection) UnauthenticatedBind(username string) error { // UnauthenticatedBind mocks UnauthenticatedBind connection function
func (c *MockConnection) UnauthenticatedBind(username string) error {
if c.unauthenticatedBindProvider != nil { if c.unauthenticatedBindProvider != nil {
return c.unauthenticatedBindProvider(username) return c.unauthenticatedBindProvider(username)
} }
@ -43,35 +46,40 @@ func (c *mockConnection) UnauthenticatedBind(username string) error {
return nil return nil
} }
func (c *mockConnection) Close() {} // Close mocks Close connection function
func (c *MockConnection) Close() {}
func (c *mockConnection) setSearchResult(result *ldap.SearchResult) { func (c *MockConnection) setSearchResult(result *ldap.SearchResult) {
c.searchResult = result c.SearchResult = result
} }
func (c *mockConnection) Search(sr *ldap.SearchRequest) (*ldap.SearchResult, error) { // Search mocks Search connection function
c.searchCalled = true func (c *MockConnection) Search(sr *ldap.SearchRequest) (*ldap.SearchResult, error) {
c.searchAttributes = sr.Attributes c.SearchCalled = true
return c.searchResult, nil c.SearchAttributes = sr.Attributes
return c.SearchResult, nil
} }
func (c *mockConnection) Add(request *ldap.AddRequest) error { // Add mocks Add connection function
c.addCalled = true func (c *MockConnection) Add(request *ldap.AddRequest) error {
c.addParams = request c.AddCalled = true
c.AddParams = request
return nil return nil
} }
func (c *mockConnection) Del(request *ldap.DelRequest) error { // Del mocks Del connection function
c.delCalled = true func (c *MockConnection) Del(request *ldap.DelRequest) error {
c.delParams = request c.DelCalled = true
c.DelParams = request
return nil return nil
} }
func (c *mockConnection) StartTLS(*tls.Config) error { // StartTLS mocks StartTLS connection function
func (c *MockConnection) StartTLS(*tls.Config) error {
return nil return nil
} }
func authScenario(desc string, fn scenarioFunc) { func serverScenario(desc string, fn scenarioFunc) {
Convey(desc, func() { Convey(desc, func() {
defer bus.ClearBusHandlers() defer bus.ClearBusHandlers()

View File

@ -35,9 +35,6 @@ type IMultiLDAP interface {
User(login string) ( User(login string) (
*models.ExternalUserInfo, error, *models.ExternalUserInfo, error,
) )
Add(dn string, values map[string][]string) error
Remove(dn string) error
} }
// MultiLDAP is basic struct of LDAP authorization // MultiLDAP is basic struct of LDAP authorization
@ -52,55 +49,6 @@ func New(configs []*ldap.ServerConfig) IMultiLDAP {
} }
} }
// Add adds user to the *first* defined LDAP
func (multiples *MultiLDAP) Add(
dn string,
values map[string][]string,
) error {
if len(multiples.configs) == 0 {
return ErrNoLDAPServers
}
config := multiples.configs[0]
ldap := ldap.New(config)
if err := ldap.Dial(); err != nil {
return err
}
defer ldap.Close()
err := ldap.Add(dn, values)
if err != nil {
return err
}
return nil
}
// Remove removes user from the *first* defined LDAP
func (multiples *MultiLDAP) Remove(dn string) error {
if len(multiples.configs) == 0 {
return ErrNoLDAPServers
}
config := multiples.configs[0]
ldap := ldap.New(config)
if err := ldap.Dial(); err != nil {
return err
}
defer ldap.Close()
err := ldap.Remove(dn)
if err != nil {
return err
}
return nil
}
// Login tries to log in the user in multiples LDAP // Login tries to log in the user in multiples LDAP
func (multiples *MultiLDAP) Login(query *models.LoginUserQuery) ( func (multiples *MultiLDAP) Login(query *models.LoginUserQuery) (
*models.ExternalUserInfo, error, *models.ExternalUserInfo, error,