github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/tinyss/socket.go (about) 1 package tinyss 2 3 import ( 4 "bufio" 5 "bytes" 6 "crypto/cipher" 7 "crypto/hmac" 8 "crypto/sha256" 9 "encoding/binary" 10 "io" 11 "net" 12 "time" 13 14 "github.com/geph-official/geph2/libs/c25519" 15 pool "github.com/libp2p/go-buffer-pool" 16 "golang.org/x/crypto/chacha20poly1305" 17 "golang.org/x/crypto/curve25519" 18 ) 19 20 // Socket represents a TinySS connection; it implements net.Conn but with more methods. 21 type Socket struct { 22 rxctr uint64 23 rxerr error 24 rxcrypt cipher.AEAD 25 rxbuf bytes.Buffer 26 27 txctr uint64 28 txcrypt cipher.AEAD 29 30 plain net.Conn 31 plainBuffered *bufio.Reader 32 sharedsec []byte 33 34 nextprot byte 35 } 36 37 func hm(m, k []byte) []byte { 38 h := hmac.New(sha256.New, k) 39 h.Write(m) 40 return h.Sum(nil) 41 } 42 43 func aead(key []byte) cipher.AEAD { 44 k, e := chacha20poly1305.New(key) 45 if e != nil { 46 panic(e) 47 } 48 return k 49 } 50 51 func newSocket(plain net.Conn, repk, lesk [32]byte) (sok *Socket) { 52 // calc 53 var lepk [32]byte 54 curve25519.ScalarBaseMult(&lepk, &lesk) 55 // calculate shared secrets 56 var sharedsec [32]byte 57 curve25519.ScalarMult(&sharedsec, &lesk, &repk) 58 s1 := hm(sharedsec[:], []byte("tinyss-s1")) 59 s2 := hm(sharedsec[:], []byte("tinyss-s2")) 60 // derive keys 61 var rxkey []byte 62 var txkey []byte 63 if bytes.Compare(lepk[:], repk[:]) < 0 { 64 rxkey = s1 65 txkey = s2 66 } else { 67 txkey = s1 68 rxkey = s2 69 } 70 // create socket 71 sok = &Socket{ 72 rxcrypt: aead(rxkey), 73 txcrypt: aead(txkey), 74 plain: plain, 75 sharedsec: sharedsec[:], 76 plainBuffered: bufio.NewReader(plain), 77 } 78 79 return 80 } 81 82 var decctr1 uint64 83 84 // NextProt returns the "next protocol" signal given by the remote. 85 func (sk *Socket) NextProt() byte { 86 return sk.nextprot 87 } 88 89 // Read reads into the given byte slice. 90 func (sk *Socket) Read(p []byte) (n int, err error) { 91 // if any in buffer, read from buffer 92 if sk.rxbuf.Len() > 0 { 93 return sk.rxbuf.Read(p) 94 } 95 // if error exists, return it 96 err = sk.rxerr 97 if err != nil { 98 return 99 } 100 // otherwise wait for record 101 lenbts := pool.GlobalPool.Get(2) 102 defer pool.GlobalPool.Put(lenbts) 103 _, err = io.ReadFull(sk.plainBuffered, lenbts) 104 if err != nil { 105 sk.rxerr = err 106 return 107 } 108 ciph := pool.GlobalPool.Get(int(binary.BigEndian.Uint16(lenbts))) 109 defer pool.GlobalPool.Put(ciph) 110 _, err = io.ReadFull(sk.plainBuffered, ciph) 111 if err != nil { 112 sk.rxerr = err 113 return 114 } 115 // decrypt the ciphertext 116 nonce := pool.GlobalPool.Get(sk.rxcrypt.NonceSize()) 117 for i := range nonce { 118 nonce[i] = 0 119 } 120 defer pool.GlobalPool.Put(nonce) 121 binary.BigEndian.PutUint64(nonce, sk.rxctr) 122 sk.rxctr++ 123 data, err := sk.rxcrypt.Open(ciph[:0], nonce, ciph, nil) 124 if err != nil { 125 sk.rxerr = err 126 return 127 } 128 // copy the data into the buffer 129 n = copy(p, data) 130 if n < len(data) { 131 sk.rxbuf.Write(data[n:]) 132 } 133 return 134 } 135 136 // Write writes out the given byte slice. No guarantees are made regarding the number of low-level segments sent over the wire. 137 func (sk *Socket) Write(p []byte) (n int, err error) { 138 if len(p) > 32768 { 139 // recurse 140 var n1 int 141 var n2 int 142 n1, err = sk.Write(p[:32768]) 143 if err != nil { 144 return 145 } 146 n2, err = sk.Write(p[32768:]) 147 if err != nil { 148 return 149 } 150 n = n1 + n2 151 return 152 } 153 // main work here 154 backing := pool.GlobalPool.Get(sk.txcrypt.Overhead() + 2 + len(p)) 155 nonce := pool.GlobalPool.Get(sk.txcrypt.NonceSize()) 156 defer pool.GlobalPool.Put(nonce) 157 for i := range nonce { 158 nonce[i] = 0 159 } 160 binary.BigEndian.PutUint64(nonce, sk.txctr) 161 sk.txctr++ 162 ciph := sk.txcrypt.Seal(backing[2:][:0], nonce, p, nil) 163 binary.BigEndian.PutUint16(backing[:2], uint16(len(ciph))) 164 _, err = sk.plain.Write(backing) 165 n = len(p) 166 return 167 } 168 169 // Close closes the socket. 170 func (sk *Socket) Close() error { 171 return sk.plain.Close() 172 } 173 174 // LocalAddr returns the local address. 175 func (sk *Socket) LocalAddr() net.Addr { 176 return sk.plain.LocalAddr() 177 } 178 179 // RemoteAddr returns the remote address. 180 func (sk *Socket) RemoteAddr() net.Addr { 181 return sk.plain.RemoteAddr() 182 } 183 184 // SetDeadline sets the deadline. 185 func (sk *Socket) SetDeadline(t time.Time) error { 186 return sk.plain.SetDeadline(t) 187 } 188 189 // SetReadDeadline sets the read deadline. 190 func (sk *Socket) SetReadDeadline(t time.Time) error { 191 return sk.plain.SetReadDeadline(t) 192 } 193 194 // SetWriteDeadline sets the write deadline. 195 func (sk *Socket) SetWriteDeadline(t time.Time) error { 196 return sk.plain.SetWriteDeadline(t) 197 } 198 199 // SharedSec returns the shared secret. Use this to authenticate the connection (through signing etc). 200 func (sk *Socket) SharedSec() []byte { 201 return sk.sharedsec 202 } 203 204 // Handshake upgrades a plaintext socket to a MiniSS socket, given our secret key. 205 func Handshake(plain net.Conn, nextProtocol byte) (sok *Socket, err error) { 206 // generate ephemeral key 207 myesk := c25519.GenSK() 208 // in another thread, send over hello 209 wet := make(chan bool) 210 go func() { 211 var msgb bytes.Buffer 212 // if nextProtocol isn't zero, we send a different protocol header 213 if nextProtocol == 0 { 214 msgb.Write([]byte("TinySS-1")) 215 } else { 216 msgb.Write([]byte("TinySS-2")) 217 } 218 var pub [32]byte 219 curve25519.ScalarBaseMult(&pub, &myesk) 220 msgb.Write(pub[:]) 221 io.Copy(plain, &msgb) 222 close(wet) 223 }() 224 // read hello 225 bts := make([]byte, 32+8) 226 _, err = io.ReadFull(plain, bts) 227 if err != nil { 228 return 229 } 230 // check version 231 if string(bts[:7]) != "TinySS-" { 232 err = io.ErrClosedPipe 233 return 234 } 235 <-wet 236 // read rest of hello 237 var repk [32]byte 238 copy(repk[:], bts[8:][:32]) 239 ns := newSocket(plain, repk, myesk) 240 wait := make(chan bool) 241 if nextProtocol != 0 { 242 go func() { 243 binary.Write(ns, binary.BigEndian, nextProtocol) 244 close(wait) 245 }() 246 } 247 switch string(bts[:8]) { 248 case "TinySS-1": 249 case "TinySS-2": 250 // then we wait for their next protocol 251 var theirNextProt byte 252 err = binary.Read(ns, binary.BigEndian, &theirNextProt) 253 if err != nil { 254 return 255 } 256 ns.nextprot = theirNextProt 257 } 258 sok = ns 259 if nextProtocol != 0 { 260 <-wait 261 } 262 return 263 }