github.com/gospider007/requests@v0.0.0-20240506025355-c73d46169a23/roundTripper.go (about)

     1  package requests
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"crypto/tls"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"net/textproto"
    11  	"net/url"
    12  	"strings"
    13  	"time"
    14  
    15  	"net/http"
    16  
    17  	"github.com/gospider007/gtls"
    18  	"github.com/gospider007/net/http2"
    19  	"github.com/gospider007/tools"
    20  	utls "github.com/refraction-networking/utls"
    21  )
    22  
    23  type reqTask struct {
    24  	ctx           context.Context
    25  	cnl           context.CancelFunc
    26  	req           *http.Request
    27  	res           *http.Response
    28  	emptyPool     chan struct{}
    29  	err           error
    30  	orderHeaders  []string
    31  	orderHeaders2 []string
    32  }
    33  
    34  func (obj *reqTask) inPool() bool {
    35  	return obj.err == nil && obj.res != nil && obj.res.StatusCode != 101 && !strings.Contains(obj.res.Header.Get("Content-Type"), "text/event-stream")
    36  }
    37  
    38  func getKey(ctxData *reqCtxData, req *http.Request) (key string) {
    39  	var proxyUser string
    40  	if ctxData.proxy != nil {
    41  		proxyUser = ctxData.proxy.User.String()
    42  	}
    43  	return fmt.Sprintf("%s@%s@%s", proxyUser, getAddr(ctxData.proxy), getAddr(req.URL))
    44  }
    45  
    46  type roundTripper struct {
    47  	ctx        context.Context
    48  	cnl        context.CancelFunc
    49  	connPools  *connPools
    50  	dialer     *DialClient
    51  	tlsConfig  *tls.Config
    52  	utlsConfig *utls.Config
    53  	getProxy   func(ctx context.Context, url *url.URL) (string, error)
    54  }
    55  
    56  func newRoundTripper(preCtx context.Context, option ClientOption) *roundTripper {
    57  	if preCtx == nil {
    58  		preCtx = context.TODO()
    59  	}
    60  	ctx, cnl := context.WithCancel(preCtx)
    61  	dialClient := NewDail(DialOption{
    62  		DialTimeout: option.DialTimeout,
    63  		Dns:         option.Dns,
    64  		KeepAlive:   option.KeepAlive,
    65  		LocalAddr:   option.LocalAddr,
    66  		AddrType:    option.AddrType,
    67  		GetAddrType: option.GetAddrType,
    68  	})
    69  	tlsConfig := &tls.Config{
    70  		InsecureSkipVerify: true,
    71  		SessionTicketKey:   [32]byte{},
    72  		ClientSessionCache: tls.NewLRUClientSessionCache(0),
    73  	}
    74  	utlsConfig := &utls.Config{
    75  		InsecureSkipVerify:                 true,
    76  		InsecureSkipTimeVerify:             true,
    77  		SessionTicketKey:                   [32]byte{},
    78  		ClientSessionCache:                 utls.NewLRUClientSessionCache(0),
    79  		OmitEmptyPsk:                       true,
    80  		PreferSkipResumptionOnNilExtension: true,
    81  	}
    82  	return &roundTripper{
    83  		tlsConfig:  tlsConfig,
    84  		utlsConfig: utlsConfig,
    85  		ctx:        ctx,
    86  		cnl:        cnl,
    87  		dialer:     dialClient,
    88  		getProxy:   option.GetProxy,
    89  		connPools:  newConnPools(),
    90  	}
    91  }
    92  func (obj *roundTripper) newConnPool(conn *connecotr, key string) *connPool {
    93  	pool := new(connPool)
    94  	pool.connKey = key
    95  	pool.deleteCtx, pool.deleteCnl = context.WithCancelCause(obj.ctx)
    96  	pool.closeCtx, pool.closeCnl = context.WithCancelCause(pool.deleteCtx)
    97  	pool.tasks = make(chan *reqTask)
    98  	pool.connPools = obj.connPools
    99  	pool.total.Add(1)
   100  	go pool.rwMain(conn)
   101  	return pool
   102  }
   103  func (obj *roundTripper) putConnPool(key string, conn *connecotr) {
   104  	conn.inPool = true
   105  	if conn.h2RawConn == nil {
   106  		go conn.read()
   107  	}
   108  	pool := obj.connPools.get(key)
   109  	if pool != nil {
   110  		pool.total.Add(1)
   111  		go pool.rwMain(conn)
   112  	} else {
   113  		obj.connPools.set(key, obj.newConnPool(conn, key))
   114  	}
   115  }
   116  func (obj *roundTripper) tlsConfigClone() *tls.Config {
   117  	return obj.tlsConfig.Clone()
   118  }
   119  func (obj *roundTripper) utlsConfigClone() *utls.Config {
   120  	return obj.utlsConfig.Clone()
   121  }
   122  func (obj *roundTripper) newConnecotr(netConn net.Conn) *connecotr {
   123  	conne := new(connecotr)
   124  	conne.withCancel(obj.ctx, obj.ctx)
   125  	conne.rawConn = netConn
   126  	return conne
   127  }
   128  func (obj *roundTripper) dial(ctxData *reqCtxData, req *http.Request) (conn *connecotr, err error) {
   129  	var proxy *url.URL
   130  	if !ctxData.disProxy {
   131  		if proxy = cloneUrl(ctxData.proxy); proxy == nil && obj.getProxy != nil {
   132  			proxyStr, err := obj.getProxy(req.Context(), proxy)
   133  			if err != nil {
   134  				return conn, err
   135  			}
   136  			if proxy, err = gtls.VerifyProxy(proxyStr); err != nil {
   137  				return conn, err
   138  			}
   139  		}
   140  	}
   141  	netConn, err := obj.dialer.DialContextWithProxy(req.Context(), ctxData, "tcp", req.URL.Scheme, getAddr(req.URL), getHost(req), proxy, obj.tlsConfigClone())
   142  	if err != nil {
   143  		return conn, err
   144  	}
   145  	var h2 bool
   146  	if req.URL.Scheme == "https" {
   147  		ctx, cnl := context.WithTimeout(req.Context(), ctxData.tlsHandshakeTimeout)
   148  		defer cnl()
   149  		if ctxData.ja3Spec.IsSet() {
   150  			tlsConfig := obj.utlsConfigClone()
   151  			if ctxData.forceHttp1 {
   152  				tlsConfig.NextProtos = []string{"http/1.1"}
   153  			}
   154  			tlsConn, err := obj.dialer.addJa3Tls(ctx, netConn, getHost(req), ctxData.isWs || ctxData.forceHttp1, ctxData.ja3Spec, tlsConfig)
   155  			if err != nil {
   156  				return conn, tools.WrapError(err, "add ja3 tls error")
   157  			}
   158  			h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2"
   159  			netConn = tlsConn
   160  		} else {
   161  			tlsConn, err := obj.dialer.addTls(ctx, netConn, getHost(req), ctxData.isWs || ctxData.forceHttp1, obj.tlsConfigClone())
   162  			if err != nil {
   163  				return conn, tools.WrapError(err, "add tls error")
   164  			}
   165  			h2 = tlsConn.ConnectionState().NegotiatedProtocol == "h2"
   166  			netConn = tlsConn
   167  		}
   168  	}
   169  	conne := obj.newConnecotr(netConn)
   170  	if proxy != nil {
   171  		conne.proxy = proxy.String()
   172  	}
   173  	if h2 {
   174  		if conne.h2RawConn, err = http2.NewClientConn(func() {
   175  			conne.closeCnl(errors.New("http2 client close"))
   176  		}, netConn, ctxData.h2Ja3Spec); err != nil {
   177  			return conne, err
   178  		}
   179  	} else {
   180  		conne.r, conne.w = textproto.NewReader(bufio.NewReader(conne)), bufio.NewWriter(conne)
   181  	}
   182  	return conne, err
   183  }
   184  func (obj *roundTripper) setGetProxy(getProxy func(ctx context.Context, url *url.URL) (string, error)) {
   185  	obj.getProxy = getProxy
   186  }
   187  
   188  func (obj *roundTripper) poolRoundTrip(ctxData *reqCtxData, task *reqTask, key string) (newConn bool) {
   189  	pool := obj.connPools.get(key)
   190  	if pool == nil {
   191  		return true
   192  	}
   193  	task.ctx, task.cnl = context.WithTimeout(task.req.Context(), ctxData.responseHeaderTimeout)
   194  	select {
   195  	case pool.tasks <- task:
   196  		select {
   197  		case <-task.emptyPool:
   198  			return true
   199  		case <-task.ctx.Done():
   200  			if task.err == nil && task.res == nil {
   201  				task.err = task.ctx.Err()
   202  			}
   203  			return false
   204  		}
   205  	default:
   206  		return true
   207  	}
   208  }
   209  func (obj *roundTripper) connRoundTrip(ctxData *reqCtxData, task *reqTask, key string) (retry bool) {
   210  	conn, err := obj.dial(ctxData, task.req)
   211  	if err != nil {
   212  		task.err = err
   213  		return
   214  	}
   215  	task.ctx, task.cnl = context.WithTimeout(task.req.Context(), ctxData.responseHeaderTimeout)
   216  	retry = conn.taskMain(task, false)
   217  	if retry || task.err != nil {
   218  		return retry
   219  	}
   220  	if task.inPool() && !ctxData.disAlive {
   221  		obj.putConnPool(key, conn)
   222  	}
   223  	return retry
   224  }
   225  
   226  func (obj *roundTripper) closeConns() {
   227  	obj.connPools.iter(func(key string, pool *connPool) bool {
   228  		pool.close()
   229  		obj.connPools.del(key)
   230  		return true
   231  	})
   232  }
   233  func (obj *roundTripper) forceCloseConns() {
   234  	obj.connPools.iter(func(key string, pool *connPool) bool {
   235  		pool.forceClose()
   236  		obj.connPools.del(key)
   237  		return true
   238  	})
   239  }
   240  func (obj *roundTripper) newReqTask(req *http.Request, ctxData *reqCtxData) *reqTask {
   241  	if ctxData.responseHeaderTimeout == 0 {
   242  		ctxData.responseHeaderTimeout = time.Second * 300
   243  	}
   244  	task := new(reqTask)
   245  	task.req = req
   246  	task.emptyPool = make(chan struct{})
   247  	task.orderHeaders = ctxData.orderHeaders
   248  	if ctxData.h2Ja3Spec.OrderHeaders != nil {
   249  		task.orderHeaders2 = ctxData.h2Ja3Spec.OrderHeaders
   250  	} else {
   251  		task.orderHeaders2 = ctxData.orderHeaders
   252  	}
   253  	return task
   254  }
   255  func (obj *roundTripper) RoundTrip(req *http.Request) (response *http.Response, err error) {
   256  	ctxData := GetReqCtxData(req.Context())
   257  	if ctxData.requestCallBack != nil {
   258  		if err = ctxData.requestCallBack(req.Context(), req, nil); err != nil {
   259  			if err == http.ErrUseLastResponse {
   260  				if req.Response == nil {
   261  					return nil, errors.New("errUseLastResponse response is nil")
   262  				} else {
   263  					return req.Response, nil
   264  				}
   265  			}
   266  			return nil, err
   267  		}
   268  	}
   269  	key := getKey(ctxData, req) //pool key
   270  	task := obj.newReqTask(req, ctxData)
   271  	//get pool conn
   272  	var isNewConn bool
   273  	if !ctxData.disAlive {
   274  		isNewConn = obj.poolRoundTrip(ctxData, task, key)
   275  	}
   276  	if ctxData.disAlive || isNewConn {
   277  		ctxData.isNewConn = true
   278  		for {
   279  			retry := obj.connRoundTrip(ctxData, task, key)
   280  			if !retry {
   281  				break
   282  			}
   283  		}
   284  	}
   285  	if task.err == nil && ctxData.requestCallBack != nil {
   286  		if err = ctxData.requestCallBack(task.req.Context(), task.req, task.res); err != nil {
   287  			task.err = err
   288  		}
   289  	}
   290  	return task.res, task.err
   291  }