github.com/yaling888/clash@v1.53.0/transport/crypto/conn.go (about) 1 package crypto 2 3 import ( 4 "crypto/rand" 5 "encoding/binary" 6 "fmt" 7 "io" 8 "net" 9 "strings" 10 "sync" 11 12 "github.com/yaling888/clash/common/pool" 13 ) 14 15 type AEADOption struct { 16 Cipher string `proxy:"cipher,omitempty"` 17 Key string `proxy:"key,omitempty"` 18 Salt string `proxy:"salt,omitempty"` 19 } 20 21 var _ net.Conn = (*aeadConn)(nil) 22 23 type aeadConn struct { 24 net.Conn 25 cipher *AEAD 26 27 rMux sync.Mutex 28 buf []byte 29 lasR int 30 } 31 32 func (c *aeadConn) Read(p []byte) (n int, err error) { 33 c.rMux.Lock() 34 defer c.rMux.Unlock() 35 36 if c.lasR > 0 && c.buf != nil { 37 n = copy(p, c.buf[len(c.buf)-c.lasR:]) 38 c.lasR -= n 39 return 40 } 41 42 if c.buf == nil { 43 c.buf = make([]byte, 64<<10) 44 } else { 45 c.buf = c.buf[:64<<10] 46 } 47 48 defer func() { 49 if err != nil { 50 c.lasR = 0 51 c.buf = nil 52 } 53 }() 54 55 hdSize := c.cipher.NonceSize() + 2 56 _, err = io.ReadFull(c.Conn, c.buf[:hdSize]) 57 if err != nil { 58 return 59 } 60 61 length := binary.BigEndian.Uint16(c.buf[c.cipher.NonceSize():]) 62 if length == 0 { 63 err = io.EOF 64 return 65 } 66 67 nonce := make([]byte, c.cipher.NonceSize()) 68 copy(nonce, c.buf[:c.cipher.NonceSize()]) 69 70 _, err = io.ReadAtLeast(c.Conn, c.buf[:length], int(length)) 71 if err != nil { 72 return 73 } 74 75 b, err := c.cipher.Open(c.buf[:0], nonce, c.buf[:length], nil) 76 if err != nil { 77 return 78 } 79 80 c.lasR = len(b) 81 c.buf = c.buf[:c.lasR] 82 83 n = copy(p, c.buf[len(c.buf)-c.lasR:]) 84 c.lasR -= n 85 return 86 } 87 88 func (c *aeadConn) Write(p []byte) (n int, err error) { 89 bufP := pool.GetBufferWriter() 90 defer pool.PutBufferWriter(bufP) 91 92 bufP.Grow(c.cipher.NonceSize() + 2 + c.cipher.Overhead() + len(p)) 93 94 nonce := (*bufP)[:c.cipher.NonceSize()] 95 if _, err = rand.Read(nonce); err != nil { 96 return 97 } 98 99 b := c.cipher.Seal((*bufP)[:c.cipher.NonceSize()+2], nonce, p, nil) 100 lenB := len(b) 101 102 binary.BigEndian.PutUint16(b[c.cipher.NonceSize():], uint16(lenB-c.cipher.NonceSize()-2)) 103 104 lenP := len(p) 105 delta := lenB - lenP 106 nw, err := c.Conn.Write(b) 107 n = max(nw-delta, 0) 108 if n < lenP && err == nil { 109 err = io.ErrShortWrite 110 } 111 return 112 } 113 114 func (c *aeadConn) Close() (err error) { 115 err = c.Conn.Close() 116 117 c.rMux.Lock() 118 defer c.rMux.Unlock() 119 120 c.lasR = 0 121 c.buf = nil 122 return 123 } 124 125 func StreamAEADConn(conn net.Conn, opt AEADOption) (net.Conn, error) { 126 aead, err := NewAEAD(opt.Cipher, opt.Key, opt.Salt) 127 if err != nil { 128 return nil, err 129 } 130 131 if aead == nil { 132 return nil, fmt.Errorf("unsupported cipher: %s", opt.Cipher) 133 } 134 135 return &aeadConn{ 136 Conn: conn, 137 cipher: aead, 138 }, nil 139 } 140 141 func StreamAEADConnOrNot(conn net.Conn, opt AEADOption) (net.Conn, error) { 142 if opt.Cipher == "" || strings.ToLower(opt.Cipher) == "none" { 143 return conn, nil 144 } 145 146 return StreamAEADConn(conn, opt) 147 } 148 149 func VerifyAEADOption(opt AEADOption, allowNone bool) (bool, error) { 150 if !allowNone && (opt.Cipher == "" || strings.ToLower(opt.Cipher) == "none" || opt.Key == "") { 151 return false, nil 152 } 153 if _, err := NewAEAD(opt.Cipher, opt.Key, opt.Salt); err != nil { 154 return false, err 155 } 156 return true, nil 157 }