Refactor CORS handler (#28587)

The CORS code has been unmaintained for long time, and the behavior is
not correct.

This PR tries to improve it. The key point is written as comment in
code. And add more tests.

Fix #28515
Fix #27642
Fix #17098
This commit is contained in:
wxiaoguang 2023-12-25 20:13:18 +08:00 committed by GitHub
parent d0f24ff4ca
commit b41925cee3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 131 additions and 78 deletions

View file

@ -7,17 +7,88 @@ import (
"net/http"
"testing"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/test"
"code.gitea.io/gitea/routers"
"code.gitea.io/gitea/tests"
"github.com/stretchr/testify/assert"
)
func TestCORSNotSet(t *testing.T) {
func TestCORS(t *testing.T) {
defer tests.PrepareTestEnv(t)()
req := NewRequest(t, "GET", "/api/v1/version")
session := loginUser(t, "user2")
resp := session.MakeRequest(t, req, http.StatusOK)
assert.Equal(t, resp.Code, http.StatusOK)
corsHeader := resp.Header().Get("Access-Control-Allow-Origin")
assert.Empty(t, corsHeader, "Access-Control-Allow-Origin: generated header should match") // header not set
t.Run("CORS enabled", func(t *testing.T) {
defer test.MockVariableValue(&setting.CORSConfig.Enabled, true)()
defer test.MockVariableValue(&testWebRoutes, routers.NormalRoutes())()
t.Run("API with CORS", func(t *testing.T) {
// GET api with no CORS header
req := NewRequest(t, "GET", "/api/v1/version")
resp := MakeRequest(t, req, http.StatusOK)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")
// OPTIONS api for CORS
req = NewRequest(t, "OPTIONS", "/api/v1/version").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusOK)
assert.NotEmpty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")
})
t.Run("Web with CORS", func(t *testing.T) {
// GET userinfo with no CORS header
req := NewRequest(t, "GET", "/login/oauth/userinfo")
resp := MakeRequest(t, req, http.StatusUnauthorized)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")
// OPTIONS userinfo for CORS
req = NewRequest(t, "OPTIONS", "/login/oauth/userinfo").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusOK)
assert.NotEmpty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Contains(t, resp.Header().Values("Vary"), "Origin")
// OPTIONS userinfo for non-CORS
req = NewRequest(t, "OPTIONS", "/login/oauth/userinfo")
resp = MakeRequest(t, req, http.StatusMethodNotAllowed)
assert.NotContains(t, resp.Header().Values("Vary"), "Origin")
})
})
t.Run("CORS disabled", func(t *testing.T) {
defer test.MockVariableValue(&setting.CORSConfig.Enabled, false)()
defer test.MockVariableValue(&testWebRoutes, routers.NormalRoutes())()
t.Run("API without CORS", func(t *testing.T) {
req := NewRequest(t, "GET", "/api/v1/version")
resp := MakeRequest(t, req, http.StatusOK)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, resp.Header().Values("Vary"))
req = NewRequest(t, "OPTIONS", "/api/v1/version").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusMethodNotAllowed)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.Empty(t, resp.Header().Values("Vary"))
})
t.Run("Web without CORS", func(t *testing.T) {
req := NewRequest(t, "GET", "/login/oauth/userinfo")
resp := MakeRequest(t, req, http.StatusUnauthorized)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.NotContains(t, resp.Header().Values("Vary"), "Origin")
req = NewRequest(t, "OPTIONS", "/login/oauth/userinfo").
SetHeader("Origin", "https://example.com").
SetHeader("Access-Control-Request-Method", "GET")
resp = MakeRequest(t, req, http.StatusMethodNotAllowed)
assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
assert.NotContains(t, resp.Header().Values("Vary"), "Origin")
})
})
}