github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/elastic_test.go (about) 1 // Copyright 2021 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License 14 15 package pytorch 16 17 import ( 18 "testing" 19 20 "github.com/onsi/ginkgo/v2" 21 "github.com/onsi/gomega" 22 corev1 "k8s.io/api/core/v1" 23 "k8s.io/utils/pointer" 24 25 kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 26 ) 27 28 func TestElasticGenerate(t *testing.T) { 29 gomega.RegisterFailHandler(ginkgo.Fail) 30 defer ginkgo.GinkgoRecover() 31 32 backendC10D := kubeflowv1.BackendC10D 33 34 tests := []struct { 35 name string 36 job *kubeflowv1.PyTorchJob 37 expectedErr error 38 expected []corev1.EnvVar 39 }{ 40 { 41 name: "Without ElasticPolicy", 42 job: &kubeflowv1.PyTorchJob{ 43 Spec: kubeflowv1.PyTorchJobSpec{ 44 PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ 45 kubeflowv1.PyTorchJobReplicaTypeWorker: { 46 Replicas: pointer.Int32(1), 47 }, 48 }, 49 }, 50 }, 51 expectedErr: nil, 52 expected: nil, 53 }, 54 { 55 name: "With ElasticPolicy", 56 job: &kubeflowv1.PyTorchJob{ 57 Spec: kubeflowv1.PyTorchJobSpec{ 58 ElasticPolicy: &kubeflowv1.ElasticPolicy{ 59 MinReplicas: pointer.Int32(1), 60 MaxReplicas: pointer.Int32(3), 61 RDZVBackend: &backendC10D, 62 RDZVPort: pointer.Int32(1234), 63 RDZVHost: pointer.String("localhost"), 64 RDZVID: pointer.String("rdzv-id"), 65 RDZVConf: []kubeflowv1.RDZVConf{ 66 { 67 Key: "rdzv-conf-name", 68 Value: "rdzv-conf-value", 69 }, 70 { 71 Key: "rdzv-conf-name-1", 72 Value: "rdzv-conf-value-1", 73 }, 74 }, 75 MaxRestarts: pointer.Int32(3), 76 }, 77 PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ 78 kubeflowv1.PyTorchJobReplicaTypeWorker: { 79 Replicas: pointer.Int32(1), 80 }, 81 }, 82 }, 83 }, 84 expectedErr: nil, 85 expected: []corev1.EnvVar{ 86 { 87 Name: EnvMaxRestarts, 88 Value: "3", 89 }, 90 { 91 Name: EnvRDZVBackend, 92 Value: "c10d", 93 }, 94 { 95 Name: EnvRDZVEndpoint, 96 Value: "localhost:1234", 97 }, 98 { 99 Name: EnvRDZVID, 100 Value: "rdzv-id", 101 }, 102 { 103 Name: EnvRDZVConf, 104 Value: "rdzv-conf-name=rdzv-conf-value,rdzv-conf-name-1=rdzv-conf-value-1", 105 }, 106 { 107 Name: EnvNnodes, 108 Value: "1:3", 109 }, 110 }, 111 }, 112 } 113 114 for _, test := range tests { 115 actual, err := GetElasticEnvVarGenerator().Generate(test.job) 116 if test.expectedErr == nil { 117 gomega.Expect(err).To(gomega.BeNil()) 118 } else { 119 gomega.Expect(err).To(gomega.Equal(test.expectedErr)) 120 } 121 if test.expected == nil { 122 gomega.Expect(actual).To(gomega.BeNil()) 123 } else { 124 gomega.Expect(actual).To(gomega.ConsistOf(test.expected)) 125 } 126 } 127 }