github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/test/benchmarks/ml/tensorflow_test.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package ml
    15  
    16  import (
    17  	"context"
    18  	"os"
    19  	"testing"
    20  
    21  	"github.com/SagerNet/gvisor/pkg/test/dockerutil"
    22  	"github.com/SagerNet/gvisor/test/benchmarks/harness"
    23  	"github.com/SagerNet/gvisor/test/benchmarks/tools"
    24  )
    25  
    26  // BenchmarkTensorflow runs workloads from a TensorFlow tutorial.
    27  // See: https://github.com/aymericdamien/TensorFlow-Examples
    28  func BenchmarkTensorflow(b *testing.B) {
    29  	workloads := map[string]string{
    30  		"GradientDecisionTree": "2_BasicModels/gradient_boosted_decision_tree.py",
    31  		"Kmeans":               "2_BasicModels/kmeans.py",
    32  		"LogisticRegression":   "2_BasicModels/logistic_regression.py",
    33  		"NearestNeighbor":      "2_BasicModels/nearest_neighbor.py",
    34  		"RandomForest":         "2_BasicModels/random_forest.py",
    35  		"ConvolutionalNetwork": "3_NeuralNetworks/convolutional_network.py",
    36  		"MultilayerPerceptron": "3_NeuralNetworks/multilayer_perceptron.py",
    37  		"NeuralNetwork":        "3_NeuralNetworks/neural_network.py",
    38  	}
    39  
    40  	machine, err := harness.GetMachine()
    41  	if err != nil {
    42  		b.Fatalf("failed to get machine: %v", err)
    43  	}
    44  	defer machine.CleanUp()
    45  
    46  	for name, workload := range workloads {
    47  		runName, err := tools.ParametersToName(tools.Parameter{
    48  			Name:  "operation",
    49  			Value: name,
    50  		})
    51  		if err != nil {
    52  			b.Fatalf("Faile to parse param: %v", err)
    53  		}
    54  
    55  		b.Run(runName, func(b *testing.B) {
    56  			ctx := context.Background()
    57  
    58  			b.ResetTimer()
    59  			b.StopTimer()
    60  
    61  			for i := 0; i < b.N; i++ {
    62  				container := machine.GetContainer(ctx, b)
    63  				defer container.CleanUp(ctx)
    64  				if err := harness.DropCaches(machine); err != nil {
    65  					b.Skipf("failed to drop caches: %v. You probably need root.", err)
    66  				}
    67  
    68  				// Run tensorflow.
    69  				b.StartTimer()
    70  				if out, err := container.Run(ctx, dockerutil.RunOpts{
    71  					Image:   "benchmarks/tensorflow",
    72  					Env:     []string{"PYTHONPATH=$PYTHONPATH:/TensorFlow-Examples/examples"},
    73  					WorkDir: "/TensorFlow-Examples/examples",
    74  				}, "python", workload); err != nil {
    75  					b.Fatalf("failed to run container: %v logs: %s", err, out)
    76  				}
    77  				b.StopTimer()
    78  			}
    79  		})
    80  	}
    81  }
    82  
    83  func TestMain(m *testing.M) {
    84  	harness.Init()
    85  	harness.SetFixedBenchmarks()
    86  	os.Exit(m.Run())
    87  }