API handler opts modifier (#26148)

* POC for API handler opts modifier

* Made upload POSt api a  file upload API

* Specified file upload local API

* Specified file upload local API

* Specified file upload API

* Simplified handler params

* Added basic security checks

* Fixed i18n

* used type for API handler options

* Removed limited reader from util deserializers (#26263)
This commit is contained in:
Harshil Sharma 2024-02-21 17:43:50 +05:30 committed by GitHub
parent ecb09de6c7
commit 521844fed5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 127 additions and 165 deletions

View File

@ -13,7 +13,7 @@ import (
func (api *API) InitBrand() {
api.BaseRoutes.Brand.Handle("/image", api.APIHandlerTrustRequester(getBrandImage)).Methods("GET")
api.BaseRoutes.Brand.Handle("/image", api.APISessionRequired(uploadBrandImage)).Methods("POST")
api.BaseRoutes.Brand.Handle("/image", api.APISessionRequired(uploadBrandImage, handlerParamFileAPI)).Methods("POST")
api.BaseRoutes.Brand.Handle("/image", api.APISessionRequired(deleteBrandImage)).Methods("DELETE")
}

View File

@ -425,7 +425,7 @@ func restoreChannel(c *Context, w http.ResponseWriter, r *http.Request) {
}
func createDirectChannel(c *Context, w http.ResponseWriter, r *http.Request) {
userIds, err := model.NonSortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
userIds, err := model.NonSortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("createDirectChannel", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -520,7 +520,7 @@ func searchGroupChannels(c *Context, w http.ResponseWriter, r *http.Request) {
}
func createGroupChannel(c *Context, w http.ResponseWriter, r *http.Request) {
userIds, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
userIds, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("createGroupChannel", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -708,7 +708,7 @@ func getChannelsMemberCount(c *Context, w http.ResponseWriter, r *http.Request)
return
}
channelIDs, sortErr := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
channelIDs, sortErr := model.SortedArrayFromJSON(r.Body)
if sortErr != nil {
c.Err = model.NewAppError("getChannelsMemberCount", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(sortErr)
return
@ -914,7 +914,7 @@ func getPublicChannelsByIdsForTeam(c *Context, w http.ResponseWriter, r *http.Re
return
}
channelIds, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
channelIds, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getPublicChannelsByIdsForTeam", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -1457,7 +1457,7 @@ func getChannelMembersByIds(c *Context, w http.ResponseWriter, r *http.Request)
return
}
userIds, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
userIds, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getChannelMembersByIds", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -1583,7 +1583,7 @@ func viewChannel(c *Context, w http.ResponseWriter, r *http.Request) {
func readMultipleChannels(c *Context, w http.ResponseWriter, r *http.Request) {
c.RequireUserId()
channelIds, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
channelIds, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("readMultipleChannels", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -118,7 +118,7 @@ func updateCategoryOrderForTeamForUser(c *Context, w http.ResponseWriter, r *htt
auditRec := c.MakeAuditRecord("updateCategoryOrderForTeamForUser", audit.Fail)
defer c.LogAuditRec(auditRec)
categoryOrder, err := model.NonSortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
categoryOrder, err := model.NonSortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("updateCategoryOrderForTeamForUser", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -264,7 +264,7 @@ func searchTeamsInPolicy(c *Context, w http.ResponseWriter, r *http.Request) {
func addTeamsToPolicy(c *Context, w http.ResponseWriter, r *http.Request) {
c.RequirePolicyId()
policyId := c.Params.PolicyId
teamIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
teamIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("addTeamsToPolicy", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -291,7 +291,7 @@ func addTeamsToPolicy(c *Context, w http.ResponseWriter, r *http.Request) {
func removeTeamsFromPolicy(c *Context, w http.ResponseWriter, r *http.Request) {
c.RequirePolicyId()
policyId := c.Params.PolicyId
teamIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
teamIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("removeTeamsFromPolicy", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -383,7 +383,7 @@ func searchChannelsInPolicy(c *Context, w http.ResponseWriter, r *http.Request)
func addChannelsToPolicy(c *Context, w http.ResponseWriter, r *http.Request) {
c.RequirePolicyId()
policyId := c.Params.PolicyId
channelIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
channelIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("addChannelsToPolicy", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -411,7 +411,7 @@ func addChannelsToPolicy(c *Context, w http.ResponseWriter, r *http.Request) {
func removeChannelsFromPolicy(c *Context, w http.ResponseWriter, r *http.Request) {
c.RequirePolicyId()
policyId := c.Params.PolicyId
channelIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
channelIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("removeChannelsFromPolicy", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -240,7 +240,7 @@ func getEmojiByName(c *Context, w http.ResponseWriter, r *http.Request) {
}
func getEmojisByNames(c *Context, w http.ResponseWriter, r *http.Request) {
names, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
names, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getEmojisByNames", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -32,7 +32,7 @@ const (
const maxMultipartFormDataBytes = 10 * 1024 // 10Kb
func (api *API) InitFile() {
api.BaseRoutes.Files.Handle("", api.APISessionRequired(uploadFileStream)).Methods("POST")
api.BaseRoutes.Files.Handle("", api.APISessionRequired(uploadFileStream, handlerParamFileAPI)).Methods("POST")
api.BaseRoutes.File.Handle("", api.APISessionRequiredTrustRequester(getFile)).Methods("GET")
api.BaseRoutes.File.Handle("/thumbnail", api.APISessionRequiredTrustRequester(getFileThumbnail)).Methods("GET")
api.BaseRoutes.File.Handle("/link", api.APISessionRequired(getFileLink)).Methods("GET")

View File

@ -18,9 +18,15 @@ type Context = web.Context
type handlerFunc func(*Context, http.ResponseWriter, *http.Request)
type APIHandlerOption string
const (
handlerParamFileAPI = APIHandlerOption("fileAPI")
)
// APIHandler provides a handler for API endpoints which do not require the user to be logged in order for access to be
// granted.
func (api *API) APIHandler(h handlerFunc) http.Handler {
func (api *API) APIHandler(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -31,6 +37,8 @@ func (api *API) APIHandler(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: false,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -39,7 +47,7 @@ func (api *API) APIHandler(h handlerFunc) http.Handler {
// APISessionRequired provides a handler for API endpoints which require the user to be logged in in order for access to
// be granted.
func (api *API) APISessionRequired(h handlerFunc) http.Handler {
func (api *API) APISessionRequired(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -50,6 +58,8 @@ func (api *API) APISessionRequired(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: false,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -57,7 +67,7 @@ func (api *API) APISessionRequired(h handlerFunc) http.Handler {
}
// CloudAPIKeyRequired provides a handler for webhook endpoints to access Cloud installations from CWS
func (api *API) CloudAPIKeyRequired(h handlerFunc) http.Handler {
func (api *API) CloudAPIKeyRequired(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -69,6 +79,8 @@ func (api *API) CloudAPIKeyRequired(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: false,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -76,7 +88,7 @@ func (api *API) CloudAPIKeyRequired(h handlerFunc) http.Handler {
}
// RemoteClusterTokenRequired provides a handler for remote cluster requests to /remotecluster endpoints.
func (api *API) RemoteClusterTokenRequired(h handlerFunc) http.Handler {
func (api *API) RemoteClusterTokenRequired(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -89,6 +101,8 @@ func (api *API) RemoteClusterTokenRequired(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: false,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -98,7 +112,7 @@ func (api *API) RemoteClusterTokenRequired(h handlerFunc) http.Handler {
// APISessionRequiredMfa provides a handler for API endpoints which require a logged-in user session but when accessed,
// if MFA is enabled, the MFA process is not yet complete, and therefore the requirement to have completed the MFA
// authentication must be waived.
func (api *API) APISessionRequiredMfa(h handlerFunc) http.Handler {
func (api *API) APISessionRequiredMfa(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -109,6 +123,8 @@ func (api *API) APISessionRequiredMfa(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: false,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -118,7 +134,7 @@ func (api *API) APISessionRequiredMfa(h handlerFunc) http.Handler {
// APIHandlerTrustRequester provides a handler for API endpoints which do not require the user to be logged in and are
// allowed to be requested directly rather than via javascript/XMLHttpRequest, such as site branding images or the
// websocket.
func (api *API) APIHandlerTrustRequester(h handlerFunc) http.Handler {
func (api *API) APIHandlerTrustRequester(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -129,6 +145,8 @@ func (api *API) APIHandlerTrustRequester(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: false,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -137,7 +155,7 @@ func (api *API) APIHandlerTrustRequester(h handlerFunc) http.Handler {
// APISessionRequiredTrustRequester provides a handler for API endpoints which do require the user to be logged in and
// are allowed to be requested directly rather than via javascript/XMLHttpRequest, such as emoji or file uploads.
func (api *API) APISessionRequiredTrustRequester(h handlerFunc) http.Handler {
func (api *API) APISessionRequiredTrustRequester(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -148,6 +166,8 @@ func (api *API) APISessionRequiredTrustRequester(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: false,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -156,7 +176,7 @@ func (api *API) APISessionRequiredTrustRequester(h handlerFunc) http.Handler {
// DisableWhenBusy provides a handler for API endpoints which should be disabled when the server is under load,
// responding with HTTP 503 (Service Unavailable).
func (api *API) APISessionRequiredDisableWhenBusy(h handlerFunc) http.Handler {
func (api *API) APISessionRequiredDisableWhenBusy(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -168,6 +188,8 @@ func (api *API) APISessionRequiredDisableWhenBusy(h handlerFunc) http.Handler {
IsLocal: false,
DisableWhenBusy: true,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
}
@ -178,7 +200,7 @@ func (api *API) APISessionRequiredDisableWhenBusy(h handlerFunc) http.Handler {
// mode, this is, through a UNIX socket and without an authenticated
// session, but with one that has no user set and no permission
// restrictions
func (api *API) APILocal(h handlerFunc) http.Handler {
func (api *API) APILocal(h handlerFunc, opts ...APIHandlerOption) http.Handler {
handler := &web.Handler{
Srv: api.srv,
HandleFunc: h,
@ -189,6 +211,7 @@ func (api *API) APILocal(h handlerFunc) http.Handler {
IsStatic: false,
IsLocal: true,
}
setHandlerOpts(handler, opts...)
if *api.srv.Config().ServiceSettings.WebserverMode == "gzip" {
return gzhttp.GzipHandler(handler)
@ -223,3 +246,16 @@ func minimumProfessionalLicense(c *Context) *model.AppError {
}
return nil
}
func setHandlerOpts(handler *web.Handler, opts ...APIHandlerOption) {
if len(opts) == 0 {
return
}
for _, option := range opts {
switch option {
case handlerParamFileAPI:
handler.FileAPI = true
}
}
}

View File

@ -6,7 +6,6 @@ package api4
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
@ -220,8 +219,7 @@ func createOutgoingOAuthConnection(c *Context, w http.ResponseWriter, r *http.Re
}
var inputConnection model.OutgoingOAuthConnection
bodyReader := io.LimitReader(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
if err := json.NewDecoder(bodyReader).Decode(&inputConnection); err != nil {
if err := json.NewDecoder(r.Body).Decode(&inputConnection); err != nil {
c.Err = model.NewAppError(whereOutgoingOAuthConnection, "api.context.outgoing_oauth_connection.create_connection.input_error", nil, err.Error(), http.StatusBadRequest)
return
}
@ -271,8 +269,7 @@ func updateOutgoingOAuthConnection(c *Context, w http.ResponseWriter, r *http.Re
}
var inputConnection model.OutgoingOAuthConnection
bodyReader := io.LimitReader(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
if err := json.NewDecoder(bodyReader).Decode(&inputConnection); err != nil {
if err := json.NewDecoder(r.Body).Decode(&inputConnection); err != nil {
c.Err = model.NewAppError(whereOutgoingOAuthConnection, "api.context.outgoing_oauth_connection.update_connection.input_error", nil, err.Error(), http.StatusBadRequest)
return
}
@ -376,8 +373,7 @@ func validateOutgoingOAuthConnectionCredentials(c *Context, w http.ResponseWrite
// connection url.
var inputConnection *model.OutgoingOAuthConnection
bodyReader := io.LimitReader(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
if err := json.NewDecoder(bodyReader).Decode(&inputConnection); err != nil {
if err := json.NewDecoder(r.Body).Decode(&inputConnection); err != nil {
c.Err = model.NewAppError(whereOutgoingOAuthConnection, "api.context.outgoing_oauth_connection.validate_connection_credentials.input_error", nil, err.Error(), http.StatusBadRequest)
w.WriteHeader(c.Err.StatusCode)
return

View File

@ -504,7 +504,7 @@ func getPost(c *Context, w http.ResponseWriter, r *http.Request) {
// getPostsByIds also sets a header to indicate, if posts were truncated as per the cloud plan's limit.
func getPostsByIds(c *Context, w http.ResponseWriter, r *http.Request) {
postIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
postIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getPostsByIds", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -103,7 +103,7 @@ func updatePreferences(c *Context, w http.ResponseWriter, r *http.Request) {
}
var preferences model.Preferences
err := model.StructFromJSONLimited(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes, &preferences)
err := model.StructFromJSONLimited(r.Body, &preferences)
if err != nil {
c.SetInvalidParamWithErr("preferences", err)
return
@ -155,7 +155,7 @@ func deletePreferences(c *Context, w http.ResponseWriter, r *http.Request) {
}
var preferences model.Preferences
err := model.StructFromJSONLimited(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes, &preferences)
err := model.StructFromJSONLimited(r.Body, &preferences)
if err != nil {
c.SetInvalidParamWithErr("preferences", err)
return

View File

@ -119,7 +119,7 @@ func deleteReaction(c *Context, w http.ResponseWriter, r *http.Request) {
}
func getBulkReactions(c *Context, w http.ResponseWriter, r *http.Request) {
postIds, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
postIds, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getBulkReactions", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -84,7 +84,7 @@ func getRoleByName(c *Context, w http.ResponseWriter, r *http.Request) {
}
func getRolesByNames(c *Context, w http.ResponseWriter, r *http.Request) {
rolenames, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
rolenames, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getRolesByNames", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -49,7 +49,7 @@ func getUserStatus(c *Context, w http.ResponseWriter, r *http.Request) {
}
func getUserStatusesByIds(c *Context, w http.ResponseWriter, r *http.Request) {
userIds, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
userIds, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getUserStatusesByIds", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -968,7 +968,7 @@ func updateViewedProductNotices(c *Context, w http.ResponseWriter, r *http.Reque
defer c.LogAuditRec(auditRec)
c.LogAudit("attempt")
ids, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
ids, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("updateViewedProductNotices", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -49,7 +49,7 @@ func (api *API) InitTeam() {
api.BaseRoutes.Team.Handle("/regenerate_invite_id", api.APISessionRequired(regenerateTeamInviteId)).Methods("POST")
api.BaseRoutes.Team.Handle("/image", api.APISessionRequiredTrustRequester(getTeamIcon)).Methods("GET")
api.BaseRoutes.Team.Handle("/image", api.APISessionRequired(setTeamIcon)).Methods("POST")
api.BaseRoutes.Team.Handle("/image", api.APISessionRequired(setTeamIcon, handlerParamFileAPI)).Methods("POST")
api.BaseRoutes.Team.Handle("/image", api.APISessionRequired(removeTeamIcon)).Methods("DELETE")
api.BaseRoutes.TeamMembers.Handle("", api.APISessionRequired(getTeamMembers)).Methods("GET")
@ -644,7 +644,7 @@ func getTeamMembersByIds(c *Context, w http.ResponseWriter, r *http.Request) {
return
}
userIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
userIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getTeamMembersByIds", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -1376,7 +1376,7 @@ func inviteUsersToTeam(c *Context, w http.ResponseWriter, r *http.Request) {
}
memberInvite := &model.MemberInvite{}
err := model.StructFromJSONLimited(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes, memberInvite)
err := model.StructFromJSONLimited(r.Body, memberInvite)
if err != nil {
c.Err = model.NewAppError("Api4.inviteUsersToTeams", "api.team.invite_members_to_team_and_channels.invalid_body.app_error", nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -79,7 +79,7 @@ func localInviteUsersToTeam(c *Context, w http.ResponseWriter, r *http.Request)
}
memberInvite := &model.MemberInvite{}
err := model.StructFromJSONLimited(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes, memberInvite)
err := model.StructFromJSONLimited(r.Body, memberInvite)
if err != nil {
c.Err = model.NewAppError("Api4.localInviteUsersToTeam", "api.team.invite_members_to_team_and_channels.invalid_body.app_error", nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -17,9 +17,9 @@ import (
)
func (api *API) InitUpload() {
api.BaseRoutes.Uploads.Handle("", api.APISessionRequired(createUpload)).Methods("POST")
api.BaseRoutes.Uploads.Handle("", api.APISessionRequired(createUpload, handlerParamFileAPI)).Methods("POST")
api.BaseRoutes.Upload.Handle("", api.APISessionRequired(getUpload)).Methods("GET")
api.BaseRoutes.Upload.Handle("", api.APISessionRequired(uploadData)).Methods("POST")
api.BaseRoutes.Upload.Handle("", api.APISessionRequired(uploadData, handlerParamFileAPI)).Methods("POST")
}
func createUpload(c *Context, w http.ResponseWriter, r *http.Request) {

View File

@ -4,7 +4,7 @@
package api4
func (api *API) InitUploadLocal() {
api.BaseRoutes.Uploads.Handle("", api.APILocal(createUpload)).Methods("POST")
api.BaseRoutes.Uploads.Handle("", api.APILocal(createUpload, handlerParamFileAPI)).Methods("POST")
api.BaseRoutes.Upload.Handle("", api.APILocal(getUpload)).Methods("GET")
api.BaseRoutes.Upload.Handle("", api.APILocal(uploadData)).Methods("POST")
api.BaseRoutes.Upload.Handle("", api.APILocal(uploadData, handlerParamFileAPI)).Methods("POST")
}

View File

@ -38,7 +38,7 @@ func (api *API) InitUser() {
api.BaseRoutes.User.Handle("", api.APISessionRequired(getUser)).Methods("GET")
api.BaseRoutes.User.Handle("/image/default", api.APISessionRequiredTrustRequester(getDefaultProfileImage)).Methods("GET")
api.BaseRoutes.User.Handle("/image", api.APISessionRequiredTrustRequester(getProfileImage)).Methods("GET")
api.BaseRoutes.User.Handle("/image", api.APISessionRequired(setProfileImage)).Methods("POST")
api.BaseRoutes.User.Handle("/image", api.APISessionRequired(setProfileImage, handlerParamFileAPI)).Methods("POST")
api.BaseRoutes.User.Handle("/image", api.APISessionRequired(setDefaultProfileImage)).Methods("DELETE")
api.BaseRoutes.User.Handle("", api.APISessionRequired(updateUser)).Methods("PUT")
api.BaseRoutes.User.Handle("/patch", api.APISessionRequired(patchUser)).Methods("PUT")
@ -623,7 +623,7 @@ func getFilteredUsersStats(c *Context, w http.ResponseWriter, r *http.Request) {
}
func getUsersByGroupChannelIds(c *Context, w http.ResponseWriter, r *http.Request) {
channelIds, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
channelIds, err := model.SortedArrayFromJSON(r.Body)
if err != nil || len(channelIds) == 0 {
c.Err = model.NewAppError("getUsersByGroupChannelIds", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -955,7 +955,7 @@ func requireGroupAccess(c *web.Context, groupID string) *model.AppError {
}
func getUsersByIds(c *Context, w http.ResponseWriter, r *http.Request) {
userIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
userIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getUsersByIds", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return
@ -1002,7 +1002,7 @@ func getUsersByIds(c *Context, w http.ResponseWriter, r *http.Request) {
}
func getUsersByNames(c *Context, w http.ResponseWriter, r *http.Request) {
usernames, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
usernames, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("getUsersByNames", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -233,7 +233,7 @@ func localGetUsers(c *Context, w http.ResponseWriter, r *http.Request) {
}
func localGetUsersByIds(c *Context, w http.ResponseWriter, r *http.Request) {
userIDs, err := model.SortedArrayFromJSON(r.Body, *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes)
userIDs, err := model.SortedArrayFromJSON(r.Body)
if err != nil {
c.Err = model.NewAppError("localGetUsersByIds", model.PayloadParseError, nil, "", http.StatusBadRequest).Wrap(err)
return

View File

@ -31,7 +31,8 @@ import (
)
const (
frameAncestors = "'self' teams.microsoft.com"
frameAncestors = "'self' teams.microsoft.com"
maxURLCharacters = 2048
)
func GetHandlerName(h func(*Context, http.ResponseWriter, *http.Request)) string {
@ -86,6 +87,7 @@ type Handler struct {
IsStatic bool
IsLocal bool
DisableWhenBusy bool
FileAPI bool
cspShaDirective string
}
@ -142,7 +144,20 @@ func generateDevCSP(c Context) string {
return " " + strings.Join(devCSP, " ")
}
func (h Handler) basicSecurityChecks(w http.ResponseWriter, r *http.Request) *model.AppError {
if len(r.RequestURI) > maxURLCharacters {
return model.NewAppError("basicSecurityChecks", "basic_security_check.url.too_long_error", nil, "", http.StatusRequestURITooLong)
}
return nil
}
func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if appErr := h.basicSecurityChecks(w, r); appErr != nil {
http.Error(w, appErr.Error(), appErr.StatusCode)
return
}
w = newWrappedWriter(w)
now := time.Now()
@ -213,13 +228,16 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.App = app_opentracing.NewOpenTracingAppLayer(c.App, ctx)
}
// Set the max request body size to be equal to MaxFileSize.
// Ideally, non-file request bodies should be smaller than file request bodies,
// but we don't have a clean way to identify all file upload handlers.
// So to keep it simple, we clamp it to the max file size.
// We add a buffer of bytes.MinRead so that file sizes close to max file size
// do not get cut off.
r.Body = http.MaxBytesReader(w, r.Body, *c.App.Config().FileSettings.MaxFileSize+bytes.MinRead)
var maxBytes int64
if h.FileAPI {
// We add a buffer of bytes.MinRead so that file sizes close to max file size
// do not get cut off.
maxBytes = *c.App.Config().FileSettings.MaxFileSize + bytes.MinRead
} else {
maxBytes = *c.App.Config().ServiceSettings.MaximumPayloadSizeBytes + bytes.MinRead
}
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
subpath, _ := utils.GetSubpathFromConfig(c.App.Config())
siteURLHeader := app.GetProtocol(r) + "://" + r.Host + subpath

View File

@ -7530,6 +7530,10 @@
"id": "app.webhooks.update_outgoing.app_error",
"translation": "Unable to update the webhook."
},
{
"id": "basic_security_check.url.too_long_error",
"translation": "URL is too long"
},
{
"id": "bleveengine.already_started.error",
"translation": "Bleve is already started."

View File

@ -498,10 +498,9 @@ func ArrayFromJSON(data io.Reader) []string {
return objmap
}
func SortedArrayFromJSON(data io.Reader, maxBytes int64) ([]string, error) {
func SortedArrayFromJSON(data io.Reader) ([]string, error) {
var obj []string
lr := io.LimitReader(data, maxBytes)
err := json.NewDecoder(lr).Decode(&obj)
err := json.NewDecoder(data).Decode(&obj)
if err != nil || obj == nil {
return nil, err
}
@ -510,10 +509,9 @@ func SortedArrayFromJSON(data io.Reader, maxBytes int64) ([]string, error) {
return RemoveDuplicateStrings(obj), nil
}
func NonSortedArrayFromJSON(data io.Reader, maxBytes int64) ([]string, error) {
func NonSortedArrayFromJSON(data io.Reader) ([]string, error) {
var obj []string
lr := io.LimitReader(data, maxBytes)
err := json.NewDecoder(lr).Decode(&obj)
err := json.NewDecoder(data).Decode(&obj)
if err != nil || obj == nil {
return nil, err
}
@ -555,9 +553,8 @@ func StringInterfaceFromJSON(data io.Reader) map[string]any {
return objmap
}
func StructFromJSONLimited[V any](data io.Reader, maxBytes int64, obj *V) error {
lr := io.LimitReader(data, maxBytes)
err := json.NewDecoder(lr).Decode(&obj)
func StructFromJSONLimited[V any](data io.Reader, obj *V) error {
err := json.NewDecoder(data).Decode(&obj)
if err != nil || obj == nil {
return err
}

View File

@ -8,7 +8,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"reflect"
"strings"
@ -222,7 +221,7 @@ func TestSortedArrayFromJSON(t *testing.T) {
t.Run("Successful parse", func(t *testing.T) {
ids := []string{NewId(), NewId(), NewId()}
b, _ := json.Marshal(ids)
a, err := SortedArrayFromJSON(bytes.NewReader(b), 1000)
a, err := SortedArrayFromJSON(bytes.NewReader(b))
require.NoError(t, err)
require.ElementsMatch(t, ids, a)
})
@ -230,22 +229,11 @@ func TestSortedArrayFromJSON(t *testing.T) {
t.Run("Empty Array", func(t *testing.T) {
ids := []string{}
b, _ := json.Marshal(ids)
a, err := SortedArrayFromJSON(bytes.NewReader(b), 1000)
a, err := SortedArrayFromJSON(bytes.NewReader(b))
require.NoError(t, err)
require.Empty(t, a)
})
t.Run("Error too large", func(t *testing.T) {
var ids []string
for i := 0; i <= 100; i++ {
ids = append(ids, NewId())
}
b, _ := json.Marshal(ids)
_, err := SortedArrayFromJSON(bytes.NewReader(b), 1000)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
})
t.Run("Duplicate keys, returns one", func(t *testing.T) {
var ids []string
id := NewId()
@ -253,7 +241,7 @@ func TestSortedArrayFromJSON(t *testing.T) {
ids = append(ids, id)
}
b, _ := json.Marshal(ids)
a, err := SortedArrayFromJSON(bytes.NewReader(b), 26000)
a, err := SortedArrayFromJSON(bytes.NewReader(b))
require.NoError(t, err)
require.Len(t, a, 1)
})
@ -263,7 +251,7 @@ func TestNonSortedArrayFromJSON(t *testing.T) {
t.Run("Successful parse", func(t *testing.T) {
ids := []string{NewId(), NewId(), NewId()}
b, _ := json.Marshal(ids)
a, err := NonSortedArrayFromJSON(bytes.NewReader(b), 1000)
a, err := NonSortedArrayFromJSON(bytes.NewReader(b))
require.NoError(t, err)
require.Equal(t, ids, a)
})
@ -271,22 +259,11 @@ func TestNonSortedArrayFromJSON(t *testing.T) {
t.Run("Empty Array", func(t *testing.T) {
ids := []string{}
b, _ := json.Marshal(ids)
a, err := NonSortedArrayFromJSON(bytes.NewReader(b), 1000)
a, err := NonSortedArrayFromJSON(bytes.NewReader(b))
require.NoError(t, err)
require.Empty(t, a)
})
t.Run("Error too large", func(t *testing.T) {
var ids []string
for i := 0; i <= 100; i++ {
ids = append(ids, NewId())
}
b, _ := json.Marshal(ids)
_, err := NonSortedArrayFromJSON(bytes.NewReader(b), 1000)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
})
t.Run("Duplicate keys, returns one", func(t *testing.T) {
var ids []string
id := NewId()
@ -294,7 +271,7 @@ func TestNonSortedArrayFromJSON(t *testing.T) {
ids = append(ids, id)
}
b, _ := json.Marshal(ids)
a, err := NonSortedArrayFromJSON(bytes.NewReader(b), 26000)
a, err := NonSortedArrayFromJSON(bytes.NewReader(b))
require.NoError(t, err)
require.Len(t, a, 1)
})
@ -1243,7 +1220,7 @@ func TestStructFromJSONLimited(t *testing.T) {
require.NoError(t, err)
b := &TestStruct{}
err = StructFromJSONLimited(bytes.NewReader(testStructBytes), 1000, b)
err = StructFromJSONLimited(bytes.NewReader(testStructBytes), b)
require.NoError(t, err)
require.Equal(t, b.StringField, "string")
@ -1252,29 +1229,6 @@ func TestStructFromJSONLimited(t *testing.T) {
require.Equal(t, b.BoolField, true)
})
t.Run("error too big", func(t *testing.T) {
type TestStruct struct {
StringField string
IntField int
FloatField float32
BoolField bool
}
testStruct := TestStruct{
StringField: "string",
IntField: 2,
FloatField: 3.1415,
BoolField: true,
}
testStructBytes, err := json.Marshal(testStruct)
require.NoError(t, err)
b := &TestStruct{}
err = StructFromJSONLimited(bytes.NewReader(testStructBytes), 10, b)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
})
t.Run("successfully parses nested struct", func(t *testing.T) {
type TestStruct struct {
StringField string
@ -1313,7 +1267,7 @@ func TestStructFromJSONLimited(t *testing.T) {
require.NoError(t, err)
b := &NestedStruct{}
err = StructFromJSONLimited(bytes.NewReader(nestedStructBytes), 1000, b)
err = StructFromJSONLimited(bytes.NewReader(nestedStructBytes), b)
require.NoError(t, err)
require.Equal(t, b.FieldA.StringField, "string A")
@ -1329,49 +1283,6 @@ func TestStructFromJSONLimited(t *testing.T) {
require.Equal(t, b.FieldC, []int{5, 9, 1, 5, 7})
})
t.Run("errors on too big nested struct", func(t *testing.T) {
type TestStruct struct {
StringField string
IntField int
FloatField float32
BoolField bool
}
type NestedStruct struct {
FieldA TestStruct
FieldB TestStruct
FieldC []int
}
testStructA := TestStruct{
StringField: "string A",
IntField: 2,
FloatField: 3.1415,
BoolField: true,
}
testStructB := TestStruct{
StringField: "string B",
IntField: 3,
FloatField: 100,
BoolField: false,
}
nestedStruct := NestedStruct{
FieldA: testStructA,
FieldB: testStructB,
FieldC: []int{5, 9, 1, 5, 7},
}
nestedStructBytes, err := json.Marshal(nestedStruct)
require.NoError(t, err)
b := &NestedStruct{}
err = StructFromJSONLimited(bytes.NewReader(nestedStructBytes), 50, b)
require.Error(t, err)
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
})
t.Run("handles empty structs", func(t *testing.T) {
type TestStruct struct{}
@ -1380,7 +1291,7 @@ func TestStructFromJSONLimited(t *testing.T) {
require.NoError(t, err)
b := &TestStruct{}
err = StructFromJSONLimited(bytes.NewReader(testStructBytes), 1000, b)
err = StructFromJSONLimited(bytes.NewReader(testStructBytes), b)
require.NoError(t, err)
})
}