mirror of
https://github.com/discourse/discourse.git
synced 2025-02-25 18:55:32 -06:00
FIX: handle CORS in hijacked requests
This commit is contained in:
parent
c64774f4f8
commit
90a55d6f7c
@ -1,41 +1,52 @@
|
||||
if GlobalSetting.enable_cors
|
||||
class Discourse::Cors
|
||||
def initialize(app, options = nil)
|
||||
@app = app
|
||||
if GlobalSetting.enable_cors && GlobalSetting.cors_origin.present?
|
||||
@global_origins = GlobalSetting.cors_origin.split(',').map(&:strip)
|
||||
end
|
||||
end
|
||||
# frozen_string_literal: true
|
||||
|
||||
def call(env)
|
||||
if env['REQUEST_METHOD'] == ('OPTIONS') && env['HTTP_ACCESS_CONTROL_REQUEST_METHOD']
|
||||
return [200, apply_headers(env), []]
|
||||
end
|
||||
class Discourse::Cors
|
||||
ORIGINS_ENV = "Discourse_Cors_Origins"
|
||||
|
||||
status, headers, body = @app.call(env)
|
||||
[status, apply_headers(env, headers), body]
|
||||
end
|
||||
|
||||
def apply_headers(env, headers = nil)
|
||||
headers ||= {}
|
||||
|
||||
origin = nil
|
||||
cors_origins = @global_origins || []
|
||||
cors_origins += SiteSetting.cors_origins.split('|') if SiteSetting.cors_origins
|
||||
|
||||
if cors_origins
|
||||
if origin = env['HTTP_ORIGIN']
|
||||
origin = nil unless cors_origins.include?(origin)
|
||||
end
|
||||
|
||||
headers['Access-Control-Allow-Origin'] = origin || cors_origins[0]
|
||||
headers['Access-Control-Allow-Headers'] = 'X-Requested-With, X-CSRF-Token, Discourse-Visible'
|
||||
headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
end
|
||||
|
||||
headers
|
||||
def initialize(app, options = nil)
|
||||
@app = app
|
||||
if GlobalSetting.enable_cors && GlobalSetting.cors_origin.present?
|
||||
@global_origins = GlobalSetting.cors_origin.split(',').map(&:strip)
|
||||
end
|
||||
end
|
||||
|
||||
def call(env)
|
||||
|
||||
cors_origins = @global_origins || []
|
||||
cors_origins += SiteSetting.cors_origins.split('|') if SiteSetting.cors_origins.present?
|
||||
cors_origins = cors_origins.presence
|
||||
|
||||
if env['REQUEST_METHOD'] == ('OPTIONS') && env['HTTP_ACCESS_CONTROL_REQUEST_METHOD']
|
||||
return [200, Discourse::Cors.apply_headers(cors_origins, env, {}), []]
|
||||
end
|
||||
|
||||
env[Discourse::Cors::ORIGINS_ENV] = cors_origins if cors_origins
|
||||
|
||||
status, headers, body = @app.call(env)
|
||||
headers ||= {}
|
||||
|
||||
Discourse::Cors.apply_headers(cors_origins, env, headers) if cors_origins
|
||||
|
||||
[status, headers, body]
|
||||
end
|
||||
|
||||
def self.apply_headers(cors_origins, env, headers)
|
||||
origin = nil
|
||||
|
||||
if cors_origins
|
||||
if origin = env['HTTP_ORIGIN']
|
||||
origin = nil unless cors_origins.include?(origin)
|
||||
end
|
||||
|
||||
headers['Access-Control-Allow-Origin'] = origin || cors_origins[0]
|
||||
headers['Access-Control-Allow-Headers'] = 'X-Requested-With, X-CSRF-Token, Discourse-Visible'
|
||||
headers['Access-Control-Allow-Credentials'] = 'true'
|
||||
end
|
||||
|
||||
headers
|
||||
end
|
||||
end
|
||||
|
||||
if GlobalSetting.enable_cors
|
||||
Rails.configuration.middleware.insert_before ActionDispatch::Flash, Discourse::Cors
|
||||
end
|
||||
|
@ -55,6 +55,11 @@ module Hijack
|
||||
body = response.body
|
||||
|
||||
headers = response.headers
|
||||
# add cors if needed
|
||||
if cors_origins = env_copy[Discourse::Cors::ORIGINS_ENV]
|
||||
Discourse::Cors.apply_headers(cors_origins, env_copy, headers)
|
||||
end
|
||||
|
||||
headers['Content-Length'] = body.bytesize
|
||||
headers['Content-Type'] = response.content_type || "text/plain"
|
||||
headers['Connection'] = "close"
|
||||
|
@ -79,6 +79,41 @@ describe Hijack do
|
||||
expect(copy_req.object_id).not_to eq(orig_req.object_id)
|
||||
end
|
||||
|
||||
it "handles cors" do
|
||||
SiteSetting.cors_origins = "www.rainbows.com"
|
||||
|
||||
app = lambda do |env|
|
||||
tester = Hijack::Tester.new(env)
|
||||
tester.hijack_test do
|
||||
render body: "hello", status: 201
|
||||
end
|
||||
|
||||
expect(tester.io.string).to include("Access-Control-Allow-Origin: www.rainbows.com")
|
||||
end
|
||||
|
||||
env = {}
|
||||
middleware = Discourse::Cors.new(app)
|
||||
middleware.call(env)
|
||||
|
||||
# it can do pre-flight
|
||||
env = {
|
||||
'REQUEST_METHOD' => 'OPTIONS',
|
||||
'HTTP_ACCESS_CONTROL_REQUEST_METHOD' => 'GET'
|
||||
}
|
||||
|
||||
status, headers, _body = middleware.call(env)
|
||||
|
||||
expect(status).to eq(200)
|
||||
|
||||
expected = {
|
||||
"Access-Control-Allow-Origin" => "www.rainbows.com",
|
||||
"Access-Control-Allow-Headers" => "X-Requested-With, X-CSRF-Token, Discourse-Visible",
|
||||
"Access-Control-Allow-Credentials" => "true"
|
||||
}
|
||||
|
||||
expect(headers).to eq(expected)
|
||||
end
|
||||
|
||||
it "handles expires_in" do
|
||||
tester.hijack_test do
|
||||
expires_in 1.year
|
||||
|
Loading…
Reference in New Issue
Block a user