gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/gpu/pytorch_test.go (about) 1 // Copyright 2023 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package pytorch_test tests basic PyTorch workloads. 16 package pytorch_test 17 18 import ( 19 "context" 20 "testing" 21 22 "gvisor.dev/gvisor/pkg/test/dockerutil" 23 ) 24 25 // runPytorch runs the given script and command in a PyTorch container. 26 func runPytorch(ctx context.Context, t *testing.T, scriptPath string, args ...string) { 27 t.Helper() 28 c := dockerutil.MakeContainer(ctx, t) 29 opts := dockerutil.GPURunOpts() 30 opts.Image = "gpu/pytorch" 31 cmd := append([]string{"python3", scriptPath}, args...) 32 out, err := c.Run(ctx, opts, cmd...) 33 if err != nil { 34 t.Errorf("Failed: %v\nContainer output:\n%s", err, out) 35 } else { 36 t.Logf("Container output:\n%s", out) 37 } 38 } 39 40 // TestCUDAIsAvailable checks that PyTorch recognizes that CUDA is available. 41 func TestCUDAIsAvailable(t *testing.T) { 42 runPytorch(context.Background(), t, "/is_cuda_available.py") 43 } 44 45 // TestLinearRegressionModel runs a simple linear regression model. 46 func TestLinearRegressionModel(t *testing.T) { 47 runPytorch(context.Background(), t, "/pytorch-examples/regression/main.py", "--cuda") 48 } 49 50 // TestMNIST runs an MNIST model. 51 func TestMNIST(t *testing.T) { 52 runPytorch(context.Background(), t, "/pytorch-examples/mnist/main.py", "--epochs=1", "--dry-run") 53 } 54 55 // TestIssue9827 verifies that issue 9827 is fixed. 56 func TestIssue9827(t *testing.T) { 57 // TODO(gvisor.dev/issue/9827): Don't skip this once the 58 // test works and doesn't run forever: 59 t.Skip("TODO(gvisor.dev/issue/9827): Issue 9827 is not yet fixed.") 60 runPytorch(context.Background(), t, "/issue_9827.py") 61 }