github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/initcontainer_test.go (about) 1 // Copyright 2021 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License 14 15 package pytorch 16 17 import ( 18 "strings" 19 "testing" 20 21 "github.com/go-logr/logr" 22 "github.com/onsi/ginkgo/v2" 23 "github.com/onsi/gomega" 24 "k8s.io/utils/pointer" 25 26 kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 27 "github.com/kubeflow/training-operator/pkg/config" 28 ) 29 30 func TestInitContainer(t *testing.T) { 31 gomega.RegisterFailHandler(ginkgo.Fail) 32 defer ginkgo.GinkgoRecover() 33 34 config.Config.PyTorchInitContainerImage = config.PyTorchInitContainerImageDefault 35 config.Config.PyTorchInitContainerTemplateFile = config.PyTorchInitContainerTemplateFileDefault 36 config.Config.PyTorchInitContainerMaxTries = config.PyTorchInitContainerMaxTriesDefault 37 38 testCases := []struct { 39 job *kubeflowv1.PyTorchJob 40 rtype kubeflowv1.ReplicaType 41 index string 42 expected int 43 exepctedErr error 44 }{ 45 { 46 job: &kubeflowv1.PyTorchJob{ 47 Spec: kubeflowv1.PyTorchJobSpec{ 48 PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ 49 kubeflowv1.PyTorchJobReplicaTypeWorker: { 50 Replicas: pointer.Int32(1), 51 }, 52 }, 53 }, 54 }, 55 rtype: kubeflowv1.PyTorchJobReplicaTypeWorker, 56 index: "0", 57 expected: 0, 58 exepctedErr: nil, 59 }, 60 { 61 job: &kubeflowv1.PyTorchJob{ 62 Spec: kubeflowv1.PyTorchJobSpec{ 63 PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ 64 kubeflowv1.PyTorchJobReplicaTypeWorker: { 65 Replicas: pointer.Int32(1), 66 }, 67 kubeflowv1.PyTorchJobReplicaTypeMaster: { 68 Replicas: pointer.Int32(1), 69 }, 70 }, 71 }, 72 }, 73 rtype: kubeflowv1.PyTorchJobReplicaTypeWorker, 74 index: "0", 75 expected: 1, 76 exepctedErr: nil, 77 }, 78 { 79 job: &kubeflowv1.PyTorchJob{ 80 Spec: kubeflowv1.PyTorchJobSpec{ 81 PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ 82 kubeflowv1.PyTorchJobReplicaTypeWorker: { 83 Replicas: pointer.Int32(1), 84 }, 85 kubeflowv1.PyTorchJobReplicaTypeMaster: { 86 Replicas: pointer.Int32(1), 87 }, 88 }, 89 }, 90 }, 91 rtype: kubeflowv1.PyTorchJobReplicaTypeMaster, 92 index: "0", 93 expected: 0, 94 exepctedErr: nil, 95 }, 96 } 97 98 for _, t := range testCases { 99 log := logr.Discard() 100 podTemplateSpec := t.job.Spec.PyTorchReplicaSpecs[t.rtype].Template 101 err := setInitContainer(t.job, &podTemplateSpec, 102 strings.ToLower(string(t.rtype)), t.index, log) 103 if t.exepctedErr == nil { 104 gomega.Expect(err).To(gomega.BeNil()) 105 } else { 106 gomega.Expect(err).To(gomega.Equal(t.exepctedErr)) 107 } 108 gomega.Expect(len(podTemplateSpec.Spec.InitContainers)).To(gomega.Equal(t.expected)) 109 } 110 }