mirror of
https://github.com/mattermost/mattermost.git
synced 2025-02-25 18:55:24 -06:00
MM-52898: WS fix (#23489)
https://mattermost.atlassian.net/browse/MM-52898 ```release-note NONE ```
This commit is contained in:
@@ -1152,7 +1152,31 @@ func (a *App) OriginChecker() func(*http.Request) bool {
|
||||
|
||||
return utils.OriginChecker(allowed)
|
||||
}
|
||||
return nil
|
||||
|
||||
// Overriding the default origin checker
|
||||
return func(r *http.Request) bool {
|
||||
origin := r.Header["Origin"]
|
||||
if len(origin) == 0 {
|
||||
return true
|
||||
}
|
||||
if origin[0] == "null" {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(origin[0])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// To maintain the case where siteURL is not set.
|
||||
if *a.Config().ServiceSettings.SiteURL == "" {
|
||||
return strings.EqualFold(u.Host, r.Host)
|
||||
}
|
||||
siteURL, err := url.Parse(*a.Config().ServiceSettings.SiteURL)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(u.Host, siteURL.Host) && strings.EqualFold(u.Scheme, siteURL.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) checkPushNotificationServerURL() {
|
||||
|
||||
@@ -494,3 +494,62 @@ func TestCancelTaskSetsTaskToNil(t *testing.T) {
|
||||
require.Nil(t, task)
|
||||
require.NotPanics(t, func() { cancelTask(&taskMut, &task) })
|
||||
}
|
||||
|
||||
func TestOriginChecker(t *testing.T) {
|
||||
th := Setup(t)
|
||||
defer th.TearDown()
|
||||
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.AllowCorsFrom = ""
|
||||
})
|
||||
|
||||
tcs := []struct {
|
||||
SiteURL string
|
||||
HeaderScheme string
|
||||
HeaderHost string
|
||||
Pass bool
|
||||
}{
|
||||
{
|
||||
HeaderHost: "test.com",
|
||||
HeaderScheme: "https://",
|
||||
SiteURL: "https://test.com",
|
||||
Pass: true,
|
||||
},
|
||||
{
|
||||
HeaderHost: "test.com",
|
||||
HeaderScheme: "http://",
|
||||
SiteURL: "https://test.com",
|
||||
Pass: false,
|
||||
},
|
||||
{
|
||||
HeaderHost: "test.com",
|
||||
HeaderScheme: "https://",
|
||||
SiteURL: "https://www.test.com",
|
||||
Pass: false,
|
||||
},
|
||||
{
|
||||
HeaderHost: "example.com",
|
||||
HeaderScheme: "http://",
|
||||
SiteURL: "http://test.com",
|
||||
Pass: false,
|
||||
},
|
||||
{
|
||||
HeaderHost: "null",
|
||||
HeaderScheme: "",
|
||||
SiteURL: "http://test.com",
|
||||
Pass: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range tcs {
|
||||
th.App.UpdateConfig(func(cfg *model.Config) {
|
||||
*cfg.ServiceSettings.SiteURL = tc.SiteURL
|
||||
})
|
||||
|
||||
r := &http.Request{
|
||||
Header: http.Header{"Origin": []string{fmt.Sprintf("%s%s", tc.HeaderScheme, tc.HeaderHost)}},
|
||||
}
|
||||
res := th.App.OriginChecker()(r)
|
||||
require.Equalf(t, tc.Pass, res, "Test case (%d)", i)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user