#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

"""Utilities for running GRPC services as external subprocesses."""

import contextlib
import errno
import glob
import hashlib
import logging
import os
import re
import shutil
import signal
import socket
import subprocess
import tempfile
import threading
import time
import zipfile
from urllib.error import URLError
from urllib.request import urlopen

import grpc

from apache_beam.version import __version__ as beam_version

_LOGGER = logging.getLogger(__name__)


class SubprocessServer(object):
  """An abstract base class for running GRPC Servers as an external process.

  This class acts as a context which will start up a server, provides a stub
  to connect to it, and then shuts the server down.  For example::

      with SubprocessServer(GrpcStubClass, [executable, arg, ...]) as stub:
          stub.CallService(...)
  """
  def __init__(self, stub_class, cmd, port=None):
    """Creates the server object.

    :param stub_class: the auto-generated GRPC client stub class used for
        connecting to the GRPC service
    :param cmd: command (including arguments) for starting up the server,
        suitable for passing to `subprocess.POpen`.
    :param port: (optional) the port at which the subprocess will serve its
        service.  If not given, one will be randomly chosen and the special
        string "{{PORT}}" will be substituted in the command line arguments
        with the chosen port.
    """
    self._process_lock = threading.RLock()
    self._process = None
    self._stub_class = stub_class
    self._cmd = [str(arg) for arg in cmd]
    self._port = port
    # Parent directory used by local_temp_dir(); None lets tempfile pick
    # its default temporary location.
    self._local_temp_root = None

  def __enter__(self):
    return self.start()

  def __exit__(self, *unused_args):
    self.stop()

  def start(self):
    """Starts the subprocess and blocks until its GRPC channel is ready.

    The readiness wait grows by 20% per attempt (starting at 0.1s) and has no
    overall deadline; it only gives up if the subprocess exits.

    Returns:
      An instance of the stub class bound to a channel to the service.

    Raises:
      RuntimeError: if the subprocess exits before the channel is ready.
      Any other exception raised while starting is re-raised after the
      channel is closed and the process has been stopped.
    """
    channel = None
    try:
      endpoint = self.start_process()
      wait_secs = .1
      channel_options = [("grpc.max_receive_message_length", -1),
                         ("grpc.max_send_message_length", -1)]
      channel = grpc.insecure_channel(endpoint, options=channel_options)
      channel_ready = grpc.channel_ready_future(channel)
      while True:
        process = self._process
        if process is not None and process.poll() is not None:
          _LOGGER.error("Starting job service with %s", process.args)
          raise RuntimeError(
              'Service failed to start up with error %s' % process.poll())
        try:
          channel_ready.result(timeout=wait_secs)
          break
        except (grpc.FutureTimeoutError, grpc.RpcError):
          wait_secs *= 1.2
          _LOGGER.log(
              logging.WARNING if wait_secs > 1 else logging.DEBUG,
              'Waiting for grpc channel to be ready at %s.',
              endpoint)
      return self._stub_class(channel)
    except BaseException:
      # Deliberately broad: clean up on anything (including interrupts),
      # then re-raise unchanged.
      _LOGGER.exception("Error bringing up service")
      if channel is not None:
        channel.close()
      self.stop()
      raise

  def start_process(self):
    """Launches the server subprocess and returns its 'localhost:port' endpoint.

    Any previously started process is stopped first.  The subprocess's
    combined stdout/stderr is forwarded to this module's logger at INFO level
    by a daemon thread.
    """
    with self._process_lock:
      if self._process:
        self.stop()
      if self._port:
        port = self._port
        cmd = self._cmd
      else:
        port, = pick_port(None)
        cmd = [arg.replace('{{PORT}}', str(port)) for arg in self._cmd]
      endpoint = 'localhost:%s' % port
      _LOGGER.info("Starting service with %s", str(cmd).replace("',", "'"))
      process = subprocess.Popen(
          cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
      self._process = process

      # Emit the output of this command as info level logging.  The thread
      # reads from the local `process`, not `self._process`: stop_process()
      # resets that attribute to None and a restart replaces it.
      def log_stdout():
        for line in iter(process.stdout.readline, b''):
          # The log obtained from stdout is bytes, decode it into string.
          # Remove newline via rstrip() to not print an empty line.
          _LOGGER.info(line.decode(errors='backslashreplace').rstrip())

      t = threading.Thread(target=log_stdout)
      t.daemon = True
      t.start()
      return endpoint

  def stop(self):
    """Stops the server subprocess, if any."""
    self.stop_process()

  def stop_process(self):
    """Terminates the server subprocess, if one is running.

    The process is sent SIGINT up to five times, one second apart.  If it is
    still alive afterwards it is killed and reaped.
    """
    with self._process_lock:
      if not self._process:
        return
      for _ in range(5):
        if self._process.poll() is not None:
          break
        _LOGGER.debug("Sending SIGINT to job_server")
        self._process.send_signal(signal.SIGINT)
        time.sleep(1)
      if self._process.poll() is None:
        self._process.kill()
        # Reap the killed process so it does not linger as a zombie.
        self._process.wait()
      self._process = None

  def local_temp_dir(self, **kwargs):
    """Creates a new temporary directory and returns its path.

    Keyword arguments are forwarded to `tempfile.mkdtemp`.
    """
    # getattr guards subclasses whose __init__ does not call ours.
    return tempfile.mkdtemp(
        dir=getattr(self, '_local_temp_root', None), **kwargs)


class JavaJarServer(SubprocessServer):
  """Runs a GRPC service packaged as an executable Java jar.

  If the jar path has the form `host:port`, no process is started.  The
  existing service at that endpoint is used instead.
  """

  MAVEN_CENTRAL_REPOSITORY = 'https://repo.maven.apache.org/maven2'
  BEAM_GROUP_ID = 'org.apache.beam'
  JAR_CACHE = os.path.expanduser("~/.apache_beam/cache/jars")

  # Thread-local object whose `replacements` dict maps a gradle target to the
  # value path_to_beam_jar() should return for it (see beam_services()).
  _BEAM_SERVICES = type(
      'local', (threading.local, ),
      dict(__init__=lambda self: setattr(self, 'replacements', {})))()

  def __init__(self, stub_class, path_to_jar, java_arguments, classpath=None):
    if classpath:
      # java -jar ignores the classpath, so we make a new jar that embeds
      # the requested classpath.
      path_to_jar = self.make_classpath_jar(path_to_jar, classpath)
    super().__init__(
        stub_class, ['java', '-jar', path_to_jar] + list(java_arguments))
    self._existing_service = path_to_jar if _is_service_endpoint(
        path_to_jar) else None

  def start_process(self):
    """Returns the existing endpoint, or launches `java -jar ...`.

    Raises:
      RuntimeError: if no `java` executable is found on the PATH.
    """
    if self._existing_service:
      return self._existing_service
    if not shutil.which('java'):
      raise RuntimeError(
          'Java must be installed on this system to use this '
          'transform/runner.')
    return super().start_process()

  def stop_process(self):
    """Stops the jar process; an externally-managed service is left alone."""
    if not self._existing_service:
      return super().stop_process()

  @classmethod
  def jar_name(cls, artifact_id, version, classifier=None, appendix=None):
    """Returns `artifact[-appendix]-version[-classifier].jar`."""
    return '-'.join(
        filter(None, [artifact_id, appendix, version, classifier])) + '.jar'

  @classmethod
  def path_to_maven_jar(
      cls,
      artifact_id,
      group_id,
      version,
      repository=MAVEN_CENTRAL_REPOSITORY,
      classifier=None,
      appendix=None):
    """Returns the Maven repository-layout URL of the given artifact's jar."""
    return '/'.join([
        repository,
        group_id.replace('.', '/'),
        artifact_id,
        version,
        cls.jar_name(artifact_id, version, classifier, appendix)
    ])

  @classmethod
  def path_to_beam_jar(
      cls,
      gradle_target,
      appendix=None,
      version=beam_version,
      artifact_id=None):
    """Returns a path or URL for the jar built by `gradle_target`.

    The first match wins:
      1. a replacement installed via beam_services();
      2. a locally built SNAPSHOT jar under the source tree;
      3. otherwise, for a `.dev` version, a RuntimeError asking the user to
         build the jar;
      4. otherwise, the Maven Central URL of the released artifact.
    """
    if gradle_target in cls._BEAM_SERVICES.replacements:
      return cls._BEAM_SERVICES.replacements[gradle_target]

    gradle_package = gradle_target.strip(':').rsplit(':', 1)[0]
    if not artifact_id:
      artifact_id = 'beam-' + gradle_package.replace(':', '-')
    # Root of the source checkout: five path components above this file.
    project_root = os.path.sep.join(
        os.path.abspath(__file__).split(os.path.sep)[:-5])
    local_path = os.path.join(
        project_root,
        gradle_package.replace(':', os.path.sep),
        'build',
        'libs',
        cls.jar_name(
            artifact_id,
            version.replace('.dev', ''),
            classifier='SNAPSHOT',
            appendix=appendix))
    if os.path.exists(local_path):
      _LOGGER.info('Using pre-built snapshot at %s', local_path)
      return local_path
    elif '.dev' in version:
      # TODO: Attempt to use nightly snapshots?
      raise RuntimeError(
          (
              '%s not found. '
              'Please build the server with \n  cd %s; ./gradlew %s') %
          (local_path, os.path.abspath(project_root), gradle_target))
    else:
      return cls.path_to_maven_jar(
          artifact_id,
          cls.BEAM_GROUP_ID,
          version,
          cls.MAVEN_CENTRAL_REPOSITORY,
          appendix=appendix)

  @classmethod
  def local_jar(cls, url, cache_dir=None):
    """Returns a local path (or service endpoint) for the jar at `url`.

    Service endpoints and existing local paths are returned unchanged.
    Anything else is treated as a URL.  It is downloaded into `cache_dir`
    (default `JAR_CACHE`) unless a file with the same basename is already
    cached there.

    Raises:
      RuntimeError: if the remote jar cannot be fetched.
    """
    if cache_dir is None:
      cache_dir = cls.JAR_CACHE
    # TODO: Verify checksum?
    if _is_service_endpoint(url):
      return url
    elif os.path.exists(url):
      return url
    cached_jar = os.path.join(cache_dir, os.path.basename(url))
    if os.path.exists(cached_jar):
      _LOGGER.info('Using cached job server jar from %s', url)
      return cached_jar
    _LOGGER.info('Downloading job server jar from %s', url)
    os.makedirs(cache_dir, exist_ok=True)
    # TODO: Clean up this cache according to some policy.
    tmp_jar = cached_jar + '.tmp'
    try:
      with urlopen(url) as url_read, open(tmp_jar, 'wb') as jar_write:
        shutil.copyfileobj(url_read, jar_write, length=1 << 20)
    except BaseException as e:
      # Don't leave a partial download behind.
      with contextlib.suppress(OSError):
        os.remove(tmp_jar)
      if isinstance(e, URLError):
        raise RuntimeError(
            'Unable to fetch remote job server jar at %s: %s' %
            (url, e)) from e
      raise
    # os.replace (unlike os.rename) also overwrites an existing destination
    # on Windows, e.g. one produced by a concurrent download.
    os.replace(tmp_jar, cached_jar)
    return cached_jar

  @classmethod
  @contextlib.contextmanager
  def beam_services(cls, replacements):
    """Temporarily overrides path_to_beam_jar() results on this thread.

    :param replacements: mapping from gradle target to the value that
        path_to_beam_jar() should return for it while the context is active.
    """
    old = cls._BEAM_SERVICES.replacements
    try:
      cls._BEAM_SERVICES.replacements = {**old, **replacements}
      yield
    finally:
      cls._BEAM_SERVICES.replacements = old

  @classmethod
  def make_classpath_jar(cls, main_jar, extra_jars, cache_dir=None):
    """Returns a jar with main_jar's Main-Class and a Class-Path of all jars.

    `main_jar` and each entry of `extra_jars` may be glob patterns.  The
    result is cached under `<cache_dir>/composite-jars`, keyed by a hash of
    the classpath.

    Raises:
      ValueError: if main_jar's manifest has no Main-Class entry.
    """
    if cache_dir is None:
      cache_dir = cls.JAR_CACHE
    composite_jar_dir = os.path.join(cache_dir, 'composite-jars')
    os.makedirs(composite_jar_dir, exist_ok=True)
    classpath = []
    # Class-Path references from a jar must be relative, so we create
    # a relatively-addressable subdirectory with symlinks to all the
    # required jars.
    for pattern in [main_jar] + list(extra_jars):
      for path in glob.glob(pattern) or [pattern]:
        path = os.path.abspath(path)
        rel_path = hashlib.sha256(
            path.encode('utf-8')).hexdigest() + os.path.splitext(path)[1]
        classpath.append(rel_path)
        link = os.path.join(composite_jar_dir, rel_path)
        if not os.path.lexists(link):
          try:
            os.symlink(path, link)
          except FileExistsError:
            # Created concurrently by another caller; the name is a hash of
            # the target path, so the existing link is equivalent.
            pass
    # Now create a single jar that simply references the rest and has the same
    # main class as main_jar.
    composite_jar = os.path.join(
        composite_jar_dir,
        hashlib.sha256(' '.join(sorted(classpath)).encode('ascii')).hexdigest()
        + '.jar')
    if not os.path.exists(composite_jar):
      with zipfile.ZipFile(main_jar) as main:
        with main.open('META-INF/MANIFEST.MF') as manifest:
          main_class = next(
              filter(lambda line: line.startswith(b'Main-Class: '), manifest),
              None)
      if main_class is None:
        raise ValueError(
            'No Main-Class entry found in the manifest of %s' % main_jar)
      if not main_class.endswith(b'\n'):
        # Last manifest line without a terminator; keep entries separate.
        main_class += b'\n'
      tmp_jar = composite_jar + '.tmp'
      with zipfile.ZipFile(tmp_jar, 'w') as composite:
        with composite.open('META-INF/MANIFEST.MF', 'w') as manifest:
          manifest.write(b'Manifest-Version: 1.0\n')
          manifest.write(main_class)
          manifest.write(
              b'Class-Path: ' + '\n '.join(classpath).encode('ascii') + b'\n')
      os.replace(tmp_jar, composite_jar)
    return composite_jar


def _is_service_endpoint(path):
  """Returns a truthy match object if `path` has the form `host:port`."""
  return re.match(r'^[a-zA-Z0-9.-]+:\d+$', path)


def pick_port(*ports):
  """
  Returns a list of ports, same length as input ports list, but replaces
  all None or 0 ports with a random free port.
  """
  sockets = []

  def find_free_port(port):
    if port:
      return port
    try:
      s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    except OSError as e:
      # Address family not supported by protocol ([Errno 97] on Linux).
      # Likely indicates we are in an IPv6-only environment (BEAM-10618). Try
      # again with AF_INET6.
      if e.errno == errno.EAFNOSUPPORT:
        s = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
      else:
        raise
    sockets.append(s)
    s.bind(('localhost', 0))
    return s.getsockname()[1]

  try:
    return [find_free_port(port) for port in ports]
  finally:
    # Close sockets only now to avoid the same port to be chosen twice;
    # the finally also releases them if a later bind fails.
    for s in sockets:
      s.close()