github.com/gospider007/requests@v0.0.0-20240506025355-c73d46169a23/conn.go (about)

     1  package requests
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"net/textproto"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/gospider007/net/http2"
    15  	"github.com/gospider007/tools"
    16  )
    17  
// connecotr wraps a single pooled network connection (HTTP/1.x, or HTTP/2 when
// h2RawConn is set) together with the contexts that control its lifetime.
// It implements the net.Conn method set by delegating to rawConn, except that
// Read goes through pr once read() has switched the connection to piped mode.
// (The type name is spelled "connecotr" in the package; kept as-is.)
type connecotr struct {
	deleteCtx context.Context //force close: cancelled by Close; when done, taskMain/waitBodyClose abort
	deleteCnl context.CancelCauseFunc
	closeCtx  context.Context //safe close: when done, taskMain returns retry=true and rwMain stops taking tasks
	closeCnl  context.CancelCauseFunc

	bodyCtx context.Context //body close: created per response by wrapBody; waitBodyClose waits on it
	bodyCnl context.CancelCauseFunc // NOTE(review): not called in this file; presumably cancelled by readWriteCloser — confirm

	rawConn   net.Conn           // underlying transport; Write, addresses and deadlines delegate to it
	h2RawConn *http2.ClientConn  // non-nil means requests are sent over HTTP/2 (see taskMain)
	proxy     string             // NOTE(review): not referenced in this file
	r         *textproto.Reader  // HTTP/1 response reader used by http1Req
	w         *bufio.Writer      // HTTP/1 request writer used by http1Req
	pr        *pipCon            // read side of the pipe installed by read(); Read uses it when non-nil
	inPool    bool               // NOTE(review): not referenced in this file
}
    35  
    36  func (obj *connecotr) withCancel(deleteCtx context.Context, closeCtx context.Context) {
    37  	obj.deleteCtx, obj.deleteCnl = context.WithCancelCause(deleteCtx)
    38  	obj.closeCtx, obj.closeCnl = context.WithCancelCause(closeCtx)
    39  }
    40  func (obj *connecotr) Close() error {
    41  	obj.deleteCnl(errors.New("connecotr close"))
    42  	if obj.pr != nil {
    43  		obj.pr.Close(errors.New("connecotr close"))
    44  	}
    45  	if obj.h2RawConn != nil {
    46  		obj.h2RawConn.Close()
    47  	}
    48  	return obj.rawConn.Close()
    49  }
    50  func (obj *connecotr) read() (err error) {
    51  	if obj.pr != nil {
    52  		return nil
    53  	}
    54  	var pw *pipCon
    55  	obj.pr, pw = pipe(obj.deleteCtx)
    56  	if _, err = io.Copy(pw, obj.rawConn); err == nil {
    57  		err = io.EOF
    58  	}
    59  	pw.Close(err)
    60  	obj.Close()
    61  	return
    62  }
    63  func (obj *connecotr) Read(b []byte) (i int, err error) {
    64  	if obj.pr == nil {
    65  		return obj.rawConn.Read(b)
    66  	}
    67  	return obj.pr.Read(b)
    68  }
    69  func (obj *connecotr) Write(b []byte) (int, error) {
    70  	return obj.rawConn.Write(b)
    71  }
    72  func (obj *connecotr) LocalAddr() net.Addr {
    73  	return obj.rawConn.LocalAddr()
    74  }
    75  func (obj *connecotr) RemoteAddr() net.Addr {
    76  	return obj.rawConn.RemoteAddr()
    77  }
    78  func (obj *connecotr) SetDeadline(t time.Time) error {
    79  	return obj.rawConn.SetDeadline(t)
    80  }
    81  func (obj *connecotr) SetReadDeadline(t time.Time) error {
    82  	return obj.rawConn.SetReadDeadline(t)
    83  }
    84  func (obj *connecotr) SetWriteDeadline(t time.Time) error {
    85  	return obj.rawConn.SetWriteDeadline(t)
    86  }
    87  
    88  func (obj *connecotr) h2Closed() bool {
    89  	if obj.h2RawConn == nil {
    90  		return false
    91  	}
    92  	state := obj.h2RawConn.State()
    93  	return state.Closed || state.Closing
    94  }
    95  func (obj *connecotr) wrapBody(task *reqTask) {
    96  	body := new(readWriteCloser)
    97  	obj.bodyCtx, obj.bodyCnl = context.WithCancelCause(task.req.Context())
    98  	body.body = task.res.Body
    99  	body.conn = obj
   100  	task.res.Body = body
   101  }
   102  func (obj *connecotr) http1Req(task *reqTask) {
   103  	if task.err = httpWrite(task.req, obj.w, task.orderHeaders); task.err == nil {
   104  		task.res, task.err = readResponse(obj.r, task.req)
   105  		if task.err != nil {
   106  			task.err = tools.WrapError(task.err, "http1 read error")
   107  		} else if task.res == nil {
   108  			task.err = errors.New("response is nil")
   109  		} else {
   110  			obj.wrapBody(task)
   111  		}
   112  	}
   113  	task.cnl()
   114  }
   115  
   116  func (obj *connecotr) http2Req(task *reqTask) {
   117  	if task.res, task.err = obj.h2RawConn.RoundTripWithOrderHeaders(task.req, task.orderHeaders2); task.res != nil && task.err == nil {
   118  		obj.wrapBody(task)
   119  	} else if task.err != nil {
   120  		task.err = tools.WrapError(task.err, "http2 roundTrip error")
   121  	}
   122  	task.cnl()
   123  }
   124  func (obj *connecotr) waitBodyClose() error {
   125  	select {
   126  	case <-obj.bodyCtx.Done(): //wait body close
   127  		if err := context.Cause(obj.bodyCtx); errors.Is(err, ErrgospiderBodyClose) {
   128  			return nil
   129  		} else {
   130  			return err
   131  		}
   132  	case <-obj.deleteCtx.Done(): //force conn close
   133  		return tools.WrapError(context.Cause(obj.deleteCtx), "delete ctx error: ")
   134  	}
   135  }
   136  
   137  func (obj *connecotr) taskMain(task *reqTask, waitBody bool) (retry bool) {
   138  	defer func() {
   139  		if retry || task.err != nil {
   140  			obj.Close()
   141  		}
   142  	}()
   143  	if obj.h2Closed() {
   144  		return true
   145  	}
   146  	select {
   147  	case <-obj.closeCtx.Done():
   148  		return true
   149  	default:
   150  	}
   151  	if obj.h2RawConn != nil {
   152  		go obj.http2Req(task)
   153  	} else {
   154  		go obj.http1Req(task)
   155  	}
   156  	select {
   157  	case <-task.ctx.Done():
   158  		if task.err != nil {
   159  			return false
   160  		}
   161  		if task.res == nil {
   162  			task.err = task.ctx.Err()
   163  			if task.err == nil {
   164  				task.err = errors.New("response is nil")
   165  			}
   166  			return false
   167  		}
   168  		if waitBody {
   169  			task.err = obj.waitBodyClose()
   170  		}
   171  		return false
   172  	case <-obj.deleteCtx.Done(): //force conn close
   173  		task.err = tools.WrapError(obj.deleteCtx.Err(), "delete ctx error: ")
   174  		task.cnl()
   175  		return false
   176  	}
   177  }
   178  
