[REFACTOR] PKT protocol

- Use `Fprintf` to convert to hex and do padding. Simplifies the code.
- Use `Read()` and `io.ReadFull` instead of `ReadByte()`. Should improve
performance and allows for cleaner code.
- s/pktLineTypeUnknow/pktLineTypeUnknown.
- Disallow empty Pkt line per the specification.
- Disallow too large Pkt line per the specification.
- Add unit tests.
This commit is contained in:
Gusted 2024-03-29 00:20:21 +01:00
parent a11116602e
commit 2c8bcc163e
No known key found for this signature in database
GPG key ID: FD821B732837125F
3 changed files with 93 additions and 52 deletions

View file

@ -583,7 +583,7 @@ Forgejo or set your environment appropriately.`, "")
for { for {
// note: pktLineTypeUnknow means pktLineTypeFlush and pktLineTypeData all allowed // note: pktLineTypeUnknow means pktLineTypeFlush and pktLineTypeData all allowed
rs, err = readPktLine(ctx, reader, pktLineTypeUnknow) rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
if err != nil { if err != nil {
return err return err
} }
@ -604,7 +604,7 @@ Forgejo or set your environment appropriately.`, "")
if hasPushOptions { if hasPushOptions {
for { for {
rs, err = readPktLine(ctx, reader, pktLineTypeUnknow) rs, err = readPktLine(ctx, reader, pktLineTypeUnknown)
if err != nil { if err != nil {
return err return err
} }
@ -699,8 +699,8 @@ Forgejo or set your environment appropriately.`, "")
type pktLineType int64 type pktLineType int64
const ( const (
// UnKnow type // Unknown type
pktLineTypeUnknow pktLineType = 0 pktLineTypeUnknown pktLineType = 0
// flush-pkt "0000" // flush-pkt "0000"
pktLineTypeFlush pktLineType = iota pktLineTypeFlush pktLineType = iota
// data line // data line
@ -714,22 +714,16 @@ type gitPktLine struct {
Data []byte Data []byte
} }
// Reads an Pkt-Line from `in`. If requestType is not unknown, it will a
func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType) (*gitPktLine, error) { func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType) (*gitPktLine, error) {
var ( // Read length prefix
err error
r *gitPktLine
)
// read prefix
lengthBytes := make([]byte, 4) lengthBytes := make([]byte, 4)
for i := 0; i < 4; i++ { if n, err := in.Read(lengthBytes); n != 4 || err != nil {
lengthBytes[i], err = in.ReadByte()
if err != nil {
return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err) return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
} }
}
r = new(gitPktLine) var err error
r := &gitPktLine{}
r.Length, err = strconv.ParseUint(string(lengthBytes), 16, 32) r.Length, err = strconv.ParseUint(string(lengthBytes), 16, 32)
if err != nil { if err != nil {
return nil, fail(ctx, "Protocol: format parse error", "Pkt-Line format is wrong :%v", err) return nil, fail(ctx, "Protocol: format parse error", "Pkt-Line format is wrong :%v", err)
@ -748,11 +742,8 @@ func readPktLine(ctx context.Context, in *bufio.Reader, requestType pktLineType)
} }
r.Data = make([]byte, r.Length-4) r.Data = make([]byte, r.Length-4)
for i := range r.Data { if n, err := io.ReadFull(in, r.Data); uint64(n) != r.Length-4 || err != nil {
r.Data[i], err = in.ReadByte() return nil, fail(ctx, "Protocol: stdin error", "Pkt-Line: read stdin failed : %v", err)
if err != nil {
return nil, fail(ctx, "Protocol: data error", "Pkt-Line: read stdin failed : %v", err)
}
} }
r.Type = pktLineTypeData r.Type = pktLineTypeData
@ -768,20 +759,23 @@ func writeFlushPktLine(ctx context.Context, out io.Writer) error {
return nil return nil
} }
// Write an Pkt-Line based on `data` to `out` according to the specifcation.
// https://git-scm.com/docs/protocol-common
func writeDataPktLine(ctx context.Context, out io.Writer, data []byte) error { func writeDataPktLine(ctx context.Context, out io.Writer, data []byte) error {
hexchar := []byte("0123456789abcdef") // Implementations SHOULD NOT send an empty pkt-line ("0004").
hex := func(n uint64) byte { if len(data) == 0 {
return hexchar[(n)&15] return fail(ctx, "Protocol: write error", "Not allowed to write empty Pkt-Line")
} }
length := uint64(len(data) + 4) length := uint64(len(data) + 4)
tmp := make([]byte, 4)
tmp[0] = hex(length >> 12)
tmp[1] = hex(length >> 8)
tmp[2] = hex(length >> 4)
tmp[3] = hex(length)
lr, err := out.Write(tmp) // The maximum length of a pkt-lines data component is 65516 bytes.
// Implementations MUST NOT send pkt-line whose length exceeds 65520 (65516 bytes of payload + 4 bytes of length data).
if length > 65520 {
return fail(ctx, "Protocol: write error", "Pkt-Line exceeds maximum of 65520 bytes")
}
lr, err := fmt.Fprintf(out, "%04x", length)
if err != nil || lr != 4 { if err != nil || lr != 4 {
return fail(ctx, "Protocol: write error", "Pkt-Line response failed: %v", err) return fail(ctx, "Protocol: write error", "Pkt-Line response failed: %v", err)
} }

View file

@ -14,8 +14,9 @@ import (
) )
func TestPktLine(t *testing.T) { func TestPktLine(t *testing.T) {
// test read
ctx := context.Background() ctx := context.Background()
t.Run("Read", func(t *testing.T) {
s := strings.NewReader("0000") s := strings.NewReader("0000")
r := bufio.NewReader(s) r := bufio.NewReader(s)
result, err := readPktLine(ctx, r, pktLineTypeFlush) result, err := readPktLine(ctx, r, pktLineTypeFlush)
@ -29,9 +30,28 @@ func TestPktLine(t *testing.T) {
assert.Equal(t, pktLineTypeData, result.Type) assert.Equal(t, pktLineTypeData, result.Type)
assert.Equal(t, []byte("a\n"), result.Data) assert.Equal(t, []byte("a\n"), result.Data)
// test write s = strings.NewReader("0004")
r = bufio.NewReader(s)
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.Error(t, err)
assert.Nil(t, result)
data := strings.Repeat("x", 65516)
r = bufio.NewReader(strings.NewReader("fff0" + data))
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.NoError(t, err)
assert.Equal(t, pktLineTypeData, result.Type)
assert.Equal(t, []byte(data), result.Data)
r = bufio.NewReader(strings.NewReader("fff1a"))
result, err = readPktLine(ctx, r, pktLineTypeData)
assert.Error(t, err)
assert.Nil(t, result)
})
t.Run("Write", func(t *testing.T) {
w := bytes.NewBuffer([]byte{}) w := bytes.NewBuffer([]byte{})
err = writeFlushPktLine(ctx, w) err := writeFlushPktLine(ctx, w)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []byte("0000"), w.Bytes()) assert.Equal(t, []byte("0000"), w.Bytes())
@ -39,4 +59,27 @@ func TestPktLine(t *testing.T) {
err = writeDataPktLine(ctx, w, []byte("a\nb")) err = writeDataPktLine(ctx, w, []byte("a\nb"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []byte("0007a\nb"), w.Bytes()) assert.Equal(t, []byte("0007a\nb"), w.Bytes())
w.Reset()
data := bytes.Repeat([]byte{0x05}, 288)
err = writeDataPktLine(ctx, w, data)
assert.NoError(t, err)
assert.Equal(t, append([]byte("0124"), data...), w.Bytes())
w.Reset()
err = writeDataPktLine(ctx, w, nil)
assert.Error(t, err)
assert.Empty(t, w.Bytes())
w.Reset()
data = bytes.Repeat([]byte{0x64}, 65516)
err = writeDataPktLine(ctx, w, data)
assert.NoError(t, err)
assert.Equal(t, append([]byte("fff0"), data...), w.Bytes())
w.Reset()
err = writeDataPktLine(ctx, w, bytes.Repeat([]byte{0x64}, 65516+1))
assert.Error(t, err)
assert.Empty(t, w.Bytes())
})
} }

View file

@ -14,6 +14,7 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"testing"
"time" "time"
"unicode" "unicode"
@ -106,8 +107,11 @@ func fail(ctx context.Context, userMessage, logMsgFmt string, args ...any) error
logMsg = userMessage + ". " + logMsg logMsg = userMessage + ". " + logMsg
} }
} }
// Don't send an log if this is done in a test and no InternalToken is set.
if !testing.Testing() || setting.InternalToken != "" {
_ = private.SSHLog(ctx, true, logMsg) _ = private.SSHLog(ctx, true, logMsg)
} }
}
return cli.Exit("", 1) return cli.Exit("", 1)
} }