github.com/kubeflow/training-operator@v1.7.0/pkg/apis/kubeflow.org/v1/tensorflow_defaults_test.go (about)

     1  // Copyright 2018 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package v1
    16  
    17  import (
    18  	"reflect"
    19  	"testing"
    20  
    21  	corev1 "k8s.io/api/core/v1"
    22  	"k8s.io/utils/pointer"
    23  )
    24  
    25  func expectedTFJob(cleanPodPolicy CleanPodPolicy, restartPolicy RestartPolicy, portName string, port int32) *TFJob {
    26  	var ports []corev1.ContainerPort
    27  
    28  	// port not set
    29  	if portName != "" {
    30  		ports = append(ports,
    31  			corev1.ContainerPort{
    32  				Name:          portName,
    33  				ContainerPort: port,
    34  			},
    35  		)
    36  	}
    37  
    38  	// port set with custom name
    39  	if portName != TFJobDefaultPortName {
    40  		ports = append(ports,
    41  			corev1.ContainerPort{
    42  				Name:          TFJobDefaultPortName,
    43  				ContainerPort: TFJobDefaultPort,
    44  			},
    45  		)
    46  	}
    47  
    48  	defaultSuccessPolicy := SuccessPolicyDefault
    49  
    50  	return &TFJob{
    51  		Spec: TFJobSpec{
    52  			SuccessPolicy: &defaultSuccessPolicy,
    53  			RunPolicy: RunPolicy{
    54  				CleanPodPolicy: &cleanPodPolicy,
    55  			},
    56  			TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{
    57  				TFJobReplicaTypeWorker: &ReplicaSpec{
    58  					Replicas:      pointer.Int32(1),
    59  					RestartPolicy: restartPolicy,
    60  					Template: corev1.PodTemplateSpec{
    61  						Spec: corev1.PodSpec{
    62  							Containers: []corev1.Container{
    63  								corev1.Container{
    64  									Name:  TFJobDefaultContainerName,
    65  									Image: testImage,
    66  									Ports: ports,
    67  								},
    68  							},
    69  						},
    70  					},
    71  				},
    72  			},
    73  		},
    74  	}
    75  }
    76  
    77  func TestSetTypeNames(t *testing.T) {
    78  	spec := &ReplicaSpec{
    79  		RestartPolicy: RestartPolicyAlways,
    80  		Template: corev1.PodTemplateSpec{
    81  			Spec: corev1.PodSpec{
    82  				Containers: []corev1.Container{
    83  					corev1.Container{
    84  						Name:  TFJobDefaultContainerName,
    85  						Image: testImage,
    86  						Ports: []corev1.ContainerPort{
    87  							corev1.ContainerPort{
    88  								Name:          TFJobDefaultPortName,
    89  								ContainerPort: TFJobDefaultPort,
    90  							},
    91  						},
    92  					},
    93  				},
    94  			},
    95  		},
    96  	}
    97  
    98  	workerUpperCase := ReplicaType("WORKER")
    99  	original := &TFJob{
   100  		Spec: TFJobSpec{
   101  			TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   102  				workerUpperCase: spec,
   103  			},
   104  		},
   105  	}
   106  
   107  	setTensorflowTypeNamesToCamelCase(original)
   108  	if _, ok := original.Spec.TFReplicaSpecs[workerUpperCase]; ok {
   109  		t.Errorf("Failed to delete key %s", workerUpperCase)
   110  	}
   111  	if _, ok := original.Spec.TFReplicaSpecs[TFJobReplicaTypeWorker]; !ok {
   112  		t.Errorf("Failed to set key %s", TFJobReplicaTypeWorker)
   113  	}
   114  }
   115  
   116  func TestSetDefaultTFJob(t *testing.T) {
   117  	customPortName := "customPort"
   118  	var customPort int32 = 1234
   119  	customRestartPolicy := RestartPolicyAlways
   120  
   121  	testCases := map[string]struct {
   122  		original *TFJob
   123  		expected *TFJob
   124  	}{
   125  		"set replicas": {
   126  			original: &TFJob{
   127  				Spec: TFJobSpec{
   128  					TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   129  						TFJobReplicaTypeWorker: &ReplicaSpec{
   130  							RestartPolicy: customRestartPolicy,
   131  							Template: corev1.PodTemplateSpec{
   132  								Spec: corev1.PodSpec{
   133  									Containers: []corev1.Container{
   134  										corev1.Container{
   135  											Name:  TFJobDefaultContainerName,
   136  											Image: testImage,
   137  											Ports: []corev1.ContainerPort{
   138  												{
   139  													Name:          TFJobDefaultPortName,
   140  													ContainerPort: TFJobDefaultPort,
   141  												},
   142  											},
   143  										},
   144  									},
   145  								},
   146  							},
   147  						},
   148  					},
   149  				},
   150  			},
   151  			expected: expectedTFJob(CleanPodPolicyNone, customRestartPolicy, TFJobDefaultPortName, TFJobDefaultPort),
   152  		},
   153  		"set replicas with default restartpolicy": {
   154  			original: &TFJob{
   155  				Spec: TFJobSpec{
   156  					TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   157  						TFJobReplicaTypeWorker: &ReplicaSpec{
   158  							Template: corev1.PodTemplateSpec{
   159  								Spec: corev1.PodSpec{
   160  									Containers: []corev1.Container{
   161  										corev1.Container{
   162  											Name:  TFJobDefaultContainerName,
   163  											Image: testImage,
   164  											Ports: []corev1.ContainerPort{
   165  												corev1.ContainerPort{
   166  													Name:          TFJobDefaultPortName,
   167  													ContainerPort: TFJobDefaultPort,
   168  												},
   169  											},
   170  										},
   171  									},
   172  								},
   173  							},
   174  						},
   175  					},
   176  				},
   177  			},
   178  			expected: expectedTFJob(CleanPodPolicyNone, TFJobDefaultRestartPolicy, TFJobDefaultPortName, TFJobDefaultPort),
   179  		},
   180  		"set replicas with default port": {
   181  			original: &TFJob{
   182  				Spec: TFJobSpec{
   183  					TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   184  						TFJobReplicaTypeWorker: &ReplicaSpec{
   185  							Replicas:      pointer.Int32(1),
   186  							RestartPolicy: customRestartPolicy,
   187  							Template: corev1.PodTemplateSpec{
   188  								Spec: corev1.PodSpec{
   189  									Containers: []corev1.Container{
   190  										corev1.Container{
   191  											Name:  TFJobDefaultContainerName,
   192  											Image: testImage,
   193  										},
   194  									},
   195  								},
   196  							},
   197  						},
   198  					},
   199  				},
   200  			},
   201  			expected: expectedTFJob(CleanPodPolicyNone, customRestartPolicy, "", 0),
   202  		},
   203  		"set replicas adding default port": {
   204  			original: &TFJob{
   205  				Spec: TFJobSpec{
   206  					TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   207  						TFJobReplicaTypeWorker: &ReplicaSpec{
   208  							Replicas:      pointer.Int32(1),
   209  							RestartPolicy: customRestartPolicy,
   210  							Template: corev1.PodTemplateSpec{
   211  								Spec: corev1.PodSpec{
   212  									Containers: []corev1.Container{
   213  										corev1.Container{
   214  											Name:  TFJobDefaultContainerName,
   215  											Image: testImage,
   216  											Ports: []corev1.ContainerPort{
   217  												corev1.ContainerPort{
   218  													Name:          customPortName,
   219  													ContainerPort: customPort,
   220  												},
   221  											},
   222  										},
   223  									},
   224  								},
   225  							},
   226  						},
   227  					},
   228  				},
   229  			},
   230  			expected: expectedTFJob(CleanPodPolicyNone, customRestartPolicy, customPortName, customPort),
   231  		},
   232  		"set custom cleanpod policy": {
   233  			original: &TFJob{
   234  				Spec: TFJobSpec{
   235  					RunPolicy: RunPolicy{
   236  						CleanPodPolicy: CleanPodPolicyPointer(CleanPodPolicyAll),
   237  					},
   238  					TFReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   239  						TFJobReplicaTypeWorker: &ReplicaSpec{
   240  							Replicas:      pointer.Int32(1),
   241  							RestartPolicy: customRestartPolicy,
   242  							Template: corev1.PodTemplateSpec{
   243  								Spec: corev1.PodSpec{
   244  									Containers: []corev1.Container{
   245  										corev1.Container{
   246  											Name:  TFJobDefaultContainerName,
   247  											Image: testImage,
   248  											Ports: []corev1.ContainerPort{
   249  												corev1.ContainerPort{
   250  													Name:          customPortName,
   251  													ContainerPort: customPort,
   252  												},
   253  											},
   254  										},
   255  									},
   256  								},
   257  							},
   258  						},
   259  					},
   260  				},
   261  			},
   262  			expected: expectedTFJob(CleanPodPolicyAll, customRestartPolicy, customPortName, customPort),
   263  		},
   264  	}
   265  
   266  	for name, tc := range testCases {
   267  		SetDefaults_TFJob(tc.original)
   268  		if !reflect.DeepEqual(tc.original, tc.expected) {
   269  			t.Errorf("%s: Want\n%v; Got\n %v", name, tc.expected, tc.original)
   270  		}
   271  	}
   272  }