github.com/kubeflow/training-operator@v1.7.0/sdk/python/test/e2e/test_e2e_tfjob.py (about)

     1  # Copyright 2021 kubeflow.org.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #    http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  import os
    16  import logging
    17  import pytest
    18  
    19  from kubernetes.client import V1PodTemplateSpec
    20  from kubernetes.client import V1ObjectMeta
    21  from kubernetes.client import V1PodSpec
    22  from kubernetes.client import V1Container
    23  from kubernetes.client import V1ResourceRequirements
    24  
    25  from kubeflow.training import TrainingClient
    26  from kubeflow.training import KubeflowOrgV1ReplicaSpec
    27  from kubeflow.training import KubeflowOrgV1RunPolicy
    28  from kubeflow.training import KubeflowOrgV1TFJob
    29  from kubeflow.training import KubeflowOrgV1TFJobSpec
    30  from kubeflow.training import KubeflowOrgV1SchedulingPolicy
    31  from kubeflow.training.constants import constants
    32  
    33  from test.e2e.utils import verify_job_e2e, verify_unschedulable_job_e2e, get_pod_spec_scheduler_name
    34  from test.e2e.constants import TEST_GANG_SCHEDULER_NAME_ENV_KEY
    35  from test.e2e.constants import GANG_SCHEDULERS, NONE_GANG_SCHEDULERS
    36  
# Emit bare log messages (no level/timestamp prefix) and make sure INFO-level
# records from the tests below are not filtered out by the root logger.
logging.basicConfig(format="%(message)s")
logging.getLogger().setLevel(logging.INFO)

# Single client instance shared by every test in this module; it is
# constructed with default arguments when the module is imported.
TRAINING_CLIENT = TrainingClient()
# Name given to every TFJob built by generate_tfjob() in this module.
JOB_NAME = "tfjob-mnist-ci-test"
# Name of the container built by generate_container(); also passed to
# verify_job_e2e().
CONTAINER_NAME = "tensorflow"
# Scheduler selected through the environment; os.getenv returns None when the
# variable is unset. The skipif markers below use this value to decide which
# of the two tests runs.
GANG_SCHEDULER_NAME = os.getenv(TEST_GANG_SCHEDULER_NAME_ENV_KEY)
    44  
    45  
    46  @pytest.mark.skipif(
    47      GANG_SCHEDULER_NAME in NONE_GANG_SCHEDULERS, reason="For gang-scheduling",
    48  )
    49  def test_sdk_e2e_with_gang_scheduling(job_namespace):
    50      container = generate_container()
    51  
    52      worker = KubeflowOrgV1ReplicaSpec(
    53          replicas=1,
    54          restart_policy="Never",
    55          template=V1PodTemplateSpec(
    56              metadata=V1ObjectMeta(annotations={constants.ISTIO_SIDECAR_INJECTION: "false"}),
    57              spec=V1PodSpec(
    58                  containers=[container],
    59                  scheduler_name=get_pod_spec_scheduler_name(GANG_SCHEDULER_NAME),
    60              )
    61          ),
    62      )
    63  
    64      unschedulable_tfjob = generate_tfjob(worker, KubeflowOrgV1SchedulingPolicy(min_available=10), job_namespace)
    65      schedulable_tfjob = generate_tfjob(worker, KubeflowOrgV1SchedulingPolicy(min_available=1), job_namespace)
    66  
    67      TRAINING_CLIENT.create_tfjob(unschedulable_tfjob, job_namespace)
    68      logging.info(f"List of created {constants.TFJOB_KIND}s")
    69      logging.info(TRAINING_CLIENT.list_tfjobs(job_namespace))
    70  
    71      verify_unschedulable_job_e2e(
    72          TRAINING_CLIENT,
    73          JOB_NAME,
    74          job_namespace,
    75          constants.TFJOB_KIND,
    76      )
    77  
    78      TRAINING_CLIENT.patch_tfjob(schedulable_tfjob, JOB_NAME, job_namespace)
    79      logging.info(f"List of patched {constants.TFJOB_KIND}s")
    80      logging.info(TRAINING_CLIENT.list_tfjobs(job_namespace))
    81  
    82      verify_job_e2e(
    83          TRAINING_CLIENT,
    84          JOB_NAME,
    85          job_namespace,
    86          constants.TFJOB_KIND,
    87          CONTAINER_NAME,
    88      )
    89  
    90      TRAINING_CLIENT.delete_tfjob(JOB_NAME, job_namespace)
    91  
    92  
    93  @pytest.mark.skipif(
    94      GANG_SCHEDULER_NAME in GANG_SCHEDULERS, reason="For plain scheduling",
    95  )
    96  def test_sdk_e2e(job_namespace):
    97      container = generate_container()
    98  
    99      worker = KubeflowOrgV1ReplicaSpec(
   100          replicas=1,
   101          restart_policy="Never",
   102          template=V1PodTemplateSpec(metadata=V1ObjectMeta(annotations={constants.ISTIO_SIDECAR_INJECTION: "false"}),
   103                                     spec=V1PodSpec(containers=[container])),
   104      )
   105  
   106      tfjob = generate_tfjob(worker, job_namespace=job_namespace)
   107  
   108      TRAINING_CLIENT.create_tfjob(tfjob, job_namespace)
   109      logging.info(f"List of created {constants.TFJOB_KIND}s")
   110      logging.info(TRAINING_CLIENT.list_tfjobs(job_namespace))
   111  
   112      verify_job_e2e(
   113          TRAINING_CLIENT, JOB_NAME, job_namespace, constants.TFJOB_KIND, CONTAINER_NAME,
   114      )
   115  
   116      TRAINING_CLIENT.delete_tfjob(JOB_NAME, job_namespace)
   117  
   118  
   119  def generate_tfjob(
   120      worker: KubeflowOrgV1ReplicaSpec,
   121      scheduling_policy: KubeflowOrgV1SchedulingPolicy = None,
   122      job_namespace: str = "default",
   123  ) -> KubeflowOrgV1TFJob:
   124      return KubeflowOrgV1TFJob(
   125          api_version="kubeflow.org/v1",
   126          kind="TFJob",
   127          metadata=V1ObjectMeta(name=JOB_NAME, namespace=job_namespace),
   128          spec=KubeflowOrgV1TFJobSpec(
   129              run_policy=KubeflowOrgV1RunPolicy(
   130                  clean_pod_policy="None",
   131                  scheduling_policy=scheduling_policy,
   132              ),
   133              tf_replica_specs={"Worker": worker},
   134          ),
   135      )
   136  
   137  
   138  def generate_container() -> V1Container:
   139      return V1Container(
   140          name=CONTAINER_NAME,
   141          image="gcr.io/kubeflow-ci/tf-mnist-with-summaries:1.0",
   142          command=[
   143              "python",
   144              "/var/tf_mnist/mnist_with_summaries.py",
   145              "--log_dir=/train/logs",
   146              "--learning_rate=0.01",
   147              "--batch_size=150",
   148          ],
   149          resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.75"}),
   150      )