github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/initcontainer_test.go (about)

     1  // Copyright 2021 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License
    14  
    15  package pytorch
    16  
    17  import (
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/go-logr/logr"
    22  	"github.com/onsi/ginkgo/v2"
    23  	"github.com/onsi/gomega"
    24  	"k8s.io/utils/pointer"
    25  
    26  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    27  	"github.com/kubeflow/training-operator/pkg/config"
    28  )
    29  
    30  func TestInitContainer(t *testing.T) {
    31  	gomega.RegisterFailHandler(ginkgo.Fail)
    32  	defer ginkgo.GinkgoRecover()
    33  
    34  	config.Config.PyTorchInitContainerImage = config.PyTorchInitContainerImageDefault
    35  	config.Config.PyTorchInitContainerTemplateFile = config.PyTorchInitContainerTemplateFileDefault
    36  	config.Config.PyTorchInitContainerMaxTries = config.PyTorchInitContainerMaxTriesDefault
    37  
    38  	testCases := []struct {
    39  		job         *kubeflowv1.PyTorchJob
    40  		rtype       kubeflowv1.ReplicaType
    41  		index       string
    42  		expected    int
    43  		exepctedErr error
    44  	}{
    45  		{
    46  			job: &kubeflowv1.PyTorchJob{
    47  				Spec: kubeflowv1.PyTorchJobSpec{
    48  					PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    49  						kubeflowv1.PyTorchJobReplicaTypeWorker: {
    50  							Replicas: pointer.Int32(1),
    51  						},
    52  					},
    53  				},
    54  			},
    55  			rtype:       kubeflowv1.PyTorchJobReplicaTypeWorker,
    56  			index:       "0",
    57  			expected:    0,
    58  			exepctedErr: nil,
    59  		},
    60  		{
    61  			job: &kubeflowv1.PyTorchJob{
    62  				Spec: kubeflowv1.PyTorchJobSpec{
    63  					PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    64  						kubeflowv1.PyTorchJobReplicaTypeWorker: {
    65  							Replicas: pointer.Int32(1),
    66  						},
    67  						kubeflowv1.PyTorchJobReplicaTypeMaster: {
    68  							Replicas: pointer.Int32(1),
    69  						},
    70  					},
    71  				},
    72  			},
    73  			rtype:       kubeflowv1.PyTorchJobReplicaTypeWorker,
    74  			index:       "0",
    75  			expected:    1,
    76  			exepctedErr: nil,
    77  		},
    78  		{
    79  			job: &kubeflowv1.PyTorchJob{
    80  				Spec: kubeflowv1.PyTorchJobSpec{
    81  					PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    82  						kubeflowv1.PyTorchJobReplicaTypeWorker: {
    83  							Replicas: pointer.Int32(1),
    84  						},
    85  						kubeflowv1.PyTorchJobReplicaTypeMaster: {
    86  							Replicas: pointer.Int32(1),
    87  						},
    88  					},
    89  				},
    90  			},
    91  			rtype:       kubeflowv1.PyTorchJobReplicaTypeMaster,
    92  			index:       "0",
    93  			expected:    0,
    94  			exepctedErr: nil,
    95  		},
    96  	}
    97  
    98  	for _, t := range testCases {
    99  		log := logr.Discard()
   100  		podTemplateSpec := t.job.Spec.PyTorchReplicaSpecs[t.rtype].Template
   101  		err := setInitContainer(t.job, &podTemplateSpec,
   102  			strings.ToLower(string(t.rtype)), t.index, log)
   103  		if t.exepctedErr == nil {
   104  			gomega.Expect(err).To(gomega.BeNil())
   105  		} else {
   106  			gomega.Expect(err).To(gomega.Equal(t.exepctedErr))
   107  		}
   108  		gomega.Expect(len(podTemplateSpec.Spec.InitContainers)).To(gomega.Equal(t.expected))
   109  	}
   110  }