Skip to content

Commit

Permalink
fix: http download retries too many times (#849)
Browse files Browse the repository at this point in the history
  • Loading branch information
monkeyWie authored Dec 24, 2024
1 parent 8b61343 commit c970378
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 43 deletions.
59 changes: 35 additions & 24 deletions internal/protocol/http/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"mime"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
Expand All @@ -25,6 +26,9 @@ import (
"golang.org/x/sync/errgroup"
)

const connectTimeout = 15 * time.Second
const readTimeout = 15 * time.Second

type RequestError struct {
Code int
Msg string
Expand Down Expand Up @@ -103,7 +107,7 @@ func (f *Fetcher) Resolve(req *base.Request) error {
if base.HttpCodePartialContent == httpResp.StatusCode || (base.HttpCodeOK == httpResp.StatusCode && httpResp.Header.Get(base.HttpHeaderAcceptRanges) == base.HttpHeaderBytes && strings.HasPrefix(httpResp.Header.Get(base.HttpHeaderContentRange), base.HttpHeaderBytes)) {
// response 206 status code, support breakpoint continuation
res.Range = true
// 解析资源大小: bytes 0-1000/1001 => 1001
// parse content length from Content-Range header, eg: bytes 0-1000/1001
contentTotal := path.Base(httpResp.Header.Get(base.HttpHeaderContentRange))
if contentTotal != "" {
parse, err := strconv.ParseInt(contentTotal, 10, 64)
Expand Down Expand Up @@ -146,7 +150,7 @@ func (f *Fetcher) Resolve(req *base.Request) error {
file.Name = filename
}
}
// Get file filePath by URL
// get file filePath by URL
if file.Name == "" {
file.Name = path.Base(httpReq.URL.Path)
}
Expand Down Expand Up @@ -250,23 +254,42 @@ func (f *Fetcher) Wait() (err error) {
return <-f.doneCh
}

type fetchResult struct {
err error
}

func (f *Fetcher) fetch() {
var ctx context.Context
ctx, f.cancel = context.WithCancel(context.Background())
f.eg, _ = errgroup.WithContext(ctx)
chunkErrs := make([]error, len(f.chunks))
for i := 0; i < len(f.chunks); i++ {
i := i
f.eg.Go(func() error {
return f.fetchChunk(i, ctx)
err := f.fetchChunk(i, ctx)
// if canceled, fail fast
if errors.Is(err, context.Canceled) {
return err
}
chunkErrs[i] = err
return nil
})
}

go func() {
err := f.eg.Wait()
// check if canceled
if errors.Is(err, context.Canceled) {
// error returned only if canceled, just return
if err != nil {
return
}
// check all fetch results, if any error, return
for _, chunkErr := range chunkErrs {
if chunkErr != nil {
err = chunkErr
break
}
}

f.file.Close()
// Update file last modified time
if f.config.UseServerCtime && f.meta.Res.Files[0].Ctime != nil {
Expand All @@ -289,24 +312,11 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) {
for {
// if chunk is completed, return
if f.meta.Res.Range && chunk.Downloaded >= chunk.End-chunk.Begin+1 {
return
return nil
}

if chunk.retryTimes >= maxRetries {
if !f.meta.Res.Range {
return
}
// check if all failed
allFailed := true
for _, c := range f.chunks {
if chunk.Downloaded < chunk.End-chunk.Begin+1 && c.retryTimes < maxRetries {
allFailed = false
break
}
}
if allFailed {
return
}
return
}

var (
Expand All @@ -333,10 +343,9 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) {
err = NewRequestError(resp.StatusCode, resp.Status)
return err
}
// Http request success, reset retry times
chunk.retryTimes = 0
reader := NewTimeoutReader(resp.Body, readTimeout)
for {
n, err := resp.Body.Read(buf)
n, err := reader.Read(buf)
if n > 0 {
_, err := f.file.WriteAt(buf[:n], chunk.Begin+chunk.Downloaded)
if err != nil {
Expand All @@ -351,7 +360,6 @@ func (f *Fetcher) fetchChunk(index int, ctx context.Context) (err error) {
return err
}
}
return nil
}()
if err != nil {
// If canceled, do not retry
Expand Down Expand Up @@ -452,6 +460,9 @@ func (f *Fetcher) splitChunk() (chunks []*chunk) {

func (f *Fetcher) buildClient() *http.Client {
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: connectTimeout,
}).DialContext,
Proxy: f.ctl.GetProxy(f.meta.Req.Proxy),
TLSClientConfig: &tls.Config{
InsecureSkipVerify: f.meta.Req.SkipVerifyCert,
Expand Down
8 changes: 8 additions & 0 deletions internal/protocol/http/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ func TestFetcher_DownloadLimit(t *testing.T) {
downloadNormal(listener, 8, t)
}

func TestFetcher_DownloadResponseBodyReadTimeout(t *testing.T) {
listener := test.StartTestLimitServer(16, readTimeout.Milliseconds()+5000)
defer listener.Close()

downloadError(listener, 1, t)
downloadError(listener, 4, t)
}

func TestFetcher_DownloadResume(t *testing.T) {
listener := test.StartTestFileServer()
defer listener.Close()
Expand Down
40 changes: 40 additions & 0 deletions internal/protocol/http/timeout_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package http

import (
"context"
"io"
"time"
)

type TimeoutReader struct {
reader io.Reader
timeout time.Duration
}

func NewTimeoutReader(r io.Reader, timeout time.Duration) *TimeoutReader {
return &TimeoutReader{
reader: r,
timeout: timeout,
}
}

func (tr *TimeoutReader) Read(p []byte) (n int, err error) {
ctx, cancel := context.WithTimeout(context.Background(), tr.timeout)
defer cancel()

done := make(chan struct{})
var readErr error
var bytesRead int

go func() {
bytesRead, readErr = tr.reader.Read(p)
close(done)
}()

select {
case <-done:
return bytesRead, readErr
case <-ctx.Done():
return 0, ctx.Err()
}
}
51 changes: 51 additions & 0 deletions internal/protocol/http/timeout_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package http

import (
"bytes"
"context"
"errors"
"io"
"testing"
"time"
)

func TestTimeoutReader_Read(t *testing.T) {
data := []byte("Hello, World!")
reader := bytes.NewReader(data)
timeoutReader := NewTimeoutReader(reader, 1*time.Second)

buf := make([]byte, len(data))
n, err := timeoutReader.Read(buf)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if n != len(data) {
t.Fatalf("expected to read %d bytes, read %d", len(data), n)
}
if !bytes.Equal(buf, data) {
t.Fatalf("expected %s, got %s", data, buf)
}
}

func TestTimeoutReader_ReadTimeout(t *testing.T) {
reader := &slowReader{delay: 2 * time.Second}
timeoutReader := NewTimeoutReader(reader, 1*time.Second)

buf := make([]byte, 8192)
_, err := timeoutReader.Read(buf)
if err == nil {
t.Fatal("expected timeout error, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected %v, got %v", context.DeadlineExceeded, err)
}
}

type slowReader struct {
delay time.Duration
}

func (sr *slowReader) Read(p []byte) (n int, err error) {
time.Sleep(sr.delay)
return 0, io.EOF
}
Loading

0 comments on commit c970378

Please sign in to comment.