MM-52898: WS fix (#23489)

https://mattermost.atlassian.net/browse/MM-52898

```release-note
NONE
```
This commit is contained in:
Agniva De Sarker
2023-05-24 23:40:12 +05:30
committed by GitHub
parent c8ee05fb76
commit 289a855330
2 changed files with 84 additions and 1 deletions

View File

@@ -1152,7 +1152,31 @@ func (a *App) OriginChecker() func(*http.Request) bool {
return utils.OriginChecker(allowed)
}
return nil
// Overriding the default origin checker
return func(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
if origin[0] == "null" {
return false
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
// To maintain the case where siteURL is not set.
if *a.Config().ServiceSettings.SiteURL == "" {
return strings.EqualFold(u.Host, r.Host)
}
siteURL, err := url.Parse(*a.Config().ServiceSettings.SiteURL)
if err != nil {
return false
}
return strings.EqualFold(u.Host, siteURL.Host) && strings.EqualFold(u.Scheme, siteURL.Scheme)
}
}
func (s *Server) checkPushNotificationServerURL() {

View File

@@ -494,3 +494,62 @@ func TestCancelTaskSetsTaskToNil(t *testing.T) {
require.Nil(t, task)
require.NotPanics(t, func() { cancelTask(&taskMut, &task) })
}
func TestOriginChecker(t *testing.T) {
th := Setup(t)
defer th.TearDown()
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.AllowCorsFrom = ""
})
tcs := []struct {
SiteURL string
HeaderScheme string
HeaderHost string
Pass bool
}{
{
HeaderHost: "test.com",
HeaderScheme: "https://",
SiteURL: "https://test.com",
Pass: true,
},
{
HeaderHost: "test.com",
HeaderScheme: "http://",
SiteURL: "https://test.com",
Pass: false,
},
{
HeaderHost: "test.com",
HeaderScheme: "https://",
SiteURL: "https://www.test.com",
Pass: false,
},
{
HeaderHost: "example.com",
HeaderScheme: "http://",
SiteURL: "http://test.com",
Pass: false,
},
{
HeaderHost: "null",
HeaderScheme: "",
SiteURL: "http://test.com",
Pass: false,
},
}
for i, tc := range tcs {
th.App.UpdateConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.SiteURL = tc.SiteURL
})
r := &http.Request{
Header: http.Header{"Origin": []string{fmt.Sprintf("%s%s", tc.HeaderScheme, tc.HeaderHost)}},
}
res := th.App.OriginChecker()(r)
require.Equalf(t, tc.Pass, res, "Test case (%d)", i)
}
}