github.com/kubeflow/training-operator@v1.7.0/pkg/apis/kubeflow.org/v1/pytorch_defaults_test.go (about)

     1  package v1
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/onsi/ginkgo/v2"
     7  	"github.com/onsi/gomega"
     8  	"k8s.io/utils/pointer"
     9  )
    10  
    11  func TestSetElasticPolicy(t *testing.T) {
    12  	gomega.RegisterFailHandler(ginkgo.Fail)
    13  
    14  	type args struct {
    15  		job *PyTorchJob
    16  	}
    17  	type result struct {
    18  		expectedMinReplicas *int32
    19  		expectedMaxReplicas *int32
    20  	}
    21  	tests := []struct {
    22  		name   string
    23  		args   args
    24  		result result
    25  	}{
    26  		{
    27  			name: "minReplicas and maxReplicas to null",
    28  			args: args{
    29  				job: &PyTorchJob{
    30  					Spec: PyTorchJobSpec{
    31  						ElasticPolicy: &ElasticPolicy{},
    32  						PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{
    33  							PyTorchJobReplicaTypeWorker: {
    34  								Replicas: pointer.Int32(1),
    35  							},
    36  						},
    37  					},
    38  				},
    39  			},
    40  			result: result{
    41  				expectedMinReplicas: pointer.Int32(1),
    42  				expectedMaxReplicas: pointer.Int32(1),
    43  			},
    44  		},
    45  		{
    46  			name: "minReplicas and maxReplicas to 1",
    47  			args: args{
    48  				job: &PyTorchJob{
    49  					Spec: PyTorchJobSpec{
    50  						ElasticPolicy: &ElasticPolicy{
    51  							MaxReplicas: pointer.Int32(1),
    52  							MinReplicas: pointer.Int32(1),
    53  						},
    54  						PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{
    55  							PyTorchJobReplicaTypeWorker: {
    56  								Replicas: pointer.Int32(1),
    57  							},
    58  						},
    59  					},
    60  				},
    61  			},
    62  			result: result{
    63  				expectedMinReplicas: pointer.Int32(1),
    64  				expectedMaxReplicas: pointer.Int32(1),
    65  			},
    66  		},
    67  		{
    68  			name: "minReplicas and maxReplicas to 1",
    69  			args: args{
    70  				job: &PyTorchJob{
    71  					Spec: PyTorchJobSpec{
    72  						ElasticPolicy: &ElasticPolicy{
    73  							MaxReplicas: pointer.Int32(1),
    74  							MinReplicas: pointer.Int32(1),
    75  						},
    76  						PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{
    77  							PyTorchJobReplicaTypeWorker: {
    78  								Replicas: pointer.Int32(1),
    79  							},
    80  						},
    81  					},
    82  				},
    83  			},
    84  			result: result{
    85  				expectedMinReplicas: pointer.Int32(1),
    86  				expectedMaxReplicas: pointer.Int32(1),
    87  			},
    88  		},
    89  		{
    90  			name: "minReplicas to null, maxRepliacs to 1",
    91  			args: args{
    92  				job: &PyTorchJob{
    93  					Spec: PyTorchJobSpec{
    94  						ElasticPolicy: &ElasticPolicy{
    95  							MaxReplicas: pointer.Int32(1),
    96  							MinReplicas: nil,
    97  						},
    98  						PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{
    99  							PyTorchJobReplicaTypeWorker: {
   100  								Replicas: pointer.Int32(1),
   101  							},
   102  						},
   103  					},
   104  				},
   105  			},
   106  			result: result{
   107  				expectedMinReplicas: pointer.Int32(1),
   108  				expectedMaxReplicas: pointer.Int32(1),
   109  			},
   110  		},
   111  		{
   112  			name: "maxRepliacs to null, minReplicas to 1",
   113  			args: args{
   114  				job: &PyTorchJob{
   115  					Spec: PyTorchJobSpec{
   116  						ElasticPolicy: &ElasticPolicy{
   117  							MaxReplicas: nil,
   118  							MinReplicas: pointer.Int32(1),
   119  						},
   120  						PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   121  							PyTorchJobReplicaTypeWorker: {
   122  								Replicas: pointer.Int32(1),
   123  							},
   124  						},
   125  					},
   126  				},
   127  			},
   128  			result: result{
   129  				expectedMinReplicas: pointer.Int32(1),
   130  				expectedMaxReplicas: pointer.Int32(1),
   131  			},
   132  		},
   133  	}
   134  	for _, test := range tests {
   135  		t.Run(test.name, func(t *testing.T) {
   136  			setElasticPolicy(test.args.job)
   137  			if test.result.expectedMinReplicas != nil {
   138  				gomega.Expect(test.args.job.Spec.ElasticPolicy.MinReplicas).
   139  					To(gomega.Equal(test.result.expectedMinReplicas))
   140  			} else {
   141  				gomega.Expect(test.args.job.Spec.ElasticPolicy.MinReplicas).
   142  					To(gomega.BeNil())
   143  			}
   144  
   145  			if test.result.expectedMaxReplicas != nil {
   146  				gomega.Expect(test.args.job.Spec.ElasticPolicy.MaxReplicas).
   147  					To(gomega.Equal(test.result.expectedMaxReplicas))
   148  			} else {
   149  				gomega.Expect(test.args.job.Spec.ElasticPolicy.MaxReplicas).
   150  					To(gomega.BeNil())
   151  			}
   152  		})
   153  	}
   154  }
   155  
   156  func TestSetDefaultNprocPerNode(t *testing.T) {
   157  	gomega.RegisterFailHandler(ginkgo.Fail)
   158  	t.Run("test default nproc per node", func(t *testing.T) {
   159  		job := &PyTorchJob{
   160  			Spec: PyTorchJobSpec{
   161  				ElasticPolicy: &ElasticPolicy{
   162  					NProcPerNode: nil,
   163  				},
   164  				PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   165  					PyTorchJobReplicaTypeWorker: {
   166  						Replicas: pointer.Int32(1),
   167  					},
   168  				},
   169  			},
   170  		}
   171  
   172  		setDefaultNprocPerNode(job)
   173  		gomega.Expect(job.Spec.NprocPerNode).
   174  			To(gomega.Equal(&DefaultNprocPerNode))
   175  	})
   176  	t.Run("test default nproc per node", func(t *testing.T) {
   177  		job := &PyTorchJob{
   178  			Spec: PyTorchJobSpec{
   179  				ElasticPolicy: nil,
   180  				PyTorchReplicaSpecs: map[ReplicaType]*ReplicaSpec{
   181  					PyTorchJobReplicaTypeWorker: {
   182  						Replicas: pointer.Int32(1),
   183  					},
   184  				},
   185  			},
   186  		}
   187  
   188  		setDefaultNprocPerNode(job)
   189  		gomega.Expect(job.Spec.NprocPerNode).
   190  			To(gomega.Equal(&DefaultNprocPerNode))
   191  	})
   192  }