package database

import (
	"context"
	"encoding/base64"
	"time"

	"github.com/grafana/grafana/pkg/infra/db"
	"github.com/grafana/grafana/pkg/infra/log"
	"github.com/grafana/grafana/pkg/services/login"
	"github.com/grafana/grafana/pkg/services/secrets"
	"github.com/grafana/grafana/pkg/services/user"
)

// GetTime supplies the timestamp written to the user_auth "created" column by
// SetAuthInfo, UpdateAuthInfo and UpdateAuthInfoDate. It defaults to time.Now
// and is a package-level variable rather than a direct call, presumably so
// tests can substitute a fixed clock.
var GetTime = time.Now

// AuthInfoStore is the SQL-backed store returned by ProvideAuthInfoStore as a
// login.Store. It reads and writes rows of the user_auth table and delegates
// plain user lookups to the user service.
type AuthInfoStore struct {
	// sqlStore runs every user_auth query issued by this store.
	sqlStore       db.DB
	// secretsService encrypts and decrypts the stored OAuth token columns.
	secretsService secrets.Service
	// logger is named "login.authinfo.store".
	logger         log.Logger
	// userService backs GetUserById, GetUserByLogin and GetUserByEmail.
	userService    user.Service
}

// ProvideAuthInfoStore wires an AuthInfoStore from its dependencies and hands
// it back as a login.Store.
func ProvideAuthInfoStore(sqlStore db.DB, secretsService secrets.Service, userService user.Service) login.Store {
	// FIXME: disabled the metric collection for duplicate user entries
	// due to query performance issues that is clogging the users Grafana instance
	// InitDuplicateUserMetrics()
	return &AuthInfoStore{
		sqlStore:       sqlStore,
		secretsService: secretsService,
		userService:    userService,
		logger:         log.New("login.authinfo.store"),
	}
}

// GetAuthInfo loads a single user_auth row and returns it with its OAuth
// token columns decrypted.
//
// The row is matched using the UserId, AuthModule and AuthId taken from query
// (populated into the bean passed to Get). When several rows match, the one
// with the newest "created" value is returned.
//
// It returns user.ErrUserNotFound when neither UserId nor AuthId is set, or
// when no row matches. Database and decryption errors are returned as-is.
func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
	if query.UserId == 0 && query.AuthId == "" {
		return nil, user.ErrUserNotFound
	}

	found := &login.UserAuth{
		UserId:     query.UserId,
		AuthModule: query.AuthModule,
		AuthId:     query.AuthId,
	}

	var exists bool
	if err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		var getErr error
		exists, getErr = sess.Desc("created").Get(found)
		return getErr
	}); err != nil {
		return nil, err
	}
	if !exists {
		return nil, user.ErrUserNotFound
	}

	// Token columns are stored base64-encoded and encrypted; replace each one
	// with its plaintext. Any failure aborts the whole lookup.
	for _, field := range []*string{
		&found.OAuthAccessToken,
		&found.OAuthRefreshToken,
		&found.OAuthTokenType,
		&found.OAuthIdToken,
	} {
		plain, err := s.decodeAndDecrypt(*field)
		if err != nil {
			return nil, err
		}
		*field = plain
	}

	return found, nil
}

// GetUserLabels maps each user in query.UserIDs to the AuthModule of its
// user_auth row with the newest "created" value. Users with no user_auth row
// are simply absent from the returned map.
func (s *AuthInfoStore) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
	// In() takes its values as interface{}s.
	ids := make([]interface{}, len(query.UserIDs))
	for i, id := range query.UserIDs {
		ids[i] = id
	}

	rows := []login.UserAuth{}
	if err := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		return sess.Table("user_auth").In("user_id", ids).OrderBy("created").Find(&rows)
	}); err != nil {
		return nil, err
	}

	// Rows arrive oldest first, so a later row for the same user overwrites an
	// earlier one and the most recent auth module wins.
	labels := make(map[int64]string, len(rows))
	for i := range rows {
		row := &rows[i]
		labels[row.UserId] = row.AuthModule
	}
	return labels, nil
}

