// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package uds contains helpers for testing external UDS functionality.
package uds

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"

	"github.com/SagerNet/gvisor/pkg/log"
	"github.com/SagerNet/gvisor/pkg/unet"
	"golang.org/x/sys/unix"
)

// createEchoSocket creates a socket that echoes back anything received.
//
// Only works for stream, seqpacket sockets.
33 func createEchoSocket(path string, protocol int) (cleanup func(), err error) { 34 fd, err := unix.Socket(unix.AF_UNIX, protocol, 0) 35 if err != nil { 36 return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err) 37 } 38 39 if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil { 40 return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err) 41 } 42 43 if err := unix.Listen(fd, 0); err != nil { 44 return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err) 45 } 46 47 server, err := unet.NewServerSocket(fd) 48 if err != nil { 49 return nil, fmt.Errorf("error creating echo(%d) unet socket: %v", protocol, err) 50 } 51 52 acceptAndEchoOne := func() error { 53 s, err := server.Accept() 54 if err != nil { 55 return fmt.Errorf("failed to accept: %v", err) 56 } 57 defer s.Close() 58 59 for { 60 buf := make([]byte, 512) 61 for { 62 n, err := s.Read(buf) 63 if err == io.EOF { 64 return nil 65 } 66 if err != nil { 67 return fmt.Errorf("failed to read: %d, %v", n, err) 68 } 69 70 n, err = s.Write(buf[:n]) 71 if err != nil { 72 return fmt.Errorf("failed to write: %d, %v", n, err) 73 } 74 } 75 } 76 } 77 78 go func() { 79 for { 80 if err := acceptAndEchoOne(); err != nil { 81 log.Warningf("Failed to handle echo(%d) socket: %v", protocol, err) 82 return 83 } 84 } 85 }() 86 87 cleanup = func() { 88 if err := server.Close(); err != nil { 89 log.Warningf("Failed to close echo(%d) socket: %v", protocol, err) 90 } 91 } 92 93 return cleanup, nil 94 } 95 96 // createNonListeningSocket creates a socket that is bound but not listening. 97 // 98 // Only relevant for stream, seqpacket sockets. 
99 func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) { 100 fd, err := unix.Socket(unix.AF_UNIX, protocol, 0) 101 if err != nil { 102 return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err) 103 } 104 105 if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil { 106 return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err) 107 } 108 109 cleanup = func() { 110 if err := unix.Close(fd); err != nil { 111 log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err) 112 } 113 } 114 115 return cleanup, nil 116 } 117 118 // createNullSocket creates a socket that reads anything received. 119 // 120 // Only works for dgram sockets. 121 func createNullSocket(path string, protocol int) (cleanup func(), err error) { 122 fd, err := unix.Socket(unix.AF_UNIX, protocol, 0) 123 if err != nil { 124 return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err) 125 } 126 127 if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil { 128 return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err) 129 } 130 131 s, err := unet.NewSocket(fd) 132 if err != nil { 133 return nil, fmt.Errorf("error creating null(%d) unet socket: %v", protocol, err) 134 } 135 136 go func() { 137 buf := make([]byte, 512) 138 for { 139 n, err := s.Read(buf) 140 if err != nil { 141 log.Warningf("failed to read: %d, %v", n, err) 142 return 143 } 144 } 145 }() 146 147 cleanup = func() { 148 if err := s.Close(); err != nil { 149 log.Warningf("Failed to close null(%d) socket: %v", protocol, err) 150 } 151 } 152 153 return cleanup, nil 154 } 155 156 type socketCreator func(path string, proto int) (cleanup func(), err error) 157 158 // CreateSocketTree creates a local tree of unix domain sockets for use in 159 // testing: 160 // * /stream/echo 161 // * /stream/nonlistening 162 // * /seqpacket/echo 163 // * /seqpacket/nonlistening 164 // * /dgram/null 165 func 
CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) { 166 dir, err = ioutil.TempDir(baseDir, "sockets") 167 if err != nil { 168 return "", nil, fmt.Errorf("error creating temp dir: %v", err) 169 } 170 171 var protocols = []struct { 172 protocol int 173 name string 174 sockets map[string]socketCreator 175 }{ 176 { 177 protocol: unix.SOCK_STREAM, 178 name: "stream", 179 sockets: map[string]socketCreator{ 180 "echo": createEchoSocket, 181 "nonlistening": createNonListeningSocket, 182 }, 183 }, 184 { 185 protocol: unix.SOCK_SEQPACKET, 186 name: "seqpacket", 187 sockets: map[string]socketCreator{ 188 "echo": createEchoSocket, 189 "nonlistening": createNonListeningSocket, 190 }, 191 }, 192 { 193 protocol: unix.SOCK_DGRAM, 194 name: "dgram", 195 sockets: map[string]socketCreator{ 196 "null": createNullSocket, 197 }, 198 }, 199 } 200 201 var cleanups []func() 202 for _, proto := range protocols { 203 protoDir := filepath.Join(dir, proto.name) 204 if err := os.Mkdir(protoDir, 0755); err != nil { 205 return "", nil, fmt.Errorf("error creating %s dir: %v", proto.name, err) 206 } 207 208 for name, fn := range proto.sockets { 209 path := filepath.Join(protoDir, name) 210 cleanup, err := fn(path, proto.protocol) 211 if err != nil { 212 return "", nil, fmt.Errorf("error creating %s %s socket: %v", proto.name, name, err) 213 } 214 215 cleanups = append(cleanups, cleanup) 216 } 217 } 218 219 cleanup = func() { 220 for _, c := range cleanups { 221 c() 222 } 223 224 os.RemoveAll(dir) 225 } 226 227 return dir, cleanup, nil 228 }