github.com/kubeflow/training-operator@v1.7.0/pkg/apis/kubeflow.org/v1/xgboost_defaults_test.go (about) 1 // Copyright 2018 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package v1 16 17 import ( 18 "reflect" 19 "testing" 20 21 corev1 "k8s.io/api/core/v1" 22 "k8s.io/utils/pointer" 23 ) 24 25 func expectedXGBoostJob(cleanPodPolicy CleanPodPolicy, restartPolicy RestartPolicy, replicas int32, portName string, port int32) *XGBoostJob { 26 var ports []corev1.ContainerPort 27 28 // port not set 29 if portName != "" { 30 ports = append(ports, 31 corev1.ContainerPort{ 32 Name: portName, 33 ContainerPort: port, 34 }, 35 ) 36 } 37 38 // port set with custom name 39 if portName != XGBoostJobDefaultPortName { 40 ports = append(ports, 41 corev1.ContainerPort{ 42 Name: XGBoostJobDefaultPortName, 43 ContainerPort: XGBoostJobDefaultPort, 44 }, 45 ) 46 } 47 48 return &XGBoostJob{ 49 Spec: XGBoostJobSpec{ 50 RunPolicy: RunPolicy{ 51 CleanPodPolicy: &cleanPodPolicy, 52 }, 53 XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 54 XGBoostJobReplicaTypeWorker: &ReplicaSpec{ 55 Replicas: pointer.Int32(replicas), 56 RestartPolicy: restartPolicy, 57 Template: corev1.PodTemplateSpec{ 58 Spec: corev1.PodSpec{ 59 Containers: []corev1.Container{ 60 corev1.Container{ 61 Name: XGBoostJobDefaultContainerName, 62 Image: testImage, 63 Ports: ports, 64 }, 65 }, 66 }, 67 }, 68 }, 69 }, 70 }, 71 } 72 } 73 74 func TestSetDefaults_XGBoostJob(t *testing.T) { 75 testCases := map[string]struct { 76 original *XGBoostJob 77 expected *XGBoostJob 78 }{ 79 "set spec with minimum setting": { 80 original: &XGBoostJob{ 81 Spec: XGBoostJobSpec{ 82 XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 83 XGBoostJobReplicaTypeWorker: &ReplicaSpec{ 84 Template: corev1.PodTemplateSpec{ 85 Spec: corev1.PodSpec{ 86 Containers: []corev1.Container{ 87 corev1.Container{ 88 Name: XGBoostJobDefaultContainerName, 89 Image: testImage, 90 }, 91 }, 92 }, 93 }, 94 }, 95 }, 96 }, 97 }, 98 expected: expectedXGBoostJob(CleanPodPolicyNone, XGBoostJobDefaultRestartPolicy, 1, XGBoostJobDefaultPortName, XGBoostJobDefaultPort), 99 }, 100 "Set spec with restart policy": { 101 original: &XGBoostJob{ 102 Spec: XGBoostJobSpec{ 103 XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 104 XGBoostJobReplicaTypeWorker: &ReplicaSpec{ 105 RestartPolicy: RestartPolicyOnFailure, 106 Template: corev1.PodTemplateSpec{ 107 Spec: corev1.PodSpec{ 108 Containers: []corev1.Container{ 109 corev1.Container{ 110 Name: XGBoostJobDefaultContainerName, 111 Image: testImage, 112 }, 113 }, 114 }, 115 }, 116 }, 117 }, 118 }, 119 }, 120 expected: expectedXGBoostJob(CleanPodPolicyNone, RestartPolicyOnFailure, 1, XGBoostJobDefaultPortName, XGBoostJobDefaultPort), 121 }, 122 "Set spec with replicas": { 123 original: &XGBoostJob{ 124 Spec: XGBoostJobSpec{ 125 XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 126 XGBoostJobReplicaTypeWorker: &ReplicaSpec{ 127 Replicas: pointer.Int32(3), 128 Template: corev1.PodTemplateSpec{ 129 Spec: corev1.PodSpec{ 130 Containers: []corev1.Container{ 131 corev1.Container{ 132 Name: XGBoostJobDefaultContainerName, 133 Image: testImage, 134 }, 135 }, 136 }, 137 }, 138 }, 139 }, 140 }, 141 }, 142 expected: expectedXGBoostJob(CleanPodPolicyNone, XGBoostJobDefaultRestartPolicy, 3, XGBoostJobDefaultPortName, XGBoostJobDefaultPort), 143 }, 144 145 "Set spec with default node port name and port": { 146 original: &XGBoostJob{ 147 Spec: XGBoostJobSpec{ 148 XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 149 XGBoostJobReplicaTypeWorker: &ReplicaSpec{ 150 Template: corev1.PodTemplateSpec{ 151 Spec: corev1.PodSpec{ 152 Containers: []corev1.Container{ 153 corev1.Container{ 154 Name: XGBoostJobDefaultContainerName, 155 Image: testImage, 156 Ports: []corev1.ContainerPort{ 157 corev1.ContainerPort{ 158 Name: XGBoostJobDefaultPortName, 159 ContainerPort: XGBoostJobDefaultPort, 160 }, 161 }, 162 }, 163 }, 164 }, 165 }, 166 }, 167 }, 168 }, 169 }, 170 expected: expectedXGBoostJob(CleanPodPolicyNone, XGBoostJobDefaultRestartPolicy, 1, XGBoostJobDefaultPortName, XGBoostJobDefaultPort), 171 }, 172 173 "Set spec with node port": { 174 original: &XGBoostJob{ 175 Spec: XGBoostJobSpec{ 176 XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 177 XGBoostJobReplicaTypeWorker: &ReplicaSpec{ 178 Template: corev1.PodTemplateSpec{ 179 Spec: corev1.PodSpec{ 180 Containers: []corev1.Container{ 181 corev1.Container{ 182 Name: XGBoostJobDefaultContainerName, 183 Image: testImage, 184 Ports: []corev1.ContainerPort{ 185 corev1.ContainerPort{ 186 Name: XGBoostJobDefaultPortName, 187 ContainerPort: 9999, 188 }, 189 }, 190 }, 191 }, 192 }, 193 }, 194 }, 195 }, 196 }, 197 }, 198 expected: expectedXGBoostJob(CleanPodPolicyNone, XGBoostJobDefaultRestartPolicy, 1, XGBoostJobDefaultPortName, 9999), 199 }, 200 "set spec with cleanpod policy": { 201 original: &XGBoostJob{ 202 Spec: XGBoostJobSpec{ 203 RunPolicy: RunPolicy{ 204 CleanPodPolicy: CleanPodPolicyPointer(CleanPodPolicyAll), 205 }, 206 XGBReplicaSpecs: map[ReplicaType]*ReplicaSpec{ 207 XGBoostJobReplicaTypeWorker: &ReplicaSpec{ 208 Template: corev1.PodTemplateSpec{ 209 Spec: corev1.PodSpec{ 210 Containers: []corev1.Container{ 211 corev1.Container{ 212 Name: XGBoostJobDefaultContainerName, 213 Image: testImage, 214 }, 215 }, 216 }, 217 }, 218 }, 219 }, 220 }, 221 }, 222 expected: expectedXGBoostJob(CleanPodPolicyAll, XGBoostJobDefaultRestartPolicy, 1, XGBoostJobDefaultPortName, XGBoostJobDefaultPort), 223 }, 224 } 225 226 for name, tc := range testCases { 227 SetDefaults_XGBoostJob(tc.original) 228 if !reflect.DeepEqual(tc.original, tc.expected) { 229 t.Errorf("%s: Want\n%v; Got\n %v", name, pformat(tc.expected), pformat(tc.original)) 230 } 231 } 232 233 }