github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/mxnet/mxjob_controller.go

// Copyright 2021 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mxnet

import (
	"context"
	"fmt"
	"reflect"
	"time"

	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common"
	"github.com/kubeflow/training-operator/pkg/common/util"
	"github.com/kubeflow/training-operator/pkg/controller.v1/common"
	"github.com/kubeflow/training-operator/pkg/controller.v1/control"
	"github.com/kubeflow/training-operator/pkg/controller.v1/expectation"
	commonutil "github.com/kubeflow/training-operator/pkg/util"

	"github.com/go-logr/logr"
	"github.com/sirupsen/logrus"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/errors"
	"k8s.io/apimachinery/pkg/api/meta"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/runtime/schema"
	"k8s.io/apimachinery/pkg/types"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/client-go/informers"
	kubeclientset "k8s.io/client-go/kubernetes"
	"k8s.io/client-go/tools/record"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/controller"
	"sigs.k8s.io/controller-runtime/pkg/event"
	"sigs.k8s.io/controller-runtime/pkg/handler"
	"sigs.k8s.io/controller-runtime/pkg/log"
	"sigs.k8s.io/controller-runtime/pkg/manager"
	"sigs.k8s.io/controller-runtime/pkg/predicate"
	"sigs.k8s.io/controller-runtime/pkg/source"
	schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
	"volcano.sh/apis/pkg/apis/scheduling/v1beta1"
)

const (
	controllerName = "mxjob-controller"
)

// NewReconciler creates an MXJob Reconciler.
func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *MXJobReconciler {
	r := &MXJobReconciler{
		Client:    mgr.GetClient(),
		Scheme:    mgr.GetScheme(),
		Recorder:  mgr.GetEventRecorderFor(controllerName),
		apiReader: mgr.GetAPIReader(),
		Log:       log.Log,
	}

	// Create clients.
	cfg := mgr.GetConfig()
	kubeClientSet := kubeclientset.NewForConfigOrDie(cfg)
	sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0)
	priorityClassInformer := sharedInformers.Scheduling().V1().PriorityClasses()

	// Initialize the common job controller.
	r.JobController = common.JobController{
		Controller:                  r,
		Expectations:                expectation.NewControllerExpectations(),
		WorkQueue:                   &util.FakeWorkQueue{},
		Recorder:                    r.Recorder,
		KubeClientSet:               kubeClientSet,
		PriorityClassLister:         priorityClassInformer.Lister(),
		PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced,
		PodControl:                  control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.Recorder},
		ServiceControl:              control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.Recorder},
	}

	gangSchedulingSetupFunc(&r.JobController)

	return r
}
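
// A minimal wiring sketch (illustrative, not part of this file): a manager
// entrypoint would construct the reconciler with one of the gang-scheduling
// setup funcs from the common package and register it. This assumes the
// GenNonGangSchedulerSetupFunc helper available in this operator version.
//
//	mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{})
//	if err != nil {
//		panic(err)
//	}
//	r := NewReconciler(mgr, common.GenNonGangSchedulerSetupFunc())
//	if err := r.SetupWithManager(mgr, 1); err != nil {
//		panic(err)
//	}
//	_ = mgr.Start(ctrl.SetupSignalHandler())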

// MXJobReconciler reconciles an MXJob object.
type MXJobReconciler struct {
	common.JobController
	client.Client
	Log       logr.Logger
	Recorder  record.EventRecorder
	apiReader client.Reader
	Scheme    *runtime.Scheme
}
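
// MXJobReconciler is assigned to JobController.Controller in NewReconciler,
// so it must satisfy the common ControllerInterface. A compile-time
// assertion (a sketch; not present in this file) would make that explicit:
//
//	var _ common.ControllerInterface = &MXJobReconciler{}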

//+kubebuilder:rbac:groups=kubeflow.org,resources=mxjobs,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=kubeflow.org,resources=mxjobs/status,verbs=get;update;patch
//+kubebuilder:rbac:groups=kubeflow.org,resources=mxjobs/finalizers,verbs=update
//+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete
//+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=scheduling.x-k8s.io,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=events,verbs=get;list;watch;create;update;patch;delete

func (r *MXJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	_ = log.FromContext(ctx)
	logger := r.Log.WithValues(kubeflowv1.MXJobSingular, req.NamespacedName)

	mxjob := &kubeflowv1.MXJob{}
	err := r.Get(ctx, req.NamespacedName, mxjob)
	if err != nil {
		logger.Info("unable to fetch MXJob", "namespacedName", req.NamespacedName.String(), "err", err.Error())
		return ctrl.Result{}, client.IgnoreNotFound(err)
	}

	if err = kubeflowv1.ValidateV1MXJob(mxjob); err != nil {
		logger.Error(err, "MXJob failed validation")
		r.Recorder.Eventf(mxjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobFailedValidationReason),
			"MXJob failed validation because %s", err)
		return ctrl.Result{}, err
	}

	// Check if reconciliation is needed.
	jobKey, err := common.KeyFunc(mxjob)
	if err != nil {
		utilruntime.HandleError(fmt.Errorf("couldn't get jobKey for job object %#v: %v", mxjob, err))
	}

	replicaTypes := util.GetReplicaTypes(mxjob.Spec.MXReplicaSpecs)
	needReconcile := util.SatisfiedExpectations(r.Expectations, jobKey, replicaTypes)

	if !needReconcile || mxjob.GetDeletionTimestamp() != nil {
		logger.Info("reconcile cancelled: the job does not need to sync or has been deleted",
			"sync", needReconcile, "deleted", mxjob.GetDeletionTimestamp() != nil)
		return ctrl.Result{}, nil
	}

	// Apply defaults registered in the scheme to the MXJob.
	r.Scheme.Default(mxjob)

	// Convert mxjob.Spec.MXReplicaSpecs to map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec.
	replicas := map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{}
	for k, v := range mxjob.Spec.MXReplicaSpecs {
		replicas[k] = v
	}

	// Use the common module to reconcile the job's pods and services.
	err = r.ReconcileJobs(mxjob, replicas, mxjob.Status, &mxjob.Spec.RunPolicy)
	if err != nil {
		logrus.Warnf("Reconcile MXJob error %v", err)
		return ctrl.Result{}, err
	}

	t, err := util.DurationUntilExpireTime(&mxjob.Spec.RunPolicy, mxjob.Status)
	if err != nil {
		logrus.Warnf("Reconcile MXJob error %v", err)
		return ctrl.Result{}, err
	}
	if t >= 0 {
		return ctrl.Result{Requeue: true, RequeueAfter: t}, nil
	}

	return ctrl.Result{}, nil
}
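
// For illustration, a minimal manifest that exercises this Reconcile path
// might look like the following (kubeflow.org/v1 fields; replica counts and
// the pod templates are placeholders):
//
//	apiVersion: kubeflow.org/v1
//	kind: MXJob
//	metadata:
//	  name: mxnet-example
//	spec:
//	  jobMode: MXTrain
//	  mxReplicaSpecs:
//	    Scheduler:
//	      replicas: 1
//	      template: ...
//	    Server:
//	      replicas: 1
//	      template: ...
//	    Worker:
//	      replicas: 2
//	      template: ...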

// SetupWithManager sets up the controller with the Manager.
func (r *MXJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads int) error {
	c, err := controller.New(r.ControllerName(), mgr, controller.Options{
		Reconciler:              r,
		MaxConcurrentReconciles: controllerThreads,
	})
	if err != nil {
		return err
	}

	// Watch owner MXJobs; onOwnerCreateFunc applies defaults and records the creation condition.
	if err = c.Watch(source.Kind(mgr.GetCache(), &kubeflowv1.MXJob{}), &handler.EnqueueRequestForObject{},
		predicate.Funcs{CreateFunc: r.onOwnerCreateFunc()}); err != nil {
		return err
	}

	// eventHandler for owned objects.
	eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.MXJob{}, handler.OnlyControllerOwner())
	// Predicates for owned objects.
	predicates := predicate.Funcs{
		CreateFunc: util.OnDependentCreateFunc(r.Expectations),
		UpdateFunc: util.OnDependentUpdateFunc(&r.JobController),
		DeleteFunc: util.OnDependentDeleteFunc(r.Expectations),
	}
	genericPredicates := predicate.Funcs{
		CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
		UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
		DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
	}
	// Watch the job's pods.
	if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Pod{}), eventHandler, predicates); err != nil {
		return err
	}
	// Watch the job's services.
	if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Service{}), eventHandler, predicates); err != nil {
		return err
	}
	// Skip watching volcano PodGroups if the volcano PodGroup CRD is not installed.
	if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"},
		v1beta1.SchemeGroupVersion.Version,
	); err == nil {
		// Watch the job's volcano PodGroups.
		if err = c.Watch(source.Kind(mgr.GetCache(), &v1beta1.PodGroup{}), eventHandler, genericPredicates); err != nil {
			return err
		}
	}
	// Skip watching scheduler-plugins PodGroups if the scheduler-plugins PodGroup CRD is not installed.
	if _, err = mgr.GetRESTMapper().RESTMapping(
		schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"},
		schedulerpluginsv1alpha1.SchemeGroupVersion.Version,
	); err == nil {
		// Watch the job's scheduler-plugins PodGroups.
		if err = c.Watch(source.Kind(mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}), eventHandler, genericPredicates); err != nil {
			return err
		}
	}

	return nil
}
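
// The RESTMapping probe above is a general pattern for conditionally watching
// optional CRDs: if the mapping lookup fails, the CRD is not installed and the
// watch is skipped, since watching an absent kind would keep the informer from
// syncing. As a standalone sketch (the helper name is illustrative):
//
//	func crdInstalled(mapper meta.RESTMapper, gk schema.GroupKind, version string) bool {
//		_, err := mapper.RESTMapping(gk, version)
//		return err == nil
//	}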

// ControllerName implements ControllerInterface.
func (r *MXJobReconciler) ControllerName() string {
	return controllerName
}

func (r *MXJobReconciler) GetAPIGroupVersionKind() schema.GroupVersionKind {
	return kubeflowv1.GroupVersion.WithKind(kubeflowv1.MXJobKind)
}

func (r *MXJobReconciler) GetAPIGroupVersion() schema.GroupVersion {
	return kubeflowv1.GroupVersion
}

func (r *MXJobReconciler) GetGroupNameLabelValue() string {
	return kubeflowv1.GroupVersion.Group
}

func (r *MXJobReconciler) GetFrameworkName() string {
	return kubeflowv1.MXJobFrameworkName
}

func (r *MXJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
	job := &kubeflowv1.MXJob{}
	err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
	if err != nil {
		if errors.IsNotFound(err) {
			logrus.Errorf("mxnet job not found: namespace=%s, name=%s, err=%v", namespace, name, err)
		} else {
			logrus.Errorf("failed to get job from the informer cache: namespace=%s, name=%s, err=%v", namespace, name, err)
		}
		return nil, err
	}
	return job, nil
}

func (r *MXJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) {
	job := &kubeflowv1.MXJob{}

	err := r.apiReader.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
	if err != nil {
		if errors.IsNotFound(err) {
			logrus.Errorf("mxnet job not found: namespace=%s, name=%s, err=%v", namespace, name, err)
		} else {
			logrus.Errorf("failed to get job from the API server: namespace=%s, name=%s, err=%v", namespace, name, err)
		}
		return nil, err
	}
	return job, nil
}
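
// Note the split above: GetJobFromInformerCache reads through the manager's
// cached client (r.Client), while GetJobFromAPIClient uses mgr.GetAPIReader(),
// which bypasses the cache and reads from the API server directly. The direct
// reader is useful when the informer cache may be stale, e.g. right after a
// create or delete that the expectations machinery is still tracking.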

func (r *MXJobReconciler) GetPodsForJob(obj interface{}) ([]*corev1.Pod, error) {
	job, err := meta.Accessor(obj)
	if err != nil {
		return nil, fmt.Errorf("%+v is not a type of MXJob", obj)
	}
	// List all pods to include those that don't match the selector anymore
	// but have a ControllerRef pointing to this controller.
	podlist := &corev1.PodList{}
	err = r.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName())), client.InNamespace(job.GetNamespace()))
	if err != nil {
		return nil, err
	}
	return util.JobControlledPodList(podlist.Items, job), nil
}
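
// GenLabels comes from the embedded common.JobController and selects objects
// by the operator-managed labels. Assuming the training.kubeflow.org label
// keys used by this operator version, an equivalent manual query would be
// roughly:
//
//	kubectl get pods -n <namespace> -l training.kubeflow.org/job-name=<job>
//
// util.JobControlledPodList then keeps only the pods whose ControllerRef
// points at this job.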

func (r *MXJobReconciler) GetServicesForJob(job interface{}) ([]*corev1.Service, error) {
	mxJob, err := meta.Accessor(job)
	if err != nil {
		return nil, fmt.Errorf("%+v is not a type of MXJob", job)
	}

	// List all services to include those that don't match the selector anymore
	// but have a ControllerRef pointing to this controller.
	serviceList := &corev1.ServiceList{}
	err = r.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(mxJob.GetName())), client.InNamespace(mxJob.GetNamespace()))
	if err != nil {
		return nil, err
	}

	ret := util.ConvertServiceList(serviceList.Items)
	return ret, nil
}

func (r *MXJobReconciler) DeleteJob(job interface{}) error {
	mxjob, ok := job.(*kubeflowv1.MXJob)
	if !ok {
		return fmt.Errorf("%+v is not a type of MXJob", job)
	}
	if err := r.Delete(context.Background(), mxjob); err != nil {
		r.Recorder.Eventf(mxjob, corev1.EventTypeWarning, control.FailedDeletePodReason, "Error deleting: %v", err)
		logrus.Errorf("failed to delete job: namespace=%s, name=%s, err=%v", mxjob.Namespace, mxjob.Name, err)
		return err
	}
	r.Recorder.Eventf(mxjob, corev1.EventTypeNormal, control.SuccessfulDeletePodReason, "Deleted job: %v", mxjob.Name)
	logrus.Infof("job deleted: namespace=%s, name=%s", mxjob.Namespace, mxjob.Name)
	trainingoperatorcommon.DeletedJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
	return nil
}

func (r *MXJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec, jobStatus *kubeflowv1.JobStatus) error {
	mxjob, ok := job.(*kubeflowv1.MXJob)
	if !ok {
		return fmt.Errorf("%+v is not a type of MXJob", job)
	}

	mxjobKey, err := common.KeyFunc(mxjob)
	if err != nil {
		utilruntime.HandleError(fmt.Errorf("couldn't get key for mxjob object %#v: %v", mxjob, err))
		return err
	}

	if jobStatus.StartTime == nil {
		now := metav1.Now()
		jobStatus.StartTime = &now
		// Enqueue a sync to check whether the job has passed ActiveDeadlineSeconds.
		if mxjob.Spec.RunPolicy.ActiveDeadlineSeconds != nil {
			logrus.Infof("Job with ActiveDeadlineSeconds will sync after %d seconds", *mxjob.Spec.RunPolicy.ActiveDeadlineSeconds)
			r.WorkQueue.AddAfter(mxjobKey, time.Duration(*mxjob.Spec.RunPolicy.ActiveDeadlineSeconds)*time.Second)
		}
	}

	// Check whether this is single-host (single-worker) MXNet training.
	singleTraining := r.isSingleWorker(replicas)

	for rtype, spec := range replicas {
		status := jobStatus.ReplicaStatuses[rtype]

		// Expect to have `replicas - succeeded` pods alive.
		succeeded := status.Succeeded
		expected := *(spec.Replicas) - succeeded
		running := status.Active
		failed := status.Failed

		r.Log.Info(fmt.Sprintf("MXJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d, failed=%d, singleTraining=%t",
			mxjob.Name, rtype, expected, running, succeeded, failed, singleTraining))

		if rtype == kubeflowv1.MXJobReplicaTypeScheduler || singleTraining {
			if running > 0 {
				msg := fmt.Sprintf("MXJob %s is running.", mxjob.Name)
				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRunning, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobRunningReason), msg)
			}
			// When the scheduler (or the single worker) has succeeded, the job is finished.
			if expected == 0 {
				msg := fmt.Sprintf("MXJob %s is successfully completed.", mxjob.Name)
				r.Recorder.Event(mxjob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobSucceededReason), msg)
				if jobStatus.CompletionTime == nil {
					now := metav1.Now()
					jobStatus.CompletionTime = &now
				}
				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobSucceeded, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobSucceededReason), msg)
				trainingoperatorcommon.SuccessfulJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
				return nil
			}
		}
		if failed > 0 {
			if spec.RestartPolicy == kubeflowv1.RestartPolicyExitCode {
				msg := fmt.Sprintf("MXJob %s is restarting because %d %s replica(s) failed.", mxjob.Name, failed, rtype)
				r.Recorder.Event(mxjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobRestartingReason), msg)
				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRestarting, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobRestartingReason), msg)
				trainingoperatorcommon.RestartedJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
			} else {
				msg := fmt.Sprintf("MXJob %s is failed because %d %s replica(s) failed.", mxjob.Name, failed, rtype)
				r.Recorder.Event(mxjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobFailedReason), msg)
				if jobStatus.CompletionTime == nil {
					now := metav1.Now()
					jobStatus.CompletionTime = &now
				}
				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobFailed, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobFailedReason), msg)
				trainingoperatorcommon.FailedJobsCounterInc(mxjob.Namespace, r.GetFrameworkName())
			}
		}
	}

	return nil
}
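
// A worked example of the bookkeeping above: a Worker spec with replicas=4
// where 1 pod has succeeded, 2 are active, and 1 has failed gives
// expected = 4 - 1 = 3, running = 2, failed = 1. With RestartPolicyExitCode
// the job is marked Restarting; otherwise it is marked Failed. The job is
// only marked Succeeded once the Scheduler replica type (or the single
// worker, in single-host mode) reaches expected == 0.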

// UpdateJobStatusInApiServer updates the status of the given MXJob.
func (r *MXJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus *kubeflowv1.JobStatus) error {
	if jobStatus.ReplicaStatuses == nil {
		jobStatus.ReplicaStatuses = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaStatus{}
	}

	mxJob, ok := job.(*kubeflowv1.MXJob)
	if !ok {
		return fmt.Errorf("%+v is not a type of MXJob", job)
	}

	if !reflect.DeepEqual(&mxJob.Status, jobStatus) {
		mxJob = mxJob.DeepCopy()
		mxJob.Status = *jobStatus.DeepCopy()
	}

	if err := r.Status().Update(context.Background(), mxJob); err != nil {
		logrus.Errorf("failed to update MXJob conditions in the API server: %v", err)
		return err
	}

	return nil
}

func (r *MXJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
	return SetPodEnv(job, podTemplate, rtype, index)
}
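
// SetPodEnv is implemented elsewhere in this package. Assuming the
// conventional DMLC variables consumed by MXNet's distributed kvstore, the
// injected environment looks roughly like:
//
//	DMLC_ROLE=worker                        // scheduler | server | worker, from rtype
//	DMLC_PS_ROOT_URI=<scheduler address>    // host of the scheduler service
//	DMLC_PS_ROOT_PORT=<scheduler port>
//	DMLC_NUM_SERVER=<server replica count>
//	DMLC_NUM_WORKER=<worker replica count>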

func (r *MXJobReconciler) GetDefaultContainerName() string {
	return kubeflowv1.MXJobDefaultContainerName
}

func (r *MXJobReconciler) GetDefaultContainerPortName() string {
	return kubeflowv1.MXJobDefaultPortName
}

func (r *MXJobReconciler) IsMasterRole(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec,
	rtype kubeflowv1.ReplicaType, index int) bool {
	return string(rtype) == string(kubeflowv1.MXJobReplicaTypeServer)
}
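
// Returning true from IsMasterRole causes the common controller to mark the
// corresponding pods with the master job-role label; for MXJobs, Server
// replicas are treated as the master role here.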

// onOwnerCreateFunc applies scheme defaults and records the creation condition
// when a new MXJob is observed.
func (r *MXJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
	return func(e event.CreateEvent) bool {
		mxJob, ok := e.Object.(*kubeflowv1.MXJob)
		if !ok {
			return true
		}

		// Use defaulters registered in the scheme.
		r.Scheme.Default(mxJob)
		msg := fmt.Sprintf("MXJob %s is created.", e.Object.GetName())
		logrus.Info(msg)
		trainingoperatorcommon.CreatedJobsCounterInc(mxJob.Namespace, r.GetFrameworkName())
		commonutil.UpdateJobConditions(&mxJob.Status, kubeflowv1.JobCreated, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobCreatedReason), msg)
		return true
	}
}

// isSingleWorker reports whether the job runs exactly one worker and no
// scheduler or server replicas, i.e. single-host training.
func (r *MXJobReconciler) isSingleWorker(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec) bool {
	var workerNum, scheNum, svrNum int32

	for rtype, spec := range replicas {
		switch rtype {
		case kubeflowv1.MXJobReplicaTypeScheduler:
			scheNum += *spec.Replicas
		case kubeflowv1.MXJobReplicaTypeServer:
			svrNum += *spec.Replicas
		case kubeflowv1.MXJobReplicaTypeWorker:
			workerNum += *spec.Replicas
		}
	}
	return workerNum == 1 && scheNum == 0 && svrNum == 0
}
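
// Example: a spec with only {Worker: replicas=1} and no Scheduler or Server
// entries counts as single-host training, so UpdateJobStatus treats the lone
// worker like a scheduler when setting the Running/Succeeded conditions. A
// minimal sketch (values are illustrative):
//
//	one := int32(1)
//	replicas := map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
//		kubeflowv1.MXJobReplicaTypeWorker: {Replicas: &one},
//	}
//	// r.isSingleWorker(replicas) == true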