github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/cshirt2/transport.go (about) 1 package cshirt2 2 3 import ( 4 "bufio" 5 "bytes" 6 "crypto/cipher" 7 "encoding/binary" 8 "fmt" 9 "io" 10 "net" 11 "time" 12 13 "github.com/geph-official/geph2/libs/erand" 14 pool "github.com/libp2p/go-buffer-pool" 15 "golang.org/x/crypto/chacha20poly1305" 16 ) 17 18 type transport struct { 19 readCrypt cipher.AEAD 20 readNonce uint64 21 writeCrypt cipher.AEAD 22 writeNonce uint64 23 wireBuf *bufio.Reader 24 wire net.Conn 25 readbuf bytes.Buffer 26 27 buf [128]byte 28 } 29 30 const maxWriteSize = 16384 31 32 func (tp *transport) getiWriteNonce() uint64 { 33 n := tp.writeNonce 34 tp.writeNonce++ 35 return n 36 } 37 38 func (tp *transport) getiReadNonce() uint64 { 39 n := tp.readNonce 40 tp.readNonce++ 41 return n 42 } 43 44 func addPadding(appendTo, data []byte) []byte { 45 padAmount := erand.Int(200) 46 appendTo = append(appendTo, byte(padAmount)) 47 appendTo = append(appendTo, make([]byte, padAmount)...) 48 appendTo = append(appendTo, data...) 49 return appendTo 50 //return data 51 } 52 53 func remPadding(data []byte) []byte { 54 if len(data) > 0 && len(data) > int(data[0]) { 55 return data[data[0]+1:] 56 } 57 return data 58 } 59 60 func (tp *transport) writeSegment(unpadded []byte) (err error) { 61 toWrite := pool.Get(len(unpadded) + 256)[:0] 62 defer pool.Put(toWrite) 63 toWrite = addPadding(toWrite, unpadded) 64 lengthNonce := tp.buf[:12] 65 binary.LittleEndian.PutUint64(lengthNonce, tp.getiWriteNonce()) 66 bodyNonce := tp.buf[12:24] 67 binary.LittleEndian.PutUint64(bodyNonce, tp.getiWriteNonce()) 68 length := tp.buf[24:26] 69 binary.LittleEndian.PutUint16(length, uint16(len(toWrite)+tp.readCrypt.Overhead())) 70 buffer := pool.Get(maxWriteSize + 128)[:0] 71 defer pool.Put(buffer) 72 buffer = tp.writeCrypt.Seal(buffer, lengthNonce, length, nil) 73 buffer = tp.writeCrypt.Seal(buffer, bodyNonce, toWrite, nil) 74 _, err = tp.wire.Write(buffer) 75 return 76 } 77 78 func (tp *transport) Write(p []byte) (n int, err error) { 79 ptr := p 80 for len(ptr) > maxWriteSize { 81 err = tp.writeSegment(ptr[:maxWriteSize]) 82 if err != nil { 83 return 84 } 85 ptr = ptr[maxWriteSize:] 86 } 87 err = tp.writeSegment(ptr) 88 if err != nil { 89 return 90 } 91 n = len(p) 92 return 93 } 94 95 func (tp *transport) Read(p []byte) (n int, err error) { 96 for tp.readbuf.Len() == 0 { 97 cryptLength := tp.buf[64:][:2+tp.readCrypt.Overhead()] 98 _, err = io.ReadFull(tp.wireBuf, cryptLength) 99 if err != nil { 100 err = fmt.Errorf("can't read encrypted length: %w", err) 101 return 102 } 103 nonce := tp.buf[32:][:12] 104 binary.LittleEndian.PutUint64(nonce, tp.getiReadNonce()) 105 // decrypt length 106 var length []byte 107 length, err = tp.readCrypt.Open(p[:0], nonce, cryptLength, nil) 108 if err != nil { 109 err = fmt.Errorf("can't decrypt length: %w", err) 110 return 111 } 112 // read body 113 ctext := pool.Get(int(binary.LittleEndian.Uint16(length))) 114 defer pool.Put(ctext) 115 _, err = io.ReadFull(tp.wireBuf, ctext) 116 if err != nil { 117 err = fmt.Errorf("can't read body: %w", err) 118 return 119 } 120 binary.LittleEndian.PutUint64(nonce, tp.getiReadNonce()) 121 var ptext []byte 122 ptext, err = tp.readCrypt.Open(ctext[:0], nonce, ctext, nil) 123 if err != nil { 124 err = fmt.Errorf("can't decrypt body: %w", err) 125 return 126 } 127 tp.readbuf.Write(remPadding(ptext)) 128 } 129 return tp.readbuf.Read(p) 130 } 131 132 func (tp *transport) Close() error { 133 return tp.wire.Close() 134 } 135 136 func (tp *transport) LocalAddr() net.Addr { 137 return tp.wire.LocalAddr() 138 } 139 140 func (tp *transport) RemoteAddr() net.Addr { 141 return tp.wire.RemoteAddr() 142 } 143 144 func (tp *transport) SetDeadline(t time.Time) error { 145 return tp.wire.SetDeadline(t) 146 } 147 148 func (tp *transport) SetWriteDeadline(t time.Time) error { 149 return tp.wire.SetWriteDeadline(t) 150 } 151 152 func (tp *transport) SetReadDeadline(t time.Time) error { 153 return tp.wire.SetReadDeadline(t) 154 } 155 156 func newTransport(wire net.Conn, ss []byte, isServer bool) *transport { 157 tp := new(transport) 158 readKey := mac256(ss, []byte("c2s")) 159 writeKey := mac256(ss, []byte("c2c")) 160 if !isServer { 161 readKey, writeKey = writeKey, readKey 162 } 163 var err error 164 tp.readCrypt, err = chacha20poly1305.New(readKey) 165 if err != nil { 166 panic(err) 167 } 168 tp.writeCrypt, err = chacha20poly1305.New(writeKey) 169 if err != nil { 170 panic(err) 171 } 172 tp.wire = wire 173 tp.wireBuf = bufio.NewReader(wire) 174 return tp 175 }