github.com/imgk/caddy-trojan@v0.0.0-20221206043256-2631719e16c8/trojan/trojan_tcp.go (about)

     1  package trojan
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"net"
     7  	"os"
     8  	"time"
     9  
    10  	"github.com/imgk/memory-go"
    11  )
    12  
    13  func copyBuffer(w io.Writer, r io.Reader, buf []byte) (n int64, err error) {
    14  	for {
    15  		nr, er := r.Read(buf)
    16  		if nr > 0 {
    17  			nw, ew := w.Write(buf[0:nr])
    18  			if nw < 0 || nr < nw {
    19  				nw = 0
    20  				if ew == nil {
    21  					ew = errors.New("invalid write result")
    22  				}
    23  			}
    24  			n += int64(nw)
    25  			if ew != nil {
    26  				err = ew
    27  				break
    28  			}
    29  			if nr != nw {
    30  				err = io.ErrShortWrite
    31  				break
    32  			}
    33  		}
    34  		if er != nil {
    35  			if !errors.Is(er, io.EOF) {
    36  				err = er
    37  			}
    38  			break
    39  		}
    40  	}
    41  	return n, err
    42  }
    43  
    44  // HandleTCP is ...
    45  // trojan TCP stream
    46  func HandleTCP(r io.Reader, w io.Writer, addr net.Addr, d Dialer) (int64, int64, error) {
    47  	rc, err := d.Dial("tcp", addr.String())
    48  	if err != nil {
    49  		return 0, 0, err
    50  	}
    51  	defer rc.Close()
    52  
    53  	type Result struct {
    54  		Num int64
    55  		Err error
    56  	}
    57  
    58  	errCh := make(chan Result, 0)
    59  	go func(rc net.Conn, r io.Reader, errCh chan Result) {
    60  		ptr, buf := memory.Alloc[byte](32 * 1024)
    61  		defer memory.Free(ptr)
    62  
    63  		nr, err := copyBuffer(io.Writer(rc), r, buf)
    64  		if err == nil || errors.Is(err, os.ErrDeadlineExceeded) {
    65  			if cw, ok := rc.(interface {
    66  				CloseWrite() error
    67  			}); ok {
    68  				cw.CloseWrite()
    69  			}
    70  			rc.SetReadDeadline(time.Now())
    71  			errCh <- Result{Num: nr, Err: nil}
    72  			return
    73  		}
    74  		if cw, ok := rc.(interface {
    75  			CloseWrite() error
    76  		}); ok {
    77  			cw.CloseWrite()
    78  		}
    79  		rc.SetReadDeadline(time.Now())
    80  		errCh <- Result{Num: nr, Err: err}
    81  	}(rc, r, errCh)
    82  
    83  	nr, nw, err := func(rc net.Conn, w io.Writer, errCh chan Result) (int64, int64, error) {
    84  		ptr, buf := memory.Alloc[byte](32 * 1024)
    85  		defer memory.Free(ptr)
    86  
    87  		nw, err := copyBuffer(w, io.Reader(rc), buf)
    88  		if err == nil {
    89  			if cw, ok := w.(interface {
    90  				CloseWrite() error
    91  			}); ok {
    92  				cw.CloseWrite()
    93  			}
    94  			r := <-errCh
    95  			return r.Num, nw, r.Err
    96  		}
    97  		if errors.Is(err, os.ErrDeadlineExceeded) {
    98  			select {
    99  			case r := <-errCh:
   100  				if r.Err == nil {
   101  					for {
   102  						rc.SetReadDeadline(time.Now().Add(time.Minute))
   103  						n, err := copyBuffer(w, io.Reader(rc), buf)
   104  						nw += n
   105  						if n == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
   106  							break
   107  						}
   108  					}
   109  					return r.Num, nw, r.Err
   110  				}
   111  
   112  				if cw, ok := w.(interface {
   113  					CloseWrite() error
   114  				}); ok {
   115  					cw.CloseWrite()
   116  				}
   117  				return r.Num, nw, r.Err
   118  			case <-time.After(time.Minute):
   119  			}
   120  			if cw, ok := w.(interface {
   121  				CloseWrite() error
   122  			}); ok {
   123  				cw.CloseWrite()
   124  			}
   125  			r := <-errCh
   126  			return r.Num, nw, r.Err
   127  		}
   128  		rc.SetWriteDeadline(time.Now())
   129  		if cw, ok := rc.(interface {
   130  			CloseWrite() error
   131  		}); ok {
   132  			cw.CloseWrite()
   133  		}
   134  		r := <-errCh
   135  		return r.Num, nw, err
   136  	}(rc, w, errCh)
   137  
   138  	return nr, nw, err
   139  }