# Extracted from: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/artifact_service.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Implementation of an Artifact{Staging,Retrieval}Service.

The staging service here can be backed by any beam filesystem.
"""

# pytype: skip-file

import concurrent.futures
import hashlib
import os
import queue
import sys
import tempfile
import threading
from io import BytesIO
from typing import Any
from typing import BinaryIO  # pylint: disable=unused-import
from typing import Callable
from typing import Dict
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Tuple
from urllib.request import urlopen

import grpc

from apache_beam.io import filesystems
from apache_beam.io.filesystems import CompressionTypes
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_artifact_api_pb2
from apache_beam.portability.api import beam_artifact_api_pb2_grpc
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.utils import proto_utils


class ArtifactRetrievalService(
    beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer):
  """Serves the contents of artifacts as a stream of byte chunks.

  FILE artifacts are opened through the supplied ``file_reader``, URL
  artifacts through ``urlopen``, and EMBEDDED artifacts are read directly from
  the bytes carried in their payload.
  """

  # 2 MiB; used when no explicit chunk size is given.
  _DEFAULT_CHUNK_SIZE = 2 << 20

  def __init__(
      self,
      file_reader,  # type: Callable[[str], BinaryIO]
      chunk_size=None,
  ):
    """Creates the service.

    Args:
      file_reader: callable that takes a path and returns a readable binary
        handle usable as a context manager.
      chunk_size: maximum number of bytes per emitted chunk; a falsy value
        (None or 0) selects ``_DEFAULT_CHUNK_SIZE``.
    """
    self._file_reader = file_reader
    self._chunk_size = chunk_size or self._DEFAULT_CHUNK_SIZE

  def ResolveArtifacts(self, request, context=None):
    """Returns the requested artifacts unchanged as their own replacements."""
    return beam_artifact_api_pb2.ResolveArtifactsResponse(
        replacements=request.artifacts)

  def GetArtifact(self, request, context=None):
    """Yields the bytes of ``request.artifact`` as GetArtifactResponse chunks.

    This is a generator, so nothing (including the error below) happens until
    it is first advanced.

    Raises:
      NotImplementedError: if the artifact's type URN is not FILE, URL or
        EMBEDDED.
    """
    if request.artifact.type_urn == common_urns.artifact_types.FILE.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload,
          beam_runner_api_pb2.ArtifactFilePayload)
      read_handle = self._file_reader(payload.path)
    elif request.artifact.type_urn == common_urns.artifact_types.URL.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload, beam_runner_api_pb2.ArtifactUrlPayload)
      read_handle = urlopen(payload.url)
    elif request.artifact.type_urn == common_urns.artifact_types.EMBEDDED.urn:
      payload = proto_utils.parse_Bytes(
          request.artifact.type_payload,
          beam_runner_api_pb2.EmbeddedFilePayload)
      read_handle = BytesIO(payload.data)
    else:
      raise NotImplementedError(request.artifact.type_urn)

    with read_handle as fin:
      while True:
        chunk = fin.read(self._chunk_size)
        if not chunk:
          break
        yield beam_artifact_api_pb2.GetArtifactResponse(data=chunk)


class ArtifactStagingService(
    beam_artifact_api_pb2_grpc.ArtifactStagingServiceServicer):
  """Stages the dependencies of registered jobs as files.

  A job is first registered under a staging token together with its
  dependency sets.  When a client later connects through
  ``ReverseArtifactRetrievalService`` with that token, the dependencies are
  fetched from the client and written out with ``file_writer``; callers of
  ``resolved_deps`` block until that has finished.
  """
  def __init__(
      self,
      file_writer,  # type: Callable[[str, Optional[str]], Tuple[BinaryIO, str]]
  ):
    """Creates the service.

    Args:
      file_writer: callable invoked with a name (here always
        ``os.path.join(staging_token, unique_name)``) that returns a
        ``(writable handle, path)`` pair.
    """
    # Guards every access to _jobs_to_stage.
    self._lock = threading.Lock()
    self._jobs_to_stage = {
    }  # type: Dict[str, Tuple[Dict[Any, List[beam_runner_api_pb2.ArtifactInformation]], threading.Event]]
    self._file_writer = file_writer

  def register_job(
      self,
      staging_token,  # type: str
      dependency_sets  # type: MutableMapping[Any, List[beam_runner_api_pb2.ArtifactInformation]]
    ):
    """Registers the dependency sets to be staged under ``staging_token``.

    A shallow copy of ``dependency_sets`` is stored alongside a fresh event
    that is set once staging has been attempted.

    Raises:
      ValueError: if ``staging_token`` is already registered.
    """
    # The membership test and the insert must happen under the same lock
    # acquisition; otherwise two concurrent registrations of one token could
    # both pass the test and the later one would silently overwrite the
    # earlier job's state.
    with self._lock:
      if staging_token in self._jobs_to_stage:
        raise ValueError('Already staging %s' % staging_token)
      self._jobs_to_stage[staging_token] = (
          dict(dependency_sets), threading.Event())

  def resolved_deps(self, staging_token, timeout=None):
    """Waits for staging of ``staging_token`` and returns its dependency sets.

    The registration is removed before returning, whether or not the wait
    succeeded.

    Args:
      staging_token: token previously passed to ``register_job``.
      timeout: seconds to wait for staging to finish; None waits forever.

    Raises:
      KeyError: if ``staging_token`` is not registered.
      concurrent.futures.TimeoutError: if staging did not finish in time.
    """
    with self._lock:
      dependency_sets, event = self._jobs_to_stage[staging_token]
    try:
      if not event.wait(timeout):
        raise concurrent.futures.TimeoutError()
      return dependency_sets
    finally:
      with self._lock:
        del self._jobs_to_stage[staging_token]

  def ReverseArtifactRetrievalService(self, responses, context=None):
    """Pulls a job's artifacts from the client over a reversed channel.

    The first item of ``responses`` must carry the staging token.  A daemon
    thread is then started which resolves every registered dependency set
    into files, sending its requests through the returned iterator and
    reading the matching answers from ``responses``.

    Args:
      responses: iterator of ArtifactResponseWrapper messages from the client.
      context: optional gRPC context; when given and the token is unknown, its
        status code and details are set before the KeyError propagates.

    Returns:
      An iterator of ArtifactRequestWrapper messages for the client to answer.

    Raises:
      KeyError: if the staging token has not been registered.
    """
    staging_token = next(responses).staging_token
    with self._lock:
      try:
        dependency_sets, event = self._jobs_to_stage[staging_token]
      except KeyError:
        if context:
          context.set_code(grpc.StatusCode.NOT_FOUND)
          context.set_details('No such staging token: %r' % staging_token)
        raise

    requests = _QueueIter()

    class ForwardingRetrievalService(object):
      """Retrieval service that forwards every call to the remote client."""

      # NOTE: the doubled 's' is intentional here; resolve_as_files() calls
      # exactly this method name, so the two must be kept in sync.
      def ResolveArtifactss(self, request):
        requests.put(
            beam_artifact_api_pb2.ArtifactRequestWrapper(
                resolve_artifact=request))
        return next(responses).resolve_artifact_response

      def GetArtifact(self, request):
        requests.put(
            beam_artifact_api_pb2.ArtifactRequestWrapper(get_artifact=request))
        # Relay chunks until the client marks the final one.
        while True:
          response = next(responses)
          yield response.get_artifact_response
          if response.is_last:
            break

    def resolve():
      try:
        for key, dependencies in dependency_sets.items():
          dependency_sets[key] = list(
              resolve_as_files(
                  ForwardingRetrievalService(),
                  lambda name: self._file_writer(
                      os.path.join(staging_token, name)),
                  dependencies))
        requests.done()
      except:  # pylint: disable=bare-except
        # Propagate the failure to whoever is consuming the request stream.
        requests.abort()
        raise
      finally:
        # Wake up resolved_deps() waiters on success and on failure alike.
        event.set()

    t = threading.Thread(target=resolve)
    t.daemon = True
    t.start()

    return requests


def resolve_as_files(retrieval_service, file_writer, dependencies):
  """Translates a set of dependencies into file-based dependencies.

  Args:
    retrieval_service: object providing ``ResolveArtifactss(request)`` (note
      the spelling) and ``GetArtifact(request)``.
    file_writer: callable taking a name and returning a
      ``(writable handle, path)`` pair.
    dependencies: the ArtifactInformation messages to materialize.

  Yields:
    One FILE-typed ArtifactInformation per resolved dependency, pointing at
    the written path and carrying the original role URN and payload.
  """
  # Resolve until nothing changes.  This ensures that they can be fetched.
  resolution = retrieval_service.ResolveArtifactss(
      beam_artifact_api_pb2.ResolveArtifactsRequest(
          artifacts=dependencies,
          # Anything fetchable will do.
          # TODO(robertwb): Take advantage of shared filesystems, urls.
          preferred_urns=[],
      ))
  dependencies = resolution.replacements

  # Fetch each of the dependencies, using file_writer to store them as
  # file-based artifacts.
  # TODO(robertwb): Consider parallelizing the actual writes.
  for dep in dependencies:
    if dep.role_urn == common_urns.artifact_roles.STAGING_TO.urn:
      base_name = os.path.basename(
          proto_utils.parse_Bytes(
              dep.role_payload,
              beam_runner_api_pb2.ArtifactStagingToRolePayload).staged_name)
    else:
      base_name = None
    # "<sha256 of the serialized dependency>[-<staged base name>]"; the
    # filter drops a missing or empty base name.
    unique_name = '-'.join(
        filter(
            None,
            [hashlib.sha256(dep.SerializeToString()).hexdigest(), base_name]))
    file_handle, path = file_writer(unique_name)
    with file_handle as fout:
      for chunk in retrieval_service.GetArtifact(
          beam_artifact_api_pb2.GetArtifactRequest(artifact=dep)):
        fout.write(chunk.data)
    yield beam_runner_api_pb2.ArtifactInformation(
        type_urn=common_urns.artifact_types.FILE.urn,
        type_payload=beam_runner_api_pb2.ArtifactFilePayload(
            path=path).SerializeToString(),
        role_urn=dep.role_urn,
        role_payload=dep.role_payload)


def offer_artifacts(
    artifact_staging_service, artifact_retrieval_service, staging_token):
  """Offers a set of artifacts to an artifact staging service, via the
  ReverseArtifactRetrievalService API.

  The given artifact_retrieval_service should be able to resolve/get all
  artifacts relevant to this job.

  Each request from the staging service is answered from
  artifact_retrieval_service.  Every get_artifact request is terminated by an
  extra, empty chunk flagged ``is_last``.  On any exception the response
  stream is aborted and the exception re-raised.
  """
  responses = _QueueIter()
  responses.put(
      beam_artifact_api_pb2.ArtifactResponseWrapper(
          staging_token=staging_token))
  requests = artifact_staging_service.ReverseArtifactRetrievalService(responses)
  try:
    for request in requests:
      if request.HasField('resolve_artifact'):
        responses.put(
            beam_artifact_api_pb2.ArtifactResponseWrapper(
                resolve_artifact_response=artifact_retrieval_service.
                ResolveArtifacts(request.resolve_artifact)))
      elif request.HasField('get_artifact'):
        for chunk in artifact_retrieval_service.GetArtifact(
            request.get_artifact):
          responses.put(
              beam_artifact_api_pb2.ArtifactResponseWrapper(
                  get_artifact_response=chunk))
        responses.put(
            beam_artifact_api_pb2.ArtifactResponseWrapper(
                get_artifact_response=beam_artifact_api_pb2.GetArtifactResponse(
                    data=b''),
                is_last=True))
    responses.done()
  except:  # pylint: disable=bare-except
    responses.abort()
    raise


class BeamFilesystemHandler(object):
  """File reader/writer pair backed by ``filesystems.FileSystems``."""
  def __init__(self, root):
    """Stores ``root``, the location under which written files are created."""
    self._root = root

  def file_reader(self, path):
    """Opens ``path`` for reading without decompression."""
    return filesystems.FileSystems.open(
        path, compression_type=CompressionTypes.UNCOMPRESSED)

  def file_writer(self, name=None):
    """Creates ``name`` under the root, without compression.

    Returns:
      A ``(writable handle, full path)`` tuple.
    """
    full_path = filesystems.FileSystems.join(self._root, name)
    return filesystems.FileSystems.create(
        full_path, compression_type=CompressionTypes.UNCOMPRESSED), full_path


def resolve_artifacts(artifacts, service, dest_dir):
  """Resolves ``artifacts`` via ``service``, storing them locally as needed.

  A falsy ``artifacts`` is returned as-is without contacting the service.
  Otherwise each replacement returned by ``service.ResolveArtifacts`` is
  passed through ``maybe_store_artifact`` and the results are returned as a
  list.
  """
  if not artifacts:
    return artifacts
  else:
    return [
        maybe_store_artifact(artifact, service,
                             dest_dir) for artifact in service.ResolveArtifacts(
                                 beam_artifact_api_pb2.ResolveArtifactsRequest(
                                     artifacts=artifacts)).replacements
    ]


def maybe_store_artifact(artifact, service, dest_dir):
  """Returns ``artifact`` itself, or a locally stored FILE copy of it.

  URL and EMBEDDED artifacts are returned unchanged.  Everything else is
  fetched with ``store_artifact``.
  """
  if artifact.type_urn in (common_urns.artifact_types.URL.urn,
                           common_urns.artifact_types.EMBEDDED.urn):
    return artifact
  elif artifact.type_urn == common_urns.artifact_types.FILE.urn:
    payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
        artifact.type_payload)
    # The trailing `and False` makes this condition constant, so an existing
    # local file with a matching hash is currently never reused and FILE
    # artifacts are always re-stored.
    # pylint: disable=condition-evals-to-constant
    if os.path.exists(
        payload.path) and payload.sha256 and payload.sha256 == sha256(
            payload.path) and False:
      return artifact
    else:
      return store_artifact(artifact, service, dest_dir)
  else:
    return store_artifact(artifact, service, dest_dir)


def store_artifact(artifact, service, dest_dir):
  """Downloads ``artifact`` from ``service`` into a new file in ``dest_dir``.

  The temporary file is created with ``delete=False`` so that it remains on
  disk after it is closed.

  Returns:
    A FILE-typed ArtifactInformation carrying the new file's path and the
    SHA-256 hex digest of the bytes written, with the original role URN and
    payload.
  """
  hasher = hashlib.sha256()
  with tempfile.NamedTemporaryFile(dir=dest_dir, delete=False) as fout:
    for block in service.GetArtifact(
        beam_artifact_api_pb2.GetArtifactRequest(artifact=artifact)):
      hasher.update(block.data)
      fout.write(block.data)
  return beam_runner_api_pb2.ArtifactInformation(
      type_urn=common_urns.artifact_types.FILE.urn,
      type_payload=beam_runner_api_pb2.ArtifactFilePayload(
          path=fout.name, sha256=hasher.hexdigest()).SerializeToString(),
      role_urn=artifact.role_urn,
      role_payload=artifact.role_payload)


def sha256(path):
  """Returns the SHA-256 hex digest of the file at ``path``.

  The file is read in 4 MiB blocks so it is never held in memory as a whole.
  """
  hasher = hashlib.sha256()
  with open(path, 'rb') as fin:
    for block in iter(lambda: fin.read(4 << 20), b''):
      hasher.update(block)
  return hasher.hexdigest()


class _QueueIter(object):
  """Blocking iterator fed through a queue.

  Items handed to ``put`` are returned by ``next`` in order.  ``done`` ends
  the iteration normally; ``abort`` makes the consumer's ``next`` raise the
  given exception instead.
  """

  # Marker queued ahead of the exception (or StopIteration) that terminates
  # the iteration.
  _END = object()

  def __init__(self):
    self._queue = queue.Queue()

  def put(self, item):
    """Enqueues ``item`` for the consumer."""
    self._queue.put(item)

  def done(self):
    """Signals normal completion; the consumer gets StopIteration."""
    self._queue.put(self._END)
    self._queue.put(StopIteration)

  def abort(self, exn=None):
    """Signals failure; the consumer's ``next`` raises ``exn``.

    When ``exn`` is None the exception currently being handled (as reported
    by ``sys.exc_info()``) is used.
    """
    if exn is None:
      exn = sys.exc_info()[1]
    self._queue.put(self._END)
    self._queue.put(exn)

  def __iter__(self):
    return self

  def __next__(self):
    # Blocks until the producer supplies an item or a terminator.
    item = self._queue.get()
    if item is self._END:
      raise self._queue.get()
    else:
      return item