From c970378fd7932043c8cb8bfda786f4ffce5994b8 Mon Sep 17 00:00:00 2001 From: Levi Date: Tue, 24 Dec 2024 20:28:04 +0800 Subject: [PATCH] fix: http download retries too many times (#849) --- internal/protocol/http/fetcher.go | 59 +++++++++++-------- internal/protocol/http/fetcher_test.go | 8 +++ internal/protocol/http/timeout_reader.go | 40 +++++++++++++ internal/protocol/http/timeout_reader_test.go | 51 ++++++++++++++++ internal/test/httptest.go | 49 +++++++++------ 5 files changed, 164 insertions(+), 43 deletions(-) create mode 100644 internal/protocol/http/timeout_reader.go create mode 100644 internal/protocol/http/timeout_reader_test.go diff --git a/internal/protocol/http/fetcher.go b/internal/protocol/http/fetcher.go index 1d3def5ae..d40be4525 100644 --- a/internal/protocol/http/fetcher.go +++ b/internal/protocol/http/fetcher.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "mime" + "net" "net/http" "net/http/cookiejar" "net/url" @@ -25,6 +26,9 @@ import ( "golang.org/x/sync/errgroup" ) +const connectTimeout = 15 * time.Second +const readTimeout = 15 * time.Second + type RequestError struct { Code int Msg string @@ -103,7 +107,7 @@ func (f *Fetcher) Resolve(req *base.Request) error { if base.HttpCodePartialContent == httpResp.StatusCode || (base.HttpCodeOK == httpResp.StatusCode && httpResp.Header.Get(base.HttpHeaderAcceptRanges) == base.HttpHeaderBytes && strings.HasPrefix(httpResp.Header.Get(base.HttpHeaderContentRange), base.HttpHeaderBytes)) { // response 206 status code, support breakpoint continuation res.Range = true - // 解析资源大小: bytes 0-1000/1001 => 1001 + // parse content length from Content-Range header, eg: bytes 0-1000/1001 contentTotal := path.Base(httpResp.Header.Get(base.HttpHeaderContentRange)) if contentTotal != "" { parse, err := strconv.ParseInt(contentTotal, 10, 64) @@ -146,7 +150,7 @@ func (f *Fetcher) Resolve(req *base.Request) error { file.Name = filename } } - // Get file filePath by URL + // get file filePath by URL if file.Name == "" { file.Name = path.Base(httpReq.URL.Path) } @@ -250,23 +254,42 @@ func (f *Fetcher) Wait() (err error) { return <-f.doneCh } +type fetchResult struct { + err error +} + func (f *Fetcher) fetch() { var ctx context.Context ctx, f.cancel = context.WithCancel(context.Background()) f.eg, _ = errgroup.WithContext(ctx) + chunkErrs := make([]error, len(f.chunks)) for i := 0; i < len(f.chunks); i++ { i := i f.eg.Go(func() error { - return f.fetchChunk(i, ctx) + err := f.fetchChunk(i, ctx) + // if canceled, fail fast + if errors.Is(err, context.Canceled) { + return err + } + chunkErrs[i] = err + return nil }) } go func() { err := f.eg.Wait() - // check if canceled - if errors.Is(err, context.Canceled) { + // error returned only if canceled, just return + if err != nil { return } + // check all fetch results, if any error, return + for _, chunkErr := range chunkErrs { + if chunkErr != nil { + err = chunkErr + break + } + } + f.file.Close() // Update file last modified time if f.config.UseServerCtime && f.meta.Res.Files[0].Ctime != nil { @@ -289,24 +312,11 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { for { // if chunk is completed, return if f.meta.Res.Range && chunk.Downloaded >= chunk.End-chunk.Begin+1 { - return + return nil } if chunk.retryTimes >= maxRetries { - if !f.meta.Res.Range { - return - } - // check if all failed - allFailed := true - for _, c := range f.chunks { - if chunk.Downloaded < chunk.End-chunk.Begin+1 && c.retryTimes < maxRetries { - allFailed = false - break - } - } - if allFailed { - return - } 
+ return } var ( @@ -333,10 +343,9 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { err = NewRequestError(resp.StatusCode, resp.Status) return err } - // Http request success, reset retry times - chunk.retryTimes = 0 + reader := NewTimeoutReader(resp.Body, readTimeout) for { - n, err := resp.Body.Read(buf) + n, err := reader.Read(buf) if n > 0 { _, err := f.file.WriteAt(buf[:n], chunk.Begin+chunk.Downloaded) if err != nil { @@ -351,7 +360,6 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) { return err } } - return nil }() if err != nil { // If canceled, do not retry @@ -452,6 +460,9 @@ func (f *Fetcher) splitChunk() (chunks []*chunk) { func (f *Fetcher) buildClient() *http.Client { transport := &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: connectTimeout, + }).DialContext, Proxy: f.ctl.GetProxy(f.meta.Req.Proxy), TLSClientConfig: &tls.Config{ InsecureSkipVerify: f.meta.Req.SkipVerifyCert, diff --git a/internal/protocol/http/fetcher_test.go b/internal/protocol/http/fetcher_test.go index 8109cbb3d..89e4e6ef9 100644 --- a/internal/protocol/http/fetcher_test.go +++ b/internal/protocol/http/fetcher_test.go @@ -132,6 +132,14 @@ func TestFetcher_DownloadLimit(t *testing.T) { downloadNormal(listener, 8, t) } +func TestFetcher_DownloadResponseBodyReadTimeout(t *testing.T) { + listener := test.StartTestLimitServer(16, readTimeout.Milliseconds()+5000) + defer listener.Close() + + downloadError(listener, 1, t) + downloadError(listener, 4, t) +} + func TestFetcher_DownloadResume(t *testing.T) { listener := test.StartTestFileServer() defer listener.Close() diff --git a/internal/protocol/http/timeout_reader.go b/internal/protocol/http/timeout_reader.go new file mode 100644 index 000000000..f23e16e9a --- /dev/null +++ b/internal/protocol/http/timeout_reader.go @@ -0,0 +1,40 @@ +package http + +import ( + "context" + "io" + "time" +) + +type TimeoutReader struct { + reader io.Reader + timeout time.Duration +} + +func NewTimeoutReader(r io.Reader, timeout time.Duration) *TimeoutReader { + return &TimeoutReader{ + reader: r, + timeout: timeout, + } +} + +func (tr *TimeoutReader) Read(p []byte) (n int, err error) { + ctx, cancel := context.WithTimeout(context.Background(), tr.timeout) + defer cancel() + + done := make(chan struct{}) + var readErr error + var bytesRead int + + go func() { + bytesRead, readErr = tr.reader.Read(p) + close(done) + }() + + select { + case <-done: + return bytesRead, readErr + case <-ctx.Done(): + return 0, ctx.Err() + } +} diff --git a/internal/protocol/http/timeout_reader_test.go b/internal/protocol/http/timeout_reader_test.go new file mode 100644 index 000000000..337d55b34 --- /dev/null +++ b/internal/protocol/http/timeout_reader_test.go @@ -0,0 +1,51 @@ +package http + +import ( + "bytes" + "context" + "errors" + "io" + "testing" + "time" +) + +func TestTimeoutReader_Read(t *testing.T) { + data := []byte("Hello, World!") + reader := bytes.NewReader(data) + timeoutReader := NewTimeoutReader(reader, 1*time.Second) + + buf := make([]byte, len(data)) + n, err := timeoutReader.Read(buf) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if n != len(data) { + t.Fatalf("expected to read %d bytes, read %d", len(data), n) + } + if !bytes.Equal(buf, data) { + t.Fatalf("expected %s, got %s", data, buf) + } +} + +func TestTimeoutReader_ReadTimeout(t *testing.T) { + reader := &slowReader{delay: 2 * time.Second} + timeoutReader := NewTimeoutReader(reader, 1*time.Second) + + buf := make([]byte, 8192) + _, err 
:= timeoutReader.Read(buf) + if err == nil { + t.Fatal("expected timeout error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err) + } +} + +type slowReader struct { + delay time.Duration +} + +func (sr *slowReader) Read(p []byte) (n int, err error) { + time.Sleep(sr.delay) + return 0, io.EOF +} diff --git a/internal/test/httptest.go b/internal/test/httptest.go index 56142b949..77f183eb8 100644 --- a/internal/test/httptest.go +++ b/internal/test/httptest.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/json" + "errors" "fmt" "github.com/GopeedLab/gopeed/pkg/base" "github.com/armon/go-socks5" @@ -36,7 +37,7 @@ const ( ) func StartTestFileServer() net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { return http.FileServer(http.Dir(Dir)) }) } @@ -52,7 +53,7 @@ func (s *SlowFileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func StartTestSlowFileServer(delay time.Duration) net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { return &SlowFileServer{ delay: delay, handler: http.FileServer(http.Dir(Dir)), @@ -61,7 +62,7 @@ func StartTestSlowFileServer(delay time.Duration) net.Listener { } func StartTestCustomServer() net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { file, err := os.Open(BuildFile) @@ -88,7 +89,7 @@ func StartTestCustomServer() net.Listener { func StartTestRetryServer() net.Listener { counter := 0 - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { counter++ @@ -108,7 +109,7 @@ func StartTestRetryServer() net.Listener { } func StartTestPostServer() net.Listener { - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { if request.Method == "POST" && request.Header.Get("Authorization") != "" { @@ -132,7 +133,7 @@ func StartTestPostServer() net.Listener { func StartTestErrorServer() net.Listener { counter := 0 - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { counter++ @@ -149,7 +150,7 @@ func StartTestErrorServer() net.Listener { func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { var connections atomic.Int32 - return startTestServer(func() http.Handler { + return startTestServer(func(sl *shutdownListener) http.Handler { mux := http.NewServeMux() mux.HandleFunc("/"+BuildName, func(writer http.ResponseWriter, request *http.Request) { defer func() { @@ -165,12 +166,14 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { if r == "" { writer.Header().Set("Content-Length", fmt.Sprintf("%d", BuildSize)) writer.WriteHeader(200) + (writer.(http.Flusher)).Flush() + file, err := os.Open(BuildFile) if err != nil { panic(err) } defer file.Close() - slowCopy(writer, file, delay) + 
slowCopy(sl, writer, file, delay) } else { // split range s := strings.Split(r, "=") @@ -204,6 +207,8 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { writer.Header().Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, BuildSize)) writer.Header().Set("Accept-Ranges", "bytes") writer.WriteHeader(206) + (writer.(http.Flusher)).Flush() + file, err := os.Open(BuildFile) if err != nil { writer.WriteHeader(500) @@ -211,7 +216,7 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { } defer file.Close() file.Seek(start, 0) - slowCopyN(writer, file, end-start+1, delay) + slowCopyN(sl, writer, file, end-start+1, delay) } }) return mux @@ -219,9 +224,12 @@ func StartTestLimitServer(maxConnections int32, delay int64) net.Listener { } // slowCopyN copies n bytes from src to dst, speed limit is bytes per second -func slowCopy(dst io.Writer, src io.Reader, delay int64) (written int64, err error) { +func slowCopy(sl *shutdownListener, dst io.Writer, src io.Reader, delay int64) (written int64, err error) { buf := make([]byte, 32*1024) for { + if sl.isShutdown { + return 0, errors.New("server shutdown") + } nr, er := src.Read(buf) if nr > 0 { nw, ew := dst.Write(buf[0:nr]) @@ -250,8 +258,8 @@ func slowCopy(dst io.Writer, src io.Reader, delay int64) (written int64, err err return written, err } -func slowCopyN(dst io.Writer, src io.Reader, n int64, delay int64) (written int64, err error) { - written, err = slowCopy(dst, io.LimitReader(src, n), delay) +func slowCopyN(sl *shutdownListener, dst io.Writer, src io.Reader, n int64, delay int64) (written int64, err error) { + written, err = slowCopy(sl, dst, io.LimitReader(src, n), delay) if written == n { return n, nil } @@ -262,7 +270,7 @@ func slowCopyN(dst io.Writer, src io.Reader, n int64, delay int64) (written int6 return } -func startTestServer(serverHandle func() http.Handler) net.Listener { +func startTestServer(serverHandle func(sl *shutdownListener) http.Handler) net.Listener { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { panic(err) @@ -272,7 +280,7 @@ func startTestServer(serverHandle func() http.Handler) net.Listener { panic(err) } defer file.Close() - // 随机生成一个文件 + // Write random data l := int64(8192) buf := make([]byte, l) size := int64(0) @@ -289,21 +297,24 @@ func startTestServer(serverHandle func() http.Handler) net.Listener { size += l } server := &http.Server{} - server.Handler = serverHandle() - go server.Serve(listener) - - return &shutdownListener{ + sl := &shutdownListener{ server: server, Listener: listener, } + server.Handler = serverHandle(sl) + go server.Serve(listener) + + return sl } type shutdownListener struct { - server *http.Server + server *http.Server + isShutdown bool net.Listener } func (c *shutdownListener) Close() error { + c.isShutdown = true closeErr := c.server.Shutdown(context.Background()) if err := ifExistAndRemove(BuildFile); err != nil { fmt.Println(err)
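
Note (illustrative, not part of the patch): the sketch below shows how the new TimeoutReader is meant to be consumed around a response body, mirroring the read loop that fetchChunk now uses. The helper name readBodyWithTimeout and the consume callback are assumptions made for this example only; it presumes the code lives in the same internal/protocol/http package so that NewTimeoutReader is in scope.

package http

import (
	"io"
	"net/http"
	"time"
)

// readBodyWithTimeout reads resp.Body through a TimeoutReader so that any
// single Read blocking longer than timeout fails with
// context.DeadlineExceeded instead of hanging the chunk forever. Each filled
// buffer is handed to the consume callback (e.g. a WriteAt on the target file).
func readBodyWithTimeout(resp *http.Response, timeout time.Duration, consume func([]byte) error) error {
	reader := NewTimeoutReader(resp.Body, timeout)
	buf := make([]byte, 32*1024)
	for {
		n, err := reader.Read(buf)
		if n > 0 {
			if werr := consume(buf[:n]); werr != nil {
				return werr
			}
		}
		if err != nil {
			if err == io.EOF {
				return nil // body fully read
			}
			// A stalled connection surfaces here (context.DeadlineExceeded),
			// allowing the caller to count a retry for this chunk only.
			return err
		}
	}
}

Connection establishment is bounded separately by the net.Dialer Timeout (connectTimeout) added to the transport in buildClient, so both the initial dial and every subsequent body read have an upper bound.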