github.com/mirantis/virtlet@v1.5.2-0.20191204181327-1659b8a48e9b/pkg/tapmanager/fdserver_test.go (about) 1 /* 2 Copyright 2017 Mirantis 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package tapmanager 18 19 import ( 20 "encoding/json" 21 "errors" 22 "fmt" 23 "io" 24 "io/ioutil" 25 "os" 26 "path/filepath" 27 "testing" 28 ) 29 30 type sampleFDData struct { 31 Content string 32 } 33 34 type sampleFDSource struct { 35 tmpDir string 36 files map[string]*os.File 37 stopped bool 38 } 39 40 var _ FDSource = &sampleFDSource{} 41 42 func newSampleFDSource(tmpDir string) *sampleFDSource { 43 return &sampleFDSource{ 44 tmpDir: tmpDir, 45 files: make(map[string]*os.File), 46 } 47 } 48 49 func (s *sampleFDSource) GetFDs(key string, data []byte) ([]int, []byte, error) { 50 if s.stopped { 51 return nil, nil, errors.New("sampleFDSource is stopped") 52 } 53 54 var fdData sampleFDData 55 if err := json.Unmarshal(data, &fdData); err != nil { 56 return nil, nil, fmt.Errorf("error unmarshalling json: %v", err) 57 } 58 if _, found := s.files[key]; found { 59 return nil, nil, fmt.Errorf("file already exists: %q", key) 60 } 61 filename := filepath.Join(s.tmpDir, key) 62 f, err := os.Create(filename) 63 if err != nil { 64 return nil, nil, fmt.Errorf("error creating file %q: %v", filename, err) 65 } 66 if err := os.Remove(f.Name()); err != nil { 67 f.Close() 68 return nil, nil, fmt.Errorf("Remove(): %v", err) 69 } 70 if _, err := f.Write([]byte(fdData.Content)); err != nil { 71 f.Close() 72 return nil, nil, fmt.Errorf("Write(): %v", err) 73 } 74 if _, err := f.Seek(0, io.SeekStart); err != nil { 75 f.Close() 76 return nil, nil, fmt.Errorf("Seek(): %v", err) 77 } 78 s.files[key] = f 79 return []int{int(f.Fd())}, []byte("abcdef"), nil 80 } 81 82 func (s *sampleFDSource) RetrieveFDs(key string) ([]int, error) { 83 if s.stopped { 84 return nil, errors.New("sampleFDSource is stopped") 85 } 86 87 f, err := os.Open(filepath.Join(s.tmpDir, key)) 88 if err != nil { 89 return nil, err 90 } 91 92 content, err := ioutil.ReadAll(f) 93 if err != nil { 94 return nil, err 95 } 96 97 if string(content) != "42" { 98 return nil, fmt.Errorf("bad data passed to RetrieveFDs: %q", content) 99 } 100 101 if err = os.Remove(filepath.Join(s.tmpDir, key)); err != nil { 102 return nil, err 103 } 104 return []int{int(f.Fd())}, nil 105 } 106 107 func (s *sampleFDSource) Recover(key string, data []byte) error { 108 if s.stopped { 109 return errors.New("sampleFDSource is stopped") 110 } 111 112 var fdData sampleFDData 113 if err := json.Unmarshal(data, &fdData); err != nil { 114 return fmt.Errorf("error unmarshalling json: %v", err) 115 } 116 117 if fdData.Content != "42" { 118 return fmt.Errorf("bad data passed to Recover: %q", data) 119 } 120 121 if _, found := s.files[key]; found { 122 return fmt.Errorf("key %q is already present", key) 123 } 124 125 s.files[key] = nil 126 return ioutil.WriteFile(filepath.Join(s.tmpDir, key), []byte(fdData.Content), 0644) 127 } 128 129 func (s *sampleFDSource) Release(key string) error { 130 if s.stopped { 131 return errors.New("sampleFDSource is stopped") 132 } 133 134 f, found := s.files[key] 135 if !found { 136 return fmt.Errorf("file not found: %q", key) 137 } 138 delete(s.files, key) 139 140 // "recovered" entries don't have FDs 141 if f != nil { 142 if err := f.Close(); err != nil { 143 return fmt.Errorf("can't close file %q: %v", f.Name(), err) 144 } 145 } 146 147 return nil 148 } 149 150 func (s *sampleFDSource) GetInfo(key string) ([]byte, error) { 151 if s.stopped { 152 return nil, errors.New("sampleFDSource is stopped") 153 } 154 155 _, found := s.files[key] 156 if !found { 157 return nil, fmt.Errorf("file not found: %q", key) 158 } 159 return []byte("info_" + key), nil 160 } 161 162 func (s *sampleFDSource) Stop() error { 163 s.stopped = true 164 return nil 165 } 166 167 func (s *sampleFDSource) isEmpty() bool { 168 return len(s.files) == 0 169 } 170 171 func (s *sampleFDSource) isRecovered(key string) bool { 172 f, found := s.files[key] 173 return found && f == nil 174 } 175 176 func verifyFD(t *testing.T, c *FDClient, key string, data string) { 177 fds, info, err := c.GetFDs(key) 178 if err != nil { 179 t.Fatalf("GetFDs(): %v", err) 180 } 181 182 expectedInfo := "info_" + key 183 if string(info) != expectedInfo { 184 t.Errorf("bad info: %q instead of %q", info, expectedInfo) 185 } 186 187 f1 := os.NewFile(uintptr(fds[0]), "acquired-fd") 188 defer f1.Close() 189 190 content, err := ioutil.ReadAll(f1) 191 if err != nil { 192 t.Fatalf("ReadAll(): %v", err) 193 } 194 195 if string(content) != data { 196 t.Errorf("bad content: %q instead of %q", content, data) 197 } 198 } 199 200 func withFDClient(t *testing.T, toCall func(*FDClient, *sampleFDSource)) { 201 tmpDir, err := ioutil.TempDir("", "pass-fd-test") 202 if err != nil { 203 t.Fatalf("ioutil.TempDir(): %v", err) 204 } 205 defer os.RemoveAll(tmpDir) 206 207 socketPath := filepath.Join(tmpDir, "passfd") 208 src := newSampleFDSource(tmpDir) 209 s := NewFDServer(socketPath, src) 210 if err := s.Serve(); err != nil { 211 t.Fatalf("Serve(): %v", err) 212 } 213 defer func() { 214 s.Stop() 215 if !src.stopped { 216 t.Errorf("FDSource not stopped") 217 } 218 }() 219 c := NewFDClient(socketPath) 220 221 toCall(c, src) 222 } 223 224 func TestFDServer(t *testing.T) { 225 withFDClient(t, func(c *FDClient, src *sampleFDSource) { 226 content := []string{"foo", "bar", "baz"} 227 for _, data := range content { 228 var err error 229 key := "k_" + data 230 respData, err := c.AddFDs(key, sampleFDData{Content: data}) 231 if err != nil { 232 t.Fatalf("AddFDs(): %v", err) 233 } 234 expectedRespData := "abcdef" 235 if string(respData) != expectedRespData { 236 t.Errorf("bad data returned from add: %q instead of %q", data, expectedRespData) 237 } 238 } 239 240 for _, data := range content { 241 key := "k_" + data 242 verifyFD(t, c, key, data) 243 } 244 245 for _, data := range content { 246 key := "k_" + data 247 if err := c.ReleaseFDs(key); err != nil { 248 t.Fatalf("ReleaseFD(): key %q: %v", key, err) 249 } 250 } 251 252 // here we make sure that releasing FDs works and also that passing errors from the 253 // server works, too 254 expectedErrorMessage := fmt.Sprintf("server returned error: bad fd key: \"k_foo\"") 255 if _, _, err := c.GetFDs("k_foo"); err == nil { 256 t.Errorf("GetFDs didn't return an error for a released fd") 257 } else if err.Error() != expectedErrorMessage { 258 t.Errorf("Bad error message from GetFD: %q instead of %q", err.Error(), expectedErrorMessage) 259 } 260 261 if !src.isEmpty() { 262 t.Errorf("fd source is not empty (but it should be)") 263 } 264 }) 265 } 266 267 func TestFDServerRecovery(t *testing.T) { 268 withFDClient(t, func(c *FDClient, src *sampleFDSource) { 269 if err := c.Recover("foobar", sampleFDData{"42"}); err != nil { 270 t.Errorf("Recover(): %v", err) 271 } 272 if !src.isRecovered("foobar") { 273 t.Errorf("the key is not recovered") 274 } 275 if err := c.ReleaseFDs("foobar"); err != nil { 276 t.Errorf("Error releasing the recovered FDs: %v", err) 277 } 278 }) 279 } 280 281 func TestFDServerRetrieveFds(t *testing.T) { 282 withFDClient(t, func(c *FDClient, src *sampleFDSource) { 283 if err := c.Recover("foobar", sampleFDData{"42"}); err != nil { 284 t.Errorf("Recover(): %v", err) 285 } 286 if !src.isRecovered("foobar") { 287 t.Errorf("the key is not recovered") 288 } 289 290 if _, err := src.RetrieveFDs("foobar"); err != nil { 291 t.Errorf("failed to RetrieveFDs: %v", err) 292 } 293 294 if err := c.ReleaseFDs("foobar"); err != nil { 295 t.Errorf("Error releasing the recovered FDs: %v", err) 296 } 297 }) 298 }