#!/usr/bin/env python
# Source: github.com/kubeflow/training-operator@v1.7.0/hack/python-sdk/post_gen.py

# Copyright 2021 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is used for updating generated SDK files.
"""

import os
import fileinput
import re

# Ordered (regex pattern, replacement) pairs applied to every line of the
# generated test files. The patterns are raw strings so that `\/` and `\.`
# reach the `re` module verbatim instead of being treated as (invalid)
# string escape sequences.
__replacements = [
    (r"import kubeflow.training", "from kubeflow.training.models import *"),
    (r"kubeflow.training.models.v1\/.*.v1.", "V1"),
    (r"kubeflow.training.models.kubeflow/org/v1/", "kubeflow_org_v1_"),
    (r"\.kubeflow.org.v1\.", ".KubeflowOrgV1"),
]

# `<dir of this file>/../../sdk/python`; the first ".." strips the file name
# itself once `abspath` normalizes the path.
sdk_dir = os.path.abspath(os.path.join(__file__, "../../..", "sdk/python"))


def main():
    """Run every post-generation fix-up against the SDK under `sdk_dir`."""
    fix_test_files()
    add_imports()


def fix_test_files() -> None:
    """
    Fix invalid model imports in generated model tests

    Every `*.py` entry in `<sdk_dir>/test` is rewritten in place, line by
    line, through `_apply_regex`. Other entries are left untouched.
    """
    test_folder_dir = os.path.join(sdk_dir, "test")
    # Sorted so the progress output is deterministic across platforms.
    for test_file in sorted(os.listdir(test_folder_dir)):
        if not test_file.endswith(".py"):
            continue
        print(f"Processing file {test_file}")
        with fileinput.FileInput(
            os.path.join(test_folder_dir, test_file), inplace=True
        ) as file:
            for line in file:
                # With inplace=True, stdout is redirected into the file being
                # read, so printing the line is what rewrites it.
                print(_apply_regex(line), end="")


def add_imports() -> None:
    """
    Add the hand-maintained imports to the generated package files.

    * `kubeflow/training/__init__.py` gets the `TrainingClient` import.
    * `kubeflow/__init__.py` gets the `pkgutil.extend_path` line.
    * `kubeflow/training/models/__init__.py` gets a
      `from kubernetes.client import *` block right after its
      `from __future__ import absolute_import` line.

    Each addition is skipped when it is already present, so running the
    script again on the same tree does not duplicate or corrupt lines.
    """
    _append_once(
        os.path.join(sdk_dir, "kubeflow/training/__init__.py"),
        "from kubeflow.training.api.training_client import TrainingClient\n",
    )
    _append_once(
        os.path.join(sdk_dir, "kubeflow/__init__.py"),
        "__path__ = __import__('pkgutil').extend_path(__path__, __name__)",
    )

    # Add Kubernetes models to proper deserialization of Training models.
    models_init = os.path.join(sdk_dir, "kubeflow/training/models/__init__.py")
    kubernetes_import = "from kubernetes.client import *\n"
    with open(models_init, "r") as f:
        lines = f.readlines()
    if any(line.strip() == kubernetes_import.strip() for line in lines):
        # Already patched by an earlier run.
        return

    new_lines = []
    for line in lines:
        new_lines.append(line)
        if line.startswith("from __future__ import absolute_import"):
            new_lines.append("\n")
            new_lines.append("# Import Kubernetes models.\n")
            new_lines.append(kubernetes_import)
    with open(models_init, "w") as f:
        f.writelines(new_lines)


def _append_once(path: str, text: str) -> None:
    """
    Append `text` to the file at `path` unless it is already one of its lines.

    The file is created when it does not exist. When the existing content
    does not end with a newline, one is written first so that `text` starts
    on its own line instead of being glued to the previous statement.
    """
    try:
        with open(path, "r") as f:
            existing = f.read()
    except FileNotFoundError:
        existing = ""

    wanted = text.strip()
    if any(line.strip() == wanted for line in existing.splitlines()):
        return

    with open(path, "a") as f:
        if existing and not existing.endswith("\n"):
            f.write("\n")
        f.write(text)


def _apply_regex(input_str: str) -> str:
    """Return `input_str` after applying every `__replacements` pair in order."""
    for pattern, replacement in __replacements:
        input_str = re.sub(pattern, replacement, input_str)
    return input_str


if __name__ == "__main__":
    main()