Second attempt at preventing zombies (#16326)
* Second attempt at preventing zombies * Ensure that the pipes are closed in ssh.go * Ensure that a cancellable context is passed up in cmd/* http requests * Make cmd.fail return properly so defers are obeyed * Ensure that something is sent to stdout in case of blocks here Signed-off-by: Andrew Thornton <art27@cantab.net> * placate lint Signed-off-by: Andrew Thornton <art27@cantab.net> * placate lint 2 Signed-off-by: Andrew Thornton <art27@cantab.net> * placate lint 3 Signed-off-by: Andrew Thornton <art27@cantab.net> * fixup Signed-off-by: Andrew Thornton <art27@cantab.net> * Apply suggestions from code review Co-authored-by: 6543 <6543@obermui.de> Co-authored-by: Lauris BH <lauris@nix.lv>
This commit is contained in:
parent
ee43d70a0c
commit
3dcb3e9073
21 changed files with 229 additions and 143 deletions
|
@ -7,6 +7,7 @@ package httplib
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
|
@ -122,6 +123,12 @@ func (r *Request) Setting(setting Settings) *Request {
|
|||
return r
|
||||
}
|
||||
|
||||
// SetContext sets the request's Context
|
||||
func (r *Request) SetContext(ctx context.Context) *Request {
|
||||
r.req = r.req.WithContext(ctx)
|
||||
return r
|
||||
}
|
||||
|
||||
// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password.
|
||||
func (r *Request) SetBasicAuth(username, password string) *Request {
|
||||
r.req.SetBasicAuth(username, password)
|
||||
|
@ -325,7 +332,7 @@ func (r *Request) getResponse() (*http.Response, error) {
|
|||
trans = &http.Transport{
|
||||
TLSClientConfig: r.setting.TLSClientConfig,
|
||||
Proxy: proxy,
|
||||
Dial: TimeoutDialer(r.setting.ConnectTimeout),
|
||||
DialContext: TimeoutDialer(r.setting.ConnectTimeout),
|
||||
}
|
||||
} else if t, ok := trans.(*http.Transport); ok {
|
||||
if t.TLSClientConfig == nil {
|
||||
|
@ -334,8 +341,8 @@ func (r *Request) getResponse() (*http.Response, error) {
|
|||
if t.Proxy == nil {
|
||||
t.Proxy = r.setting.Proxy
|
||||
}
|
||||
if t.Dial == nil {
|
||||
t.Dial = TimeoutDialer(r.setting.ConnectTimeout)
|
||||
if t.DialContext == nil {
|
||||
t.DialContext = TimeoutDialer(r.setting.ConnectTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -458,9 +465,10 @@ func (r *Request) Response() (*http.Response, error) {
|
|||
}
|
||||
|
||||
// TimeoutDialer returns functions of connection dialer with timeout settings for http.Transport Dial field.
|
||||
func TimeoutDialer(cTimeout time.Duration) func(net, addr string) (c net.Conn, err error) {
|
||||
return func(netw, addr string) (net.Conn, error) {
|
||||
conn, err := net.DialTimeout(netw, addr, cTimeout)
|
||||
func TimeoutDialer(cTimeout time.Duration) func(ctx context.Context, net, addr string) (c net.Conn, err error) {
|
||||
return func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
d := net.Dialer{Timeout: cTimeout}
|
||||
conn, err := d.DialContext(ctx, netw, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package private
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -80,12 +81,12 @@ type HookPostReceiveBranchResult struct {
|
|||
}
|
||||
|
||||
// HookPreReceive check whether the provided commits are allowed
|
||||
func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string) {
|
||||
func HookPreReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (int, string) {
|
||||
reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/pre-receive/%s/%s",
|
||||
url.PathEscape(ownerName),
|
||||
url.PathEscape(repoName),
|
||||
)
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req = req.Header("Content-Type", "application/json")
|
||||
json := jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
jsonBytes, _ := json.Marshal(opts)
|
||||
|
@ -105,13 +106,13 @@ func HookPreReceive(ownerName, repoName string, opts HookOptions) (int, string)
|
|||
}
|
||||
|
||||
// HookPostReceive updates services and users
|
||||
func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) {
|
||||
func HookPostReceive(ctx context.Context, ownerName, repoName string, opts HookOptions) (*HookPostReceiveResult, string) {
|
||||
reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/post-receive/%s/%s",
|
||||
url.PathEscape(ownerName),
|
||||
url.PathEscape(repoName),
|
||||
)
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req = req.Header("Content-Type", "application/json")
|
||||
req.SetTimeout(60*time.Second, time.Duration(60+len(opts.OldCommitIDs))*time.Second)
|
||||
json := jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
|
@ -133,13 +134,13 @@ func HookPostReceive(ownerName, repoName string, opts HookOptions) (*HookPostRec
|
|||
}
|
||||
|
||||
// SetDefaultBranch will set the default branch to the provided branch for the provided repository
|
||||
func SetDefaultBranch(ownerName, repoName, branch string) error {
|
||||
func SetDefaultBranch(ctx context.Context, ownerName, repoName, branch string) error {
|
||||
reqURL := setting.LocalURL + fmt.Sprintf("api/internal/hook/set-default-branch/%s/%s/%s",
|
||||
url.PathEscape(ownerName),
|
||||
url.PathEscape(repoName),
|
||||
url.PathEscape(branch),
|
||||
)
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req = req.Header("Content-Type", "application/json")
|
||||
|
||||
req.SetTimeout(60*time.Second, 60*time.Second)
|
||||
|
@ -155,9 +156,9 @@ func SetDefaultBranch(ownerName, repoName, branch string) error {
|
|||
}
|
||||
|
||||
// SSHLog sends ssh error log response
|
||||
func SSHLog(isErr bool, msg string) error {
|
||||
func SSHLog(ctx context.Context, isErr bool, msg string) error {
|
||||
reqURL := setting.LocalURL + "api/internal/ssh/log"
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req = req.Header("Content-Type", "application/json")
|
||||
|
||||
jsonBytes, _ := json.Marshal(&SSHLogOption{
|
||||
|
@ -171,6 +172,7 @@ func SSHLog(isErr bool, msg string) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("unable to contact gitea: %v", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("Error returned from gitea: %v", decodeJSONError(resp).Err)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package private
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
|
@ -15,9 +16,11 @@ import (
|
|||
jsoniter "github.com/json-iterator/go"
|
||||
)
|
||||
|
||||
func newRequest(url, method string) *httplib.Request {
|
||||
return httplib.NewRequest(url, method).Header("Authorization",
|
||||
fmt.Sprintf("Bearer %s", setting.InternalToken))
|
||||
func newRequest(ctx context.Context, url, method string) *httplib.Request {
|
||||
return httplib.NewRequest(url, method).
|
||||
SetContext(ctx).
|
||||
Header("Authorization",
|
||||
fmt.Sprintf("Bearer %s", setting.InternalToken))
|
||||
}
|
||||
|
||||
// Response internal request response
|
||||
|
@ -35,8 +38,8 @@ func decodeJSONError(resp *http.Response) *Response {
|
|||
return &res
|
||||
}
|
||||
|
||||
func newInternalRequest(url, method string) *httplib.Request {
|
||||
req := newRequest(url, method).SetTLSClientConfig(&tls.Config{
|
||||
func newInternalRequest(ctx context.Context, url, method string) *httplib.Request {
|
||||
req := newRequest(ctx, url, method).SetTLSClientConfig(&tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: setting.Domain,
|
||||
})
|
||||
|
@ -45,6 +48,10 @@ func newInternalRequest(url, method string) *httplib.Request {
|
|||
Dial: func(_, _ string) (net.Conn, error) {
|
||||
return net.Dial("unix", setting.HTTPAddr)
|
||||
},
|
||||
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "unix", setting.HTTPAddr)
|
||||
},
|
||||
})
|
||||
}
|
||||
return req
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package private
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -13,10 +14,10 @@ import (
|
|||
)
|
||||
|
||||
// UpdatePublicKeyInRepo update public key and if necessary deploy key updates
|
||||
func UpdatePublicKeyInRepo(keyID, repoID int64) error {
|
||||
func UpdatePublicKeyInRepo(ctx context.Context, keyID, repoID int64) error {
|
||||
// Ask for running deliver hook and test pull request tasks.
|
||||
reqURL := setting.LocalURL + fmt.Sprintf("api/internal/ssh/%d/update/%d", keyID, repoID)
|
||||
resp, err := newInternalRequest(reqURL, "POST").Response()
|
||||
resp, err := newInternalRequest(ctx, reqURL, "POST").Response()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -32,10 +33,10 @@ func UpdatePublicKeyInRepo(keyID, repoID int64) error {
|
|||
|
||||
// AuthorizedPublicKeyByContent searches content as prefix (leak e-mail part)
|
||||
// and returns public key found.
|
||||
func AuthorizedPublicKeyByContent(content string) (string, error) {
|
||||
func AuthorizedPublicKeyByContent(ctx context.Context, content string) (string, error) {
|
||||
// Ask for running deliver hook and test pull request tasks.
|
||||
reqURL := setting.LocalURL + "api/internal/ssh/authorized_keys"
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req.Param("content", content)
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package private
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -27,10 +28,10 @@ type Email struct {
|
|||
//
|
||||
// If to list == nil its supposed to send an email to every
|
||||
// user present in DB
|
||||
func SendEmail(subject, message string, to []string) (int, string) {
|
||||
func SendEmail(ctx context.Context, subject, message string, to []string) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/mail/send"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req = req.Header("Content-Type", "application/json")
|
||||
json := jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
jsonBytes, _ := json.Marshal(Email{
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package private
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -15,10 +16,10 @@ import (
|
|||
)
|
||||
|
||||
// Shutdown calls the internal shutdown function
|
||||
func Shutdown() (int, string) {
|
||||
func Shutdown(ctx context.Context) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/manager/shutdown"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
|
||||
|
@ -33,10 +34,10 @@ func Shutdown() (int, string) {
|
|||
}
|
||||
|
||||
// Restart calls the internal restart function
|
||||
func Restart() (int, string) {
|
||||
func Restart(ctx context.Context) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/manager/restart"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
|
||||
|
@ -57,10 +58,10 @@ type FlushOptions struct {
|
|||
}
|
||||
|
||||
// FlushQueues calls the internal flush-queues function
|
||||
func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) {
|
||||
func FlushQueues(ctx context.Context, timeout time.Duration, nonBlocking bool) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/manager/flush-queues"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
if timeout > 0 {
|
||||
req.SetTimeout(timeout+10*time.Second, timeout+10*time.Second)
|
||||
}
|
||||
|
@ -85,10 +86,10 @@ func FlushQueues(timeout time.Duration, nonBlocking bool) (int, string) {
|
|||
}
|
||||
|
||||
// PauseLogging pauses logging
|
||||
func PauseLogging() (int, string) {
|
||||
func PauseLogging(ctx context.Context) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/manager/pause-logging"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
|
||||
|
@ -103,10 +104,10 @@ func PauseLogging() (int, string) {
|
|||
}
|
||||
|
||||
// ResumeLogging resumes logging
|
||||
func ResumeLogging() (int, string) {
|
||||
func ResumeLogging(ctx context.Context) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/manager/resume-logging"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
|
||||
|
@ -121,10 +122,10 @@ func ResumeLogging() (int, string) {
|
|||
}
|
||||
|
||||
// ReleaseReopenLogging releases and reopens logging files
|
||||
func ReleaseReopenLogging() (int, string) {
|
||||
func ReleaseReopenLogging(ctx context.Context) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/manager/release-and-reopen-logging"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
|
||||
|
@ -147,10 +148,10 @@ type LoggerOptions struct {
|
|||
}
|
||||
|
||||
// AddLogger adds a logger
|
||||
func AddLogger(group, name, mode string, config map[string]interface{}) (int, string) {
|
||||
func AddLogger(ctx context.Context, group, name, mode string, config map[string]interface{}) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/manager/add-logger"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req = req.Header("Content-Type", "application/json")
|
||||
json := jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
jsonBytes, _ := json.Marshal(LoggerOptions{
|
||||
|
@ -175,10 +176,10 @@ func AddLogger(group, name, mode string, config map[string]interface{}) (int, st
|
|||
}
|
||||
|
||||
// RemoveLogger removes a logger
|
||||
func RemoveLogger(group, name string) (int, string) {
|
||||
func RemoveLogger(ctx context.Context, group, name string) (int, string) {
|
||||
reqURL := setting.LocalURL + fmt.Sprintf("api/internal/manager/remove-logger/%s/%s", url.PathEscape(group), url.PathEscape(name))
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
resp, err := req.Response()
|
||||
if err != nil {
|
||||
return http.StatusInternalServerError, fmt.Sprintf("Unable to contact gitea: %v", err.Error())
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package private
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -23,10 +24,10 @@ type RestoreParams struct {
|
|||
}
|
||||
|
||||
// RestoreRepo calls the internal RestoreRepo function
|
||||
func RestoreRepo(repoDir, ownerName, repoName string, units []string) (int, string) {
|
||||
func RestoreRepo(ctx context.Context, repoDir, ownerName, repoName string, units []string) (int, string) {
|
||||
reqURL := setting.LocalURL + "api/internal/restore_repo"
|
||||
|
||||
req := newInternalRequest(reqURL, "POST")
|
||||
req := newInternalRequest(ctx, reqURL, "POST")
|
||||
req.SetTimeout(3*time.Second, 0) // since the request will spend much time, don't timeout
|
||||
req = req.Header("Content-Type", "application/json")
|
||||
json := jsoniter.ConfigCompatibleWithStandardLibrary
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
package private
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
@ -21,10 +22,10 @@ type KeyAndOwner struct {
|
|||
}
|
||||
|
||||
// ServNoCommand returns information about the provided key
|
||||
func ServNoCommand(keyID int64) (*models.PublicKey, *models.User, error) {
|
||||
func ServNoCommand(ctx context.Context, keyID int64) (*models.PublicKey, *models.User, error) {
|
||||
reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/none/%d",
|
||||
keyID)
|
||||
resp, err := newInternalRequest(reqURL, "GET").Response()
|
||||
resp, err := newInternalRequest(ctx, reqURL, "GET").Response()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -73,7 +74,7 @@ func IsErrServCommand(err error) bool {
|
|||
}
|
||||
|
||||
// ServCommand preps for a serv call
|
||||
func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) {
|
||||
func ServCommand(ctx context.Context, keyID int64, ownerName, repoName string, mode models.AccessMode, verbs ...string) (*ServCommandResults, error) {
|
||||
reqURL := setting.LocalURL + fmt.Sprintf("api/internal/serv/command/%d/%s/%s?mode=%d",
|
||||
keyID,
|
||||
url.PathEscape(ownerName),
|
||||
|
@ -85,7 +86,7 @@ func ServCommand(keyID int64, ownerName, repoName string, mode models.AccessMode
|
|||
}
|
||||
}
|
||||
|
||||
resp, err := newInternalRequest(reqURL, "GET").Response()
|
||||
resp, err := newInternalRequest(ctx, reqURL, "GET").Response()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ package ssh
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
|
@ -66,7 +67,11 @@ func sessionHandler(session ssh.Session) {
|
|||
|
||||
args := []string{"serv", "key-" + keyID, "--config=" + setting.CustomConf}
|
||||
log.Trace("SSH: Arguments: %v", args)
|
||||
cmd := exec.CommandContext(session.Context(), setting.AppPath, args...)
|
||||
|
||||
ctx, cancel := context.WithCancel(session.Context())
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, setting.AppPath, args...)
|
||||
cmd.Env = append(
|
||||
os.Environ(),
|
||||
"SSH_ORIGINAL_COMMAND="+command,
|
||||
|
@ -78,16 +83,21 @@ func sessionHandler(session ssh.Session) {
|
|||
log.Error("SSH: StdoutPipe: %v", err)
|
||||
return
|
||||
}
|
||||
defer stdout.Close()
|
||||
|
||||
stderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
log.Error("SSH: StderrPipe: %v", err)
|
||||
return
|
||||
}
|
||||
defer stderr.Close()
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
log.Error("SSH: StdinPipe: %v", err)
|
||||
return
|
||||
}
|
||||
defer stdin.Close()
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
@ -106,6 +116,7 @@ func sessionHandler(session ssh.Session) {
|
|||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer stdout.Close()
|
||||
if _, err := io.Copy(session, stdout); err != nil {
|
||||
log.Error("Failed to write stdout to session. %s", err)
|
||||
}
|
||||
|
@ -113,6 +124,7 @@ func sessionHandler(session ssh.Session) {
|
|||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer stderr.Close()
|
||||
if _, err := io.Copy(session.Stderr(), stderr); err != nil {
|
||||
log.Error("Failed to write stderr to session. %s", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue