github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/elastic_test.go (about)

     1  // Copyright 2021 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License
    14  
    15  package pytorch
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/onsi/ginkgo/v2"
    21  	"github.com/onsi/gomega"
    22  	corev1 "k8s.io/api/core/v1"
    23  	"k8s.io/utils/pointer"
    24  
    25  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    26  )
    27  
    28  func TestElasticGenerate(t *testing.T) {
    29  	gomega.RegisterFailHandler(ginkgo.Fail)
    30  	defer ginkgo.GinkgoRecover()
    31  
    32  	backendC10D := kubeflowv1.BackendC10D
    33  
    34  	tests := []struct {
    35  		name        string
    36  		job         *kubeflowv1.PyTorchJob
    37  		expectedErr error
    38  		expected    []corev1.EnvVar
    39  	}{
    40  		{
    41  			name: "Without ElasticPolicy",
    42  			job: &kubeflowv1.PyTorchJob{
    43  				Spec: kubeflowv1.PyTorchJobSpec{
    44  					PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    45  						kubeflowv1.PyTorchJobReplicaTypeWorker: {
    46  							Replicas: pointer.Int32(1),
    47  						},
    48  					},
    49  				},
    50  			},
    51  			expectedErr: nil,
    52  			expected:    nil,
    53  		},
    54  		{
    55  			name: "With ElasticPolicy",
    56  			job: &kubeflowv1.PyTorchJob{
    57  				Spec: kubeflowv1.PyTorchJobSpec{
    58  					ElasticPolicy: &kubeflowv1.ElasticPolicy{
    59  						MinReplicas: pointer.Int32(1),
    60  						MaxReplicas: pointer.Int32(3),
    61  						RDZVBackend: &backendC10D,
    62  						RDZVPort:    pointer.Int32(1234),
    63  						RDZVHost:    pointer.String("localhost"),
    64  						RDZVID:      pointer.String("rdzv-id"),
    65  						RDZVConf: []kubeflowv1.RDZVConf{
    66  							{
    67  								Key:   "rdzv-conf-name",
    68  								Value: "rdzv-conf-value",
    69  							},
    70  							{
    71  								Key:   "rdzv-conf-name-1",
    72  								Value: "rdzv-conf-value-1",
    73  							},
    74  						},
    75  						MaxRestarts: pointer.Int32(3),
    76  					},
    77  					PyTorchReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    78  						kubeflowv1.PyTorchJobReplicaTypeWorker: {
    79  							Replicas: pointer.Int32(1),
    80  						},
    81  					},
    82  				},
    83  			},
    84  			expectedErr: nil,
    85  			expected: []corev1.EnvVar{
    86  				{
    87  					Name:  EnvMaxRestarts,
    88  					Value: "3",
    89  				},
    90  				{
    91  					Name:  EnvRDZVBackend,
    92  					Value: "c10d",
    93  				},
    94  				{
    95  					Name:  EnvRDZVEndpoint,
    96  					Value: "localhost:1234",
    97  				},
    98  				{
    99  					Name:  EnvRDZVID,
   100  					Value: "rdzv-id",
   101  				},
   102  				{
   103  					Name:  EnvRDZVConf,
   104  					Value: "rdzv-conf-name=rdzv-conf-value,rdzv-conf-name-1=rdzv-conf-value-1",
   105  				},
   106  				{
   107  					Name:  EnvNnodes,
   108  					Value: "1:3",
   109  				},
   110  			},
   111  		},
   112  	}
   113  
   114  	for _, test := range tests {
   115  		actual, err := GetElasticEnvVarGenerator().Generate(test.job)
   116  		if test.expectedErr == nil {
   117  			gomega.Expect(err).To(gomega.BeNil())
   118  		} else {
   119  			gomega.Expect(err).To(gomega.Equal(test.expectedErr))
   120  		}
   121  		if test.expected == nil {
   122  			gomega.Expect(actual).To(gomega.BeNil())
   123  		} else {
   124  			gomega.Expect(actual).To(gomega.ConsistOf(test.expected))
   125  		}
   126  	}
   127  }