gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/gpu/pytorch_test.go (about)

     1  // Copyright 2023 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package pytorch_test tests basic PyTorch workloads.
    16  package pytorch_test
    17  
    18  import (
    19  	"context"
    20  	"testing"
    21  
    22  	"gvisor.dev/gvisor/pkg/test/dockerutil"
    23  )
    24  
    25  // runPytorch runs the given script and command in a PyTorch container.
    26  func runPytorch(ctx context.Context, t *testing.T, scriptPath string, args ...string) {
    27  	t.Helper()
    28  	c := dockerutil.MakeContainer(ctx, t)
    29  	opts := dockerutil.GPURunOpts()
    30  	opts.Image = "gpu/pytorch"
    31  	cmd := append([]string{"python3", scriptPath}, args...)
    32  	out, err := c.Run(ctx, opts, cmd...)
    33  	if err != nil {
    34  		t.Errorf("Failed: %v\nContainer output:\n%s", err, out)
    35  	} else {
    36  		t.Logf("Container output:\n%s", out)
    37  	}
    38  }
    39  
    40  // TestCUDAIsAvailable checks that PyTorch recognizes that CUDA is available.
    41  func TestCUDAIsAvailable(t *testing.T) {
    42  	runPytorch(context.Background(), t, "/is_cuda_available.py")
    43  }
    44  
    45  // TestLinearRegressionModel runs a simple linear regression model.
    46  func TestLinearRegressionModel(t *testing.T) {
    47  	runPytorch(context.Background(), t, "/pytorch-examples/regression/main.py", "--cuda")
    48  }
    49  
    50  // TestMNIST runs an MNIST model.
    51  func TestMNIST(t *testing.T) {
    52  	runPytorch(context.Background(), t, "/pytorch-examples/mnist/main.py", "--epochs=1", "--dry-run")
    53  }
    54  
    55  // TestIssue9827 verifies that issue 9827 is fixed.
    56  func TestIssue9827(t *testing.T) {
    57  	// TODO(gvisor.dev/issue/9827): Don't skip this once the
    58  	// test works and doesn't run forever:
    59  	t.Skip("TODO(gvisor.dev/issue/9827): Issue 9827 is not yet fixed.")
    60  	runPytorch(context.Background(), t, "/issue_9827.py")
    61  }