@@ -29,7 +29,6 @@ import (
29
29
"context"
30
30
"fmt"
31
31
"io"
32
- "io/ioutil"
33
32
"log"
34
33
"math"
35
34
"math/rand"
62
61
// limit the size we consume to respReadLimit.
63
62
respReadLimit = int64 (4096 )
64
63
64
+ // timeNow sets the function that returns the current time.
65
+ // This defaults to time.Now. Changes to this should only be done in tests.
66
+ timeNow = time .Now
67
+
65
68
// A regular expression to match the error returned by net/http when the
66
69
// configured number of redirects is exhausted. This error isn't typed
67
70
// specifically so we resort to matching on the error string.
@@ -252,29 +255,27 @@ func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, erro
252
255
// deal with it seeking so want it to match here instead of the
253
256
// io.ReadSeeker case.
254
257
case * bytes.Reader :
255
- buf , err := ioutil .ReadAll (body )
256
- if err != nil {
257
- return nil , 0 , err
258
- }
258
+ snapshot := * body
259
259
bodyReader = func () (io.Reader , error ) {
260
- return bytes .NewReader (buf ), nil
260
+ r := snapshot
261
+ return & r , nil
261
262
}
262
- contentLength = int64 (len ( buf ))
263
+ contentLength = int64 (body . Len ( ))
263
264
264
265
// Compat case
265
266
case io.ReadSeeker :
266
267
raw := body
267
268
bodyReader = func () (io.Reader , error ) {
268
269
_ , err := raw .Seek (0 , 0 )
269
- return ioutil .NopCloser (raw ), err
270
+ return io .NopCloser (raw ), err
270
271
}
271
272
if lr , ok := raw .(LenReader ); ok {
272
273
contentLength = int64 (lr .Len ())
273
274
}
274
275
275
276
// Read all in so we can reset
276
277
case io.Reader :
277
- buf , err := ioutil .ReadAll (body )
278
+ buf , err := io .ReadAll (body )
278
279
if err != nil {
279
280
return nil , 0 , err
280
281
}
@@ -397,6 +398,9 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
397
398
// attempted. If overriding this, be sure to close the body if needed.
398
399
type ErrorHandler func (resp * http.Response , err error , numTries int ) (* http.Response , error )
399
400
401
+ // PrepareRetry is called before retry operation. It can be used for example to re-sign the request
402
+ type PrepareRetry func (req * http.Request ) error
403
+
400
404
// Client is used to make HTTP requests. It adds additional functionality
401
405
// like automatic retries to tolerate minor outages.
402
406
type Client struct {
@@ -425,6 +429,9 @@ type Client struct {
425
429
// ErrorHandler specifies the custom error handler to use, if any
426
430
ErrorHandler ErrorHandler
427
431
432
+ // PrepareRetry can prepare the request for retry operation, for example re-sign it
433
+ PrepareRetry PrepareRetry
434
+
428
435
loggerInit sync.Once
429
436
clientInit sync.Once
430
437
}
@@ -544,10 +551,8 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
544
551
func DefaultBackoff (min , max time.Duration , attemptNum int , resp * http.Response ) time.Duration {
545
552
if resp != nil {
546
553
if resp .StatusCode == http .StatusTooManyRequests || resp .StatusCode == http .StatusServiceUnavailable {
547
- if s , ok := resp .Header ["Retry-After" ]; ok {
548
- if sleep , err := strconv .ParseInt (s [0 ], 10 , 64 ); err == nil {
549
- return time .Second * time .Duration (sleep )
550
- }
554
+ if sleep , ok := parseRetryAfterHeader (resp .Header ["Retry-After" ]); ok {
555
+ return sleep
551
556
}
552
557
}
553
558
}
@@ -560,6 +565,41 @@ func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response)
560
565
return sleep
561
566
}
562
567
568
+ // parseRetryAfterHeader parses the Retry-After header and returns the
569
+ // delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after
570
+ // The bool returned will be true if the header was successfully parsed.
571
+ // Otherwise, the header was either not present, or was not parseable according to the spec.
572
+ //
573
+ // Retry-After headers come in two flavors: Seconds or HTTP-Date
574
+ //
575
+ // Examples:
576
+ // * Retry-After: Fri, 31 Dec 1999 23:59:59 GMT
577
+ // * Retry-After: 120
578
+ func parseRetryAfterHeader (headers []string ) (time.Duration , bool ) {
579
+ if len (headers ) == 0 || headers [0 ] == "" {
580
+ return 0 , false
581
+ }
582
+ header := headers [0 ]
583
+ // Retry-After: 120
584
+ if sleep , err := strconv .ParseInt (header , 10 , 64 ); err == nil {
585
+ if sleep < 0 { // a negative sleep doesn't make sense
586
+ return 0 , false
587
+ }
588
+ return time .Second * time .Duration (sleep ), true
589
+ }
590
+
591
+ // Retry-After: Fri, 31 Dec 1999 23:59:59 GMT
592
+ retryTime , err := time .Parse (time .RFC1123 , header )
593
+ if err != nil {
594
+ return 0 , false
595
+ }
596
+ if until := retryTime .Sub (timeNow ()); until > 0 {
597
+ return until , true
598
+ }
599
+ // date is in the past
600
+ return 0 , true
601
+ }
602
+
563
603
// LinearJitterBackoff provides a callback for Client.Backoff which will
564
604
// perform linear backoff based on the attempt number and with jitter to
565
605
// prevent a thundering herd.
@@ -587,13 +627,13 @@ func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Resp
587
627
}
588
628
589
629
// Seed rand; doing this every time is fine
590
- rand := rand .New (rand .NewSource (int64 (time .Now ().Nanosecond ())))
630
+ source := rand .New (rand .NewSource (int64 (time .Now ().Nanosecond ())))
591
631
592
632
// Pick a random number that lies somewhere between the min and max and
593
633
// multiply by the attemptNum. attemptNum starts at zero so we always
594
634
// increment here. We first get a random percentage, then apply that to the
595
635
// difference between min and max, and add to min.
596
- jitter := rand .Float64 () * float64 (max - min )
636
+ jitter := source .Float64 () * float64 (max - min )
597
637
jitterMin := int64 (jitter ) + int64 (min )
598
638
return time .Duration (jitterMin * int64 (attemptNum ))
599
639
}
@@ -627,10 +667,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
627
667
var resp * http.Response
628
668
var attempt int
629
669
var shouldRetry bool
630
- var doErr , respErr , checkErr error
670
+ var doErr , respErr , checkErr , prepareErr error
631
671
632
672
for i := 0 ; ; i ++ {
633
- doErr , respErr = nil , nil
673
+ doErr , respErr , prepareErr = nil , nil , nil
634
674
attempt ++
635
675
636
676
// Always rewind the request body when non-nil.
@@ -643,7 +683,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
643
683
if c , ok := body .(io.ReadCloser ); ok {
644
684
req .Body = c
645
685
} else {
646
- req .Body = ioutil .NopCloser (body )
686
+ req .Body = io .NopCloser (body )
647
687
}
648
688
}
649
689
@@ -737,17 +777,26 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
737
777
// without racing against the closeBody call in persistConn.writeLoop.
738
778
httpreq := * req .Request
739
779
req .Request = & httpreq
780
+
781
+ if c .PrepareRetry != nil {
782
+ if err := c .PrepareRetry (req .Request ); err != nil {
783
+ prepareErr = err
784
+ break
785
+ }
786
+ }
740
787
}
741
788
742
789
// this is the closest we have to success criteria
743
- if doErr == nil && respErr == nil && checkErr == nil && ! shouldRetry {
790
+ if doErr == nil && respErr == nil && checkErr == nil && prepareErr == nil && ! shouldRetry {
744
791
return resp , nil
745
792
}
746
793
747
794
defer c .HTTPClient .CloseIdleConnections ()
748
795
749
796
var err error
750
- if checkErr != nil {
797
+ if prepareErr != nil {
798
+ err = prepareErr
799
+ } else if checkErr != nil {
751
800
err = checkErr
752
801
} else if respErr != nil {
753
802
err = respErr
@@ -779,7 +828,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
779
828
// Try to read the response body so we can reuse this connection.
780
829
func (c * Client ) drainBody (body io.ReadCloser ) {
781
830
defer body .Close ()
782
- _ , err := io .Copy (ioutil .Discard , io .LimitReader (body , respReadLimit ))
831
+ _ , err := io .Copy (io .Discard , io .LimitReader (body , respReadLimit ))
783
832
if err != nil {
784
833
if c .logger () != nil {
785
834
switch v := c .logger ().(type ) {
0 commit comments