github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/backedtcp/backed.go (about) 1 package backedtcp 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "errors" 7 "io" 8 "net" 9 "strings" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 pool "github.com/libp2p/go-buffer-pool" 15 "gopkg.in/tomb.v1" 16 ) 17 18 const maxBufferSize = 500 * 1000 19 20 type backedWriter struct { 21 lastsn uint64 22 buffer []byte 23 lk sync.Mutex 24 } 25 26 func (br *backedWriter) addData(ob []byte) { 27 b := make([]byte, len(ob)) 28 copy(b, ob) 29 br.lastsn += uint64(len(b)) 30 br.buffer = append(br.buffer, b...) 31 //log.Println("addData buffer size now", cap(br.buffer)) 32 if len(br.buffer) >= maxBufferSize { 33 br.buffer = br.buffer[len(br.buffer)-maxBufferSize:] 34 } 35 } 36 37 func (br *backedWriter) reset() { 38 if len(br.buffer) > 10*1000 { 39 newlen := len(br.buffer) / 2 40 nbuf := make([]byte, newlen) 41 copy(nbuf, br.buffer[len(br.buffer)-newlen:]) 42 br.buffer = nbuf 43 } 44 } 45 46 func (br *backedWriter) since(sn uint64) []byte { 47 if sn == br.lastsn { 48 return make([]byte, 0) 49 } 50 if sn > br.lastsn || sn < br.lastsn-uint64(len(br.buffer)) { 51 return nil 52 } 53 return br.buffer[len(br.buffer)-int(br.lastsn-sn):] 54 } 55 56 // Socket represents a single BackedTCP connection 57 type Socket struct { 58 bw backedWriter 59 getWire func() (net.Conn, error) 60 cachedWires chan net.Conn 61 chWrite chan []byte 62 chRead chan []byte 63 chReplace chan struct{} 64 readBuf bytes.Buffer 65 readBytes uint64 66 death tomb.Tomb 67 remAddr atomic.Value 68 locAddr atomic.Value 69 70 rDeadline atomic.Value 71 wDeadline atomic.Value 72 } 73 74 // NewSocket constructs a new BackedTCP connection. 75 func NewSocket(getWire func() (net.Conn, error)) *Socket { 76 s := &Socket{ 77 getWire: getWire, 78 chWrite: make(chan []byte), 79 cachedWires: make(chan net.Conn, 10000), 80 chRead: make(chan []byte), 81 chReplace: make(chan struct{}), 82 } 83 s.SetDeadline(time.Time{}) 84 go s.mainLoop() 85 return s 86 } 87 88 func (sock *Socket) realGetWire() (net.Conn, error) { 89 select { 90 case c := <-sock.cachedWires: 91 return c, nil 92 default: 93 return sock.getWire() 94 } 95 } 96 97 func (sock *Socket) mainLoop() { 98 for { 99 select { 100 case <-sock.death.Dying(): 101 return 102 default: 103 } 104 // first we get a wire 105 wire, err := sock.realGetWire() 106 if err != nil { 107 // this is fatal 108 sock.death.Kill(err) 109 return 110 } 111 wra := wire.RemoteAddr() 112 wla := wire.LocalAddr() 113 sock.remAddr.Store(&wra) 114 sock.locAddr.Store(&wla) 115 stopWrite := make(chan struct{}) 116 // negotiation shouldn't take more than 10 secs 117 wire.SetDeadline(time.Now().Add(time.Second * 10)) 118 sent := make(chan bool) 119 // negotiate 120 go func() { 121 defer close(sent) 122 // we write our total bytes read. in a new goroutine to prevent dedlock 123 binary.Write(wire, binary.BigEndian, sock.readBytes) 124 }() 125 // read the remote bytes read 126 var theirReadBytes uint64 127 err = binary.Read(wire, binary.BigEndian, &theirReadBytes) 128 if err != nil { 129 wire.Close() 130 continue 131 } 132 <-sent 133 // get the data that needs to be resent 134 toResend := sock.bw.since(theirReadBytes) 135 if toResend == nil { 136 // out of range 137 sock.death.Kill(errors.New("out of resumption range")) 138 return 139 } 140 wire.SetDeadline(time.Time{}) 141 done := make(chan bool) 142 go func() { 143 defer close(done) 144 defer close(stopWrite) 145 sock.readLoop(wire) 146 }() 147 sock.writeLoop(toResend, wire, stopWrite) 148 <-done 149 } 150 } 151 152 func (sock *Socket) writeLoop(toResend []byte, wire net.Conn, stopWrite chan struct{}) { 153 defer wire.Close() 154 wire.SetWriteDeadline(sock.wDeadline.Load().(time.Time)) 155 _, err := wire.Write(toResend) 156 if err != nil { 157 return 158 } 159 for { 160 var timeout <-chan time.Time 161 if sock.bw.buffer != nil { 162 timeout = time.After(time.Second * 10) 163 } 164 select { 165 case toWrite := <-sock.chWrite: 166 // first we remember this so that we can restore 167 sock.bw.addData(toWrite) 168 wire.SetWriteDeadline(sock.wDeadline.Load().(time.Time)) 169 // then we try to write. it's okay if we fail! 170 _, err := wire.Write(toWrite) 171 pool.GlobalPool.Put(toWrite) 172 if err != nil { 173 if strings.Contains(err.Error(), "timeout") { 174 sock.death.Kill(err) 175 } 176 return 177 } 178 case <-timeout: 179 sock.bw.reset() 180 case <-stopWrite: 181 //log.Println("writeLoop stopped") 182 return 183 case <-sock.chReplace: 184 //log.Println("writeLoop stopped for replace") 185 return 186 case <-sock.death.Dying(): 187 //log.Println("writeLoop forced to die", sock.death.Err()) 188 return 189 } 190 } 191 } 192 193 func (sock *Socket) readLoop(wire net.Conn) { 194 defer wire.Close() 195 // just loop and read and feed into the channel 196 for { 197 wire.SetReadDeadline(sock.rDeadline.Load().(time.Time)) 198 buf := pool.GlobalPool.Get(65536) 199 n, err := wire.Read(buf) 200 if err != nil { 201 return 202 } 203 sock.readBytes += uint64(n) 204 sock.chRead <- buf[:n] 205 } 206 } 207 208 // Reset forces the socket to discard its underlying connection and reconnect. 209 func (sock *Socket) Reset() (err error) { 210 wire, err := sock.getWire() 211 if err != nil { 212 return 213 } 214 sock.cachedWires <- wire 215 select { 216 case sock.chReplace <- struct{}{}: 217 return 218 case <-sock.death.Dying(): 219 err = sock.death.Err() 220 return 221 } 222 } 223 224 // Close closes the socket. 225 func (sock *Socket) Close() (err error) { 226 sock.death.Kill(io.ErrClosedPipe) 227 return 228 } 229 230 func (sock *Socket) Read(p []byte) (n int, err error) { 231 for { 232 if sock.readBuf.Len() > 0 { 233 return sock.readBuf.Read(p) 234 } 235 select { 236 case <-sock.death.Dying(): 237 err = sock.death.Err() 238 return 239 case bts := <-sock.chRead: 240 sock.readBuf.Write(bts) 241 pool.GlobalPool.Put(bts) 242 } 243 } 244 } 245 246 func (sock *Socket) Write(p []byte) (n int, err error) { 247 buf := pool.GlobalPool.Get(len(p)) 248 copy(buf, p) 249 select { 250 case sock.chWrite <- buf: 251 n = len(p) 252 return 253 case <-sock.death.Dying(): 254 err = sock.death.Err() 255 return 256 } 257 } 258 259 func (sock *Socket) LocalAddr() net.Addr { 260 zz := sock.locAddr.Load() 261 if zz == nil { 262 return dummyAddr("dummy-local") 263 } 264 return *(zz.(*net.Addr)) 265 } 266 267 func (sock *Socket) RemoteAddr() net.Addr { 268 zz := sock.remAddr.Load() 269 if zz == nil { 270 return dummyAddr("dummy-remote") 271 } 272 return *(zz.(*net.Addr)) 273 } 274 275 func (sock *Socket) SetDeadline(t time.Time) error { 276 sock.SetReadDeadline(t) 277 sock.SetWriteDeadline(t) 278 return nil 279 } 280 281 func (sock *Socket) SetReadDeadline(t time.Time) error { 282 sock.rDeadline.Store(t) 283 return nil 284 } 285 286 func (sock *Socket) SetWriteDeadline(t time.Time) error { 287 sock.wDeadline.Store(t) 288 return nil 289 } 290 291 type dummyAddr string 292 293 func (da dummyAddr) String() string { 294 return string(da) 295 } 296 297 func (da dummyAddr) Network() string { 298 return string(da) 299 }