github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/link/sharedmem/sharedmem_server.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 //go:build linux 16 // +build linux 17 18 package sharedmem 19 20 import ( 21 "github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops" 22 "github.com/nicocha30/gvisor-ligolo/pkg/buffer" 23 "github.com/nicocha30/gvisor-ligolo/pkg/sync" 24 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip" 25 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/header" 26 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/link/rawfile" 27 "github.com/nicocha30/gvisor-ligolo/pkg/tcpip/stack" 28 ) 29 30 type serverEndpoint struct { 31 // mtu (maximum transmission unit) is the maximum size of a packet. 32 // mtu is immutable. 33 mtu uint32 34 35 // bufferSize is the size of each individual buffer. 36 // bufferSize is immutable. 37 bufferSize uint32 38 39 // addr is the local address of this endpoint. 40 // addr is immutable 41 addr tcpip.LinkAddress 42 43 // rx is the receive queue. 44 rx serverRx 45 46 // stopRequested determines whether the worker goroutines should stop. 47 stopRequested atomicbitops.Uint32 48 49 // Wait group used to indicate that all workers have stopped. 50 completed sync.WaitGroup 51 52 // peerFD is an fd to the peer that can be used to detect when the peer is 53 // gone. 54 // peerFD is immutable. 55 peerFD int 56 57 // caps holds the endpoint capabilities. 58 caps stack.LinkEndpointCapabilities 59 60 // hdrSize is the size of the link layer header if any. 61 // hdrSize is immutable. 62 hdrSize uint32 63 64 // virtioNetHeaderRequired if true indicates that a virtio header is expected 65 // in all inbound/outbound packets. 66 virtioNetHeaderRequired bool 67 68 // onClosed is a function to be called when the FD's peer (if any) closes its 69 // end of the communication pipe. 70 onClosed func(tcpip.Error) 71 72 // mu protects the following fields. 73 mu sync.Mutex 74 75 // tx is the transmit queue. 76 // +checklocks:mu 77 tx serverTx 78 79 // workerStarted specifies whether the worker goroutine was started. 80 // +checklocks:mu 81 workerStarted bool 82 } 83 84 // NewServerEndpoint creates a new shared-memory-based endpoint. Buffers will be 85 // broken up into buffers of "bufferSize" bytes. 86 func NewServerEndpoint(opts Options) (stack.LinkEndpoint, error) { 87 e := &serverEndpoint{ 88 mtu: opts.MTU, 89 bufferSize: opts.BufferSize, 90 addr: opts.LinkAddress, 91 peerFD: opts.PeerFD, 92 onClosed: opts.OnClosed, 93 } 94 95 if err := e.tx.init(&opts.RX); err != nil { 96 return nil, err 97 } 98 99 if err := e.rx.init(&opts.TX); err != nil { 100 e.tx.cleanup() 101 return nil, err 102 } 103 104 e.caps = stack.LinkEndpointCapabilities(0) 105 if opts.RXChecksumOffload { 106 e.caps |= stack.CapabilityRXChecksumOffload 107 } 108 109 if opts.TXChecksumOffload { 110 e.caps |= stack.CapabilityTXChecksumOffload 111 } 112 113 if opts.LinkAddress != "" { 114 e.hdrSize = header.EthernetMinimumSize 115 e.caps |= stack.CapabilityResolutionRequired 116 } 117 118 return e, nil 119 } 120 121 // Close frees all resources associated with the endpoint. 122 func (e *serverEndpoint) Close() { 123 // Tell dispatch goroutine to stop, then write to the eventfd so that it wakes 124 // up in case it's sleeping. 125 e.stopRequested.Store(1) 126 e.rx.eventFD.Notify() 127 128 // Cleanup the queues inline if the worker hasn't started yet; we also know it 129 // won't start from now on because stopRequested is set to 1. 130 e.mu.Lock() 131 defer e.mu.Unlock() 132 workerPresent := e.workerStarted 133 134 if !workerPresent { 135 e.tx.cleanup() 136 e.rx.cleanup() 137 } 138 } 139 140 // Wait implements stack.LinkEndpoint.Wait. It waits until all workers have 141 // stopped after a Close() call. 142 func (e *serverEndpoint) Wait() { 143 e.completed.Wait() 144 } 145 146 // Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that 147 // reads packets from the rx queue. 148 func (e *serverEndpoint) Attach(dispatcher stack.NetworkDispatcher) { 149 e.mu.Lock() 150 if !e.workerStarted && e.stopRequested.Load() == 0 { 151 e.workerStarted = true 152 e.completed.Add(1) 153 if e.peerFD >= 0 { 154 e.completed.Add(1) 155 // Spin up a goroutine to monitor for peer shutdown. 156 go func() { 157 b := make([]byte, 1) 158 // When sharedmem endpoint is in use the peerFD is never used for any 159 // data transfer and this Read should only return if the peer is 160 // shutting down. 161 _, err := rawfile.BlockingRead(e.peerFD, b) 162 if e.onClosed != nil { 163 e.onClosed(err) 164 } 165 e.completed.Done() 166 }() 167 } 168 // Link endpoints are not savable. When transportation endpoints are saved, 169 // they stop sending outgoing packets and all incoming packets are rejected. 170 go e.dispatchLoop(dispatcher) // S/R-SAFE: see above. 171 } 172 e.mu.Unlock() 173 } 174 175 // IsAttached implements stack.LinkEndpoint.IsAttached. 176 func (e *serverEndpoint) IsAttached() bool { 177 e.mu.Lock() 178 defer e.mu.Unlock() 179 return e.workerStarted 180 } 181 182 // MTU implements stack.LinkEndpoint.MTU. It returns the value initialized 183 // during construction. 184 func (e *serverEndpoint) MTU() uint32 { 185 return e.mtu 186 } 187 188 // Capabilities implements stack.LinkEndpoint.Capabilities. 189 func (e *serverEndpoint) Capabilities() stack.LinkEndpointCapabilities { 190 return e.caps 191 } 192 193 // MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the 194 // ethernet frame header size. 195 func (e *serverEndpoint) MaxHeaderLength() uint16 { 196 return uint16(e.hdrSize) 197 } 198 199 // LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local 200 // link address. 201 func (e *serverEndpoint) LinkAddress() tcpip.LinkAddress { 202 return e.addr 203 } 204 205 // AddHeader implements stack.LinkEndpoint.AddHeader. 206 func (e *serverEndpoint) AddHeader(pkt stack.PacketBufferPtr) { 207 // Add ethernet header if needed. 208 if len(e.addr) == 0 { 209 return 210 } 211 212 eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) 213 eth.Encode(&header.EthernetFields{ 214 SrcAddr: pkt.EgressRoute.LocalLinkAddress, 215 DstAddr: pkt.EgressRoute.RemoteLinkAddress, 216 Type: pkt.NetworkProtocolNumber, 217 }) 218 } 219 220 func (e *serverEndpoint) parseHeader(pkt stack.PacketBufferPtr) bool { 221 _, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) 222 return ok 223 } 224 225 // ParseHeader implements stack.LinkEndpoint.ParseHeader. 226 func (e *serverEndpoint) ParseHeader(pkt stack.PacketBufferPtr) bool { 227 // Add ethernet header if needed. 228 if len(e.addr) == 0 { 229 return true 230 } 231 232 return e.parseHeader(pkt) 233 } 234 235 func (e *serverEndpoint) AddVirtioNetHeader(pkt stack.PacketBufferPtr) { 236 virtio := header.VirtioNetHeader(pkt.VirtioNetHeader().Push(header.VirtioNetHeaderSize)) 237 virtio.Encode(&header.VirtioNetHeaderFields{}) 238 } 239 240 // +checklocks:e.mu 241 func (e *serverEndpoint) writePacketLocked(r stack.RouteInfo, protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) tcpip.Error { 242 if e.virtioNetHeaderRequired { 243 e.AddVirtioNetHeader(pkt) 244 } 245 246 ok := e.tx.transmit(pkt) 247 if !ok { 248 return &tcpip.ErrWouldBlock{} 249 } 250 251 return nil 252 } 253 254 // WritePacket writes outbound packets to the file descriptor. If it is not 255 // currently writable, the packet is dropped. 256 // WritePacket implements stack.LinkEndpoint.WritePacket. 257 func (e *serverEndpoint) WritePacket(_ stack.RouteInfo, _ tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) tcpip.Error { 258 // Transmit the packet. 259 e.mu.Lock() 260 defer e.mu.Unlock() 261 if err := e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { 262 return err 263 } 264 e.tx.notify() 265 return nil 266 } 267 268 // WritePackets implements stack.LinkEndpoint.WritePackets. 269 func (e *serverEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { 270 n := 0 271 var err tcpip.Error 272 e.mu.Lock() 273 defer e.mu.Unlock() 274 for _, pkt := range pkts.AsSlice() { 275 if err = e.writePacketLocked(pkt.EgressRoute, pkt.NetworkProtocolNumber, pkt); err != nil { 276 break 277 } 278 n++ 279 } 280 // WritePackets never returns an error if it successfully transmitted at least 281 // one packet. 282 if err != nil && n == 0 { 283 return 0, err 284 } 285 e.tx.notify() 286 return n, nil 287 } 288 289 // dispatchLoop reads packets from the rx queue in a loop and dispatches them 290 // to the network stack. 291 func (e *serverEndpoint) dispatchLoop(d stack.NetworkDispatcher) { 292 for e.stopRequested.Load() == 0 { 293 b := e.rx.receive() 294 if b == nil { 295 e.rx.EnableNotification() 296 // Now pull again to make sure we didn't receive any packets 297 // while notifications were not enabled. 298 for { 299 b = e.rx.receive() 300 if b != nil { 301 // Disable notifications as we only need to be notified when we are going 302 // to block on eventFD. This should prevent the peer from needlessly 303 // writing to eventFD when this end is already awake and processing 304 // packets. 305 e.rx.DisableNotification() 306 break 307 } 308 e.rx.waitForPackets() 309 } 310 } 311 pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ 312 Payload: buffer.MakeWithView(b), 313 }) 314 if e.virtioNetHeaderRequired { 315 _, ok := pkt.VirtioNetHeader().Consume(header.VirtioNetHeaderSize) 316 if !ok { 317 pkt.DecRef() 318 continue 319 } 320 } 321 var proto tcpip.NetworkProtocolNumber 322 if len(e.addr) != 0 { 323 if !e.parseHeader(pkt) { 324 pkt.DecRef() 325 continue 326 } 327 proto = header.Ethernet(pkt.LinkHeader().Slice()).Type() 328 } else { 329 // We don't get any indication of what the packet is, so try to guess 330 // if it's an IPv4 or IPv6 packet. 331 // IP version information is at the first octet, so pulling up 1 byte. 332 h, ok := pkt.Data().PullUp(1) 333 if !ok { 334 pkt.DecRef() 335 continue 336 } 337 switch header.IPVersion(h) { 338 case header.IPv4Version: 339 proto = header.IPv4ProtocolNumber 340 case header.IPv6Version: 341 proto = header.IPv6ProtocolNumber 342 default: 343 pkt.DecRef() 344 continue 345 } 346 } 347 // Send packet up the stack. 348 d.DeliverNetworkPacket(proto, pkt) 349 pkt.DecRef() 350 } 351 352 e.mu.Lock() 353 defer e.mu.Unlock() 354 355 // Clean state. 356 e.tx.cleanup() 357 e.rx.cleanup() 358 359 e.completed.Done() 360 } 361 362 // ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType 363 func (e *serverEndpoint) ARPHardwareType() header.ARPHardwareType { 364 if e.hdrSize > 0 { 365 return header.ARPHardwareEther 366 } 367 return header.ARPHardwareNone 368 }