gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/seccheck/sinks/remote/server/server.go (about) 1 // Copyright 2022 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package server provides a common server implementation that can connect with 16 // remote.Remote. 17 package server 18 19 import ( 20 "errors" 21 "fmt" 22 "io" 23 "os" 24 25 "golang.org/x/sys/unix" 26 "google.golang.org/protobuf/proto" 27 "gvisor.dev/gvisor/pkg/cleanup" 28 "gvisor.dev/gvisor/pkg/log" 29 pb "gvisor.dev/gvisor/pkg/sentry/seccheck/points/points_go_proto" 30 "gvisor.dev/gvisor/pkg/sentry/seccheck/sinks/remote/wire" 31 "gvisor.dev/gvisor/pkg/sync" 32 "gvisor.dev/gvisor/pkg/unet" 33 ) 34 35 // ClientHandler is used to interface with client that connect to the server. 36 type ClientHandler interface { 37 // NewClient is called when a new client connects to the server. It returns 38 // a handler that will be bound to the client. 39 NewClient() (MessageHandler, error) 40 } 41 42 // MessageHandler is used to process messages from a client. 43 type MessageHandler interface { 44 // Message processes a single message. raw contains the entire unparsed 45 // message. hdr is the parser message header and payload is the unparsed 46 // message data. 47 Message(raw []byte, hdr wire.Header, payload []byte) error 48 49 // Version returns what wire version of the protocol is supported. 50 Version() uint32 51 52 // Close closes the handler. 53 Close() 54 } 55 56 type client struct { 57 socket *unet.Socket 58 handler MessageHandler 59 } 60 61 func (c client) close() { 62 _ = c.socket.Close() 63 c.handler.Close() 64 } 65 66 // CommonServer provides common functionality to connect and process messages 67 // from different clients. Implementors decide how clients and messages are 68 // handled, e.g. counting messages for testing. 69 type CommonServer struct { 70 // Endpoint is the path to the socket that the server listens to. 71 Endpoint string 72 73 socket *unet.ServerSocket 74 75 handler ClientHandler 76 77 cond sync.Cond 78 79 // +checklocks:cond.L 80 clients []client 81 } 82 83 // Init initializes the server. It must be called before it is used. 84 func (s *CommonServer) Init(path string, handler ClientHandler) { 85 s.Endpoint = path 86 s.handler = handler 87 s.cond = sync.Cond{L: &sync.Mutex{}} 88 } 89 90 // Start creates the socket file and listens for new connections. 91 func (s *CommonServer) Start() error { 92 socket, err := unix.Socket(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0) 93 if err != nil { 94 return fmt.Errorf("socket(AF_UNIX, SOCK_SEQPACKET, 0): %w", err) 95 } 96 cu := cleanup.Make(func() { 97 _ = unix.Close(socket) 98 }) 99 defer cu.Clean() 100 101 sa := &unix.SockaddrUnix{Name: s.Endpoint} 102 if err := unix.Bind(socket, sa); err != nil { 103 return fmt.Errorf("bind(%q): %w", s.Endpoint, err) 104 } 105 106 s.socket, err = unet.NewServerSocket(socket) 107 if err != nil { 108 return err 109 } 110 cu.Add(func() { s.socket.Close() }) 111 112 if err := s.socket.Listen(); err != nil { 113 return err 114 } 115 116 go s.run() 117 cu.Release() 118 return nil 119 } 120 121 func (s *CommonServer) run() { 122 for { 123 socket, err := s.socket.Accept() 124 if err != nil { 125 // EBADF returns when the socket closes. 126 if !errors.Is(err, unix.EBADF) { 127 log.Warningf("socket.Accept(): %v", err) 128 } 129 return 130 } 131 msgHandler, err := s.handler.NewClient() 132 if err != nil { 133 log.Warningf("handler.NewClient: %v", err) 134 return 135 } 136 client := client{ 137 socket: socket, 138 handler: msgHandler, 139 } 140 s.cond.L.Lock() 141 s.clients = append(s.clients, client) 142 s.cond.Broadcast() 143 s.cond.L.Unlock() 144 145 if err := s.handshake(client); err != nil { 146 log.Warningf(err.Error()) 147 s.closeClient(client) 148 continue 149 } 150 go s.handleClient(client) 151 } 152 } 153 154 // handshake performs version exchange with client. See common.proto for details 155 // about the protocol. 156 func (s *CommonServer) handshake(client client) error { 157 var in [1024]byte 158 read, err := client.socket.Read(in[:]) 159 if err != nil { 160 return fmt.Errorf("reading handshake message: %w", err) 161 } 162 hsIn := pb.Handshake{} 163 if err := proto.Unmarshal(in[:read], &hsIn); err != nil { 164 return fmt.Errorf("unmarshalling handshake message: %w", err) 165 } 166 if hsIn.Version != wire.CurrentVersion { 167 return fmt.Errorf("wrong version number, want: %d, got, %d", wire.CurrentVersion, hsIn.Version) 168 } 169 170 hsOut := pb.Handshake{Version: client.handler.Version()} 171 out, err := proto.Marshal(&hsOut) 172 if err != nil { 173 return fmt.Errorf("marshalling handshake message: %w", err) 174 } 175 if _, err := client.socket.Write(out); err != nil { 176 return fmt.Errorf("sending handshake message: %w", err) 177 } 178 return nil 179 } 180 181 func (s *CommonServer) handleClient(client client) { 182 defer s.closeClient(client) 183 184 var buf = make([]byte, 1024*1024) 185 for { 186 read, err := client.socket.Read(buf) 187 if err != nil { 188 if errors.Is(err, io.EOF) || errors.Is(err, unix.EBADF) { 189 // Both errors indicate that the socket has been closed. 190 return 191 } 192 panic(err) 193 } 194 if read < wire.HeaderStructSize { 195 panic("message too small") 196 } 197 hdr := wire.Header{} 198 hdr.UnmarshalUnsafe(buf[0:wire.HeaderStructSize]) 199 if read < int(hdr.HeaderSize) { 200 panic(fmt.Sprintf("message truncated, header size: %d, read: %d", hdr.HeaderSize, read)) 201 } 202 if err := client.handler.Message(buf[:read], hdr, buf[hdr.HeaderSize:read]); err != nil { 203 panic(err) 204 } 205 } 206 } 207 208 func (s *CommonServer) closeClient(client client) { 209 client.close() 210 211 // Stop tracking this client. 212 s.cond.L.Lock() 213 for i, c := range s.clients { 214 if c == client { 215 s.clients = append(s.clients[:i], s.clients[i+1:]...) 216 break 217 } 218 } 219 s.cond.Broadcast() 220 s.cond.L.Unlock() 221 } 222 223 // Close stops listening and closes all connections. 224 func (s *CommonServer) Close() { 225 if s.socket != nil { 226 _ = s.socket.Close() 227 } 228 s.cond.L.Lock() 229 for _, client := range s.clients { 230 client.close() 231 } 232 s.clients = nil 233 s.cond.Broadcast() 234 s.cond.L.Unlock() 235 _ = os.Remove(s.Endpoint) 236 } 237 238 // WaitForNoClients waits until the number of clients connected reaches 0. 239 func (s *CommonServer) WaitForNoClients() { 240 s.cond.L.Lock() 241 defer s.cond.L.Unlock() 242 for len(s.clients) > 0 { 243 s.cond.Wait() 244 } 245 }