gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/seccheck/sinks/remote/server/server.go (about)

     1  // Copyright 2022 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package server provides a common server implementation that can connect with
    16  // remote.Remote.
    17  package server
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"os"
    24  
    25  	"golang.org/x/sys/unix"
    26  	"google.golang.org/protobuf/proto"
    27  	"gvisor.dev/gvisor/pkg/cleanup"
    28  	"gvisor.dev/gvisor/pkg/log"
    29  	pb "gvisor.dev/gvisor/pkg/sentry/seccheck/points/points_go_proto"
    30  	"gvisor.dev/gvisor/pkg/sentry/seccheck/sinks/remote/wire"
    31  	"gvisor.dev/gvisor/pkg/sync"
    32  	"gvisor.dev/gvisor/pkg/unet"
    33  )
    34  
// ClientHandler is used to interface with clients that connect to the server.
// CommonServer calls NewClient for each connection it accepts.
type ClientHandler interface {
	// NewClient is called when a new client connects to the server. It returns
	// a handler that will be bound to the client.
	NewClient() (MessageHandler, error)
}
    41  
// MessageHandler is used to process messages from a single client.
type MessageHandler interface {
	// Message processes a single message. raw contains the entire unparsed
	// message. hdr is the parsed message header and payload is the unparsed
	// message data that follows the first hdr.HeaderSize bytes of raw.
	//
	// raw and payload alias a receive buffer that CommonServer reuses for the
	// next message, so implementations must copy any bytes they need to keep
	// after returning. CommonServer treats a non-nil error as fatal and
	// panics with it.
	Message(raw []byte, hdr wire.Header, payload []byte) error

	// Version returns what wire version of the protocol is supported. It is
	// sent back to the client in the server's handshake reply.
	Version() uint32

	// Close closes the handler. CommonServer calls it when the client's
	// connection is torn down.
	Close()
}
    55  
    56  type client struct {
    57  	socket  *unet.Socket
    58  	handler MessageHandler
    59  }
    60  
    61  func (c client) close() {
    62  	_ = c.socket.Close()
    63  	c.handler.Close()
    64  }
    65  
