sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/rayjob/rayjob_controller_test.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package rayjob
    18  
    19  import (
    20  	"testing"
    21  
    22  	"github.com/google/go-cmp/cmp"
    23  	"github.com/google/go-cmp/cmp/cmpopts"
    24  	rayjobapi "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
    25  	corev1 "k8s.io/api/core/v1"
    26  	"k8s.io/utils/ptr"
    27  
    28  	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
    29  	"sigs.k8s.io/kueue/pkg/podset"
    30  	testingrayutil "sigs.k8s.io/kueue/pkg/util/testingjobs/rayjob"
    31  )
    32  
    33  func TestPodSets(t *testing.T) {
    34  	job := testingrayutil.MakeJob("job", "ns").
    35  		WithHeadGroupSpec(
    36  			rayjobapi.HeadGroupSpec{
    37  				Template: corev1.PodTemplateSpec{
    38  					Spec: corev1.PodSpec{
    39  						Containers: []corev1.Container{
    40  							{
    41  								Name: "head_c",
    42  							},
    43  						},
    44  					},
    45  				},
    46  			},
    47  		).
    48  		WithWorkerGroups(
    49  			rayjobapi.WorkerGroupSpec{
    50  				GroupName: "group1",
    51  				Template: corev1.PodTemplateSpec{
    52  					Spec: corev1.PodSpec{
    53  						Containers: []corev1.Container{
    54  							{
    55  								Name: "group1_c",
    56  							},
    57  						},
    58  					},
    59  				},
    60  			},
    61  			rayjobapi.WorkerGroupSpec{
    62  				GroupName: "group2",
    63  				Replicas:  ptr.To[int32](3),
    64  				Template: corev1.PodTemplateSpec{
    65  					Spec: corev1.PodSpec{
    66  						Containers: []corev1.Container{
    67  							{
    68  								Name: "group2_c",
    69  							},
    70  						},
    71  					},
    72  				},
    73  			},
    74  		).
    75  		Obj()
    76  
    77  	wantPodSets := []kueue.PodSet{
    78  		{
    79  			Name:  "head",
    80  			Count: 1,
    81  			Template: corev1.PodTemplateSpec{
    82  				Spec: corev1.PodSpec{
    83  					Containers: []corev1.Container{
    84  						{
    85  							Name: "head_c",
    86  						},
    87  					},
    88  				},
    89  			},
    90  		},
    91  		{
    92  			Name:  "group1",
    93  			Count: 1,
    94  			Template: corev1.PodTemplateSpec{
    95  				Spec: corev1.PodSpec{
    96  					Containers: []corev1.Container{
    97  						{
    98  							Name: "group1_c",
    99  						},
   100  					},
   101  				},
   102  			},
   103  		},
   104  		{
   105  			Name:  "group2",
   106  			Count: 3,
   107  			Template: corev1.PodTemplateSpec{
   108  				Spec: corev1.PodSpec{
   109  					Containers: []corev1.Container{
   110  						{
   111  							Name: "group2_c",
   112  						},
   113  					},
   114  				},
   115  			},
   116  		},
   117  	}
   118  
   119  	result := ((*RayJob)(job)).PodSets()
   120  
   121  	if diff := cmp.Diff(wantPodSets, result); diff != "" {
   122  		t.Errorf("PodSets() mismatch (-want +got):\n%s", diff)
   123  	}
   124  }
   125  
   126  func TestNodeSelectors(t *testing.T) {
   127  	baseJob := testingrayutil.MakeJob("job", "ns").
   128  		WithHeadGroupSpec(rayjobapi.HeadGroupSpec{
   129  			Template: corev1.PodTemplateSpec{
   130  				Spec: corev1.PodSpec{
   131  					NodeSelector: map[string]string{},
   132  				},
   133  			},
   134  		}).
   135  		WithWorkerGroups(rayjobapi.WorkerGroupSpec{
   136  			Template: corev1.PodTemplateSpec{
   137  				Spec: corev1.PodSpec{
   138  					NodeSelector: map[string]string{
   139  						"key-wg1": "value-wg1",
   140  					},
   141  				},
   142  			},
   143  		}, rayjobapi.WorkerGroupSpec{
   144  			Template: corev1.PodTemplateSpec{
   145  				Spec: corev1.PodSpec{
   146  					NodeSelector: map[string]string{
   147  						"key-wg2": "value-wg2",
   148  					},
   149  				},
   150  			},
   151  		}).
   152  		Obj()
   153  
   154  	cases := map[string]struct {
   155  		job          *rayjobapi.RayJob
   156  		runInfo      []podset.PodSetInfo
   157  		restoreInfo  []podset.PodSetInfo
   158  		wantRunError error
   159  		wantAfterRun *rayjobapi.RayJob
   160  		wantFinal    *rayjobapi.RayJob
   161  	}{
   162  		"valid configuration": {
   163  			job: baseJob.DeepCopy(),
   164  			runInfo: []podset.PodSetInfo{
   165  				{
   166  					NodeSelector: map[string]string{
   167  						"newKey": "newValue",
   168  					},
   169  				},
   170  				{
   171  					NodeSelector: map[string]string{
   172  						"key-wg1": "value-wg1",
   173  					},
   174  				},
   175  				{
   176  					NodeSelector: map[string]string{
   177  						// don't add anything
   178  					},
   179  				},
   180  			},
   181  			restoreInfo: []podset.PodSetInfo{
   182  				{
   183  					NodeSelector: map[string]string{
   184  						// clean it all
   185  					},
   186  				},
   187  				{
   188  					NodeSelector: map[string]string{
   189  						"key-wg1": "value-wg1",
   190  					},
   191  				},
   192  				{
   193  					NodeSelector: map[string]string{
   194  						"key-wg2": "value-wg2",
   195  					},
   196  				},
   197  			},
   198  			wantAfterRun: testingrayutil.MakeJob("job", "ns").
   199  				Suspend(false).
   200  				WithHeadGroupSpec(rayjobapi.HeadGroupSpec{
   201  					Template: corev1.PodTemplateSpec{
   202  						Spec: corev1.PodSpec{
   203  							NodeSelector: map[string]string{
   204  								"newKey": "newValue",
   205  							},
   206  						},
   207  					},
   208  				}).
   209  				WithWorkerGroups(rayjobapi.WorkerGroupSpec{
   210  					Template: corev1.PodTemplateSpec{
   211  						Spec: corev1.PodSpec{
   212  							NodeSelector: map[string]string{
   213  								"key-wg1": "value-wg1",
   214  							},
   215  						},
   216  					},
   217  				}, rayjobapi.WorkerGroupSpec{
   218  					Template: corev1.PodTemplateSpec{
   219  						Spec: corev1.PodSpec{
   220  							NodeSelector: map[string]string{
   221  								"key-wg2": "value-wg2",
   222  							},
   223  						},
   224  					},
   225  				}).
   226  				Obj(),
   227  
   228  			wantFinal: baseJob.DeepCopy(),
   229  		},
   230  		"invalid runInfo": {
   231  			job: baseJob.DeepCopy(),
   232  			runInfo: []podset.PodSetInfo{
   233  				{
   234  					NodeSelector: map[string]string{
   235  						"newKey": "newValue",
   236  					},
   237  				},
   238  				{
   239  					NodeSelector: map[string]string{
   240  						"key-wg1": "updated-value-wg1",
   241  					},
   242  				},
   243  			},
   244  			wantRunError: podset.ErrInvalidPodsetInfo,
   245  			wantAfterRun: baseJob.DeepCopy(),
   246  		},
   247  	}
   248  
   249  	for name, tc := range cases {
   250  		t.Run(name, func(t *testing.T) {
   251  			genJob := (*RayJob)(tc.job)
   252  			gotRunError := genJob.RunWithPodSetsInfo(tc.runInfo)
   253  
   254  			if diff := cmp.Diff(tc.wantRunError, gotRunError, cmpopts.EquateErrors()); diff != "" {
   255  				t.Errorf("Unexpected run error (-want/+got): %s", diff)
   256  			}
   257  			if diff := cmp.Diff(tc.wantAfterRun, tc.job); diff != "" {
   258  				t.Errorf("Unexpected job after run (-want/+got): %s", diff)
   259  			}
   260  
   261  			if tc.wantRunError == nil {
   262  				genJob.Suspend()
   263  				genJob.RestorePodSetsInfo(tc.restoreInfo)
   264  				if diff := cmp.Diff(tc.wantFinal, tc.job); diff != "" {
   265  					t.Errorf("Unexpected job after restore (-want/+got): %s", diff)
   266  				}
   267  			}
   268  		})
   269  	}
   270  }