// SetAuthInfo inserts a new user_auth row for cmd.UserId with cmd.AuthModule
// and cmd.AuthId, stamped with GetTime(). The insert runs in a transactional
// session.
//
// When cmd.OAuthToken is non-nil:
//   - the access token, refresh token and token type are always stored
//     encrypted, even when empty;
//   - the "id_token" extra is stored encrypted only if it is a non-empty string;
//   - the expiry is copied as-is.
func (s *AuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
	row := &login.UserAuth{
		UserId:     cmd.UserId,
		AuthModule: cmd.AuthModule,
		AuthId:     cmd.AuthId,
		Created:    GetTime(),
	}

	if tok := cmd.OAuthToken; tok != nil {
		// Encrypt each value straight into its destination column, in order.
		pending := []struct {
			dst   *string
			plain string
		}{
			{&row.OAuthAccessToken, tok.AccessToken},
			{&row.OAuthRefreshToken, tok.RefreshToken},
			{&row.OAuthTokenType, tok.TokenType},
		}
		for _, p := range pending {
			enc, err := s.encryptAndEncode(p.plain)
			if err != nil {
				return err
			}
			*p.dst = enc
		}

		// The ID token is optional; leave the column empty when it is absent.
		if idToken, ok := tok.Extra("id_token").(string); ok && idToken != "" {
			enc, err := s.encryptAndEncode(idToken)
			if err != nil {
				return err
			}
			row.OAuthIdToken = enc
		}

		row.OAuthExpiry = tok.Expiry
	}

	return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
		if _, err := sess.Insert(row); err != nil {
			return err
		}
		return nil
	})
}

// UpdateAuthInfoDate refreshes the "created" timestamp of one user_auth row to
// GetTime(), so that the most recently used login method sorts first when a
// user has overlapping entries (ex: LDAP->SAML->LDAP).
//
// authInfo is modified in place: its Created field receives the new time. Only
// the "created" column is written, on the row matching authInfo's Id, UserId
// and AuthModule.
func (s *AuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *login.UserAuth) error {
	authInfo.Created = GetTime()

	return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
		where := &login.UserAuth{
			Id:         authInfo.Id,
			UserId:     authInfo.UserId,
			AuthModule: authInfo.AuthModule,
		}
		if _, err := sess.Cols("created").Update(authInfo, where); err != nil {
			return err
		}
		return nil
	})
}

// UpdateAuthInfo overwrites every user_auth row for (cmd.UserId,
// cmd.AuthModule) inside a transaction. Each row gets:
//   - cmd.AuthId;
//   - a fresh Created timestamp from GetTime();
//   - when cmd.OAuthToken is non-nil, the encrypted OAuth tokens and expiry.
//
// If the update touched more than one row, the duplicates are deleted so that a
// single row for (UserId, AuthModule, AuthId) remains.
func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
	authUser := &login.UserAuth{
		UserId:     cmd.UserId,
		AuthModule: cmd.AuthModule,
		AuthId:     cmd.AuthId,
		Created:    GetTime(),
	}

	if cmd.OAuthToken != nil {
		secretAccessToken, err := s.encryptAndEncode(cmd.OAuthToken.AccessToken)
		if err != nil {
			return err
		}
		secretRefreshToken, err := s.encryptAndEncode(cmd.OAuthToken.RefreshToken)
		if err != nil {
			return err
		}
		secretTokenType, err := s.encryptAndEncode(cmd.OAuthToken.TokenType)
		if err != nil {
			return err
		}

		// The ID token is optional; leave the column empty when it is absent.
		var secretIdToken string
		if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" {
			secretIdToken, err = s.encryptAndEncode(idToken)
			if err != nil {
				return err
			}
		}

		authUser.OAuthAccessToken = secretAccessToken
		authUser.OAuthRefreshToken = secretRefreshToken
		authUser.OAuthTokenType = secretTokenType
		authUser.OAuthIdToken = secretIdToken
		authUser.OAuthExpiry = cmd.OAuthToken.Expiry
	}

	return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
		// xorm's Update writes only non-zero fields of authUser. MustCols forces
		// o_auth_expiry to be written even when it is the zero time.
		upd, err := sess.MustCols("o_auth_expiry").Where("user_id = ? AND auth_module = ?", cmd.UserId, cmd.AuthModule).Update(authUser)
		if err != nil {
			return err
		}

		s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_id", cmd.AuthId, "auth_module", cmd.AuthModule, "rows", upd)

		if upd <= 1 {
			return nil
		}

		// Clean up duplicated entries: keep one row for this identity and delete
		// the rest.
		// NOTE(review): the SELECT has no ORDER BY, so which row survives is
		// database-dependent.
		var id int64
		ok, err := sess.SQL(
			"SELECT id FROM user_auth WHERE user_id = ? AND auth_module = ? AND auth_id = ?",
			cmd.UserId, cmd.AuthModule, cmd.AuthId,
		).Get(&id)
		if err != nil {
			return err
		}
		if !ok {
			return nil
		}

		_, err = sess.Exec(
			"DELETE FROM user_auth WHERE user_id = ? AND auth_module = ? AND auth_id = ? AND id != ?",
			cmd.UserId, cmd.AuthModule, cmd.AuthId, id,
		)
		return err
	})
}

