gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/gpu/drivers/install_driver_test.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 package drivers 15 16 import ( 17 "bytes" 18 "context" 19 "crypto/sha256" 20 "fmt" 21 "io" 22 "strings" 23 "testing" 24 25 "gvisor.dev/gvisor/pkg/sentry/devices/nvproxy" 26 ) 27 28 // TestVersionInstalled tests when the version is already installed. 29 func TestVersionInstalled(t *testing.T) { 30 ctx := context.Background() 31 versionContent := []byte("some cool content") 32 checksum := fmt.Sprintf("%x", sha256.Sum256(versionContent)) 33 version := nvproxy.NewDriverVersion(1, 2, 3) 34 getFunction := func() (nvproxy.DriverVersion, error) { return version, nil } 35 downloadFunction := func(context.Context, string) (io.ReadCloser, error) { return nil, fmt.Errorf("should not get here") } 36 installer := &Installer{ 37 requestedVersion: version, 38 expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { 39 if v == version { 40 return checksum, true 41 } 42 return "", false 43 }, 44 getCurrentDriverFunc: getFunction, 45 downloadFunc: downloadFunction, 46 } 47 if err := installer.MaybeInstall(ctx); err != nil { 48 t.Fatalf("Installation failed: %v", err) 49 } 50 } 51 52 // TestVersionNotSupported tests when the version is not supported. 53 func TestVersionNotSupported(t *testing.T) { 54 ctx := context.Background() 55 unsupportedVersion := nvproxy.NewDriverVersion(1, 2, 3) 56 installer := &Installer{ 57 requestedVersion: unsupportedVersion, 58 expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { 59 return "", false 60 }, 61 } 62 err := installer.MaybeInstall(ctx) 63 if err == nil { 64 t.Fatalf("Installation succeeded, want error") 65 } 66 if !strings.Contains(err.Error(), "not supported") { 67 t.Errorf("Installation failed, want error containing 'not supported' got: %s", err.Error()) 68 } 69 } 70 71 // TestShaMismatch tests when a checksum of a driver doesn't match what's in the map. 72 func TestShaMismatch(t *testing.T) { 73 ctx := context.Background() 74 version := nvproxy.NewDriverVersion(1, 2, 3) 75 installer := &Installer{ 76 requestedVersion: version, 77 getCurrentDriverFunc: func() (nvproxy.DriverVersion, error) { 78 return nvproxy.DriverVersion{}, nil 79 }, 80 expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { 81 if v == version { 82 return "mismatch", true 83 } 84 return "", false 85 }, 86 downloadFunc: func(context.Context, string) (io.ReadCloser, error) { 87 reader := bytes.NewReader([]byte("some content")) 88 return io.NopCloser(reader), nil 89 }, 90 } 91 err := installer.MaybeInstall(ctx) 92 if err == nil { 93 t.Fatalf("Installation succeeded, want error") 94 } 95 if !strings.Contains(err.Error(), "checksum mismatch") { 96 t.Errorf("Installation failed, want error containing 'mismatch checksum' got: %s", err.Error()) 97 } 98 } 99 100 // TestDriverInstalls tests the successful installation of a driver. 101 func TestDriverInstalls(t *testing.T) { 102 ctx := context.Background() 103 content := []byte("some content") 104 checksum := fmt.Sprintf("%x", sha256.Sum256(content)) 105 version := nvproxy.NewDriverVersion(1, 2, 3) 106 installer := &Installer{ 107 requestedVersion: version, 108 getCurrentDriverFunc: func() (nvproxy.DriverVersion, error) { 109 return nvproxy.DriverVersion{}, nil 110 }, 111 expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { 112 if v == version { 113 return checksum, true 114 } 115 return "", false 116 }, 117 downloadFunc: func(context.Context, string) (io.ReadCloser, error) { 118 reader := bytes.NewReader(content) 119 return io.NopCloser(reader), nil 120 }, 121 installFunc: func(_ string) error { 122 return nil 123 }, 124 } 125 if err := installer.MaybeInstall(ctx); err != nil { 126 t.Fatalf("Installation failed: %v", err) 127 } 128 }