github.com/kubeflow/training-operator@v1.7.0/pkg/apis/kubeflow.org/v1/pytorch_defaults_test.go (about) 1 package v1 2 3 import ( 4 "testing" 5 6 "github.com/onsi/ginkgo/v2" 7 "github.com/onsi/gomega" 8 "k8s.io/utils/pointer" 9 ) 10 11 func TestSetElasticPolicy(t *testing.T) { 12 gomega.RegisterFailHandler(ginkgo.Fail) 13 14 type args struct { 15 job *PyTorchJob 16 } 17 type result struct { 18 expectedMinReplicas *int32 19 expectedMaxReplicas *int32 20 } 21 tests := []struct { 22 name string 23 args args 24 result result 25 }{ 26 { 27 name: "minReplicas and maxReplicas to null", 28 args: args{ 29 job: &PyTorchJob{ 30 Spec: PyTorchJobSpec{ 31 ElasticPolicy: &ElasticPolicy{}, 32 PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 33 PyTorchJobReplicaTypeWorker: { 34 Replicas: pointer.Int32(1), 35 }, 36 }, 37 }, 38 }, 39 }, 40 result: result{ 41 expectedMinReplicas: pointer.Int32(1), 42 expectedMaxReplicas: pointer.Int32(1), 43 }, 44 }, 45 { 46 name: "minReplicas and maxReplicas to 1", 47 args: args{ 48 job: &PyTorchJob{ 49 Spec: PyTorchJobSpec{ 50 ElasticPolicy: &ElasticPolicy{ 51 MaxReplicas: pointer.Int32(1), 52 MinReplicas: pointer.Int32(1), 53 }, 54 PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 55 PyTorchJobReplicaTypeWorker: { 56 Replicas: pointer.Int32(1), 57 }, 58 }, 59 }, 60 }, 61 }, 62 result: result{ 63 expectedMinReplicas: pointer.Int32(1), 64 expectedMaxReplicas: pointer.Int32(1), 65 }, 66 }, 67 { 68 name: "minReplicas and maxReplicas to 1", 69 args: args{ 70 job: &PyTorchJob{ 71 Spec: PyTorchJobSpec{ 72 ElasticPolicy: &ElasticPolicy{ 73 MaxReplicas: pointer.Int32(1), 74 MinReplicas: pointer.Int32(1), 75 }, 76 PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 77 PyTorchJobReplicaTypeWorker: { 78 Replicas: pointer.Int32(1), 79 }, 80 }, 81 }, 82 }, 83 }, 84 result: result{ 85 expectedMinReplicas: pointer.Int32(1), 86 expectedMaxReplicas: pointer.Int32(1), 87 }, 88 }, 89 { 90 name: "minReplicas to null, maxRepliacs to 1", 91 args: args{ 92 job: &PyTorchJob{ 93 Spec: PyTorchJobSpec{ 94 ElasticPolicy: &ElasticPolicy{ 95 MaxReplicas: pointer.Int32(1), 96 MinReplicas: nil, 97 }, 98 PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 99 PyTorchJobReplicaTypeWorker: { 100 Replicas: pointer.Int32(1), 101 }, 102 }, 103 }, 104 }, 105 }, 106 result: result{ 107 expectedMinReplicas: pointer.Int32(1), 108 expectedMaxReplicas: pointer.Int32(1), 109 }, 110 }, 111 { 112 name: "maxRepliacs to null, minReplicas to 1", 113 args: args{ 114 job: &PyTorchJob{ 115 Spec: PyTorchJobSpec{ 116 ElasticPolicy: &ElasticPolicy{ 117 MaxReplicas: nil, 118 MinReplicas: pointer.Int32(1), 119 }, 120 PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 121 PyTorchJobReplicaTypeWorker: { 122 Replicas: pointer.Int32(1), 123 }, 124 }, 125 }, 126 }, 127 }, 128 result: result{ 129 expectedMinReplicas: pointer.Int32(1), 130 expectedMaxReplicas: pointer.Int32(1), 131 }, 132 }, 133 } 134 for _, test := range tests { 135 t.Run(test.name, func(t *testing.T) { 136 setElasticPolicy(test.args.job) 137 if test.result.expectedMinReplicas != nil { 138 gomega.Expect(test.args.job.Spec.ElasticPolicy.MinReplicas). 139 To(gomega.Equal(test.result.expectedMinReplicas)) 140 } else { 141 gomega.Expect(test.args.job.Spec.ElasticPolicy.MinReplicas). 142 To(gomega.BeNil()) 143 } 144 145 if test.result.expectedMaxReplicas != nil { 146 gomega.Expect(test.args.job.Spec.ElasticPolicy.MaxReplicas). 147 To(gomega.Equal(test.result.expectedMaxReplicas)) 148 } else { 149 gomega.Expect(test.args.job.Spec.ElasticPolicy.MaxReplicas). 150 To(gomega.BeNil()) 151 } 152 }) 153 } 154 } 155 156 func TestSetDefaultNprocPerNode(t *testing.T) { 157 gomega.RegisterFailHandler(ginkgo.Fail) 158 t.Run("test default nproc per node", func(t *testing.T) { 159 job := &PyTorchJob{ 160 Spec: PyTorchJobSpec{ 161 ElasticPolicy: &ElasticPolicy{ 162 NProcPerNode: nil, 163 }, 164 PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 165 PyTorchJobReplicaTypeWorker: { 166 Replicas: pointer.Int32(1), 167 }, 168 }, 169 }, 170 } 171 172 setDefaultNprocPerNode(job) 173 gomega.Expect(job.Spec.NprocPerNode). 174 To(gomega.Equal(&DefaultNprocPerNode)) 175 }) 176 t.Run("test default nproc per node", func(t *testing.T) { 177 job := &PyTorchJob{ 178 Spec: PyTorchJobSpec{ 179 ElasticPolicy: nil, 180 PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 181 PyTorchJobReplicaTypeWorker: { 182 Replicas: pointer.Int32(1), 183 }, 184 }, 185 }, 186 } 187 188 setDefaultNprocPerNode(job) 189 gomega.Expect(job.Spec.NprocPerNode). 190 To(gomega.Equal(&DefaultNprocPerNode)) 191 }) 192 }