github.com/mirantis/virtlet@v1.5.2-0.20191204181327-1659b8a48e9b/pkg/tapmanager/fdserver_test.go (about)

     1  /*
     2  Copyright 2017 Mirantis
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package tapmanager
    18  
    19  import (
    20  	"encoding/json"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"io/ioutil"
    25  	"os"
    26  	"path/filepath"
    27  	"testing"
    28  )
    29  
    30  type sampleFDData struct {
    31  	Content string
    32  }
    33  
    34  type sampleFDSource struct {
    35  	tmpDir  string
    36  	files   map[string]*os.File
    37  	stopped bool
    38  }
    39  
    40  var _ FDSource = &sampleFDSource{}
    41  
    42  func newSampleFDSource(tmpDir string) *sampleFDSource {
    43  	return &sampleFDSource{
    44  		tmpDir: tmpDir,
    45  		files:  make(map[string]*os.File),
    46  	}
    47  }
    48  
    49  func (s *sampleFDSource) GetFDs(key string, data []byte) ([]int, []byte, error) {
    50  	if s.stopped {
    51  		return nil, nil, errors.New("sampleFDSource is stopped")
    52  	}
    53  
    54  	var fdData sampleFDData
    55  	if err := json.Unmarshal(data, &fdData); err != nil {
    56  		return nil, nil, fmt.Errorf("error unmarshalling json: %v", err)
    57  	}
    58  	if _, found := s.files[key]; found {
    59  		return nil, nil, fmt.Errorf("file already exists: %q", key)
    60  	}
    61  	filename := filepath.Join(s.tmpDir, key)
    62  	f, err := os.Create(filename)
    63  	if err != nil {
    64  		return nil, nil, fmt.Errorf("error creating file %q: %v", filename, err)
    65  	}
    66  	if err := os.Remove(f.Name()); err != nil {
    67  		f.Close()
    68  		return nil, nil, fmt.Errorf("Remove(): %v", err)
    69  	}
    70  	if _, err := f.Write([]byte(fdData.Content)); err != nil {
    71  		f.Close()
    72  		return nil, nil, fmt.Errorf("Write(): %v", err)
    73  	}
    74  	if _, err := f.Seek(0, io.SeekStart); err != nil {
    75  		f.Close()
    76  		return nil, nil, fmt.Errorf("Seek(): %v", err)
    77  	}
    78  	s.files[key] = f
    79  	return []int{int(f.Fd())}, []byte("abcdef"), nil
    80  }
    81  
    82  func (s *sampleFDSource) RetrieveFDs(key string) ([]int, error) {
    83  	if s.stopped {
    84  		return nil, errors.New("sampleFDSource is stopped")
    85  	}
    86  
    87  	f, err := os.Open(filepath.Join(s.tmpDir, key))
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	content, err := ioutil.ReadAll(f)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	if string(content) != "42" {
    98  		return nil, fmt.Errorf("bad data passed to RetrieveFDs: %q", content)
    99  	}
   100  
   101  	if err = os.Remove(filepath.Join(s.tmpDir, key)); err != nil {
   102  		return nil, err
   103  	}
   104  	return []int{int(f.Fd())}, nil
   105  }
   106  
   107  func (s *sampleFDSource) Recover(key string, data []byte) error {
   108  	if s.stopped {
   109  		return errors.New("sampleFDSource is stopped")
   110  	}
   111  
   112  	var fdData sampleFDData
   113  	if err := json.Unmarshal(data, &fdData); err != nil {
   114  		return fmt.Errorf("error unmarshalling json: %v", err)
   115  	}
   116  
   117  	if fdData.Content != "42" {
   118  		return fmt.Errorf("bad data passed to Recover: %q", data)
   119  	}
   120  
   121  	if _, found := s.files[key]; found {
   122  		return fmt.Errorf("key %q is already present", key)
   123  	}
   124  
   125  	s.files[key] = nil
   126  	return ioutil.WriteFile(filepath.Join(s.tmpDir, key), []byte(fdData.Content), 0644)
   127  }
   128  
   129  func (s *sampleFDSource) Release(key string) error {
   130  	if s.stopped {
   131  		return errors.New("sampleFDSource is stopped")
   132  	}
   133  
   134  	f, found := s.files[key]
   135  	if !found {
   136  		return fmt.Errorf("file not found: %q", key)
   137  	}
   138  	delete(s.files, key)
   139  
   140  	// "recovered" entries don't have FDs
   141  	if f != nil {
   142  		if err := f.Close(); err != nil {
   143  			return fmt.Errorf("can't close file %q: %v", f.Name(), err)
   144  		}
   145  	}
   146  
   147  	return nil
   148  }
   149  
   150  func (s *sampleFDSource) GetInfo(key string) ([]byte, error) {
   151  	if s.stopped {
   152  		return nil, errors.New("sampleFDSource is stopped")
   153  	}
   154  
   155  	_, found := s.files[key]
   156  	if !found {
   157  		return nil, fmt.Errorf("file not found: %q", key)
   158  	}
   159  	return []byte("info_" + key), nil
   160  }
   161  
   162  func (s *sampleFDSource) Stop() error {
   163  	s.stopped = true
   164  	return nil
   165  }
   166  
   167  func (s *sampleFDSource) isEmpty() bool {
   168  	return len(s.files) == 0
   169  }
   170  
   171  func (s *sampleFDSource) isRecovered(key string) bool {
   172  	f, found := s.files[key]
   173  	return found && f == nil
   174  }
   175  
   176  func verifyFD(t *testing.T, c *FDClient, key string, data string) {
   177  	fds, info, err := c.GetFDs(key)
   178  	if err != nil {
   179  		t.Fatalf("GetFDs(): %v", err)
   180  	}
   181  
   182  	expectedInfo := "info_" + key
   183  	if string(info) != expectedInfo {
   184  		t.Errorf("bad info: %q instead of %q", info, expectedInfo)
   185  	}
   186  
   187  	f1 := os.NewFile(uintptr(fds[0]), "acquired-fd")
   188  	defer f1.Close()
   189  
   190  	content, err := ioutil.ReadAll(f1)
   191  	if err != nil {
   192  		t.Fatalf("ReadAll(): %v", err)
   193  	}
   194  
   195  	if string(content) != data {
   196  		t.Errorf("bad content: %q instead of %q", content, data)
   197  	}
   198  }
   199  
   200  func withFDClient(t *testing.T, toCall func(*FDClient, *sampleFDSource)) {
   201  	tmpDir, err := ioutil.TempDir("", "pass-fd-test")
   202  	if err != nil {
   203  		t.Fatalf("ioutil.TempDir(): %v", err)
   204  	}
   205  	defer os.RemoveAll(tmpDir)
   206  
   207  	socketPath := filepath.Join(tmpDir, "passfd")
   208  	src := newSampleFDSource(tmpDir)
   209  	s := NewFDServer(socketPath, src)
   210  	if err := s.Serve(); err != nil {
   211  		t.Fatalf("Serve(): %v", err)
   212  	}
   213  	defer func() {
   214  		s.Stop()
   215  		if !src.stopped {
   216  			t.Errorf("FDSource not stopped")
   217  		}
   218  	}()
   219  	c := NewFDClient(socketPath)
   220  
   221  	toCall(c, src)
   222  }
   223  
   224  func TestFDServer(t *testing.T) {
   225  	withFDClient(t, func(c *FDClient, src *sampleFDSource) {
   226  		content := []string{"foo", "bar", "baz"}
   227  		for _, data := range content {
   228  			var err error
   229  			key := "k_" + data
   230  			respData, err := c.AddFDs(key, sampleFDData{Content: data})
   231  			if err != nil {
   232  				t.Fatalf("AddFDs(): %v", err)
   233  			}
   234  			expectedRespData := "abcdef"
   235  			if string(respData) != expectedRespData {
   236  				t.Errorf("bad data returned from add: %q instead of %q", data, expectedRespData)
   237  			}
   238  		}
   239  
   240  		for _, data := range content {
   241  			key := "k_" + data
   242  			verifyFD(t, c, key, data)
   243  		}
   244  
   245  		for _, data := range content {
   246  			key := "k_" + data
   247  			if err := c.ReleaseFDs(key); err != nil {
   248  				t.Fatalf("ReleaseFD(): key %q: %v", key, err)
   249  			}
   250  		}
   251  
   252  		// here we make sure that releasing FDs works and also that passing errors from the
   253  		// server works, too
   254  		expectedErrorMessage := fmt.Sprintf("server returned error: bad fd key: \"k_foo\"")
   255  		if _, _, err := c.GetFDs("k_foo"); err == nil {
   256  			t.Errorf("GetFDs didn't return an error for a released fd")
   257  		} else if err.Error() != expectedErrorMessage {
   258  			t.Errorf("Bad error message from GetFD: %q instead of %q", err.Error(), expectedErrorMessage)
   259  		}
   260  
   261  		if !src.isEmpty() {
   262  			t.Errorf("fd source is not empty (but it should be)")
   263  		}
   264  	})
   265  }
   266  
   267  func TestFDServerRecovery(t *testing.T) {
   268  	withFDClient(t, func(c *FDClient, src *sampleFDSource) {
   269  		if err := c.Recover("foobar", sampleFDData{"42"}); err != nil {
   270  			t.Errorf("Recover(): %v", err)
   271  		}
   272  		if !src.isRecovered("foobar") {
   273  			t.Errorf("the key is not recovered")
   274  		}
   275  		if err := c.ReleaseFDs("foobar"); err != nil {
   276  			t.Errorf("Error releasing the recovered FDs: %v", err)
   277  		}
   278  	})
   279  }
   280  
   281  func TestFDServerRetrieveFds(t *testing.T) {
   282  	withFDClient(t, func(c *FDClient, src *sampleFDSource) {
   283  		if err := c.Recover("foobar", sampleFDData{"42"}); err != nil {
   284  			t.Errorf("Recover(): %v", err)
   285  		}
   286  		if !src.isRecovered("foobar") {
   287  			t.Errorf("the key is not recovered")
   288  		}
   289  
   290  		if _, err := src.RetrieveFDs("foobar"); err != nil {
   291  			t.Errorf("failed to RetrieveFDs: %v", err)
   292  		}
   293  
   294  		if err := c.ReleaseFDs("foobar"); err != nil {
   295  			t.Errorf("Error releasing the recovered FDs: %v", err)
   296  		}
   297  	})
   298  }