github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/pytorchjob_controller.go (about)

     1  // Copyright 2021 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package pytorch
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  	"time"
    22  
    23  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    24  	trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common"
    25  	"github.com/kubeflow/training-operator/pkg/common/util"
    26  	"github.com/kubeflow/training-operator/pkg/controller.v1/common"
    27  	"github.com/kubeflow/training-operator/pkg/controller.v1/control"
    28  	"github.com/kubeflow/training-operator/pkg/controller.v1/expectation"
    29  	commonutil "github.com/kubeflow/training-operator/pkg/util"
    30  
    31  	"github.com/go-logr/logr"
    32  	"github.com/sirupsen/logrus"
    33  	corev1 "k8s.io/api/core/v1"
    34  	"k8s.io/apimachinery/pkg/api/equality"
    35  	"k8s.io/apimachinery/pkg/api/errors"
    36  	"k8s.io/apimachinery/pkg/api/meta"
    37  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    38  	"k8s.io/apimachinery/pkg/runtime"
    39  	"k8s.io/apimachinery/pkg/runtime/schema"
    40  	"k8s.io/apimachinery/pkg/types"
    41  	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
    42  	"k8s.io/client-go/informers"
    43  	kubeclientset "k8s.io/client-go/kubernetes"
    44  	"k8s.io/client-go/tools/record"
    45  	ctrl "sigs.k8s.io/controller-runtime"
    46  	"sigs.k8s.io/controller-runtime/pkg/client"
    47  	"sigs.k8s.io/controller-runtime/pkg/controller"
    48  	"sigs.k8s.io/controller-runtime/pkg/event"
    49  	"sigs.k8s.io/controller-runtime/pkg/handler"
    50  	"sigs.k8s.io/controller-runtime/pkg/log"
    51  	"sigs.k8s.io/controller-runtime/pkg/manager"
    52  	"sigs.k8s.io/controller-runtime/pkg/predicate"
    53  	"sigs.k8s.io/controller-runtime/pkg/source"
    54  	schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
    55  	"volcano.sh/apis/pkg/apis/scheduling/v1beta1"
    56  )
    57  
    58  const (
    59  	controllerName = "pytorchjob-controller"
    60  )
    61  
    62  // NewReconciler creates a PyTorchJob Reconciler
    63  func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *PyTorchJobReconciler {
    64  	r := &PyTorchJobReconciler{
    65  		Client:    mgr.GetClient(),
    66  		Scheme:    mgr.GetScheme(),
    67  		recorder:  mgr.GetEventRecorderFor(controllerName),
    68  		apiReader: mgr.GetAPIReader(),
    69  		Log:       log.Log,
    70  	}
    71  
    72  	// Create clients
    73  	cfg := mgr.GetConfig()
    74  	kubeClientSet := kubeclientset.NewForConfigOrDie(cfg)
    75  	sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0)
    76  	priorityClassInformer := sharedInformers.Scheduling().V1().PriorityClasses()
    77  
    78  	// Initialize common job controller
    79  	r.JobController = common.JobController{
    80  		Controller:                  r,
    81  		Expectations:                expectation.NewControllerExpectations(),
    82  		WorkQueue:                   &util.FakeWorkQueue{},
    83  		Recorder:                    r.recorder,
    84  		KubeClientSet:               kubeClientSet,
    85  		PriorityClassLister:         priorityClassInformer.Lister(),
    86  		PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced,
    87  		PodControl:                  control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder},
    88  		ServiceControl:              control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder},
    89  	}
    90  
    91  	gangSchedulingSetupFunc(&r.JobController)
    92  
    93  	return r
    94  }
    95  
    96  // PyTorchJobReconciler reconciles a PyTorchJob object
    97  type PyTorchJobReconciler struct {
    98  	common.JobController
    99  	client.Client
   100  	Scheme    *runtime.Scheme
   101  	Log       logr.Logger
   102  	recorder  record.EventRecorder
   103  	apiReader client.Reader
   104  }
   105  
   106  //+kubebuilder:rbac:groups=kubeflow.org,resources=pytorchjobs,verbs=get;list;watch;create;update;patch;delete
   107  //+kubebuilder:rbac:groups=kubeflow.org,resources=pytorchjobs/status,verbs=get;update;patch
   108  //+kubebuilder:rbac:groups=kubeflow.org,resources=pytorchjobs/finalizers,verbs=update
   109  //+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete
   110  //+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete
   111  //+kubebuilder:rbac:groups=autoscaling,resources=horizontalpodautoscalers,verbs=get;list;watch;create;update;patch;delete
   112  //+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
   113  //+kubebuilder:rbac:groups=scheduling.x-k8s.io,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
   114  //+kubebuilder:rbac:groups="",resources=events,verbs=get;list;watch;create;update;patch;delete
   115  
   116  // Reconcile is part of the main kubernetes reconciliation loop which aims to
   117  // move the current state of the cluster closer to the desired state.
   118  // the PyTorchJob object against the actual cluster state, and then
   119  // perform operations to make the cluster state reflect the state specified by
   120  // the user.
   121  //
   122  // For more details, check Reconcile and its Result here:
   123  // - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.8.3/pkg/reconcile
   124  func (r *PyTorchJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
   125  	_ = log.FromContext(ctx)
   126  	logger := r.Log.WithValues(kubeflowv1.PyTorchJobSingular, req.NamespacedName)
   127  
   128  	pytorchjob := &kubeflowv1.PyTorchJob{}
   129  	err := r.Get(ctx, req.NamespacedName, pytorchjob)
   130  	if err != nil {
   131  		logger.Info(err.Error(), "unable to fetch PyTorchJob", req.NamespacedName.String())
   132  		return ctrl.Result{}, client.IgnoreNotFound(err)
   133  	}
   134  
   135  	if err = kubeflowv1.ValidateV1PyTorchJob(pytorchjob); err != nil {
   136  		logger.Error(err, "PyTorchJob failed validation")
   137  		r.Recorder.Eventf(pytorchjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobFailedValidationReason),
   138  			"PyTorchJob failed validation because %s", err)
   139  		return ctrl.Result{}, err
   140  	}
   141  
   142  	// Check if reconciliation is needed
   143  	jobKey, err := common.KeyFunc(pytorchjob)
   144  	if err != nil {
   145  		utilruntime.HandleError(fmt.Errorf("couldn't get jobKey for job object %#v: %v", pytorchjob, err))
   146  	}
   147  
   148  	replicaTypes := util.GetReplicaTypes(pytorchjob.Spec.PyTorchReplicaSpecs)
   149  	needReconcile := util.SatisfiedExpectations(r.Expectations, jobKey, replicaTypes)
   150  
   151  	if !needReconcile || pytorchjob.GetDeletionTimestamp() != nil {
   152  		logger.Info("reconcile cancelled, job does not need to do reconcile or has been deleted",
   153  			"sync", needReconcile, "deleted", pytorchjob.GetDeletionTimestamp() != nil)
   154  		return ctrl.Result{}, nil
   155  	}
   156  
   157  	// Set default priorities to pytorch job
   158  	r.Scheme.Default(pytorchjob)
   159  
   160  	err = r.ReconcileHPA(pytorchjob)
   161  	if err != nil {
   162  		logger.Error(err, "Reconcile PyTorchJob HPA error")
   163  		return ctrl.Result{}, err
   164  	}
   165  	// Use common to reconcile the job related pod and service
   166  	err = r.ReconcileJobs(pytorchjob, pytorchjob.Spec.PyTorchReplicaSpecs, pytorchjob.Status, &pytorchjob.Spec.RunPolicy)
   167  	if err != nil {
   168  		logger.Error(err, "Reconcile PyTorchJob error")
   169  		return ctrl.Result{}, err
   170  	}
   171  	t, err := util.DurationUntilExpireTime(&pytorchjob.Spec.RunPolicy, pytorchjob.Status)
   172  	if err != nil {
   173  		logrus.Warnf("Reconcile PyTorchJob error %v", err)
   174  		return ctrl.Result{}, err
   175  	}
   176  	if t >= 0 {
   177  		return ctrl.Result{Requeue: true, RequeueAfter: t}, nil
   178  	}
   179  
   180  	return ctrl.Result{}, nil
   181  }
   182  
   183  // SetupWithManager sets up the controller with the Manager.
   184  func (r *PyTorchJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads int) error {
   185  	c, err := controller.New(r.ControllerName(), mgr, controller.Options{
   186  		Reconciler:              r,
   187  		MaxConcurrentReconciles: controllerThreads,
   188  	})
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	// using onOwnerCreateFunc is easier to set defaults
   194  	if err = c.Watch(source.Kind(mgr.GetCache(), &kubeflowv1.PyTorchJob{}), &handler.EnqueueRequestForObject{},
   195  		predicate.Funcs{CreateFunc: r.onOwnerCreateFunc()},
   196  	); err != nil {
   197  		return err
   198  	}
   199  
   200  	// eventHandler for owned object
   201  	eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.PyTorchJob{}, handler.OnlyControllerOwner())
   202  	predicates := predicate.Funcs{
   203  		CreateFunc: util.OnDependentCreateFunc(r.Expectations),
   204  		UpdateFunc: util.OnDependentUpdateFunc(&r.JobController),
   205  		DeleteFunc: util.OnDependentDeleteFunc(r.Expectations),
   206  	}
   207  	// Create generic predicates
   208  	genericPredicates := predicate.Funcs{
   209  		CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
   210  		UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
   211  		DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
   212  	}
   213  	// inject watching for job related pod
   214  	if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Pod{}), eventHandler, predicates); err != nil {
   215  		return err
   216  	}
   217  	// inject watching for job related service
   218  	if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Service{}), eventHandler, predicates); err != nil {
   219  		return err
   220  	}
   221  	// skip watching volcano PodGroup if volcano PodGroup is not installed
   222  	if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"},
   223  		v1beta1.SchemeGroupVersion.Version); err == nil {
   224  		// inject watching for job related volcano PodGroup
   225  		if err = c.Watch(source.Kind(mgr.GetCache(), &v1beta1.PodGroup{}), eventHandler, genericPredicates); err != nil {
   226  			return err
   227  		}
   228  	}
   229  	// skip watching scheduler-plugins PodGroup if scheduler-plugins PodGroup is not installed
   230  	if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"},
   231  		schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil {
   232  		// inject watching for job related scheduler-plugins PodGroup
   233  		if err = c.Watch(source.Kind(mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}), eventHandler, genericPredicates); err != nil {
   234  			return err
   235  		}
   236  	}
   237  	return nil
   238  }
   239  
   240  func (r *PyTorchJobReconciler) ControllerName() string {
   241  	return controllerName
   242  }
   243  
   244  func (r *PyTorchJobReconciler) GetAPIGroupVersionKind() schema.GroupVersionKind {
   245  	return kubeflowv1.GroupVersion.WithKind(kubeflowv1.PyTorchJobKind)
   246  }
   247  
   248  func (r *PyTorchJobReconciler) GetAPIGroupVersion() schema.GroupVersion {
   249  	return kubeflowv1.GroupVersion
   250  }
   251  
   252  func (r *PyTorchJobReconciler) GetGroupNameLabelValue() string {
   253  	return kubeflowv1.GroupVersion.Group
   254  }
   255  
   256  func (r *PyTorchJobReconciler) GetFrameworkName() string {
   257  	return kubeflowv1.PyTorchJobFrameworkName
   258  }
   259  
   260  func (r *PyTorchJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
   261  	job := &kubeflowv1.PyTorchJob{}
   262  	err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
   263  	if err != nil {
   264  		if errors.IsNotFound(err) {
   265  			logrus.Error(err, "pytorch job not found", "namespace", namespace, "name", name)
   266  		} else {
   267  			logrus.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name)
   268  		}
   269  		return nil, err
   270  	}
   271  	return job, nil
   272  }
   273  
   274  func (r *PyTorchJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) {
   275  	job := &kubeflowv1.PyTorchJob{}
   276  
   277  	err := r.apiReader.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
   278  	if err != nil {
   279  		if errors.IsNotFound(err) {
   280  			logrus.Error(err, "pytorch job not found", "namespace", namespace, "name", name)
   281  		} else {
   282  			logrus.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name)
   283  		}
   284  		return nil, err
   285  	}
   286  	return job, nil
   287  }
   288  
   289  func (r *PyTorchJobReconciler) GetPodsForJob(obj interface{}) ([]*corev1.Pod, error) {
   290  	job, err := meta.Accessor(obj)
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  
   295  	// List all pods to include those that don't match the selector anymore
   296  	// but have a ControllerRef pointing to this controller.
   297  	podlist := &corev1.PodList{}
   298  	err = r.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace()))
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  
   303  	return util.JobControlledPodList(podlist.Items, job), nil
   304  }
   305  
   306  func (r *PyTorchJobReconciler) GetServicesForJob(obj interface{}) ([]*corev1.Service, error) {
   307  	job, err := meta.Accessor(obj)
   308  	if err != nil {
   309  		return nil, err
   310  	}
   311  
   312  	// List all pods to include those that don't match the selector anymore
   313  	// but have a ControllerRef pointing to this controller.
   314  	serviceList := &corev1.ServiceList{}
   315  	err = r.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace()))
   316  	if err != nil {
   317  		return nil, err
   318  	}
   319  
   320  	ret := util.ConvertServiceList(serviceList.Items)
   321  	return ret, nil
   322  }
   323  
   324  func (r *PyTorchJobReconciler) DeleteJob(job interface{}) error {
   325  	pytorchjob, ok := job.(*kubeflowv1.PyTorchJob)
   326  	if !ok {
   327  		return fmt.Errorf("%+v is not a type of PyTorchJob", job)
   328  	}
   329  	if err := r.Delete(context.Background(), pytorchjob); err != nil {
   330  		r.recorder.Eventf(pytorchjob, corev1.EventTypeWarning, control.FailedDeletePodReason, "Error deleting: %v", err)
   331  		logrus.Error(err, "failed to delete job", "namespace", pytorchjob.Namespace, "name", pytorchjob.Name)
   332  		return err
   333  	}
   334  	r.recorder.Eventf(pytorchjob, corev1.EventTypeNormal, control.SuccessfulDeletePodReason, "Deleted job: %v", pytorchjob.Name)
   335  	logrus.Info("job deleted", "namespace", pytorchjob.Namespace, "name", pytorchjob.Name)
   336  	trainingoperatorcommon.DeletedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
   337  	return nil
   338  }
   339  
   340  func (jc *PyTorchJobReconciler) GenLabelSelector(jobName string,
   341  	rtype kubeflowv1.ReplicaType) *metav1.LabelSelector {
   342  	labels := jc.GenLabels(jobName)
   343  	labels[kubeflowv1.ReplicaTypeLabel] = strings.ToLower(string(rtype))
   344  
   345  	return &metav1.LabelSelector{
   346  		MatchLabels: labels,
   347  	}
   348  }
   349  
   350  // UpdateJobStatus updates the job status and job conditions
   351  func (r *PyTorchJobReconciler) UpdateJobStatus(job interface{},
   352  	replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec,
   353  	jobStatus *kubeflowv1.JobStatus) error {
   354  	pytorchjob, ok := job.(*kubeflowv1.PyTorchJob)
   355  	if !ok {
   356  		return fmt.Errorf("%+v is not a type of PyTorchJob", job)
   357  	}
   358  	pytorchjobKey, err := common.KeyFunc(pytorchjob)
   359  	if err != nil {
   360  		utilruntime.HandleError(fmt.Errorf("couldn't get key for pytorchjob object %#v: %v", pytorchjob, err))
   361  		return err
   362  	}
   363  
   364  	logger := commonutil.LoggerForJob(pytorchjob)
   365  
   366  	// Set StartTime.
   367  	if jobStatus.StartTime == nil {
   368  		now := metav1.Now()
   369  		jobStatus.StartTime = &now
   370  		// enqueue a sync to check if job past ActiveDeadlineSeconds
   371  		if pytorchjob.Spec.RunPolicy.ActiveDeadlineSeconds != nil {
   372  			logger.Infof("Job with ActiveDeadlineSeconds will sync after %d seconds", *pytorchjob.Spec.RunPolicy.ActiveDeadlineSeconds)
   373  			r.WorkQueue.AddAfter(pytorchjobKey, time.Duration(*pytorchjob.Spec.RunPolicy.ActiveDeadlineSeconds)*time.Second)
   374  		}
   375  	}
   376  
   377  	for rtype, spec := range replicas {
   378  		status := jobStatus.ReplicaStatuses[rtype]
   379  		// Generate the label selector.
   380  		status.Selector = metav1.FormatLabelSelector(r.GenLabelSelector(pytorchjob.Name, rtype))
   381  
   382  		succeeded := status.Succeeded
   383  		expected := *(spec.Replicas) - succeeded
   384  		running := status.Active
   385  		failed := status.Failed
   386  		specReplicas := *spec.Replicas
   387  
   388  		logrus.Infof("PyTorchJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d, failed=%d, Replicas=%d",
   389  			pytorchjob.Name, rtype, expected, running, succeeded, failed, specReplicas)
   390  
   391  		if ContainsMasterSpec(replicas) {
   392  			if rtype == kubeflowv1.PyTorchJobReplicaTypeMaster {
   393  				if running > 0 {
   394  					msg := fmt.Sprintf("PyTorchJob %s is running.", pytorchjob.Name)
   395  					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRunning, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRunningReason), msg)
   396  				}
   397  				// when master is succeed, the job is finished.
   398  				if expected == 0 {
   399  					msg := fmt.Sprintf("PyTorchJob %s is successfully completed.", pytorchjob.Name)
   400  					logrus.Info(msg)
   401  					r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSucceededReason), msg)
   402  					if jobStatus.CompletionTime == nil {
   403  						now := metav1.Now()
   404  						jobStatus.CompletionTime = &now
   405  					}
   406  					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobSucceeded, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSucceededReason), msg)
   407  					trainingoperatorcommon.SuccessfulJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
   408  					return nil
   409  				}
   410  			}
   411  		} else {
   412  			if rtype == kubeflowv1.PyTorchJobReplicaTypeWorker {
   413  				// TODO(gaocegege): Support SuccessPolicy
   414  				// Leave a succeeded condition for the following two cases:
   415  				// 1. If all workers are succeeded.
   416  				// 2. If `ElasticPolicy` is not nil and any worker has completed.
   417  				if expected == 0 || (pytorchjob.Spec.ElasticPolicy != nil && succeeded > 0) {
   418  					msg := fmt.Sprintf("PyTorchJob %s/%s successfully completed.",
   419  						pytorchjob.Namespace, pytorchjob.Name)
   420  					r.recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSucceededReason), msg)
   421  					if jobStatus.CompletionTime == nil {
   422  						now := metav1.Now()
   423  						jobStatus.CompletionTime = &now
   424  					}
   425  					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobSucceeded, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSucceededReason), msg)
   426  					trainingoperatorcommon.SuccessfulJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
   427  				} else if running > 0 {
   428  					// Some workers are still running, leave a running condition.
   429  					msg := fmt.Sprintf("PyTorchJob %s/%s is running.",
   430  						pytorchjob.Namespace, pytorchjob.Name)
   431  					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRunning, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRunningReason), msg)
   432  				}
   433  			}
   434  		}
   435  
   436  		if failed > 0 && (specReplicas > succeeded+running) {
   437  			if spec.RestartPolicy != kubeflowv1.RestartPolicyNever {
   438  				msg := fmt.Sprintf("PyTorchJob %s is restarting because %d %s replica(s) failed.", pytorchjob.Name, failed, rtype)
   439  				r.Recorder.Event(pytorchjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRestartingReason), msg)
   440  				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRestarting, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRestartingReason), msg)
   441  				trainingoperatorcommon.RestartedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
   442  			} else {
   443  				msg := fmt.Sprintf("PyTorchJob %s is failed because %d %s replica(s) failed.", pytorchjob.Name, failed, rtype)
   444  				r.Recorder.Event(pytorchjob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobFailedReason), msg)
   445  				if jobStatus.CompletionTime == nil {
   446  					now := metav1.Now()
   447  					jobStatus.CompletionTime = &now
   448  				}
   449  				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobFailed, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobFailedReason), msg)
   450  				trainingoperatorcommon.FailedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
   451  			}
   452  		}
   453  	}
   454  	return nil
   455  }
   456  
   457  // ContainsMasterSpec returns true if the pytorchjob contains master spec.
   458  func ContainsMasterSpec(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec) bool {
   459  	if _, ok := replicas[kubeflowv1.PyTorchJobReplicaTypeMaster]; ok {
   460  		return true
   461  	}
   462  	return false
   463  }
   464  
   465  // UpdateJobStatusInApiServer updates the job status in to cluster.
   466  func (r *PyTorchJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus *kubeflowv1.JobStatus) error {
   467  	if jobStatus.ReplicaStatuses == nil {
   468  		jobStatus.ReplicaStatuses = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaStatus{}
   469  	}
   470  
   471  	pytorchjob, ok := job.(*kubeflowv1.PyTorchJob)
   472  	trainingoperatorcommon.ClearGeneratedFields(&pytorchjob.ObjectMeta)
   473  	if !ok {
   474  		return fmt.Errorf("%+v is not a type of PyTorchJob", job)
   475  	}
   476  
   477  	// Job status passed in differs with status in job, update in basis of the passed in one.
   478  	if !equality.Semantic.DeepEqual(&pytorchjob.Status, jobStatus) {
   479  		pytorchjob = pytorchjob.DeepCopy()
   480  		pytorchjob.Status = *jobStatus.DeepCopy()
   481  	}
   482  
   483  	result := r.Status().Update(context.Background(), pytorchjob)
   484  
   485  	if result != nil {
   486  		r.Log.WithValues("pytorchjob", types.NamespacedName{
   487  			Namespace: pytorchjob.GetNamespace(),
   488  			Name:      pytorchjob.GetName(),
   489  		})
   490  		return result
   491  	}
   492  
   493  	return nil
   494  }
   495  
   496  // SetClusterSpec sets the cluster spec and init container for the pod
   497  func (r *PyTorchJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
   498  	if err := setPodEnv(job, podTemplate, rtype, index); err != nil {
   499  		return err
   500  	}
   501  	if err := setInitContainer(job, podTemplate, rtype, index, r.Log); err != nil {
   502  		return err
   503  	}
   504  	return nil
   505  }
   506  
   507  func (r *PyTorchJobReconciler) GetDefaultContainerName() string {
   508  	return kubeflowv1.PyTorchJobDefaultContainerName
   509  }
   510  
   511  func (r *PyTorchJobReconciler) GetDefaultContainerPortName() string {
   512  	return kubeflowv1.PyTorchJobDefaultPortName
   513  }
   514  
   515  func (r *PyTorchJobReconciler) IsMasterRole(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec,
   516  	rtype kubeflowv1.ReplicaType, index int) bool {
   517  	return string(rtype) == string(kubeflowv1.PyTorchJobReplicaTypeMaster)
   518  }
   519  
   520  // onOwnerCreateFunc modify creation condition.
   521  func (r *PyTorchJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
   522  	return func(e event.CreateEvent) bool {
   523  		pytorchjob, ok := e.Object.(*kubeflowv1.PyTorchJob)
   524  		if !ok {
   525  			return true
   526  		}
   527  		r.Scheme.Default(pytorchjob)
   528  		msg := fmt.Sprintf("PyTorchJob %s is created.", e.Object.GetName())
   529  		logrus.Info(msg)
   530  		trainingoperatorcommon.CreatedJobsCounterInc(pytorchjob.Namespace, r.GetFrameworkName())
   531  		commonutil.UpdateJobConditions(&pytorchjob.Status, kubeflowv1.JobCreated, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason), msg)
   532  		return true
   533  	}
   534  }