github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/yuubinsya/crypto/handshake.go (about) 1 package crypto 2 3 import ( 4 "crypto/cipher" 5 "crypto/ecdh" 6 "crypto/rand" 7 "encoding/binary" 8 "errors" 9 "fmt" 10 "io" 11 "math" 12 "net" 13 "time" 14 15 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 16 "github.com/Asutorufa/yuhaiin/pkg/net/proxy/socks5/tools" 17 "github.com/Asutorufa/yuhaiin/pkg/net/proxy/yuubinsya/types" 18 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 19 "golang.org/x/crypto/chacha20" 20 "golang.org/x/crypto/hkdf" 21 ) 22 23 type encryptedHandshaker struct { 24 server bool 25 26 signer types.Signer 27 hash types.Hash 28 aead types.Aead 29 password []byte 30 } 31 32 func (t *encryptedHandshaker) EncodeHeader(net types.Protocol, buf types.Buffer, addr netapi.Address) { 33 _, _ = buf.Write([]byte{byte(net)}) 34 35 if net == types.TCP { 36 tools.EncodeAddr(addr, buf) 37 } 38 } 39 40 func (t *encryptedHandshaker) DecodeHeader(c net.Conn) (types.Protocol, error) { 41 z := make([]byte, 1) 42 43 if _, err := io.ReadFull(c, z); err != nil { 44 return 0, fmt.Errorf("read net type failed: %w", err) 45 } 46 net := types.Protocol(z[0]) 47 48 if net.Unknown() { 49 return 0, fmt.Errorf("unknown network") 50 } 51 52 return net, nil 53 } 54 55 func (h *encryptedHandshaker) Handshake(conn net.Conn) (net.Conn, error) { 56 if h.server { 57 return h.handshakeServer(conn) 58 } 59 60 return h.handshakeClient(conn) 61 } 62 63 func (h *encryptedHandshaker) handshakeClient(conn net.Conn) (net.Conn, error) { 64 header := newHeader(h) 65 defer header.Def() 66 67 salt := make([]byte, h.hash.Size()) 68 69 pk, time1, err := h.send(header, conn, nil) 70 if err != nil { 71 return nil, err 72 } 73 74 copy(salt, header.salt()) // client salt 75 76 rpb, time2, err := h.receive(header, conn, salt) 77 if err != nil { 78 return nil, err 79 } 80 81 if pk.PublicKey().Equal(rpb) { 82 return nil, fmt.Errorf("look like replay attack") 83 } 84 85 cryptKey, err := pk.ECDH(rpb) 86 if err != nil { 87 return nil, err 88 } 89 90 raead, rnonce, err := h.newAead(cryptKey, salt, time1) 91 if err != nil { 92 return nil, err 93 } 94 95 waead, wnonce, err := h.newAead(cryptKey, salt, time2) 96 if err != nil { 97 return nil, err 98 } 99 100 return NewConn(conn, rnonce, wnonce, raead, waead), nil 101 } 102 103 func (h *encryptedHandshaker) handshakeServer(conn net.Conn) (net.Conn, error) { 104 header := newHeader(h) 105 defer header.Def() 106 107 salt := make([]byte, h.hash.Size()) 108 109 rpb, time1, err := h.receive(header, conn, nil) 110 if err != nil { 111 return nil, err 112 } 113 114 copy(salt, header.salt()) // client salt 115 116 pk, time2, err := h.send(header, conn, salt) 117 if err != nil { 118 return nil, err 119 } 120 121 if pk.PublicKey().Equal(rpb) { 122 return nil, fmt.Errorf("look like replay attack") 123 } 124 125 cryptKey, err := pk.ECDH(rpb) 126 if err != nil { 127 return nil, err 128 } 129 130 raead, rnonce, err := h.newAead(cryptKey, salt, time1) 131 if err != nil { 132 return nil, err 133 } 134 135 waead, wnonce, err := h.newAead(cryptKey, salt, time2) 136 if err != nil { 137 return nil, err 138 } 139 140 return NewConn(conn, wnonce, rnonce, waead, raead), nil 141 } 142 143 func (h *encryptedHandshaker) newAead(cryptKey, salt, time []byte) (cipher.AEAD, []byte, error) { 144 keyNonce := make([]byte, h.aead.KeySize()+h.aead.NonceSize()) 145 if _, err := io.ReadFull(hkdf.New(h.hash.New, cryptKey, salt, append(h.aead.Name(), time...)), keyNonce); err != nil { 146 return nil, nil, err 147 } 148 aead, err := h.aead.New(keyNonce[:h.aead.KeySize()]) 149 if err != nil { 150 return nil, nil, err 151 } 152 153 return aead, keyNonce[h.aead.KeySize():], nil 154 } 155 156 func (h *encryptedHandshaker) receive(buf *header, conn net.Conn, salt []byte) (_ *ecdh.PublicKey, ttime []byte, _ error) { 157 _, err := io.ReadFull(conn, buf.Bytes()) 158 if err != nil { 159 return nil, nil, err 160 } 161 162 if salt != nil { 163 copy(buf.salt(), salt) // client: verify signature with client salt 164 } 165 166 if !h.signer.Verify(buf.saltTimeSignature(), buf.signature()) { 167 return nil, nil, errors.New("can't verify signature") 168 } 169 170 ttime = make([]byte, 8) 171 if err = h.encryptTime(h.password, buf.salt(), ttime, buf.time()); err != nil { 172 return nil, nil, fmt.Errorf("decrypt time failed: %w", err) 173 } 174 175 if math.Abs(float64(time.Now().Unix()-int64(binary.BigEndian.Uint64(ttime)))) > 30 { // check time is in +-30s 176 return nil, nil, errors.New("bad timestamp") 177 } 178 179 pubkey, err := ecdh.P256().NewPublicKey(buf.publickey()) 180 if err != nil { 181 return nil, nil, err 182 } 183 184 return pubkey, ttime, nil 185 } 186 187 func (h *encryptedHandshaker) send(buf *header, conn net.Conn, salt []byte) (_ *ecdh.PrivateKey, ttime []byte, _ error) { 188 pk, err := ecdh.P256().GenerateKey(rand.Reader) 189 if err != nil { 190 return nil, nil, err 191 } 192 193 if salt != nil { 194 copy(buf.salt(), salt) // server: sign with client salt 195 } else { 196 if _, err = rand.Read(buf.salt()); err != nil { // client: read random bytes to salt 197 return nil, nil, fmt.Errorf("read salt from rand failed: %w", err) 198 } 199 } 200 201 copy(buf.publickey(), pk.PublicKey().Bytes()) 202 203 ttime = make([]byte, 8) 204 binary.BigEndian.PutUint64(ttime, uint64(time.Now().Unix())) 205 206 if err = h.encryptTime(h.password, buf.salt(), buf.time(), ttime); err != nil { 207 return nil, nil, fmt.Errorf("encrypt time failed: %w", err) 208 } 209 210 signature, err := h.signer.Sign(rand.Reader, buf.saltTimeSignature()) 211 if err != nil { 212 return nil, nil, err 213 } 214 215 copy(buf.signature(), signature) 216 217 if salt != nil { 218 if _, err := rand.Read(buf.salt()); err != nil { // server: read random bytes to padding 219 return nil, nil, fmt.Errorf("read salt from rand failed: %w", err) 220 } 221 } 222 223 if _, err = conn.Write(buf.Bytes()); err != nil { 224 return nil, nil, err 225 } 226 227 return pk, ttime, nil 228 } 229 230 type header struct { 231 bytes *pool.Bytes 232 th *encryptedHandshaker 233 } 234 235 func newHeader(h *encryptedHandshaker) *header { 236 return &header{pool.GetBytesBuffer(h.hash.Size() + 8 + h.signer.SignatureSize() + 65), h} 237 } 238 func (h *header) Bytes() []byte { return h.bytes.Bytes() } 239 func (h *header) signature() []byte { 240 return h.Bytes()[:h.th.signer.SignatureSize()] 241 } 242 func (h *header) publickey() []byte { 243 return h.Bytes()[h.th.hash.Size()+8+h.th.signer.SignatureSize():] 244 } 245 func (h *header) time() []byte { 246 return h.Bytes()[h.th.hash.Size()+h.th.signer.SignatureSize() : h.th.hash.Size()+8+h.th.signer.SignatureSize()] 247 } 248 func (h *header) salt() []byte { 249 return h.Bytes()[h.th.signer.SignatureSize() : h.th.signer.SignatureSize()+h.th.hash.Size()] 250 } 251 func (h *header) saltTimeSignature() []byte { 252 return h.Bytes()[h.th.signer.SignatureSize():] 253 } 254 func (h *header) Def() { defer h.bytes.Free() } 255 256 func (h *encryptedHandshaker) encryptTime(password, salt, dst, src []byte) error { 257 nonce := make([]byte, chacha20.NonceSize) 258 key := make([]byte, chacha20.KeySize) 259 260 kdf := hkdf.New(h.hash.New, password, salt, []byte{'t', 'i', 'm', 'e'}) 261 262 if _, err := io.ReadFull(kdf, key); err != nil { 263 return err 264 } 265 if _, err := io.ReadFull(kdf, nonce); err != nil { 266 return err 267 } 268 269 cipher, err := chacha20.NewUnauthenticatedCipher(key, nonce) 270 if err != nil { 271 return err 272 } 273 274 cipher.XORKeyStream(dst, src) 275 276 return nil 277 } 278 279 func NewHandshaker(server bool, hash []byte, password []byte) *encryptedHandshaker { 280 // sha256-hkdf-ecdh-ed25519-chacha20poly1305 281 return &encryptedHandshaker{ 282 signer: NewEd25519(Sha256, hash), 283 hash: Sha256, 284 aead: Chacha20poly1305, 285 password: password, 286 server: server, 287 } 288 }