github.com/livekit/protocol@v1.39.3/utils/xtwirp/timeout.go (about)

     1  package xtwirp
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"strconv"
     7  	"time"
     8  
     9  	"github.com/twitchtv/twirp"
    10  )
    11  
    12  const timeoutHeader = "X-Twirp-Timeout-Ms"
    13  
    14  // ClientPassTimout adds context timeout as a Twirp request header.
    15  func ClientPassTimout() twirp.ClientOption {
    16  	return twirp.WithClientInterceptors(func(fnc twirp.Method) twirp.Method {
    17  		return func(ctx context.Context, req any) (any, error) {
    18  			if deadline, ok := ctx.Deadline(); ok {
    19  				dt := time.Until(deadline)
    20  				if dt > 0 {
    21  					h, ok := twirp.HTTPRequestHeaders(ctx)
    22  					if !ok {
    23  						h = make(http.Header)
    24  					}
    25  					h.Add(timeoutHeader, strconv.FormatInt(dt.Milliseconds(), 10))
    26  					var err error
    27  					ctx, err = twirp.WithHTTPRequestHeaders(ctx, h)
    28  					if err != nil {
    29  						return nil, err
    30  					}
    31  				}
    32  			}
    33  			return fnc(ctx, req)
    34  		}
    35  	})
    36  }
    37  
    38  // ServerPassTimeout uses context timeout from Twirp request header.
    39  // It requires that Twirp server handler is wrapped with PassHeadersHandler.
    40  func ServerPassTimeout() twirp.ServerOption {
    41  	return twirp.WithServerInterceptors(func(fnc twirp.Method) twirp.Method {
    42  		return func(ctx context.Context, req any) (any, error) {
    43  			if h := GetHeaders(ctx); h != nil {
    44  				if v, err := strconv.ParseInt(h.Get(timeoutHeader), 10, 64); err == nil {
    45  					var cancel context.CancelFunc
    46  					ctx, cancel = context.WithTimeout(ctx, time.Duration(v)*time.Millisecond)
    47  					defer cancel()
    48  				}
    49  			}
    50  			return fnc(ctx, req)
    51  		}
    52  	})
    53  }