// DeleteAuthInfo deletes the user_auth row(s) described by cmd.UserAuth in a
// transactional session. xorm uses the bean's populated fields as the delete
// condition.
func (s *AuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
	return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
		if _, err := sess.Delete(cmd.UserAuth); err != nil {
			return err
		}
		return nil
	})
}

// DeleteUserAuthInfo removes every user_auth row that belongs to userID.
func (s *AuthInfoStore) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
	const deleteByUserSQL = "DELETE FROM user_auth WHERE user_id = ?"

	return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		if _, err := sess.Exec(deleteByUserSQL, userID); err != nil {
			return err
		}
		return nil
	})
}

// GetUserById fetches the user with the given numeric ID from the user
// service. On error the user is nil and the error is returned unchanged.
func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
	query := user.GetUserByIDQuery{ID: id}
	// Named usr (not user) so the result does not shadow the user package.
	usr, err := s.userService.GetByID(ctx, &query)
	if err != nil {
		return nil, err
	}

	return usr, nil
}

// GetUserByLogin resolves a user through the user service. The value is sent
// as LoginOrEmail, so presumably an email address can match as well — confirm
// against the user service implementation.
func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
	found, err := s.userService.GetByLogin(ctx, &user.GetUserByLoginQuery{LoginOrEmail: login})
	if err != nil {
		return nil, err
	}
	return found, nil
}

// GetUserByEmail resolves the user registered under email through the user
// service. On error the user is nil and the error is returned unchanged.
func (s *AuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
	found, err := s.userService.GetByEmail(ctx, &user.GetUserByEmailQuery{Email: email})
	if err != nil {
		return nil, err
	}
	return found, nil
}

// decodeAndDecrypt is the inverse of encryptAndEncode. It decodes str with
// standard base64, decrypts the bytes with the secrets service and returns the
// plaintext.
//
// An empty str yields "" without calling Decrypt, which is not safe to call on
// empty input (it can crash).
// NOTE(review): decryption uses context.Background() because this helper does
// not receive the caller's ctx.
func (s *AuthInfoStore) decodeAndDecrypt(str string) (string, error) {
	if len(str) == 0 {
		return "", nil
	}

	ciphertext, err := base64.StdEncoding.DecodeString(str)
	if err != nil {
		return "", err
	}

	plaintext, err := s.secretsService.Decrypt(context.Background(), ciphertext)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}

// encryptAndEncode encrypts str with the secrets service (without a scope) and
// returns the ciphertext encoded with standard base64, ready to be stored in a
// text column. Unlike decodeAndDecrypt, an empty str is still encrypted.
func (s *AuthInfoStore) encryptAndEncode(str string) (string, error) {
	ciphertext, err := s.secretsService.Encrypt(context.Background(), []byte(str), secrets.WithoutScope())
	if err != nil {
		return "", err
	}
	encoded := base64.StdEncoding.EncodeToString(ciphertext)
	return encoded, nil
}