github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/uds/uds.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package uds contains helpers for testing external UDS functionality.
    16  package uds
    17  
    18  import (
    19  	"fmt"
    20  	"io"
    21  	"io/ioutil"
    22  	"os"
    23  	"path/filepath"
    24  
    25  	"golang.org/x/sys/unix"
    26  	"github.com/SagerNet/gvisor/pkg/log"
    27  	"github.com/SagerNet/gvisor/pkg/unet"
    28  )
    29  
    30  // createEchoSocket creates a socket that echoes back anything received.
    31  //
    32  // Only works for stream, seqpacket sockets.
    33  func createEchoSocket(path string, protocol int) (cleanup func(), err error) {
    34  	fd, err := unix.Socket(unix.AF_UNIX, protocol, 0)
    35  	if err != nil {
    36  		return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err)
    37  	}
    38  
    39  	if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil {
    40  		return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err)
    41  	}
    42  
    43  	if err := unix.Listen(fd, 0); err != nil {
    44  		return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err)
    45  	}
    46  
    47  	server, err := unet.NewServerSocket(fd)
    48  	if err != nil {
    49  		return nil, fmt.Errorf("error creating echo(%d) unet socket: %v", protocol, err)
    50  	}
    51  
    52  	acceptAndEchoOne := func() error {
    53  		s, err := server.Accept()
    54  		if err != nil {
    55  			return fmt.Errorf("failed to accept: %v", err)
    56  		}
    57  		defer s.Close()
    58  
    59  		for {
    60  			buf := make([]byte, 512)
    61  			for {
    62  				n, err := s.Read(buf)
    63  				if err == io.EOF {
    64  					return nil
    65  				}
    66  				if err != nil {
    67  					return fmt.Errorf("failed to read: %d, %v", n, err)
    68  				}
    69  
    70  				n, err = s.Write(buf[:n])
    71  				if err != nil {
    72  					return fmt.Errorf("failed to write: %d, %v", n, err)
    73  				}
    74  			}
    75  		}
    76  	}
    77  
    78  	go func() {
    79  		for {
    80  			if err := acceptAndEchoOne(); err != nil {
    81  				log.Warningf("Failed to handle echo(%d) socket: %v", protocol, err)
    82  				return
    83  			}
    84  		}
    85  	}()
    86  
    87  	cleanup = func() {
    88  		if err := server.Close(); err != nil {
    89  			log.Warningf("Failed to close echo(%d) socket: %v", protocol, err)
    90  		}
    91  	}
    92  
    93  	return cleanup, nil
    94  }
    95  
    96  // createNonListeningSocket creates a socket that is bound but not listening.
    97  //
    98  // Only relevant for stream, seqpacket sockets.
    99  func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) {
   100  	fd, err := unix.Socket(unix.AF_UNIX, protocol, 0)
   101  	if err != nil {
   102  		return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err)
   103  	}
   104  
   105  	if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil {
   106  		return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err)
   107  	}
   108  
   109  	cleanup = func() {
   110  		if err := unix.Close(fd); err != nil {
   111  			log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err)
   112  		}
   113  	}
   114  
   115  	return cleanup, nil
   116  }
   117  
   118  // createNullSocket creates a socket that reads anything received.
   119  //
   120  // Only works for dgram sockets.
   121  func createNullSocket(path string, protocol int) (cleanup func(), err error) {
   122  	fd, err := unix.Socket(unix.AF_UNIX, protocol, 0)
   123  	if err != nil {
   124  		return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err)
   125  	}
   126  
   127  	if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil {
   128  		return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err)
   129  	}
   130  
   131  	s, err := unet.NewSocket(fd)
   132  	if err != nil {
   133  		return nil, fmt.Errorf("error creating null(%d) unet socket: %v", protocol, err)
   134  	}
   135  
   136  	go func() {
   137  		buf := make([]byte, 512)
   138  		for {
   139  			n, err := s.Read(buf)
   140  			if err != nil {
   141  				log.Warningf("failed to read: %d, %v", n, err)
   142  				return
   143  			}
   144  		}
   145  	}()
   146  
   147  	cleanup = func() {
   148  		if err := s.Close(); err != nil {
   149  			log.Warningf("Failed to close null(%d) socket: %v", protocol, err)
   150  		}
   151  	}
   152  
   153  	return cleanup, nil
   154  }
   155  
   156  type socketCreator func(path string, proto int) (cleanup func(), err error)
   157  
   158  // CreateSocketTree creates a local tree of unix domain sockets for use in
   159  // testing:
   160  //  * /stream/echo
   161  //  * /stream/nonlistening
   162  //  * /seqpacket/echo
   163  //  * /seqpacket/nonlistening
   164  //  * /dgram/null
   165  func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) {
   166  	dir, err = ioutil.TempDir(baseDir, "sockets")
   167  	if err != nil {
   168  		return "", nil, fmt.Errorf("error creating temp dir: %v", err)
   169  	}
   170  
   171  	var protocols = []struct {
   172  		protocol int
   173  		name     string
   174  		sockets  map[string]socketCreator
   175  	}{
   176  		{
   177  			protocol: unix.SOCK_STREAM,
   178  			name:     "stream",
   179  			sockets: map[string]socketCreator{
   180  				"echo":         createEchoSocket,
   181  				"nonlistening": createNonListeningSocket,
   182  			},
   183  		},
   184  		{
   185  			protocol: unix.SOCK_SEQPACKET,
   186  			name:     "seqpacket",
   187  			sockets: map[string]socketCreator{
   188  				"echo":         createEchoSocket,
   189  				"nonlistening": createNonListeningSocket,
   190  			},
   191  		},
   192  		{
   193  			protocol: unix.SOCK_DGRAM,
   194  			name:     "dgram",
   195  			sockets: map[string]socketCreator{
   196  				"null": createNullSocket,
   197  			},
   198  		},
   199  	}
   200  
   201  	var cleanups []func()
   202  	for _, proto := range protocols {
   203  		protoDir := filepath.Join(dir, proto.name)
   204  		if err := os.Mkdir(protoDir, 0755); err != nil {
   205  			return "", nil, fmt.Errorf("error creating %s dir: %v", proto.name, err)
   206  		}
   207  
   208  		for name, fn := range proto.sockets {
   209  			path := filepath.Join(protoDir, name)
   210  			cleanup, err := fn(path, proto.protocol)
   211  			if err != nil {
   212  				return "", nil, fmt.Errorf("error creating %s %s socket: %v", proto.name, name, err)
   213  			}
   214  
   215  			cleanups = append(cleanups, cleanup)
   216  		}
   217  	}
   218  
   219  	cleanup = func() {
   220  		for _, c := range cleanups {
   221  			c()
   222  		}
   223  
   224  		os.RemoveAll(dir)
   225  	}
   226  
   227  	return dir, cleanup, nil
   228  }