github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/artifact_service.py (about)

     1  # Licensed to the Apache Software Foundation (ASF) under one or more
     2  # contributor license agreements.  See the NOTICE file distributed with
     3  # this work for additional information regarding copyright ownership.
     4  # The ASF licenses this file to You under the Apache License, Version 2.0
     5  # (the "License"); you may not use this file except in compliance with
     6  # the License.  You may obtain a copy of the License at
     7  #
     8  #    http://www.apache.org/licenses/LICENSE-2.0
     9  #
    10  # Unless required by applicable law or agreed to in writing, software
    11  # distributed under the License is distributed on an "AS IS" BASIS,
    12  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  # See the License for the specific language governing permissions and
    14  # limitations under the License.
    15  #
    16  
    17  """Implementation of an Artifact{Staging,Retrieval}Service.
    18  
    19  The staging service here can be backed by any beam filesystem.
    20  """
    21  
    22  # pytype: skip-file
    23  
    24  import concurrent.futures
    25  import hashlib
    26  import os
    27  import queue
    28  import sys
    29  import tempfile
    30  import threading
    31  from io import BytesIO
    32  from typing import Any
    33  from typing import BinaryIO  # pylint: disable=unused-import
    34  from typing import Callable
    35  from typing import Dict
    36  from typing import List
    37  from typing import MutableMapping
    38  from typing import Optional
    39  from typing import Tuple
    40  from urllib.request import urlopen
    41  
    42  import grpc
    43  
    44  from apache_beam.io import filesystems
    45  from apache_beam.io.filesystems import CompressionTypes
    46  from apache_beam.portability import common_urns
    47  from apache_beam.portability.api import beam_artifact_api_pb2
    48  from apache_beam.portability.api import beam_artifact_api_pb2_grpc
    49  from apache_beam.portability.api import beam_runner_api_pb2
    50  from apache_beam.utils import proto_utils
    51  
    52  
    53  class ArtifactRetrievalService(
    54      beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer):
    55  
    56    _DEFAULT_CHUNK_SIZE = 2 << 20
    57  
    58    def __init__(
    59        self,
    60        file_reader,  # type: Callable[[str], BinaryIO]
    61        chunk_size=None,
    62    ):
    63      self._file_reader = file_reader
    64      self._chunk_size = chunk_size or self._DEFAULT_CHUNK_SIZE
    65  
    66    def ResolveArtifacts(self, request, context=None):
    67      return beam_artifact_api_pb2.ResolveArtifactsResponse(
    68          replacements=request.artifacts)
    69  
    70    def GetArtifact(self, request, context=None):
    71      if request.artifact.type_urn == common_urns.artifact_types.FILE.urn:
    72        payload = proto_utils.parse_Bytes(
    73            request.artifact.type_payload,
    74            beam_runner_api_pb2.ArtifactFilePayload)
    75        read_handle = self._file_reader(payload.path)
    76      elif request.artifact.type_urn == common_urns.artifact_types.URL.urn:
    77        payload = proto_utils.parse_Bytes(
    78            request.artifact.type_payload, beam_runner_api_pb2.ArtifactUrlPayload)
    79        read_handle = urlopen(payload.url)
    80      elif request.artifact.type_urn == common_urns.artifact_types.EMBEDDED.urn:
    81        payload = proto_utils.parse_Bytes(
    82            request.artifact.type_payload,
    83            beam_runner_api_pb2.EmbeddedFilePayload)
    84        read_handle = BytesIO(payload.data)
    85      else:
    86        raise NotImplementedError(request.artifact.type_urn)
    87  
    88      with read_handle as fin:
    89        while True:
    90          chunk = fin.read(self._chunk_size)
    91          if not chunk:
    92            break
    93          yield beam_artifact_api_pb2.GetArtifactResponse(data=chunk)
    94  
    95  
    96  class ArtifactStagingService(
    97      beam_artifact_api_pb2_grpc.ArtifactStagingServiceServicer):
    98    def __init__(
    99        self,
   100        file_writer,  # type: Callable[[str, Optional[str]], Tuple[BinaryIO, str]]
   101      ):
   102      self._lock = threading.Lock()
   103      self._jobs_to_stage = {
   104      }  # type: Dict[str, Tuple[Dict[Any, List[beam_runner_api_pb2.ArtifactInformation]], threading.Event]]
   105      self._file_writer = file_writer
   106  
   107    def register_job(
   108        self,
   109        staging_token,  # type: str
   110        dependency_sets  # type: MutableMapping[Any, List[beam_runner_api_pb2.ArtifactInformation]]
   111      ):
   112      if staging_token in self._jobs_to_stage:
   113        raise ValueError('Already staging %s' % staging_token)
   114      with self._lock:
   115        self._jobs_to_stage[staging_token] = (
   116            dict(dependency_sets), threading.Event())
   117  
   118    def resolved_deps(self, staging_token, timeout=None):
   119      with self._lock:
   120        dependency_sets, event = self._jobs_to_stage[staging_token]
   121      try:
   122        if not event.wait(timeout):
   123          raise concurrent.futures.TimeoutError()
   124        return dependency_sets
   125      finally:
   126        with self._lock:
   127          del self._jobs_to_stage[staging_token]
   128  
   129    def ReverseArtifactRetrievalService(self, responses, context=None):
   130      staging_token = next(responses).staging_token
   131      with self._lock:
   132        try:
   133          dependency_sets, event = self._jobs_to_stage[staging_token]
   134        except KeyError:
   135          if context:
   136            context.set_code(grpc.StatusCode.NOT_FOUND)
   137            context.set_details('No such staging token: %r' % staging_token)
   138          raise
   139  
   140      requests = _QueueIter()
   141  
   142      class ForwardingRetrievalService(object):
   143        def ResolveArtifactss(self, request):
   144          requests.put(
   145              beam_artifact_api_pb2.ArtifactRequestWrapper(
   146                  resolve_artifact=request))
   147          return next(responses).resolve_artifact_response
   148  
   149        def GetArtifact(self, request):
   150          requests.put(
   151              beam_artifact_api_pb2.ArtifactRequestWrapper(get_artifact=request))
   152          while True:
   153            response = next(responses)
   154            yield response.get_artifact_response
   155            if response.is_last:
   156              break
   157  
   158      def resolve():
   159        try:
   160          for key, dependencies in dependency_sets.items():
   161            dependency_sets[key] = list(
   162                resolve_as_files(
   163                    ForwardingRetrievalService(),
   164                    lambda name: self._file_writer(
   165                        os.path.join(staging_token, name)),
   166                    dependencies))
   167          requests.done()
   168        except:  # pylint: disable=bare-except
   169          requests.abort()
   170          raise
   171        finally:
   172          event.set()
   173  
   174      t = threading.Thread(target=resolve)
   175      t.daemon = True
   176      t.start()
   177  
   178      return requests
   179  
   180  
   181  def resolve_as_files(retrieval_service, file_writer, dependencies):
   182    """Translates a set of dependencies into file-based dependencies."""
   183    # Resolve until nothing changes.  This ensures that they can be fetched.
   184    resolution = retrieval_service.ResolveArtifactss(
   185        beam_artifact_api_pb2.ResolveArtifactsRequest(
   186            artifacts=dependencies,
   187            # Anything fetchable will do.
   188            # TODO(robertwb): Take advantage of shared filesystems, urls.
   189            preferred_urns=[],
   190        ))
   191    dependencies = resolution.replacements
   192  
   193    # Fetch each of the dependencies, using file_writer to store them as
   194    # file-based artifacts.
   195    # TODO(robertwb): Consider parallelizing the actual writes.
   196    for dep in dependencies:
   197      if dep.role_urn == common_urns.artifact_roles.STAGING_TO.urn:
   198        base_name = os.path.basename(
   199            proto_utils.parse_Bytes(
   200                dep.role_payload,
   201                beam_runner_api_pb2.ArtifactStagingToRolePayload).staged_name)
   202      else:
   203        base_name = None
   204      unique_name = '-'.join(
   205          filter(
   206              None,
   207              [hashlib.sha256(dep.SerializeToString()).hexdigest(), base_name]))
   208      file_handle, path = file_writer(unique_name)
   209      with file_handle as fout:
   210        for chunk in retrieval_service.GetArtifact(
   211            beam_artifact_api_pb2.GetArtifactRequest(artifact=dep)):
   212          fout.write(chunk.data)
   213      yield beam_runner_api_pb2.ArtifactInformation(
   214          type_urn=common_urns.artifact_types.FILE.urn,
   215          type_payload=beam_runner_api_pb2.ArtifactFilePayload(
   216              path=path).SerializeToString(),
   217          role_urn=dep.role_urn,
   218          role_payload=dep.role_payload)
   219  
   220  
   221  def offer_artifacts(
   222      artifact_staging_service, artifact_retrieval_service, staging_token):
   223    """Offers a set of artifacts to an artifact staging service, via the
   224    ReverseArtifactRetrievalService API.
   225  
   226    The given artifact_retrieval_service should be able to resolve/get all
   227    artifacts relevant to this job.
   228    """
   229    responses = _QueueIter()
   230    responses.put(
   231        beam_artifact_api_pb2.ArtifactResponseWrapper(
   232            staging_token=staging_token))
   233    requests = artifact_staging_service.ReverseArtifactRetrievalService(responses)
   234    try:
   235      for request in requests:
   236        if request.HasField('resolve_artifact'):
   237          responses.put(
   238              beam_artifact_api_pb2.ArtifactResponseWrapper(
   239                  resolve_artifact_response=artifact_retrieval_service.
   240                  ResolveArtifacts(request.resolve_artifact)))
   241        elif request.HasField('get_artifact'):
   242          for chunk in artifact_retrieval_service.GetArtifact(
   243              request.get_artifact):
   244            responses.put(
   245                beam_artifact_api_pb2.ArtifactResponseWrapper(
   246                    get_artifact_response=chunk))
   247          responses.put(
   248              beam_artifact_api_pb2.ArtifactResponseWrapper(
   249                  get_artifact_response=beam_artifact_api_pb2.GetArtifactResponse(
   250                      data=b''),
   251                  is_last=True))
   252      responses.done()
   253    except:  # pylint: disable=bare-except
   254      responses.abort()
   255      raise
   256  
   257  
   258  class BeamFilesystemHandler(object):
   259    def __init__(self, root):
   260      self._root = root
   261  
   262    def file_reader(self, path):
   263      return filesystems.FileSystems.open(
   264          path, compression_type=CompressionTypes.UNCOMPRESSED)
   265  
   266    def file_writer(self, name=None):
   267      full_path = filesystems.FileSystems.join(self._root, name)
   268      return filesystems.FileSystems.create(
   269          full_path, compression_type=CompressionTypes.UNCOMPRESSED), full_path
   270  
   271  
   272  def resolve_artifacts(artifacts, service, dest_dir):
   273    if not artifacts:
   274      return artifacts
   275    else:
   276      return [
   277          maybe_store_artifact(artifact, service,
   278                               dest_dir) for artifact in service.ResolveArtifacts(
   279                                   beam_artifact_api_pb2.ResolveArtifactsRequest(
   280                                       artifacts=artifacts)).replacements
   281      ]
   282  
   283  
   284  def maybe_store_artifact(artifact, service, dest_dir):
   285    if artifact.type_urn in (common_urns.artifact_types.URL.urn,
   286                             common_urns.artifact_types.EMBEDDED.urn):
   287      return artifact
   288    elif artifact.type_urn == common_urns.artifact_types.FILE.urn:
   289      payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
   290          artifact.type_payload)
   291      # pylint: disable=condition-evals-to-constant
   292      if os.path.exists(
   293          payload.path) and payload.sha256 and payload.sha256 == sha256(
   294              payload.path) and False:
   295        return artifact
   296      else:
   297        return store_artifact(artifact, service, dest_dir)
   298    else:
   299      return store_artifact(artifact, service, dest_dir)
   300  
   301  
   302  def store_artifact(artifact, service, dest_dir):
   303    hasher = hashlib.sha256()
   304    with tempfile.NamedTemporaryFile(dir=dest_dir, delete=False) as fout:
   305      for block in service.GetArtifact(
   306          beam_artifact_api_pb2.GetArtifactRequest(artifact=artifact)):
   307        hasher.update(block.data)
   308        fout.write(block.data)
   309    return beam_runner_api_pb2.ArtifactInformation(
   310        type_urn=common_urns.artifact_types.FILE.urn,
   311        type_payload=beam_runner_api_pb2.ArtifactFilePayload(
   312            path=fout.name, sha256=hasher.hexdigest()).SerializeToString(),
   313        role_urn=artifact.role_urn,
   314        role_payload=artifact.role_payload)
   315  
   316  
   317  def sha256(path):
   318    hasher = hashlib.sha256()
   319    with open(path, 'rb') as fin:
   320      for block in iter(lambda: fin.read(4 << 20), b''):
   321        hasher.update(block)
   322    return hasher.hexdigest()
   323  
   324  
   325  class _QueueIter(object):
   326  
   327    _END = object()
   328  
   329    def __init__(self):
   330      self._queue = queue.Queue()
   331  
   332    def put(self, item):
   333      self._queue.put(item)
   334  
   335    def done(self):
   336      self._queue.put(self._END)
   337      self._queue.put(StopIteration)
   338  
   339    def abort(self, exn=None):
   340      if exn is None:
   341        exn = sys.exc_info()[1]
   342      self._queue.put(self._END)
   343      self._queue.put(exn)
   344  
   345    def __iter__(self):
   346      return self
   347  
   348    def __next__(self):
   349      item = self._queue.get()
   350      if item is self._END:
   351        raise self._queue.get()
   352      else:
   353        return item