github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/runsc/config/config_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package config
    16  
    17  import (
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/SagerNet/gvisor/runsc/flag"
    22  )
    23  
    24  func init() {
    25  	RegisterFlags()
    26  }
    27  
    28  func TestDefault(t *testing.T) {
    29  	c, err := NewFromFlags()
    30  	if err != nil {
    31  		t.Fatal(err)
    32  	}
    33  	// "--root" is always set to something different than the default. Reset it
    34  	// to make it easier to test that default values do not generate flags.
    35  	c.RootDir = ""
    36  
    37  	// All defaults doesn't require setting flags.
    38  	flags := c.ToFlags()
    39  	if len(flags) > 0 {
    40  		t.Errorf("default flags not set correctly for: %s", flags)
    41  	}
    42  }
    43  
    44  func setDefault(name string) {
    45  	fl := flag.CommandLine.Lookup(name)
    46  	fl.Value.Set(fl.DefValue)
    47  }
    48  
    49  func TestFromFlags(t *testing.T) {
    50  	flag.CommandLine.Lookup("root").Value.Set("some-path")
    51  	flag.CommandLine.Lookup("debug").Value.Set("true")
    52  	flag.CommandLine.Lookup("num-network-channels").Value.Set("123")
    53  	flag.CommandLine.Lookup("network").Value.Set("none")
    54  	defer func() {
    55  		setDefault("root")
    56  		setDefault("debug")
    57  		setDefault("num-network-channels")
    58  		setDefault("network")
    59  	}()
    60  
    61  	c, err := NewFromFlags()
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  	if want := "some-path"; c.RootDir != want {
    66  		t.Errorf("RootDir=%v, want: %v", c.RootDir, want)
    67  	}
    68  	if want := true; c.Debug != want {
    69  		t.Errorf("Debug=%v, want: %v", c.Debug, want)
    70  	}
    71  	if want := 123; c.NumNetworkChannels != want {
    72  		t.Errorf("NumNetworkChannels=%v, want: %v", c.NumNetworkChannels, want)
    73  	}
    74  	if want := NetworkNone; c.Network != want {
    75  		t.Errorf("Network=%v, want: %v", c.Network, want)
    76  	}
    77  }
    78  
    79  func TestToFlags(t *testing.T) {
    80  	c, err := NewFromFlags()
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	c.RootDir = "some-path"
    85  	c.Debug = true
    86  	c.NumNetworkChannels = 123
    87  	c.Network = NetworkNone
    88  
    89  	flags := c.ToFlags()
    90  	if len(flags) != 4 {
    91  		t.Errorf("wrong number of flags set, want: 4, got: %d: %s", len(flags), flags)
    92  	}
    93  	t.Logf("Flags: %s", flags)
    94  	fm := map[string]string{}
    95  	for _, f := range flags {
    96  		kv := strings.Split(f, "=")
    97  		fm[kv[0]] = kv[1]
    98  	}
    99  	for name, want := range map[string]string{
   100  		"--root":                 "some-path",
   101  		"--debug":                "true",
   102  		"--num-network-channels": "123",
   103  		"--network":              "none",
   104  	} {
   105  		if got, ok := fm[name]; ok {
   106  			if got != want {
   107  				t.Errorf("flag %q, want: %q, got: %q", name, want, got)
   108  			}
   109  		} else {
   110  			t.Errorf("flag %q not set", name)
   111  		}
   112  	}
   113  }
   114  
   115  // TestInvalidFlags checks that enum flags fail when value is not in enum set.
   116  func TestInvalidFlags(t *testing.T) {
   117  	for _, tc := range []struct {
   118  		name  string
   119  		error string
   120  	}{
   121  		{
   122  			name:  "file-access",
   123  			error: "invalid file access type",
   124  		},
   125  		{
   126  			name:  "network",
   127  			error: "invalid network type",
   128  		},
   129  		{
   130  			name:  "qdisc",
   131  			error: "invalid qdisc",
   132  		},
   133  		{
   134  			name:  "watchdog-action",
   135  			error: "invalid watchdog action",
   136  		},
   137  		{
   138  			name:  "ref-leak-mode",
   139  			error: "invalid ref leak mode",
   140  		},
   141  	} {
   142  		t.Run(tc.name, func(t *testing.T) {
   143  			defer setDefault(tc.name)
   144  			if err := flag.CommandLine.Lookup(tc.name).Value.Set("invalid"); err == nil || !strings.Contains(err.Error(), tc.error) {
   145  				t.Errorf("flag.Value.Set(invalid) wrong error reported: %v", err)
   146  			}
   147  		})
   148  	}
   149  }
   150  
   151  func TestValidationFail(t *testing.T) {
   152  	for _, tc := range []struct {
   153  		name  string
   154  		flags map[string]string
   155  		error string
   156  	}{
   157  		{
   158  			name: "shared+overlay",
   159  			flags: map[string]string{
   160  				"file-access": "shared",
   161  				"overlay":     "true",
   162  			},
   163  			error: "overlay flag is incompatible",
   164  		},
   165  		{
   166  			name: "network-channels",
   167  			flags: map[string]string{
   168  				"num-network-channels": "-1",
   169  			},
   170  			error: "num_network_channels must be > 0",
   171  		},
   172  	} {
   173  		t.Run(tc.name, func(t *testing.T) {
   174  			for name, val := range tc.flags {
   175  				defer setDefault(name)
   176  				if err := flag.CommandLine.Lookup(name).Value.Set(val); err != nil {
   177  					t.Errorf("%s=%q: %v", name, val, err)
   178  				}
   179  			}
   180  			if _, err := NewFromFlags(); err == nil || !strings.Contains(err.Error(), tc.error) {
   181  				t.Errorf("NewFromFlags() wrong error reported: %v", err)
   182  			}
   183  		})
   184  	}
   185  }
   186  
   187  func TestOverride(t *testing.T) {
   188  	c, err := NewFromFlags()
   189  	if err != nil {
   190  		t.Fatal(err)
   191  	}
   192  	c.AllowFlagOverride = true
   193  
   194  	t.Run("string", func(t *testing.T) {
   195  		c.RootDir = "foobar"
   196  		if err := c.Override("root", "bar"); err != nil {
   197  			t.Fatalf("Override(root, bar) failed: %v", err)
   198  		}
   199  		defer setDefault("root")
   200  		if c.RootDir != "bar" {
   201  			t.Errorf("Override(root, bar) didn't work: %+v", c)
   202  		}
   203  	})
   204  
   205  	t.Run("bool", func(t *testing.T) {
   206  		c.Debug = true
   207  		if err := c.Override("debug", "false"); err != nil {
   208  			t.Fatalf("Override(debug, false) failed: %v", err)
   209  		}
   210  		defer setDefault("debug")
   211  		if c.Debug {
   212  			t.Errorf("Override(debug, false) didn't work: %+v", c)
   213  		}
   214  	})
   215  
   216  	t.Run("enum", func(t *testing.T) {
   217  		c.FileAccess = FileAccessShared
   218  		if err := c.Override("file-access", "exclusive"); err != nil {
   219  			t.Fatalf("Override(file-access, exclusive) failed: %v", err)
   220  		}
   221  		defer setDefault("file-access")
   222  		if c.FileAccess != FileAccessExclusive {
   223  			t.Errorf("Override(file-access, exclusive) didn't work: %+v", c)
   224  		}
   225  	})
   226  }
   227  
   228  func TestOverrideDisabled(t *testing.T) {
   229  	c, err := NewFromFlags()
   230  	if err != nil {
   231  		t.Fatal(err)
   232  	}
   233  	const errMsg = "flag override disabled"
   234  	if err := c.Override("root", "path"); err == nil || !strings.Contains(err.Error(), errMsg) {
   235  		t.Errorf("Override() wrong error: %v", err)
   236  	}
   237  }
   238  
   239  func TestOverrideError(t *testing.T) {
   240  	c, err := NewFromFlags()
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  	c.AllowFlagOverride = true
   245  	for _, tc := range []struct {
   246  		name  string
   247  		value string
   248  		error string
   249  	}{
   250  		{
   251  			name:  "invalid",
   252  			value: "valid",
   253  			error: `flag "invalid" not found`,
   254  		},
   255  		{
   256  			name:  "debug",
   257  			value: "invalid",
   258  			error: "error setting flag debug",
   259  		},
   260  		{
   261  			name:  "file-access",
   262  			value: "invalid",
   263  			error: "invalid file access type",
   264  		},
   265  	} {
   266  		t.Run(tc.name, func(t *testing.T) {
   267  			if err := c.Override(tc.name, tc.value); err == nil || !strings.Contains(err.Error(), tc.error) {
   268  				t.Errorf("Override(%q, %q) wrong error: %v", tc.name, tc.value, err)
   269  			}
   270  		})
   271  	}
   272  }