// connPool groups the connections that share one connKey. Each connection is
// served by an rwMain worker that pulls tasks from the shared tasks channel.
type connPool struct {
	deleteCtx context.Context        // parent of every connection's force-close context (see withCancel)
	deleteCnl context.CancelCauseFunc // cancelled by forceClose
	closeCtx  context.Context        // parent of every connection's safe-close context
	closeCnl  context.CancelCauseFunc // cancelled by close
	connKey   string                 // key under which this pool is registered in connPools
	total     atomic.Int64           // live connection count; rwMain decrements it and closes the pool at <= 0
	tasks     chan *reqTask          // tasks waiting for a connection; rwMain receives, notice sends
	connPools *connPools             // registry this pool removes itself from on close
}
   189  
// connPools is a concurrent registry of connection pools keyed by connKey.
// Keys are strings and values are *connPool (see get/set/iter).
type connPools struct {
	connPools sync.Map
}
   193  
   194  func newConnPools() *connPools {
   195  	return new(connPools)
   196  }
   197  
   198  func (obj *connPools) get(key string) *connPool {
   199  	val, ok := obj.connPools.Load(key)
   200  	if !ok {
   201  		return nil
   202  	}
   203  	return val.(*connPool)
   204  }
   205  
   206  func (obj *connPools) set(key string, pool *connPool) {
   207  	obj.connPools.Store(key, pool)
   208  }
   209  
   210  func (obj *connPools) del(key string) {
   211  	obj.connPools.Delete(key)
   212  }
   213  
   214  func (obj *connPools) iter(f func(key string, value *connPool) bool) {
   215  	obj.connPools.Range(func(key, value any) bool {
   216  		return f(key.(string), value.(*connPool))
   217  	})
   218  }
   219  
   220  func (obj *connPool) notice(task *reqTask) {
   221  	select {
   222  	case obj.tasks <- task:
   223  	case task.emptyPool <- struct{}{}:
   224  	}
   225  }
   226  
   227  func (obj *connPool) rwMain(conn *connecotr) {
   228  	conn.withCancel(obj.deleteCtx, obj.closeCtx)
   229  	defer func() {
   230  		conn.Close()
   231  		obj.total.Add(-1)
   232  		if obj.total.Load() <= 0 {
   233  			obj.close()
   234  		}
   235  	}()
   236  	if err := conn.waitBodyClose(); err != nil {
   237  		return
   238  	}
   239  	for {
   240  		select {
   241  		case <-conn.closeCtx.Done(): //safe close conn
   242  			return
   243  		case task := <-obj.tasks: //recv task
   244  			if task == nil {
   245  				return
   246  			}
   247  			if conn.taskMain(task, true) {
   248  				obj.notice(task)
   249  				return
   250  			}
   251  			if task.err != nil {
   252  				return
   253  			}
   254  		}
   255  	}
   256  }
   257  func (obj *connPool) forceClose() {
   258  	obj.close()
   259  	obj.deleteCnl(errors.New("connPool forceClose"))
   260  }
   261  func (obj *connPool) close() {
   262  	obj.connPools.del(obj.connKey)
   263  	obj.closeCnl(errors.New("connPool close"))
   264  }