github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/transport/tcp/dispatcher.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tcp 16 17 import ( 18 "encoding/binary" 19 "math/rand" 20 21 "github.com/SagerNet/gvisor/pkg/sleep" 22 "github.com/SagerNet/gvisor/pkg/sync" 23 "github.com/SagerNet/gvisor/pkg/tcpip" 24 "github.com/SagerNet/gvisor/pkg/tcpip/hash/jenkins" 25 "github.com/SagerNet/gvisor/pkg/tcpip/header" 26 "github.com/SagerNet/gvisor/pkg/tcpip/stack" 27 ) 28 29 // epQueue is a queue of endpoints. 30 type epQueue struct { 31 mu sync.Mutex 32 list endpointList 33 } 34 35 // enqueue adds e to the queue if the endpoint is not already on the queue. 36 func (q *epQueue) enqueue(e *endpoint) { 37 q.mu.Lock() 38 if e.pendingProcessing { 39 q.mu.Unlock() 40 return 41 } 42 q.list.PushBack(e) 43 e.pendingProcessing = true 44 q.mu.Unlock() 45 } 46 47 // dequeue removes and returns the first element from the queue if available, 48 // returns nil otherwise. 49 func (q *epQueue) dequeue() *endpoint { 50 q.mu.Lock() 51 if e := q.list.Front(); e != nil { 52 q.list.Remove(e) 53 e.pendingProcessing = false 54 q.mu.Unlock() 55 return e 56 } 57 q.mu.Unlock() 58 return nil 59 } 60 61 // empty returns true if the queue is empty, false otherwise. 62 func (q *epQueue) empty() bool { 63 q.mu.Lock() 64 v := q.list.Empty() 65 q.mu.Unlock() 66 return v 67 } 68 69 // processor is responsible for processing packets queued to a tcp endpoint. 70 type processor struct { 71 epQ epQueue 72 sleeper sleep.Sleeper 73 newEndpointWaker sleep.Waker 74 closeWaker sleep.Waker 75 } 76 77 func (p *processor) close() { 78 p.closeWaker.Assert() 79 } 80 81 func (p *processor) queueEndpoint(ep *endpoint) { 82 // Queue an endpoint for processing by the processor goroutine. 83 p.epQ.enqueue(ep) 84 p.newEndpointWaker.Assert() 85 } 86 87 const ( 88 newEndpointWaker = 1 89 closeWaker = 2 90 ) 91 92 func (p *processor) start(wg *sync.WaitGroup) { 93 defer wg.Done() 94 defer p.sleeper.Done() 95 96 for { 97 if id, _ := p.sleeper.Fetch(true); id == closeWaker { 98 break 99 } 100 for { 101 ep := p.epQ.dequeue() 102 if ep == nil { 103 break 104 } 105 if ep.segmentQueue.empty() { 106 continue 107 } 108 109 // If socket has transitioned out of connected state then just let the 110 // worker handle the packet. 111 // 112 // NOTE: We read this outside of e.mu lock which means that by the time 113 // we get to handleSegments the endpoint may not be in ESTABLISHED. But 114 // this should be fine as all normal shutdown states are handled by 115 // handleSegments and if the endpoint moves to a CLOSED/ERROR state 116 // then handleSegments is a noop. 117 if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { 118 // If the endpoint is in a connected state then we do direct delivery 119 // to ensure low latency and avoid scheduler interactions. 120 switch err := ep.handleSegmentsLocked(true /* fastPath */); { 121 case err != nil: 122 // Send any active resets if required. 123 ep.resetConnectionLocked(err) 124 fallthrough 125 case ep.EndpointState() == StateClose: 126 ep.notifyProtocolGoroutine(notifyTickleWorker) 127 case !ep.segmentQueue.empty(): 128 p.epQ.enqueue(ep) 129 } 130 ep.mu.Unlock() // +checklocksforce 131 } else { 132 ep.newSegmentWaker.Assert() 133 } 134 } 135 } 136 } 137 138 // dispatcher manages a pool of TCP endpoint processors which are responsible 139 // for the processing of inbound segments. This fixed pool of processor 140 // goroutines do full tcp processing. The processor is selected based on the 141 // hash of the endpoint id to ensure that delivery for the same endpoint happens 142 // in-order. 143 type dispatcher struct { 144 processors []processor 145 // seed is a random secret for a jenkins hash. 146 seed uint32 147 wg sync.WaitGroup 148 } 149 150 func (d *dispatcher) init(rng *rand.Rand, nProcessors int) { 151 d.close() 152 d.wait() 153 d.processors = make([]processor, nProcessors) 154 d.seed = rng.Uint32() 155 for i := range d.processors { 156 p := &d.processors[i] 157 p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) 158 p.sleeper.AddWaker(&p.closeWaker, closeWaker) 159 d.wg.Add(1) 160 // NB: sleeper-waker registration must happen synchronously to avoid races 161 // with `close`. It's possible to pull all this logic into `start`, but 162 // that results in a heap-allocated function literal. 163 go p.start(&d.wg) 164 } 165 } 166 167 func (d *dispatcher) close() { 168 for i := range d.processors { 169 d.processors[i].close() 170 } 171 } 172 173 func (d *dispatcher) wait() { 174 d.wg.Wait() 175 } 176 177 func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, clock tcpip.Clock, pkt *stack.PacketBuffer) { 178 ep := stackEP.(*endpoint) 179 180 s := newIncomingSegment(id, clock, pkt) 181 if !s.parse(pkt.RXTransportChecksumValidated) { 182 ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() 183 ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() 184 s.decRef() 185 return 186 } 187 188 if !s.csumValid { 189 ep.stack.Stats().TCP.ChecksumErrors.Increment() 190 ep.stats.ReceiveErrors.ChecksumErrors.Increment() 191 s.decRef() 192 return 193 } 194 195 ep.stack.Stats().TCP.ValidSegmentsReceived.Increment() 196 ep.stats.SegmentsReceived.Increment() 197 if (s.flags & header.TCPFlagRst) != 0 { 198 ep.stack.Stats().TCP.ResetsReceived.Increment() 199 } 200 201 if !ep.enqueueSegment(s) { 202 s.decRef() 203 return 204 } 205 206 // For sockets not in established state let the worker goroutine 207 // handle the packets. 208 if ep.EndpointState() != StateEstablished { 209 ep.newSegmentWaker.Assert() 210 return 211 } 212 213 d.selectProcessor(id).queueEndpoint(ep) 214 } 215 216 func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { 217 var payload [4]byte 218 binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) 219 binary.LittleEndian.PutUint16(payload[2:], id.RemotePort) 220 221 h := jenkins.Sum32(d.seed) 222 h.Write(payload[:]) 223 h.Write([]byte(id.LocalAddress)) 224 h.Write([]byte(id.RemoteAddress)) 225 226 return &d.processors[h.Sum32()%uint32(len(d.processors))] 227 }