volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/pytorch/pytorch_test.go (about) 1 package pytorch 2 3 import ( 4 "fmt" 5 "reflect" 6 "testing" 7 8 v1 "k8s.io/api/core/v1" 9 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 10 11 "volcano.sh/apis/pkg/apis/batch/v1alpha1" 12 pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface" 13 ) 14 15 func TestPytorch(t *testing.T) { 16 plugins := make(map[string][]string) 17 plugins[PytorchPluginName] = []string{"--port=5000"} 18 19 testcases := []struct { 20 Name string 21 Job *v1alpha1.Job 22 Pod *v1.Pod 23 port int 24 envs []v1.EnvVar 25 }{ 26 { 27 Name: "test pod without master", 28 Job: &v1alpha1.Job{ 29 ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"}, 30 Spec: v1alpha1.JobSpec{ 31 Tasks: []v1alpha1.TaskSpec{ 32 { 33 Name: "worker", 34 Replicas: 1, 35 Template: v1.PodTemplateSpec{}, 36 }, 37 }, 38 }, 39 }, 40 Pod: &v1.Pod{ 41 ObjectMeta: metav1.ObjectMeta{ 42 Name: "test-pytorch-worker-0", 43 }, 44 Spec: v1.PodSpec{ 45 Containers: []v1.Container{ 46 { 47 Name: "worker", 48 }, 49 }, 50 }, 51 }, 52 port: -1, 53 envs: nil, 54 }, 55 { 56 Name: "test master pod without port", 57 Job: &v1alpha1.Job{ 58 ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"}, 59 Spec: v1alpha1.JobSpec{ 60 Tasks: []v1alpha1.TaskSpec{ 61 { 62 Name: "master", 63 Replicas: 1, 64 Template: v1.PodTemplateSpec{}, 65 }, 66 { 67 Name: "worker", 68 Replicas: 1, 69 Template: v1.PodTemplateSpec{}, 70 }, 71 }, 72 }, 73 }, 74 Pod: &v1.Pod{ 75 ObjectMeta: metav1.ObjectMeta{ 76 Name: "test-pytorch-master-0", 77 Annotations: map[string]string{ 78 v1alpha1.TaskSpecKey: "master", 79 }, 80 }, 81 Spec: v1.PodSpec{ 82 Containers: []v1.Container{ 83 { 84 Name: "master", 85 }, 86 }, 87 }, 88 }, 89 port: DefaultPort, 90 envs: []v1.EnvVar{ 91 { 92 Name: EnvMasterAddr, 93 Value: "test-pytorch-master-0.test-pytorch", 94 }, 95 { 96 Name: EnvMasterPort, 97 Value: fmt.Sprintf("%v", DefaultPort), 98 }, 99 { 100 Name: "WORLD_SIZE", 101 Value: fmt.Sprintf("%v", 2), 102 }, 103 { 104 Name: "RANK", 105 Value: fmt.Sprintf("%v", 0), 106 }, 107 }, 108 }, 109 { 110 Name: "test master pod with port", 111 Job: &v1alpha1.Job{ 112 ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"}, 113 Spec: v1alpha1.JobSpec{ 114 Tasks: []v1alpha1.TaskSpec{ 115 { 116 Name: "master", 117 Replicas: 1, 118 Template: v1.PodTemplateSpec{}, 119 }, 120 { 121 Name: "worker", 122 Replicas: 1, 123 Template: v1.PodTemplateSpec{}, 124 }, 125 }, 126 }, 127 }, 128 Pod: &v1.Pod{ 129 ObjectMeta: metav1.ObjectMeta{ 130 Name: "test-pytorch-master-0", 131 Annotations: map[string]string{ 132 v1alpha1.TaskSpecKey: "master", 133 }, 134 }, 135 Spec: v1.PodSpec{ 136 Containers: []v1.Container{ 137 { 138 Name: "master", 139 Ports: []v1.ContainerPort{ 140 { 141 Name: "pytorchjob-port", 142 ContainerPort: 23456, 143 }, 144 }, 145 }, 146 }, 147 }, 148 }, 149 port: DefaultPort, 150 envs: []v1.EnvVar{ 151 { 152 Name: EnvMasterAddr, 153 Value: "test-pytorch-master-0.test-pytorch", 154 }, 155 { 156 Name: EnvMasterPort, 157 Value: fmt.Sprintf("%v", DefaultPort), 158 }, 159 { 160 Name: "WORLD_SIZE", 161 Value: fmt.Sprintf("%v", 2), 162 }, 163 { 164 Name: "RANK", 165 Value: fmt.Sprintf("%v", 0), 166 }, 167 }, 168 }, 169 { 170 Name: "test master pod env", 171 Job: &v1alpha1.Job{ 172 ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"}, 173 Spec: v1alpha1.JobSpec{ 174 Tasks: []v1alpha1.TaskSpec{ 175 { 176 Name: "master", 177 Replicas: 1, 178 Template: v1.PodTemplateSpec{}, 179 }, 180 { 181 Name: "worker", 182 Replicas: 2, 183 Template: v1.PodTemplateSpec{}, 184 }, 185 }, 186 }, 187 }, 188 Pod: &v1.Pod{ 189 ObjectMeta: metav1.ObjectMeta{ 190 Name: "test-pytorch-master-0", 191 Annotations: map[string]string{ 192 v1alpha1.TaskSpecKey: "master", 193 }, 194 }, 195 Spec: v1.PodSpec{ 196 Containers: []v1.Container{ 197 { 198 Name: "master", 199 Ports: []v1.ContainerPort{ 200 { 201 Name: "pytorchjob-port", 202 ContainerPort: 123, 203 }, 204 }, 205 }, 206 }, 207 }, 208 }, 209 port: 123, 210 envs: []v1.EnvVar{ 211 { 212 Name: EnvMasterAddr, 213 Value: "test-pytorch-master-0.test-pytorch", 214 }, 215 { 216 Name: EnvMasterPort, 217 Value: fmt.Sprintf("%v", DefaultPort), 218 }, 219 { 220 Name: "WORLD_SIZE", 221 Value: fmt.Sprintf("%v", 3), 222 }, 223 { 224 Name: "RANK", 225 Value: fmt.Sprintf("%v", 0), 226 }, 227 }, 228 }, 229 { 230 Name: "test worker-1 pod env", 231 Job: &v1alpha1.Job{ 232 ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"}, 233 Spec: v1alpha1.JobSpec{ 234 Tasks: []v1alpha1.TaskSpec{ 235 { 236 Name: "master", 237 Replicas: 1, 238 Template: v1.PodTemplateSpec{}, 239 }, 240 { 241 Name: "worker", 242 Replicas: 2, 243 Template: v1.PodTemplateSpec{}, 244 }, 245 }, 246 }, 247 }, 248 Pod: &v1.Pod{ 249 ObjectMeta: metav1.ObjectMeta{ 250 Name: "test-pytorch-worker-0", 251 Annotations: map[string]string{ 252 v1alpha1.TaskSpecKey: "worker", 253 }, 254 }, 255 Spec: v1.PodSpec{ 256 Containers: []v1.Container{ 257 { 258 Name: "worker", 259 Ports: []v1.ContainerPort{ 260 { 261 Name: "pytorchjob-port", 262 ContainerPort: 123, 263 }, 264 }, 265 }, 266 }, 267 }, 268 }, 269 port: 123, 270 envs: []v1.EnvVar{ 271 { 272 Name: EnvMasterAddr, 273 Value: "test-pytorch-master-0.test-pytorch", 274 }, 275 { 276 Name: EnvMasterPort, 277 Value: fmt.Sprintf("%v", DefaultPort), 278 }, 279 { 280 Name: "WORLD_SIZE", 281 Value: fmt.Sprintf("%v", 3), 282 }, 283 { 284 Name: "RANK", 285 Value: fmt.Sprintf("%v", 1), 286 }, 287 }, 288 }, 289 { 290 Name: "test worker-2 pod env", 291 Job: &v1alpha1.Job{ 292 ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"}, 293 Spec: v1alpha1.JobSpec{ 294 Tasks: []v1alpha1.TaskSpec{ 295 { 296 Name: "master", 297 Replicas: 1, 298 Template: v1.PodTemplateSpec{}, 299 }, 300 { 301 Name: "worker", 302 Replicas: 2, 303 Template: v1.PodTemplateSpec{}, 304 }, 305 }, 306 }, 307 }, 308 Pod: &v1.Pod{ 309 ObjectMeta: metav1.ObjectMeta{ 310 Name: "test-pytorch-worker-1", 311 Annotations: map[string]string{ 312 v1alpha1.TaskSpecKey: "worker", 313 }, 314 }, 315 Spec: v1.PodSpec{ 316 Containers: []v1.Container{ 317 { 318 Name: "worker", 319 Ports: []v1.ContainerPort{ 320 { 321 Name: "pytorchjob-port", 322 ContainerPort: 123, 323 }, 324 }, 325 }, 326 }, 327 }, 328 }, 329 port: 123, 330 envs: []v1.EnvVar{ 331 { 332 Name: EnvMasterAddr, 333 Value: "test-pytorch-master-0.test-pytorch", 334 }, 335 { 336 Name: EnvMasterPort, 337 Value: fmt.Sprintf("%v", DefaultPort), 338 }, 339 { 340 Name: "WORLD_SIZE", 341 Value: fmt.Sprintf("%v", 3), 342 }, 343 { 344 Name: "RANK", 345 Value: fmt.Sprintf("%v", 2), 346 }, 347 }, 348 }, 349 } 350 351 for index, testcase := range testcases { 352 t.Run(testcase.Name, func(t *testing.T) { 353 mp := New(pluginsinterface.PluginClientset{}, testcase.Job.Spec.Plugins[PytorchPluginName]) 354 if err := mp.OnPodCreate(testcase.Pod, testcase.Job); err != nil { 355 t.Errorf("Case %d (%s): expect no error, but got error %v", index, testcase.Name, err) 356 } 357 358 if testcase.port != -1 { 359 if testcase.Pod.Spec.Containers[0].Ports == nil || testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort != int32(testcase.port) { 360 t.Errorf("Case %d (%s): wrong port, got %d, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort, testcase.port) 361 } 362 } else { 363 if testcase.Pod.Spec.Containers[0].Ports != nil { 364 t.Errorf("Case %d (%s): wrong port, got %d, expected empty", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort) 365 } 366 } 367 368 if !reflect.DeepEqual(testcase.Pod.Spec.Containers[0].Env, testcase.envs) { 369 t.Errorf("Case %d (%s): wrong envs, got %v, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Env, testcase.envs) 370 } 371 }) 372 } 373 }