// CommonServer provides common functionality to connect and process messages
// from different clients. Implementors decide how clients and messages are
// handled, e.g. counting messages for testing.
//
// Init must be called before the server is used. Start then binds a
// SOCK_SEQPACKET Unix domain socket at Endpoint and accepts clients on it.
type CommonServer struct {
	// Endpoint is the path to the socket that the server listens to.
	Endpoint string

	// socket is the listening socket; it is set by Start.
	socket *unet.ServerSocket

	// handler is asked for a new MessageHandler for every accepted client.
	handler ClientHandler

	// cond's L (set up by Init) guards clients. cond is broadcast whenever the
	// set of tracked clients changes, which WaitForNoClients waits on.
	cond sync.Cond

	// +checklocks:cond.L
	clients []client
}
    82  
    83  // Init initializes the server. It must be called before it is used.
    84  func (s *CommonServer) Init(path string, handler ClientHandler) {
    85  	s.Endpoint = path
    86  	s.handler = handler
    87  	s.cond = sync.Cond{L: &sync.Mutex{}}
    88  }
    89  
    90  // Start creates the socket file and listens for new connections.
    91  func (s *CommonServer) Start() error {
    92  	socket, err := unix.Socket(unix.AF_UNIX, unix.SOCK_SEQPACKET, 0)
    93  	if err != nil {
    94  		return fmt.Errorf("socket(AF_UNIX, SOCK_SEQPACKET, 0): %w", err)
    95  	}
    96  	cu := cleanup.Make(func() {
    97  		_ = unix.Close(socket)
    98  	})
    99  	defer cu.Clean()
   100  
   101  	sa := &unix.SockaddrUnix{Name: s.Endpoint}
   102  	if err := unix.Bind(socket, sa); err != nil {
   103  		return fmt.Errorf("bind(%q): %w", s.Endpoint, err)
   104  	}
   105  
   106  	s.socket, err = unet.NewServerSocket(socket)
   107  	if err != nil {
   108  		return err
   109  	}
   110  	cu.Add(func() { s.socket.Close() })
   111  
   112  	if err := s.socket.Listen(); err != nil {
   113  		return err
   114  	}
   115  
   116  	go s.run()
   117  	cu.Release()
   118  	return nil
   119  }
   120  
   121  func (s *CommonServer) run() {
   122  	for {
   123  		socket, err := s.socket.Accept()
   124  		if err != nil {
   125  			// EBADF returns when the socket closes.
   126  			if !errors.Is(err, unix.EBADF) {
   127  				log.Warningf("socket.Accept(): %v", err)
   128  			}
   129  			return
   130  		}
   131  		msgHandler, err := s.handler.NewClient()
   132  		if err != nil {
   133  			log.Warningf("handler.NewClient: %v", err)
   134  			return
   135  		}
   136  		client := client{
   137  			socket:  socket,
   138  			handler: msgHandler,
   139  		}
   140  		s.cond.L.Lock()
   141  		s.clients = append(s.clients, client)
   142  		s.cond.Broadcast()
   143  		s.cond.L.Unlock()
   144  
   145  		if err := s.handshake(client); err != nil {
   146  			log.Warningf(err.Error())
   147  			s.closeClient(client)
   148  			continue
   149  		}
   150  		go s.handleClient(client)
   151  	}
   152  }
   153  
   154  // handshake performs version exchange with client. See common.proto for details
   155  // about the protocol.
   156  func (s *CommonServer) handshake(client client) error {
   157  	var in [1024]byte
   158  	read, err := client.socket.Read(in[:])
   159  	if err != nil {
   160  		return fmt.Errorf("reading handshake message: %w", err)
   161  	}
   162  	hsIn := pb.Handshake{}
   163  	if err := proto.Unmarshal(in[:read], &hsIn); err != nil {
   164  		return fmt.Errorf("unmarshalling handshake message: %w", err)
   165  	}
   166  	if hsIn.Version != wire.CurrentVersion {
   167  		return fmt.Errorf("wrong version number, want: %d, got, %d", wire.CurrentVersion, hsIn.Version)
   168  	}
   169  
   170  	hsOut := pb.Handshake{Version: client.handler.Version()}
   171  	out, err := proto.Marshal(&hsOut)
   172  	if err != nil {
   173  		return fmt.Errorf("marshalling handshake message: %w", err)
   174  	}
   175  	if _, err := client.socket.Write(out); err != nil {
   176  		return fmt.Errorf("sending handshake message: %w", err)
   177  	}
   178  	return nil
   179  }
   180  
   181  func (s *CommonServer) handleClient(client client) {
   182  	defer s.closeClient(client)
   183  
   184  	var buf = make([]byte, 1024*1024)
   185  	for {
   186  		read, err := client.socket.Read(buf)
   187  		if err != nil {
   188  			if errors.Is(err, io.EOF) || errors.Is(err, unix.EBADF) {
   189  				// Both errors indicate that the socket has been closed.
   190  				return
   191  			}
   192  			panic(err)
   193  		}
   194  		if read < wire.HeaderStructSize {
   195  			panic("message too small")
   196  		}
   197  		hdr := wire.Header{}
   198  		hdr.UnmarshalUnsafe(buf[0:wire.HeaderStructSize])
   199  		if read < int(hdr.HeaderSize) {
   200  			panic(fmt.Sprintf("message truncated, header size: %d, read: %d", hdr.HeaderSize, read))
   201  		}
   202  		if err := client.handler.Message(buf[:read], hdr, buf[hdr.HeaderSize:read]); err != nil {
   203  			panic(err)
   204  		}
   205  	}
   206  }
   207  
   208  func (s *CommonServer) closeClient(client client) {
   209  	client.close()
   210  
   211  	// Stop tracking this client.
   212  	s.cond.L.Lock()
   213  	for i, c := range s.clients {
   214  		if c == client {
   215  			s.clients = append(s.clients[:i], s.clients[i+1:]...)
   216  			break
   217  		}
   218  	}
   219  	s.cond.Broadcast()
   220  	s.cond.L.Unlock()
   221  }
   222  
   223  // Close stops listening and closes all connections.
   224  func (s *CommonServer) Close() {
   225  	if s.socket != nil {
   226  		_ = s.socket.Close()
   227  	}
   228  	s.cond.L.Lock()
   229  	for _, client := range s.clients {
   230  		client.close()
   231  	}
   232  	s.clients = nil
   233  	s.cond.Broadcast()
   234  	s.cond.L.Unlock()
   235  	_ = os.Remove(s.Endpoint)
   236  }
   237  
   238  // WaitForNoClients waits until the number of clients connected reaches 0.
   239  func (s *CommonServer) WaitForNoClients() {
   240  	s.cond.L.Lock()
   241  	defer s.cond.L.Unlock()
   242  	for len(s.clients) > 0 {
   243  		s.cond.Wait()
   244  	}
   245  }