github.com/webmeshproj/webmesh-cni@v0.0.27/internal/types/netconf_test.go (about)

     1  /*
     2  Copyright 2023 Avi Zimmerman <avi.zimmerman@gmail.com>.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package types
    18  
    19  import (
    20  	"encoding/json"
    21  	"io"
    22  	"log/slog"
    23  	"os"
    24  	"testing"
    25  
    26  	"github.com/containernetworking/cni/pkg/skel"
    27  	meshsys "github.com/webmeshproj/webmesh/pkg/meshnet/system"
    28  	meshtypes "github.com/webmeshproj/webmesh/pkg/storage/types"
    29  
    30  	v1 "github.com/webmeshproj/webmesh-cni/api/v1"
    31  )
    32  
    33  // TestNetConf tests the NetConf type.
    34  func TestNetConf(t *testing.T) {
    35  	t.Parallel()
    36  
    37  	testConf := &NetConf{
    38  		Kubernetes: Kubernetes{
    39  			Kubeconfig: "foo",
    40  			NodeName:   "bar",
    41  			K8sAPIRoot: "http://localhost:8080",
    42  			Namespace:  "baz",
    43  		},
    44  		Interface: Interface{
    45  			MTU:         1234,
    46  			DisableIPv4: true,
    47  			DisableIPv6: true,
    48  		},
    49  		LogLevel: "info",
    50  	}
    51  	testData, err := json.Marshal(testConf)
    52  	if err != nil {
    53  		t.Fatal("marshal test data", err)
    54  	}
    55  
    56  	t.Run("Defaults", func(t *testing.T) {
    57  		t.Parallel()
    58  		// Make sure a nil configuration produces the correct defaults.
    59  		var conf *NetConf
    60  		if conf.DeepEqual(testConf) {
    61  			t.Errorf("expected testConf to not be equal to nil conf")
    62  		}
    63  		if testConf.DeepEqual(conf) {
    64  			t.Errorf("expected nil conf to not be equal to testConf")
    65  		}
    66  		conf = conf.SetDefaults()
    67  		if conf.LogLevel != "info" {
    68  			t.Errorf("expected default log level to be info, got %s", conf.LogLevel)
    69  		}
    70  		if conf.Kubernetes.Kubeconfig != DefaultKubeconfigPath {
    71  			t.Errorf("expected default kubeconfig to be %s, got %s", DefaultKubeconfigPath, conf.Kubernetes.Kubeconfig)
    72  		}
    73  		if conf.Interface.MTU != meshsys.DefaultMTU {
    74  			t.Errorf("expected default MTU to be %d, got %d", meshsys.DefaultMTU, conf.Interface.MTU)
    75  		}
    76  		if conf.DeepEqual(testConf) {
    77  			t.Errorf("expected testConf to not be equal to conf")
    78  		}
    79  		// Make sure the same goes for an empty one
    80  		conf = &NetConf{}
    81  		conf = conf.SetDefaults()
    82  		if conf.LogLevel != "info" {
    83  			t.Errorf("expected default log level to be info, got %s", conf.LogLevel)
    84  		}
    85  		if conf.Kubernetes.Kubeconfig != DefaultKubeconfigPath {
    86  			t.Errorf("expected default kubeconfig to be %s, got %s", DefaultKubeconfigPath, conf.Kubernetes.Kubeconfig)
    87  		}
    88  		if conf.Interface.MTU != meshsys.DefaultMTU {
    89  			t.Errorf("expected default MTU to be %d, got %d", meshsys.DefaultMTU, conf.Interface.MTU)
    90  		}
    91  		if conf.DeepEqual(testConf) {
    92  			t.Errorf("expected testConf to not be equal to conf")
    93  		}
    94  		// Make sure defaults dont override existing values.
    95  		conf = &NetConf{
    96  			LogLevel: "debug",
    97  			Kubernetes: Kubernetes{
    98  				Kubeconfig: "foo",
    99  			},
   100  			Interface: Interface{
   101  				MTU: 1234,
   102  			},
   103  		}
   104  		conf = conf.SetDefaults()
   105  		if conf.LogLevel != "debug" {
   106  			t.Errorf("expected log level to be debug, got %s", conf.LogLevel)
   107  		}
   108  		if conf.Kubernetes.Kubeconfig != "foo" {
   109  			t.Errorf("expected kubeconfig to be foo, got %s", conf.Kubernetes.Kubeconfig)
   110  		}
   111  		if conf.Interface.MTU != 1234 {
   112  			t.Errorf("expected MTU to be 1234, got %d", conf.Interface.MTU)
   113  		}
   114  	})
   115  
   116  	t.Run("Decoders", func(t *testing.T) {
   117  		t.Parallel()
   118  
   119  		t.Run("FromFile", func(t *testing.T) {
   120  			t.Parallel()
   121  			f, err := os.CreateTemp("", "")
   122  			if err != nil {
   123  				t.Fatal("create temporary file", err)
   124  			}
   125  			defer os.Remove(f.Name())
   126  			_, err = f.Write(testData)
   127  			if err != nil {
   128  				t.Fatal("write test data", err)
   129  			}
   130  			err = f.Close()
   131  			if err != nil {
   132  				t.Fatal("close file", err)
   133  			}
   134  			conf, err := LoadNetConfFromFile(f.Name())
   135  			if err != nil {
   136  				t.Fatal("load config from file", err)
   137  			}
   138  			if !testConf.DeepEqual(conf) {
   139  				t.Errorf("expected config to be equal to test config, got %v", conf)
   140  			}
   141  			t.Run("NonExist", func(t *testing.T) {
   142  				t.Parallel()
   143  				_, err := LoadNetConfFromFile("nonexist")
   144  				if err == nil {
   145  					t.Error("expected error, got nil")
   146  				}
   147  			})
   148  		})
   149  
   150  		t.Run("FromArgs", func(t *testing.T) {
   151  			t.Parallel()
   152  			conf, err := LoadNetConfFromArgs(&skel.CmdArgs{
   153  				StdinData: testData,
   154  			})
   155  			if err != nil {
   156  				t.Fatal("load config from file", err)
   157  			}
   158  			if !testConf.DeepEqual(conf) {
   159  				t.Errorf("expected config to be equal to test config, got %v", conf)
   160  			}
   161  		})
   162  
   163  		t.Run("InvalidData", func(t *testing.T) {
   164  			_, err := DecodeNetConf([]byte("invalid"))
   165  			if err == nil {
   166  				t.Error("expected error, got nil")
   167  			}
   168  		})
   169  	})
   170  
   171  	t.Run("Logging", func(t *testing.T) {
   172  		t.Parallel()
   173  
   174  		t.Run("NewLogger", func(t *testing.T) {
   175  			t.Parallel()
   176  			// NewLogger should never return nil.
   177  			log := testConf.NewLogger(&skel.CmdArgs{})
   178  			if log == nil {
   179  				t.Error("expected logger to not be nil")
   180  			}
   181  		})
   182  
   183  		t.Run("LogWriter", func(t *testing.T) {
   184  			t.Parallel()
   185  			conf := &NetConf{}
   186  			tc := []struct {
   187  				name     string
   188  				level    string
   189  				expected io.Writer
   190  			}{
   191  				{
   192  					name:     "Default",
   193  					expected: os.Stderr,
   194  				},
   195  				{
   196  					name:     "Debug",
   197  					level:    "debug",
   198  					expected: os.Stderr,
   199  				},
   200  				{
   201  					name:     "Info",
   202  					level:    "info",
   203  					expected: os.Stderr,
   204  				},
   205  				{
   206  					name:     "Warn",
   207  					level:    "warn",
   208  					expected: os.Stderr,
   209  				},
   210  				{
   211  					name:     "Error",
   212  					level:    "error",
   213  					expected: os.Stderr,
   214  				},
   215  				{
   216  					name:     "Silent",
   217  					level:    "silent",
   218  					expected: io.Discard,
   219  				},
   220  				{
   221  					name:     "Off",
   222  					level:    "off",
   223  					expected: io.Discard,
   224  				},
   225  			}
   226  			for _, c := range tc {
   227  				conf.LogLevel = c.level
   228  				if conf.LogWriter() != c.expected {
   229  					t.Errorf("expected log writer to be %v, got %v", c.expected, conf.LogWriter())
   230  				}
   231  			}
   232  		})
   233  
   234  		t.Run("LogLevels", func(t *testing.T) {
   235  			t.Parallel()
   236  			conf := &NetConf{}
   237  			tc := []struct {
   238  				name     string
   239  				level    string
   240  				expected slog.Level
   241  			}{
   242  				{
   243  					name:     "Default",
   244  					expected: slog.LevelInfo,
   245  				},
   246  				{
   247  					name:     "Debug",
   248  					level:    "debug",
   249  					expected: slog.LevelDebug,
   250  				},
   251  				{
   252  					name:     "DebugAllCaps",
   253  					level:    "DEBUG",
   254  					expected: slog.LevelDebug,
   255  				},
   256  				{
   257  					name:     "DebugMixedCase",
   258  					level:    "DeBuG",
   259  					expected: slog.LevelDebug,
   260  				},
   261  				{
   262  					name:     "Info",
   263  					level:    "info",
   264  					expected: slog.LevelInfo,
   265  				},
   266  				{
   267  					name:     "InfoAllCaps",
   268  					level:    "INFO",
   269  					expected: slog.LevelInfo,
   270  				},
   271  				{
   272  					name:     "InfoMixedCase",
   273  					level:    "InFo",
   274  					expected: slog.LevelInfo,
   275  				},
   276  				{
   277  					name:     "Warn",
   278  					level:    "warn",
   279  					expected: slog.LevelWarn,
   280  				},
   281  				{
   282  					name:     "WarnAllCaps",
   283  					level:    "WARN",
   284  					expected: slog.LevelWarn,
   285  				},
   286  				{
   287  					name:     "WarnMixedCase",
   288  					level:    "WaRn",
   289  					expected: slog.LevelWarn,
   290  				},
   291  				{
   292  					name:     "Error",
   293  					level:    "error",
   294  					expected: slog.LevelError,
   295  				},
   296  				{
   297  					name:     "ErrorAllCaps",
   298  					level:    "ERROR",
   299  					expected: slog.LevelError,
   300  				},
   301  				{
   302  					name:     "ErrorMixedCase",
   303  					level:    "ErRoR",
   304  					expected: slog.LevelError,
   305  				},
   306  			}
   307  			for _, c := range tc {
   308  				conf.LogLevel = c.level
   309  				if conf.SlogLevel() != c.expected {
   310  					t.Errorf("expected slog level to be %v, got %v", c.expected, conf.SlogLevel())
   311  				}
   312  			}
   313  		})
   314  
   315  	})
   316  
   317  	t.Run("PeerContainers", func(t *testing.T) {
   318  		t.Parallel()
   319  
   320  		t.Run("ObjectKeys", func(t *testing.T) {
   321  			t.Parallel()
   322  			// Object keys should be the container ID and configured namespace.
   323  			conf := &NetConf{
   324  				Kubernetes: Kubernetes{
   325  					Namespace: "foo",
   326  				},
   327  			}
   328  			args := &skel.CmdArgs{
   329  				ContainerID: "bar",
   330  			}
   331  			key := conf.ObjectKeyFromArgs(args)
   332  			if key.Name != args.ContainerID {
   333  				t.Errorf("expected object key name to be %s, got %s", args.ContainerID, key.Name)
   334  			}
   335  			if key.Namespace != conf.Kubernetes.Namespace {
   336  				t.Errorf("expected object key namespace to be %s, got %s", conf.Kubernetes.Namespace, key.Namespace)
   337  			}
   338  		})
   339  
   340  		t.Run("ContainerObjects", func(t *testing.T) {
   341  			t.Parallel()
   342  			// A new container's spec should match the given args and configuration.
   343  			conf := &NetConf{
   344  				Interface: Interface{
   345  					MTU:         1234,
   346  					DisableIPv4: true,
   347  					DisableIPv6: true,
   348  				},
   349  				Kubernetes: Kubernetes{
   350  					NodeName:  "k8s-node",
   351  					Namespace: "default",
   352  				},
   353  				LogLevel: "debug",
   354  			}
   355  			args := &skel.CmdArgs{
   356  				ContainerID: "bar",
   357  				Netns:       "/proc/1234/ns/net",
   358  			}
   359  			container := conf.ContainerFromArgs(args)
   360  
   361  			// Make sure the container's spec matches the configuration.
   362  			EnsureContainerEqualsTestConf(t, conf, &container, args)
   363  
   364  			// Set the container ID to a really long name and make sure the interface
   365  			// is truncated to 15 characters.
   366  			args.ContainerID = "reallylongcontainerid"
   367  			container = conf.ContainerFromArgs(args)
   368  			if len(container.Spec.IfName) != 15 {
   369  				t.Errorf("expected container ifname to be truncated to 15 characters, got %s", container.Spec.IfName)
   370  			}
   371  			if container.Spec.IfName != IfacePrefix+"reallylon0" {
   372  				t.Errorf("expected container ifname to be %s, got %s", "wmeshreallylongc0", container.Spec.IfName)
   373  			}
   374  		})
   375  	})
   376  }
   377  
   378  func EnsureContainerEqualsTestConf(t *testing.T, conf *NetConf, container *v1.PeerContainer, args *skel.CmdArgs) {
   379  	if container.Name != args.ContainerID {
   380  		t.Errorf("expected container name to be %s, got %s", args.ContainerID, container.Name)
   381  	}
   382  	if container.Namespace != conf.Kubernetes.Namespace {
   383  		t.Errorf("expected container namespace to be %s, got %s", conf.Kubernetes.Namespace, container.Namespace)
   384  	}
   385  	if container.Spec.NodeID != meshtypes.TruncateID(args.ContainerID) {
   386  		t.Errorf("expected container node ID to be %s, got %s", args.ContainerID, container.Spec.NodeID)
   387  	}
   388  	if container.Spec.Netns != args.Netns {
   389  		t.Errorf("expected container netns to be %s, got %s", args.Netns, container.Spec.Netns)
   390  	}
   391  	expectedIfName := IfNameFromID(meshtypes.TruncateID(args.ContainerID))
   392  	if container.Spec.IfName != expectedIfName {
   393  		t.Errorf("expected container ifname to be %s, got %s", expectedIfName, container.Spec.IfName)
   394  	}
   395  	if container.Spec.NodeName != conf.Kubernetes.NodeName {
   396  		t.Errorf("expected container node name to be %s, got %s", conf.Kubernetes.NodeName, container.Spec.NodeName)
   397  	}
   398  	if container.Spec.MTU != conf.Interface.MTU {
   399  		t.Errorf("expected container mtu to be %d, got %d", conf.Interface.MTU, container.Spec.MTU)
   400  	}
   401  	if container.Spec.DisableIPv4 != conf.Interface.DisableIPv4 {
   402  		t.Errorf("expected container disable ipv4 to be %t, got %t", conf.Interface.DisableIPv4, container.Spec.DisableIPv4)
   403  	}
   404  	if container.Spec.DisableIPv6 != conf.Interface.DisableIPv6 {
   405  		t.Errorf("expected container disable ipv6 to be %t, got %t", conf.Interface.DisableIPv6, container.Spec.DisableIPv6)
   406  	}
   407  	if container.Spec.LogLevel != conf.LogLevel {
   408  		t.Errorf("expected container log level to be %s, got %s", conf.LogLevel, container.Spec.LogLevel)
   409  	}
   410  }