github.com/imgk/caddy-trojan@v0.0.0-20221206043256-2631719e16c8/trojan/trojan_tcp.go (about) 1 package trojan 2 3 import ( 4 "errors" 5 "io" 6 "net" 7 "os" 8 "time" 9 10 "github.com/imgk/memory-go" 11 ) 12 13 func copyBuffer(w io.Writer, r io.Reader, buf []byte) (n int64, err error) { 14 for { 15 nr, er := r.Read(buf) 16 if nr > 0 { 17 nw, ew := w.Write(buf[0:nr]) 18 if nw < 0 || nr < nw { 19 nw = 0 20 if ew == nil { 21 ew = errors.New("invalid write result") 22 } 23 } 24 n += int64(nw) 25 if ew != nil { 26 err = ew 27 break 28 } 29 if nr != nw { 30 err = io.ErrShortWrite 31 break 32 } 33 } 34 if er != nil { 35 if !errors.Is(er, io.EOF) { 36 err = er 37 } 38 break 39 } 40 } 41 return n, err 42 } 43 44 // HandleTCP is ... 45 // trojan TCP stream 46 func HandleTCP(r io.Reader, w io.Writer, addr net.Addr, d Dialer) (int64, int64, error) { 47 rc, err := d.Dial("tcp", addr.String()) 48 if err != nil { 49 return 0, 0, err 50 } 51 defer rc.Close() 52 53 type Result struct { 54 Num int64 55 Err error 56 } 57 58 errCh := make(chan Result, 0) 59 go func(rc net.Conn, r io.Reader, errCh chan Result) { 60 ptr, buf := memory.Alloc[byte](32 * 1024) 61 defer memory.Free(ptr) 62 63 nr, err := copyBuffer(io.Writer(rc), r, buf) 64 if err == nil || errors.Is(err, os.ErrDeadlineExceeded) { 65 if cw, ok := rc.(interface { 66 CloseWrite() error 67 }); ok { 68 cw.CloseWrite() 69 } 70 rc.SetReadDeadline(time.Now()) 71 errCh <- Result{Num: nr, Err: nil} 72 return 73 } 74 if cw, ok := rc.(interface { 75 CloseWrite() error 76 }); ok { 77 cw.CloseWrite() 78 } 79 rc.SetReadDeadline(time.Now()) 80 errCh <- Result{Num: nr, Err: err} 81 }(rc, r, errCh) 82 83 nr, nw, err := func(rc net.Conn, w io.Writer, errCh chan Result) (int64, int64, error) { 84 ptr, buf := memory.Alloc[byte](32 * 1024) 85 defer memory.Free(ptr) 86 87 nw, err := copyBuffer(w, io.Reader(rc), buf) 88 if err == nil { 89 if cw, ok := w.(interface { 90 CloseWrite() error 91 }); ok { 92 cw.CloseWrite() 93 } 94 r := <-errCh 95 return r.Num, nw, r.Err 96 } 97 if errors.Is(err, os.ErrDeadlineExceeded) { 98 select { 99 case r := <-errCh: 100 if r.Err == nil { 101 for { 102 rc.SetReadDeadline(time.Now().Add(time.Minute)) 103 n, err := copyBuffer(w, io.Reader(rc), buf) 104 nw += n 105 if n == 0 || !errors.Is(err, os.ErrDeadlineExceeded) { 106 break 107 } 108 } 109 return r.Num, nw, r.Err 110 } 111 112 if cw, ok := w.(interface { 113 CloseWrite() error 114 }); ok { 115 cw.CloseWrite() 116 } 117 return r.Num, nw, r.Err 118 case <-time.After(time.Minute): 119 } 120 if cw, ok := w.(interface { 121 CloseWrite() error 122 }); ok { 123 cw.CloseWrite() 124 } 125 r := <-errCh 126 return r.Num, nw, r.Err 127 } 128 rc.SetWriteDeadline(time.Now()) 129 if cw, ok := rc.(interface { 130 CloseWrite() error 131 }); ok { 132 cw.CloseWrite() 133 } 134 r := <-errCh 135 return r.Num, nw, err 136 }(rc, w, errCh) 137 138 return nr, nw, err 139 }