volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/pytorch/pytorch_test.go (about)

     1  package pytorch
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"testing"
     7  
     8  	v1 "k8s.io/api/core/v1"
     9  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    10  
    11  	"volcano.sh/apis/pkg/apis/batch/v1alpha1"
    12  	pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface"
    13  )
    14  
    15  func TestPytorch(t *testing.T) {
    16  	plugins := make(map[string][]string)
    17  	plugins[PytorchPluginName] = []string{"--port=5000"}
    18  
    19  	testcases := []struct {
    20  		Name string
    21  		Job  *v1alpha1.Job
    22  		Pod  *v1.Pod
    23  		port int
    24  		envs []v1.EnvVar
    25  	}{
    26  		{
    27  			Name: "test pod without master",
    28  			Job: &v1alpha1.Job{
    29  				ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
    30  				Spec: v1alpha1.JobSpec{
    31  					Tasks: []v1alpha1.TaskSpec{
    32  						{
    33  							Name:     "worker",
    34  							Replicas: 1,
    35  							Template: v1.PodTemplateSpec{},
    36  						},
    37  					},
    38  				},
    39  			},
    40  			Pod: &v1.Pod{
    41  				ObjectMeta: metav1.ObjectMeta{
    42  					Name: "test-pytorch-worker-0",
    43  				},
    44  				Spec: v1.PodSpec{
    45  					Containers: []v1.Container{
    46  						{
    47  							Name: "worker",
    48  						},
    49  					},
    50  				},
    51  			},
    52  			port: -1,
    53  			envs: nil,
    54  		},
    55  		{
    56  			Name: "test master pod without port",
    57  			Job: &v1alpha1.Job{
    58  				ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
    59  				Spec: v1alpha1.JobSpec{
    60  					Tasks: []v1alpha1.TaskSpec{
    61  						{
    62  							Name:     "master",
    63  							Replicas: 1,
    64  							Template: v1.PodTemplateSpec{},
    65  						},
    66  						{
    67  							Name:     "worker",
    68  							Replicas: 1,
    69  							Template: v1.PodTemplateSpec{},
    70  						},
    71  					},
    72  				},
    73  			},
    74  			Pod: &v1.Pod{
    75  				ObjectMeta: metav1.ObjectMeta{
    76  					Name: "test-pytorch-master-0",
    77  					Annotations: map[string]string{
    78  						v1alpha1.TaskSpecKey: "master",
    79  					},
    80  				},
    81  				Spec: v1.PodSpec{
    82  					Containers: []v1.Container{
    83  						{
    84  							Name: "master",
    85  						},
    86  					},
    87  				},
    88  			},
    89  			port: DefaultPort,
    90  			envs: []v1.EnvVar{
    91  				{
    92  					Name:  EnvMasterAddr,
    93  					Value: "test-pytorch-master-0.test-pytorch",
    94  				},
    95  				{
    96  					Name:  EnvMasterPort,
    97  					Value: fmt.Sprintf("%v", DefaultPort),
    98  				},
    99  				{
   100  					Name:  "WORLD_SIZE",
   101  					Value: fmt.Sprintf("%v", 2),
   102  				},
   103  				{
   104  					Name:  "RANK",
   105  					Value: fmt.Sprintf("%v", 0),
   106  				},
   107  			},
   108  		},
   109  		{
   110  			Name: "test master pod with port",
   111  			Job: &v1alpha1.Job{
   112  				ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
   113  				Spec: v1alpha1.JobSpec{
   114  					Tasks: []v1alpha1.TaskSpec{
   115  						{
   116  							Name:     "master",
   117  							Replicas: 1,
   118  							Template: v1.PodTemplateSpec{},
   119  						},
   120  						{
   121  							Name:     "worker",
   122  							Replicas: 1,
   123  							Template: v1.PodTemplateSpec{},
   124  						},
   125  					},
   126  				},
   127  			},
   128  			Pod: &v1.Pod{
   129  				ObjectMeta: metav1.ObjectMeta{
   130  					Name: "test-pytorch-master-0",
   131  					Annotations: map[string]string{
   132  						v1alpha1.TaskSpecKey: "master",
   133  					},
   134  				},
   135  				Spec: v1.PodSpec{
   136  					Containers: []v1.Container{
   137  						{
   138  							Name: "master",
   139  							Ports: []v1.ContainerPort{
   140  								{
   141  									Name:          "pytorchjob-port",
   142  									ContainerPort: 23456,
   143  								},
   144  							},
   145  						},
   146  					},
   147  				},
   148  			},
   149  			port: DefaultPort,
   150  			envs: []v1.EnvVar{
   151  				{
   152  					Name:  EnvMasterAddr,
   153  					Value: "test-pytorch-master-0.test-pytorch",
   154  				},
   155  				{
   156  					Name:  EnvMasterPort,
   157  					Value: fmt.Sprintf("%v", DefaultPort),
   158  				},
   159  				{
   160  					Name:  "WORLD_SIZE",
   161  					Value: fmt.Sprintf("%v", 2),
   162  				},
   163  				{
   164  					Name:  "RANK",
   165  					Value: fmt.Sprintf("%v", 0),
   166  				},
   167  			},
   168  		},
   169  		{
   170  			Name: "test master pod env",
   171  			Job: &v1alpha1.Job{
   172  				ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
   173  				Spec: v1alpha1.JobSpec{
   174  					Tasks: []v1alpha1.TaskSpec{
   175  						{
   176  							Name:     "master",
   177  							Replicas: 1,
   178  							Template: v1.PodTemplateSpec{},
   179  						},
   180  						{
   181  							Name:     "worker",
   182  							Replicas: 2,
   183  							Template: v1.PodTemplateSpec{},
   184  						},
   185  					},
   186  				},
   187  			},
   188  			Pod: &v1.Pod{
   189  				ObjectMeta: metav1.ObjectMeta{
   190  					Name: "test-pytorch-master-0",
   191  					Annotations: map[string]string{
   192  						v1alpha1.TaskSpecKey: "master",
   193  					},
   194  				},
   195  				Spec: v1.PodSpec{
   196  					Containers: []v1.Container{
   197  						{
   198  							Name: "master",
   199  							Ports: []v1.ContainerPort{
   200  								{
   201  									Name:          "pytorchjob-port",
   202  									ContainerPort: 123,
   203  								},
   204  							},
   205  						},
   206  					},
   207  				},
   208  			},
   209  			port: 123,
   210  			envs: []v1.EnvVar{
   211  				{
   212  					Name:  EnvMasterAddr,
   213  					Value: "test-pytorch-master-0.test-pytorch",
   214  				},
   215  				{
   216  					Name:  EnvMasterPort,
   217  					Value: fmt.Sprintf("%v", DefaultPort),
   218  				},
   219  				{
   220  					Name:  "WORLD_SIZE",
   221  					Value: fmt.Sprintf("%v", 3),
   222  				},
   223  				{
   224  					Name:  "RANK",
   225  					Value: fmt.Sprintf("%v", 0),
   226  				},
   227  			},
   228  		},
   229  		{
   230  			Name: "test worker-1 pod env",
   231  			Job: &v1alpha1.Job{
   232  				ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
   233  				Spec: v1alpha1.JobSpec{
   234  					Tasks: []v1alpha1.TaskSpec{
   235  						{
   236  							Name:     "master",
   237  							Replicas: 1,
   238  							Template: v1.PodTemplateSpec{},
   239  						},
   240  						{
   241  							Name:     "worker",
   242  							Replicas: 2,
   243  							Template: v1.PodTemplateSpec{},
   244  						},
   245  					},
   246  				},
   247  			},
   248  			Pod: &v1.Pod{
   249  				ObjectMeta: metav1.ObjectMeta{
   250  					Name: "test-pytorch-worker-0",
   251  					Annotations: map[string]string{
   252  						v1alpha1.TaskSpecKey: "worker",
   253  					},
   254  				},
   255  				Spec: v1.PodSpec{
   256  					Containers: []v1.Container{
   257  						{
   258  							Name: "worker",
   259  							Ports: []v1.ContainerPort{
   260  								{
   261  									Name:          "pytorchjob-port",
   262  									ContainerPort: 123,
   263  								},
   264  							},
   265  						},
   266  					},
   267  				},
   268  			},
   269  			port: 123,
   270  			envs: []v1.EnvVar{
   271  				{
   272  					Name:  EnvMasterAddr,
   273  					Value: "test-pytorch-master-0.test-pytorch",
   274  				},
   275  				{
   276  					Name:  EnvMasterPort,
   277  					Value: fmt.Sprintf("%v", DefaultPort),
   278  				},
   279  				{
   280  					Name:  "WORLD_SIZE",
   281  					Value: fmt.Sprintf("%v", 3),
   282  				},
   283  				{
   284  					Name:  "RANK",
   285  					Value: fmt.Sprintf("%v", 1),
   286  				},
   287  			},
   288  		},
   289  		{
   290  			Name: "test worker-2 pod env",
   291  			Job: &v1alpha1.Job{
   292  				ObjectMeta: metav1.ObjectMeta{Name: "test-pytorch"},
   293  				Spec: v1alpha1.JobSpec{
   294  					Tasks: []v1alpha1.TaskSpec{
   295  						{
   296  							Name:     "master",
   297  							Replicas: 1,
   298  							Template: v1.PodTemplateSpec{},
   299  						},
   300  						{
   301  							Name:     "worker",
   302  							Replicas: 2,
   303  							Template: v1.PodTemplateSpec{},
   304  						},
   305  					},
   306  				},
   307  			},
   308  			Pod: &v1.Pod{
   309  				ObjectMeta: metav1.ObjectMeta{
   310  					Name: "test-pytorch-worker-1",
   311  					Annotations: map[string]string{
   312  						v1alpha1.TaskSpecKey: "worker",
   313  					},
   314  				},
   315  				Spec: v1.PodSpec{
   316  					Containers: []v1.Container{
   317  						{
   318  							Name: "worker",
   319  							Ports: []v1.ContainerPort{
   320  								{
   321  									Name:          "pytorchjob-port",
   322  									ContainerPort: 123,
   323  								},
   324  							},
   325  						},
   326  					},
   327  				},
   328  			},
   329  			port: 123,
   330  			envs: []v1.EnvVar{
   331  				{
   332  					Name:  EnvMasterAddr,
   333  					Value: "test-pytorch-master-0.test-pytorch",
   334  				},
   335  				{
   336  					Name:  EnvMasterPort,
   337  					Value: fmt.Sprintf("%v", DefaultPort),
   338  				},
   339  				{
   340  					Name:  "WORLD_SIZE",
   341  					Value: fmt.Sprintf("%v", 3),
   342  				},
   343  				{
   344  					Name:  "RANK",
   345  					Value: fmt.Sprintf("%v", 2),
   346  				},
   347  			},
   348  		},
   349  	}
   350  
   351  	for index, testcase := range testcases {
   352  		t.Run(testcase.Name, func(t *testing.T) {
   353  			mp := New(pluginsinterface.PluginClientset{}, testcase.Job.Spec.Plugins[PytorchPluginName])
   354  			if err := mp.OnPodCreate(testcase.Pod, testcase.Job); err != nil {
   355  				t.Errorf("Case %d (%s): expect no error, but got error %v", index, testcase.Name, err)
   356  			}
   357  
   358  			if testcase.port != -1 {
   359  				if testcase.Pod.Spec.Containers[0].Ports == nil || testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort != int32(testcase.port) {
   360  					t.Errorf("Case %d (%s): wrong port, got %d, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort, testcase.port)
   361  				}
   362  			} else {
   363  				if testcase.Pod.Spec.Containers[0].Ports != nil {
   364  					t.Errorf("Case %d (%s): wrong port, got %d, expected empty", index, testcase.Name, testcase.Pod.Spec.Containers[0].Ports[0].ContainerPort)
   365  				}
   366  			}
   367  
   368  			if !reflect.DeepEqual(testcase.Pod.Spec.Containers[0].Env, testcase.envs) {
   369  				t.Errorf("Case %d (%s): wrong envs, got %v, expected %v", index, testcase.Name, testcase.Pod.Spec.Containers[0].Env, testcase.envs)
   370  			}
   371  		})
   372  	}
   373  }