github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/tensorflow/tfjob_controller.go

// Copyright 2021 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tensorflow

import (
	"context"
	"fmt"
	"strings"
	"time"

	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	trainingoperatorcommon "github.com/kubeflow/training-operator/pkg/common"
	"github.com/kubeflow/training-operator/pkg/common/util"
	"github.com/kubeflow/training-operator/pkg/controller.v1/common"
	"github.com/kubeflow/training-operator/pkg/controller.v1/control"
	"github.com/kubeflow/training-operator/pkg/controller.v1/expectation"
	commonutil "github.com/kubeflow/training-operator/pkg/util"

	"github.com/go-logr/logr"
	"github.com/sirupsen/logrus"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/errors"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/runtime/schema"
	"k8s.io/apimachinery/pkg/types"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/client-go/informers"
	kubeclientset "k8s.io/client-go/kubernetes"
	"k8s.io/client-go/tools/record"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/controller-runtime/pkg/controller"
	"sigs.k8s.io/controller-runtime/pkg/event"
	"sigs.k8s.io/controller-runtime/pkg/handler"
	"sigs.k8s.io/controller-runtime/pkg/log"
	"sigs.k8s.io/controller-runtime/pkg/manager"
	"sigs.k8s.io/controller-runtime/pkg/predicate"
	"sigs.k8s.io/controller-runtime/pkg/source"
	schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
	"volcano.sh/apis/pkg/apis/scheduling/v1beta1"
)

const (
	FailedDeleteJobReason     = "FailedDeleteJob"
	SuccessfulDeleteJobReason = "SuccessfulDeleteJob"

	controllerName = "tfjob-controller"

	// tfConfig is the name of the environment variable that carries the
	// TensorFlow cluster spec.
	tfConfig = "TF_CONFIG"
)

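// NewReconciler creates a TFJobReconciler wired to the manager's client,
// scheme, event recorder, and API reader, and fills in the embedded
// JobController with pod/service controls and a PriorityClass informer.
// A minimal sketch of typical wiring in a controller-runtime main; the no-op
// gang-scheduling setup func and the thread count below are illustrative
// assumptions, not requirements of this file:
//
//	mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{})
//	if err != nil {
//		panic(err)
//	}
//	r := NewReconciler(mgr, func(jc *common.JobController) {}) // no gang scheduling
//	if err := r.SetupWithManager(mgr, 1); err != nil {
//		panic(err)
//	}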
func NewReconciler(mgr manager.Manager, gangSchedulingSetupFunc common.GangSchedulingSetupFunc) *TFJobReconciler {
	r := &TFJobReconciler{
		Client:    mgr.GetClient(),
		Scheme:    mgr.GetScheme(),
		recorder:  mgr.GetEventRecorderFor(controllerName),
		apiReader: mgr.GetAPIReader(),
		Log:       log.Log,
	}

	cfg := mgr.GetConfig()
	kubeClientSet := kubeclientset.NewForConfigOrDie(cfg)
	sharedInformers := informers.NewSharedInformerFactory(kubeClientSet, 0)
	priorityClassInformer := sharedInformers.Scheduling().V1().PriorityClasses()

	r.JobController = common.JobController{
		Controller:                  r,
		Expectations:                expectation.NewControllerExpectations(),
		WorkQueue:                   &util.FakeWorkQueue{},
		Recorder:                    r.recorder,
		KubeClientSet:               kubeClientSet,
		PriorityClassLister:         priorityClassInformer.Lister(),
		PriorityClassInformerSynced: priorityClassInformer.Informer().HasSynced,
		PodControl:                  control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder},
		ServiceControl:              control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder},
	}

	gangSchedulingSetupFunc(&r.JobController)

	return r
}

// TFJobReconciler reconciles a TFJob object
type TFJobReconciler struct {
	common.JobController
	client.Client
	Scheme    *runtime.Scheme
	recorder  record.EventRecorder
	apiReader client.Reader
	Log       logr.Logger
}

//+kubebuilder:rbac:groups=kubeflow.org,resources=tfjobs,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=kubeflow.org,resources=tfjobs/status,verbs=get;update;patch
//+kubebuilder:rbac:groups=kubeflow.org,resources=tfjobs/finalizers,verbs=update
//+kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;delete
//+kubebuilder:rbac:groups=scheduling.volcano.sh,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=scheduling.x-k8s.io,resources=podgroups,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups="",resources=events,verbs=get;list;watch;create;update;patch;delete

// Reconcile is part of the main kubernetes reconciliation loop which aims to
// move the current state of the cluster closer to the desired state.
func (r *TFJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
	_ = log.FromContext(ctx)
	logger := r.Log.WithValues(kubeflowv1.TFJobSingular, req.NamespacedName)

	tfjob := &kubeflowv1.TFJob{}
	err := r.Get(ctx, req.NamespacedName, tfjob)
	if err != nil {
		logger.Info("unable to fetch TFJob", "error", err.Error())
		return ctrl.Result{}, client.IgnoreNotFound(err)
	}

	if err = kubeflowv1.ValidateV1TFJob(tfjob); err != nil {
		logger.Error(err, "TFJob failed validation")
		r.Recorder.Eventf(tfjob, corev1.EventTypeWarning, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobFailedValidationReason),
			"TFJob failed validation because %s", err)
		return ctrl.Result{}, err
	}

	// Check if reconciliation is needed.
	jobKey, err := common.KeyFunc(tfjob)
	if err != nil {
		utilruntime.HandleError(fmt.Errorf("couldn't get jobKey for job object %#v: %v", tfjob, err))
	}

	replicaTypes := util.GetReplicaTypes(tfjob.Spec.TFReplicaSpecs)
	needReconcile := util.SatisfiedExpectations(r.Expectations, jobKey, replicaTypes)

	if !needReconcile || tfjob.GetDeletionTimestamp() != nil {
		logger.Info("reconcile cancelled: the job does not need to be reconciled or has been deleted",
			"sync", needReconcile, "deleted", tfjob.GetDeletionTimestamp() != nil)
		return ctrl.Result{}, nil
	}

	// Set default values for the TFJob.
	r.Scheme.Default(tfjob)

	// Use the common JobController to reconcile the job-related pods and services.
	err = r.ReconcileJobs(tfjob, tfjob.Spec.TFReplicaSpecs, tfjob.Status, &tfjob.Spec.RunPolicy)
	if err != nil {
		logrus.Warnf("Reconcile TensorFlow Job error %v", err)
		return ctrl.Result{}, err
	}

	// Requeue the job so it can be cleaned up once its TTL after finishing
	// (RunPolicy.TTLSecondsAfterFinished) expires.
	t, err := util.DurationUntilExpireTime(&tfjob.Spec.RunPolicy, tfjob.Status)
	if err != nil {
		logrus.Warnf("Reconcile TensorFlow Job error %v", err)
		return ctrl.Result{}, err
	}
	if t >= 0 {
		return ctrl.Result{Requeue: true, RequeueAfter: t}, nil
	}

	return ctrl.Result{}, nil
}

// SetupWithManager sets up the controller with the Manager.
func (r *TFJobReconciler) SetupWithManager(mgr ctrl.Manager, controllerThreads int) error {
	c, err := controller.New(r.ControllerName(), mgr, controller.Options{
		Reconciler:              r,
		MaxConcurrentReconciles: controllerThreads,
	})
	if err != nil {
		return err
	}

	// Using onOwnerCreateFunc makes it easy to apply defaults when a TFJob is created.
	if err = c.Watch(source.Kind(mgr.GetCache(), &kubeflowv1.TFJob{}), &handler.EnqueueRequestForObject{},
		predicate.Funcs{CreateFunc: r.onOwnerCreateFunc()},
	); err != nil {
		return err
	}

	// eventHandler enqueues the owning TFJob for events on owned objects.
	eventHandler := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &kubeflowv1.TFJob{}, handler.OnlyControllerOwner())
	predicates := predicate.Funcs{
		CreateFunc: util.OnDependentCreateFunc(r.Expectations),
		UpdateFunc: util.OnDependentUpdateFunc(&r.JobController),
		DeleteFunc: util.OnDependentDeleteFunc(r.Expectations),
	}
	// Create generic predicates.
	genericPredicates := predicate.Funcs{
		CreateFunc: util.OnDependentCreateFuncGeneric(r.Expectations),
		UpdateFunc: util.OnDependentUpdateFuncGeneric(&r.JobController),
		DeleteFunc: util.OnDependentDeleteFuncGeneric(r.Expectations),
	}
	// Watch job-related Pods.
	if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Pod{}), eventHandler, predicates); err != nil {
		return err
	}
	// Watch job-related Services.
	if err = c.Watch(source.Kind(mgr.GetCache(), &corev1.Service{}), eventHandler, predicates); err != nil {
		return err
	}
	// Skip watching volcano PodGroups if the volcano PodGroup CRD is not installed.
	if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: v1beta1.GroupName, Kind: "PodGroup"},
		v1beta1.SchemeGroupVersion.Version); err == nil {
		// Watch job-related volcano PodGroups.
		if err = c.Watch(source.Kind(mgr.GetCache(), &v1beta1.PodGroup{}), eventHandler, genericPredicates); err != nil {
			return err
		}
	}
	// Skip watching scheduler-plugins PodGroups if the scheduler-plugins PodGroup CRD is not installed.
	if _, err = mgr.GetRESTMapper().RESTMapping(schema.GroupKind{Group: schedulerpluginsv1alpha1.SchemeGroupVersion.Group, Kind: "PodGroup"},
		schedulerpluginsv1alpha1.SchemeGroupVersion.Version); err == nil {
		// Watch job-related scheduler-plugins PodGroups.
		if err = c.Watch(source.Kind(mgr.GetCache(), &schedulerpluginsv1alpha1.PodGroup{}), eventHandler, genericPredicates); err != nil {
			return err
		}
	}
	return nil
}

func (r *TFJobReconciler) ControllerName() string {
	return controllerName
}

func (r *TFJobReconciler) GetAPIGroupVersionKind() schema.GroupVersionKind {
	return kubeflowv1.GroupVersion.WithKind(kubeflowv1.TFJobKind)
}

func (r *TFJobReconciler) GetAPIGroupVersion() schema.GroupVersion {
	return kubeflowv1.GroupVersion
}

func (r *TFJobReconciler) GetGroupNameLabelValue() string {
	return kubeflowv1.GroupVersion.Group
}

func (r *TFJobReconciler) GetFrameworkName() string {
	return kubeflowv1.TFJobFrameworkName
}

func (r *TFJobReconciler) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
	tfjob := &kubeflowv1.TFJob{}
	err := r.Get(context.Background(), types.NamespacedName{
		Namespace: namespace, Name: name,
	}, tfjob)
	return tfjob, err
}

func (r *TFJobReconciler) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) {
	job := &kubeflowv1.TFJob{}

	err := r.apiReader.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
	if err != nil {
		if errors.IsNotFound(err) {
			logrus.Errorf("tensorflow job not found: namespace=%s, name=%s: %v", namespace, name, err)
		} else {
			logrus.Errorf("failed to get job from the API server: namespace=%s, name=%s: %v", namespace, name, err)
		}
		return nil, err
	}
	return job, nil
}

// GetPodsForJob returns the set of pods that this job should manage.
// It also reconciles ControllerRef by adopting/orphaning.
// Note that the returned Pods are pointers into the cache.
func (r *TFJobReconciler) GetPodsForJob(jobObject interface{}) ([]*corev1.Pod, error) {
	job, ok := jobObject.(metav1.Object)
	if !ok {
		return nil, fmt.Errorf("job is not of type metav1.Object")
	}

	// Create selector.
	selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
		MatchLabels: r.GenLabels(job.GetName()),
	})

	if err != nil {
		return nil, fmt.Errorf("couldn't convert Job selector: %v", err)
	}
	// List all pods to include those that don't match the selector anymore
	// but have a ControllerRef pointing to this controller.
	podlist := &corev1.PodList{}
	err = r.List(context.Background(), podlist,
		client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(job.GetNamespace()))
	if err != nil {
		return nil, err
	}

	pods := util.JobControlledPodList(podlist.Items, job)

	// If any adoptions are attempted, we should first recheck for deletion
	// with an uncached quorum read sometime after listing Pods (see #42639).
	canAdoptFunc := common.RecheckDeletionTimestamp(func() (metav1.Object, error) {
		fresh, err := r.Controller.GetJobFromAPIClient(job.GetNamespace(), job.GetName())
		if err != nil {
			return nil, err
		}
		if fresh.GetUID() != job.GetUID() {
			return nil, fmt.Errorf("original Job %v/%v is gone: got uid %v, wanted %v", job.GetNamespace(), job.GetName(), fresh.GetUID(), job.GetUID())
		}
		return fresh, nil
	})
	cm := control.NewPodControllerRefManager(r.PodControl, job, selector, r.Controller.GetAPIGroupVersionKind(), canAdoptFunc)
	return cm.ClaimPods(pods)
}

// GetServicesForJob returns the set of services that this job should manage.
// It also reconciles ControllerRef by adopting/orphaning.
// Note that the returned services are pointers into the cache.
func (r *TFJobReconciler) GetServicesForJob(jobObject interface{}) ([]*corev1.Service, error) {
	job, ok := jobObject.(metav1.Object)
	if !ok {
		return nil, fmt.Errorf("job is not of type metav1.Object")
	}

	// Create selector.
	selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
		MatchLabels: r.GenLabels(job.GetName()),
	})

	if err != nil {
		return nil, fmt.Errorf("couldn't convert Job selector: %v", err)
	}
	// List all services to include those that don't match the selector anymore
	// but have a ControllerRef pointing to this controller.
	svclist := &corev1.ServiceList{}
	err = r.List(context.Background(), svclist,
		client.MatchingLabelsSelector{Selector: selector}, client.InNamespace(job.GetNamespace()))
	if err != nil {
		return nil, fmt.Errorf("couldn't get Service: %v", err)
	}

	// If any adoptions are attempted, we should first recheck for deletion
	// with an uncached quorum read sometime after listing services (see #42639).
	canAdoptFunc := common.RecheckDeletionTimestamp(func() (metav1.Object, error) {
		fresh, err := r.GetJobFromInformerCache(job.GetNamespace(), job.GetName())
		if err != nil {
			return nil, err
		}
		if fresh.GetUID() != job.GetUID() {
			return nil, fmt.Errorf("original Job %v/%v is gone: got uid %v, wanted %v", job.GetNamespace(), job.GetName(), fresh.GetUID(), job.GetUID())
		}
		return fresh, nil
	})
	cm := control.NewServiceControllerRefManager(r.ServiceControl, job, selector, r.Controller.GetAPIGroupVersionKind(), canAdoptFunc)

	services := util.ConvertServiceList(svclist.Items)
	return cm.ClaimServices(services)
}

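// DeleteJob deletes the given TFJob and emits an event recording whether the
// deletion succeeded.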
func (r *TFJobReconciler) DeleteJob(job interface{}) error {
	tfJob, ok := job.(*kubeflowv1.TFJob)
	if !ok {
		return fmt.Errorf("%+v is not a TFJob", job)
	}

	log := commonutil.LoggerForJob(tfJob)
	if err := r.Delete(context.Background(), tfJob); err != nil {
		r.recorder.Eventf(tfJob, corev1.EventTypeWarning, FailedDeleteJobReason, "Error deleting: %v", err)
		log.Errorf("failed to delete job %s/%s, %v", tfJob.Namespace, tfJob.Name, err)
		return err
	}

	r.recorder.Eventf(tfJob, corev1.EventTypeNormal, SuccessfulDeleteJobReason, "Deleted job: %v", tfJob.Name)
	log.Infof("job %s/%s has been deleted", tfJob.Namespace, tfJob.Name)
	trainingoperatorcommon.DeletedJobsCounterInc(tfJob.Namespace, r.GetFrameworkName())
	return nil
}

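// UpdateJobStatus updates jobStatus in place from the replica statuses: it
// stamps StartTime on first sight, promotes Running/Succeeded conditions
// (driven by the Chief/Master spec when present, otherwise by the worker
// success policy), and marks the job Failed or Restarting when replicas fail.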
func (r *TFJobReconciler) UpdateJobStatus(job interface{}, replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec, jobStatus *kubeflowv1.JobStatus) error {
	tfJob, ok := job.(*kubeflowv1.TFJob)
	if !ok {
		return fmt.Errorf("%+v is not a TFJob", job)
	}

	tfJobKey, err := common.KeyFunc(tfJob)
	if err != nil {
		utilruntime.HandleError(fmt.Errorf("couldn't get key for tfjob object %#v: %v", tfJob, err))
		return err
	}

	logger := commonutil.LoggerForJob(tfJob)

	worker0Completed, err := r.IsWorker0Completed(tfJob, replicas)
	if err != nil {
		logger.Warnf("check if worker 0 completed error %v", err)
		return err
	}

	// Set StartTime.
	if jobStatus.StartTime == nil {
		now := metav1.Now()
		jobStatus.StartTime = &now
		// Enqueue a sync to check whether the job has passed ActiveDeadlineSeconds.
		if tfJob.Spec.RunPolicy.ActiveDeadlineSeconds != nil {
			logger.Infof("Job with ActiveDeadlineSeconds will sync after %d seconds", *tfJob.Spec.RunPolicy.ActiveDeadlineSeconds)
			// TODO(Jeffwan): requeue job key in reconciler scenarios
			r.WorkQueue.AddAfter(tfJobKey, time.Duration(*tfJob.Spec.RunPolicy.ActiveDeadlineSeconds)*time.Second)
		}
	}

	// If jobStatus already has a restarting condition and a running condition is
	// appended later, the restarting condition will be removed from jobStatus by
	// kubeflowv1.filterOutCondition(), so record the existing restarting
	// condition here for later use.
	var existingRestartingCondition *kubeflowv1.JobCondition
	for _, condition := range jobStatus.Conditions {
		if condition.Type == kubeflowv1.JobRestarting {
			existingRestartingCondition = &kubeflowv1.JobCondition{
				Reason:  condition.Reason,
				Message: condition.Message,
			}
		}
	}

	// Iterate over the replica specs in this order.
	allTypes := []kubeflowv1.ReplicaType{
		kubeflowv1.TFJobReplicaTypeChief,
		kubeflowv1.TFJobReplicaTypeEval,
		kubeflowv1.TFJobReplicaTypeMaster,
		kubeflowv1.TFJobReplicaTypePS,
		kubeflowv1.TFJobReplicaTypeWorker,
	}
	for _, rtype := range allTypes {
		if replicas[rtype] == nil {
			continue
		}
		spec := replicas[rtype]
		status := jobStatus.ReplicaStatuses[rtype]

		// Expect to have `replicas - succeeded` pods alive.
		succeeded := status.Succeeded
		expected := *(spec.Replicas) - succeeded
		running := status.Active
		failed := status.Failed

		logger.Infof("TFJob=%s/%s, ReplicaType=%s expected=%d, running=%d, failed=%d",
			tfJob.Namespace, tfJob.Name, rtype, expected, running, failed)

		// If the TFJob contains a Chief or Master spec, then update the status
		// according to the Chief/Master spec.
		if ContainsChiefOrMasterSpec(tfJob.Spec.TFReplicaSpecs) {
			if kubeflowv1.IsChieforMaster(rtype) {
				if running > 0 {
					msg := fmt.Sprintf("TFJob %s/%s is running.", tfJob.Namespace, tfJob.Name)
					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRunning, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobRunningReason), msg)
				}
				if expected == 0 {
					msg := fmt.Sprintf("TFJob %s/%s successfully completed.",
						tfJob.Namespace, tfJob.Name)
					r.recorder.Event(tfJob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobSucceededReason), msg)
					if jobStatus.CompletionTime == nil {
						now := metav1.Now()
						jobStatus.CompletionTime = &now
					}
					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobSucceeded, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobSucceededReason), msg)
					trainingoperatorcommon.SuccessfulJobsCounterInc(tfJob.Namespace, r.GetFrameworkName())
				}
			}
		} else {
			if rtype == kubeflowv1.TFJobReplicaTypeWorker {
				// Leave a succeeded condition in the following two cases:
				// 1. The default success policy is used and worker 0 has completed.
				// 2. The `SuccessPolicyAllWorkers` success policy is used and all workers have succeeded.
				if expected == 0 || (worker0Completed && *tfJob.Spec.SuccessPolicy != kubeflowv1.SuccessPolicyAllWorkers) {
					msg := fmt.Sprintf("TFJob %s/%s successfully completed.",
						tfJob.Namespace, tfJob.Name)
					r.recorder.Event(tfJob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobSucceededReason), msg)
					if jobStatus.CompletionTime == nil {
						now := metav1.Now()
						jobStatus.CompletionTime = &now
					}
					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobSucceeded, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobSucceededReason), msg)
					trainingoperatorcommon.SuccessfulJobsCounterInc(tfJob.Namespace, r.GetFrameworkName())
				} else if running > 0 {
					// Some workers are still running; leave a running condition.
					msg := fmt.Sprintf("TFJob %s/%s is running.", tfJob.Namespace, tfJob.Name)
					commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRunning, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobRunningReason), msg)
				}
			}
		}

		if failed > 0 {
			// If jobStatus had a restarting condition and a new running condition
			// was appended above, the restarting condition has been removed from
			// jobStatus by kubeflowv1.filterOutCondition(), so append the
			// recorded restarting condition back to jobStatus.
			if existingRestartingCondition != nil {
				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobRestarting, corev1.ConditionTrue, existingRestartingCondition.Reason, existingRestartingCondition.Message)
				// The job is restarting, so there is no need to set it failed;
				// we know this because the status condition is updated when
				// reconciling the replicas.
				trainingoperatorcommon.RestartedJobsCounterInc(tfJob.Namespace, r.GetFrameworkName())
			} else {
				if tfJob.Spec.EnableDynamicWorker && rtype == kubeflowv1.TFJobReplicaTypeWorker {
					commonutil.LoggerForJob(tfJob).Infof("TFJob %s/%s continues even though %d Worker replica(s) failed, because enableDynamicWorker is set to true.",
						tfJob.Namespace, tfJob.Name, failed)
					continue
				}
				msg := fmt.Sprintf("TFJob %s/%s has failed because %d %s replica(s) failed.",
					tfJob.Namespace, tfJob.Name, failed, rtype)
				r.recorder.Event(tfJob, corev1.EventTypeNormal, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobFailedReason), msg)
				if jobStatus.CompletionTime == nil {
					now := metav1.Now()
					jobStatus.CompletionTime = &now
				}
				commonutil.UpdateJobConditions(jobStatus, kubeflowv1.JobFailed, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobFailedReason), msg)
				trainingoperatorcommon.FailedJobsCounterInc(tfJob.Namespace, r.GetFrameworkName())
			}
		}
	}
	// Assign jobStatus to tfJob.Status for testing purposes. This does not
	// affect the main reconcile logic, because oldStatus := jobStatus.DeepCopy()
	// records the old status and !reflect.DeepEqual(*oldStatus, jobStatus)
	// decides whether to update the TFJob or not.
	tfJob.Status = *jobStatus.DeepCopy()

	return nil
}

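// UpdateJobStatusInApiServer writes jobStatus to the TFJob's status
// subresource. The job object is deep-copied first, so the cached copy is
// never mutated.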
func (r *TFJobReconciler) UpdateJobStatusInApiServer(job interface{}, jobStatus *kubeflowv1.JobStatus) error {
	if jobStatus.ReplicaStatuses == nil {
		jobStatus.ReplicaStatuses = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaStatus{}
	}

	tfJob, ok := job.(*kubeflowv1.TFJob)
	if !ok {
		return fmt.Errorf("%+v is not a TFJob", job)
	}

	startTime := time.Now()
	logger := commonutil.LoggerForJob(tfJob)
	defer func() {
		logger.Infof("Finished updating TFJob status %q (%v)",
			tfJob.Name, time.Since(startTime))
	}()

	tfJob = tfJob.DeepCopy()
	tfJob.Status = *jobStatus.DeepCopy()

	result := r.Status().Update(context.Background(), tfJob)

	if result != nil {
		r.Log.WithValues("tfjob", types.NamespacedName{
			Namespace: tfJob.GetNamespace(),
			Name:      tfJob.GetName(),
		}).Error(result, "failed to update TFJob status")
		return result
	}

	return nil
}

// SetClusterSpec is the same as func (tc *TFController) SetClusterSpec(...) in pod.go.
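// An illustrative TF_CONFIG value for a worker with index 0. The exact
// content is produced by genTFConfigJSONStr from the TFJob's replica specs;
// the hostnames and ports below are hypothetical:
//
//	{"cluster": {"ps":     ["tfjob-ps-0.default.svc:2222"],
//	             "worker": ["tfjob-worker-0.default.svc:2222",
//	                        "tfjob-worker-1.default.svc:2222"]},
//	 "task": {"type": "worker", "index": 0}}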
func (r *TFJobReconciler) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
	tfjob, ok := job.(*kubeflowv1.TFJob)
	if !ok {
		return fmt.Errorf("%+v is not a TFJob", job)
	}

	// Do not set TF_CONFIG for local training jobs.
	if !isDistributed(tfjob) {
		return nil
	}
	// Generate the TF_CONFIG JSON string.
	tfConfigStr, err := genTFConfigJSONStr(tfjob, rtype, index)
	if err != nil {
		return err
	}

	if tfConfigStr == "" {
		return nil
	}
	// Add the TF_CONFIG environment variable to the tensorflow container in the pod.
	for i := range podTemplate.Spec.Containers {
		if podTemplate.Spec.Containers[i].Name == kubeflowv1.TFJobDefaultContainerName {
			podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
				Name:  tfConfig,
				Value: tfConfigStr,
			})
			break
		}
	}
	return nil
}

func (r *TFJobReconciler) GetDefaultContainerName() string {
	return kubeflowv1.TFJobDefaultContainerName
}

func (r *TFJobReconciler) GetDefaultContainerPortName() string {
	return kubeflowv1.TFJobDefaultPortName
}

func (r *TFJobReconciler) IsMasterRole(replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec,
	rtype kubeflowv1.ReplicaType, index int) bool {
	if ContainsChiefOrMasterSpec(replicas) {
		return rtype == kubeflowv1.TFJobReplicaTypeChief || rtype == kubeflowv1.TFJobReplicaTypeMaster
	}
	// Otherwise, check whether it is the worker with index 0.
	return rtype == kubeflowv1.TFJobReplicaTypeWorker && index == 0
}

// IsWorker0Completed returns true if the pod of worker 0 succeeded and exited
// with code 0.
func (r *TFJobReconciler) IsWorker0Completed(tfJob *kubeflowv1.TFJob, replicas map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec) (bool, error) {
	worker0Completed := false
	_, ok := replicas[kubeflowv1.TFJobReplicaTypeWorker]
	if !ok {
		return true, nil
	}
	podSlices, err := r.getPodSlices(tfJob, replicas[kubeflowv1.TFJobReplicaTypeWorker].Replicas)
	if err != nil {
		return false, err
	}
	for index, podSlice := range podSlices {
		if len(podSlice) == 1 {
			pod := podSlice[0]
			exitCode := getContainerExitCode(pod)
			if index == 0 && exitCode == 0 && pod.Status.Phase == corev1.PodSucceeded {
				worker0Completed = true
			}
		}
	}
	return worker0Completed, nil
}

// getPodSlices returns a slice whose elements are slices of pods grouped by
// worker replica index. It gives the caller enough information to decide
// whether to scale resources up or down.
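// For example (pod names hypothetical), with *replicasNum == 3 and only the
// worker pods with indices 0 and 2 present, the result has the shape
// [[worker-0] [] [worker-2]]: one inner slice per expected replica index.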
func (r *TFJobReconciler) getPodSlices(tfjob *kubeflowv1.TFJob, replicasNum *int32) ([][]*corev1.Pod, error) {
	logger := commonutil.LoggerForReplica(tfjob, strings.ToLower(string(kubeflowv1.TFJobReplicaTypeWorker)))

	pods, err := r.GetPodsForJob(tfjob)
	if err != nil {
		commonutil.LoggerForJob(tfjob).Warnf("getPodsForTFJob error %v", err)
		return nil, err
	}

	// Keep only the pods of the worker replica type.
	pods, err = r.JobController.FilterPodsForReplicaType(pods, strings.ToLower(string(kubeflowv1.TFJobReplicaTypeWorker)))
	if err != nil {
		return nil, err
	}

	podSlices := r.GetPodSlices(pods, int(*replicasNum), logger)
	return podSlices, nil
}

// onOwnerCreateFunc sets default values on a newly created TFJob and records
// its created condition.
func (r *TFJobReconciler) onOwnerCreateFunc() func(event.CreateEvent) bool {
	return func(e event.CreateEvent) bool {
		tfJob, ok := e.Object.(*kubeflowv1.TFJob)
		if !ok {
			return true
		}

		r.Scheme.Default(tfJob)
		msg := fmt.Sprintf("TFJob %s is created.", e.Object.GetName())
		logrus.Info(msg)
		trainingoperatorcommon.CreatedJobsCounterInc(tfJob.Namespace, r.GetFrameworkName())
		commonutil.UpdateJobConditions(&tfJob.Status, kubeflowv1.JobCreated, corev1.ConditionTrue, commonutil.NewReason(kubeflowv1.TFJobKind, commonutil.JobCreatedReason), msg)
		return true
	}
}