Refactor CSRF protection modules, make sure CSRF tokens can be up-to-date. (#19337)

Do a refactoring to the CSRF related code, remove most unnecessary functions.
Parse the generated token's issue time, regenerate the token every a few minutes.
This commit is contained in:
wxiaoguang 2022-04-08 13:21:05 +08:00 committed by GitHub
parent 3c3d49899f
commit 84ceaa98bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 170 additions and 196 deletions

View file

@ -37,18 +37,18 @@ var (
func Test_ValidToken(t *testing.T) {
t.Run("Validate token", func(t *testing.T) {
tok := generateTokenAtTime(key, userID, actionID, now)
assert.True(t, validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow))
assert.True(t, validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond)))
assert.True(t, validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute)))
tok := GenerateCsrfToken(key, userID, actionID, now)
assert.True(t, ValidCsrfToken(tok, key, userID, actionID, oneMinuteFromNow))
assert.True(t, ValidCsrfToken(tok, key, userID, actionID, now.Add(CsrfTokenTimeout-1*time.Nanosecond)))
assert.True(t, ValidCsrfToken(tok, key, userID, actionID, now.Add(-1*time.Minute)))
})
}
// Test_SeparatorReplacement tests that separators are being correctly substituted
func Test_SeparatorReplacement(t *testing.T) {
t.Run("Test two separator replacements", func(t *testing.T) {
assert.NotEqual(t, generateTokenAtTime("foo:bar", "baz", "wah", now),
generateTokenAtTime("foo", "bar:baz", "wah", now))
assert.NotEqual(t, GenerateCsrfToken("foo:bar", "baz", "wah", now),
GenerateCsrfToken("foo", "bar:baz", "wah", now))
})
}
@ -61,13 +61,13 @@ func Test_InvalidToken(t *testing.T) {
{"Bad key", "foobar", userID, actionID, oneMinuteFromNow},
{"Bad userID", key, "foobar", actionID, oneMinuteFromNow},
{"Bad actionID", key, userID, "foobar", oneMinuteFromNow},
{"Expired", key, userID, actionID, now.Add(Timeout)},
{"Expired", key, userID, actionID, now.Add(CsrfTokenTimeout)},
{"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute)},
}
tok := generateTokenAtTime(key, userID, actionID, now)
tok := GenerateCsrfToken(key, userID, actionID, now)
for _, itt := range invalidTokenTests {
assert.False(t, validTokenAtTime(tok, itt.key, itt.userID, itt.actionID, itt.t))
assert.False(t, ValidCsrfToken(tok, itt.key, itt.userID, itt.actionID, itt.t))
}
})
}
@ -84,7 +84,7 @@ func Test_ValidateBadData(t *testing.T) {
}
for _, bdt := range badDataTests {
assert.False(t, validTokenAtTime(bdt.tok, key, userID, actionID, oneMinuteFromNow))
assert.False(t, ValidCsrfToken(bdt.tok, key, userID, actionID, oneMinuteFromNow))
}
})
}