github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/artifact_service_test.py (about)

     1  # Licensed to the Apache Software Foundation (ASF) under one or more
     2  # contributor license agreements.  See the NOTICE file distributed with
     3  # this work for additional information regarding copyright ownership.
     4  # The ASF licenses this file to You under the Apache License, Version 2.0
     5  # (the "License"); you may not use this file except in compliance with
     6  # the License.  You may obtain a copy of the License at
     7  #
     8  #    http://www.apache.org/licenses/LICENSE-2.0
     9  #
    10  # Unless required by applicable law or agreed to in writing, software
    11  # distributed under the License is distributed on an "AS IS" BASIS,
    12  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  # See the License for the specific language governing permissions and
    14  # limitations under the License.
    15  #
    16  
"""Test cases for :mod:`apache_beam.runners.portability.artifact_service`."""
    18  
    19  # pytype: skip-file
    20  
    21  import contextlib
    22  import io
    23  import threading
    24  import unittest
    25  from urllib.parse import quote
    26  
    27  from apache_beam.portability import common_urns
    28  from apache_beam.portability.api import beam_artifact_api_pb2
    29  from apache_beam.portability.api import beam_runner_api_pb2
    30  from apache_beam.runners.portability import artifact_service
    31  from apache_beam.utils import proto_utils
    32  
    33  
    34  class InMemoryFileManager(object):
    35    def __init__(self, contents=()):
    36      self._contents = dict(contents)
    37  
    38    def get(self, path):
    39      return self._contents[path]
    40  
    41    def file_reader(self, path):
    42      return io.BytesIO(self._contents[path])
    43  
    44    def file_writer(self, name):
    45      path = 'prefix:' + name
    46  
    47      @contextlib.contextmanager
    48      def writable():
    49        buffer = io.BytesIO()
    50        yield buffer
    51        buffer.seek(0)
    52        self._contents[path] = buffer.read()
    53  
    54      return writable(), path
    55  
    56  
    57  class ArtifactServiceTest(unittest.TestCase):
    58    def file_artifact(self, path):
    59      return beam_runner_api_pb2.ArtifactInformation(
    60          type_urn=common_urns.artifact_types.FILE.urn,
    61          type_payload=beam_runner_api_pb2.ArtifactFilePayload(
    62              path=path).SerializeToString())
    63  
    64    def embedded_artifact(self, data, name=None):
    65      return beam_runner_api_pb2.ArtifactInformation(
    66          type_urn=common_urns.artifact_types.EMBEDDED.urn,
    67          type_payload=beam_runner_api_pb2.EmbeddedFilePayload(
    68              data=data).SerializeToString(),
    69          role_urn=common_urns.artifact_roles.STAGING_TO.urn if name else None,
    70          role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload(
    71              staged_name=name).SerializeToString() if name else None)
    72  
    73    def test_file_retrieval(self):
    74      file_manager = InMemoryFileManager({
    75          'path/to/a': b'a', 'path/to/b': b'b' * 37
    76      })
    77      retrieval_service = artifact_service.ArtifactRetrievalService(
    78          file_manager.file_reader, chunk_size=10)
    79      dep_a = self.file_artifact('path/to/a')
    80      self.assertEqual(
    81          retrieval_service.ResolveArtifacts(
    82              beam_artifact_api_pb2.ResolveArtifactsRequest(artifacts=[dep_a])),
    83          beam_artifact_api_pb2.ResolveArtifactsResponse(replacements=[dep_a]))
    84  
    85      self.assertEqual(
    86          list(
    87              retrieval_service.GetArtifact(
    88                  beam_artifact_api_pb2.GetArtifactRequest(artifact=dep_a))),
    89          [beam_artifact_api_pb2.GetArtifactResponse(data=b'a')])
    90      self.assertEqual(
    91          list(
    92              retrieval_service.GetArtifact(
    93                  beam_artifact_api_pb2.GetArtifactRequest(
    94                      artifact=self.file_artifact('path/to/b')))),
    95          [
    96              beam_artifact_api_pb2.GetArtifactResponse(data=b'b' * 10),
    97              beam_artifact_api_pb2.GetArtifactResponse(data=b'b' * 10),
    98              beam_artifact_api_pb2.GetArtifactResponse(data=b'b' * 10),
    99              beam_artifact_api_pb2.GetArtifactResponse(data=b'b' * 7)
   100          ])
   101  
   102    def test_embedded_retrieval(self):
   103      retrieval_service = artifact_service.ArtifactRetrievalService(None)
   104      embedded_dep = self.embedded_artifact(b'some_data')
   105      self.assertEqual(
   106          list(
   107              retrieval_service.GetArtifact(
   108                  beam_artifact_api_pb2.GetArtifactRequest(
   109                      artifact=embedded_dep))),
   110          [beam_artifact_api_pb2.GetArtifactResponse(data=b'some_data')])
   111  
   112    def test_url_retrieval(self):
   113      retrieval_service = artifact_service.ArtifactRetrievalService(None)
   114      url_dep = beam_runner_api_pb2.ArtifactInformation(
   115          type_urn=common_urns.artifact_types.URL.urn,
   116          type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
   117              url='file:' + quote(__file__)).SerializeToString())
   118      content = b''.join([
   119          r.data for r in retrieval_service.GetArtifact(
   120              beam_artifact_api_pb2.GetArtifactRequest(artifact=url_dep))
   121      ])
   122      with open(__file__, 'rb') as fin:
   123        self.assertEqual(content, fin.read())
   124  
   125    def test_push_artifacts(self):
   126      unresolved = beam_runner_api_pb2.ArtifactInformation(type_urn='unresolved')
   127      resolved_a = self.embedded_artifact(data=b'a', name='a.txt')
   128      resolved_b = self.embedded_artifact(data=b'bb', name='b.txt')
   129      dep_big = self.embedded_artifact(data=b'big ' * 100, name='big.txt')
   130  
   131      class TestArtifacts(object):
   132        def ResolveArtifacts(self, request):
   133          replacements = []
   134          for artifact in request.artifacts:
   135            if artifact.type_urn == 'unresolved':
   136              replacements += [resolved_a, resolved_b]
   137            else:
   138              replacements.append(artifact)
   139          return beam_artifact_api_pb2.ResolveArtifactsResponse(
   140              replacements=replacements)
   141  
   142        def GetArtifact(self, request):
   143          if request.artifact.type_urn == common_urns.artifact_types.EMBEDDED.urn:
   144            content = proto_utils.parse_Bytes(
   145                request.artifact.type_payload,
   146                beam_runner_api_pb2.EmbeddedFilePayload).data
   147            for k in range(0, len(content), 13):
   148              yield beam_artifact_api_pb2.GetArtifactResponse(
   149                  data=content[k:k + 13])
   150          else:
   151            raise NotImplementedError
   152  
   153      file_manager = InMemoryFileManager()
   154      server = artifact_service.ArtifactStagingService(file_manager.file_writer)
   155  
   156      server.register_job('staging_token', {'env': [unresolved, dep_big]})
   157  
   158      # "Push" artifacts as if from a client.
   159      t = threading.Thread(
   160          target=lambda: artifact_service.offer_artifacts(
   161              server, TestArtifacts(), 'staging_token'))
   162      t.daemon = True
   163      t.start()
   164  
   165      resolved_deps = server.resolved_deps('staging_token', timeout=5)['env']
   166      expected = {
   167          'a.txt': b'a',
   168          'b.txt': b'bb',
   169          'big.txt': b'big ' * 100,
   170      }
   171      for dep in resolved_deps:
   172        self.assertEqual(dep.type_urn, common_urns.artifact_types.FILE.urn)
   173        self.assertEqual(dep.role_urn, common_urns.artifact_roles.STAGING_TO.urn)
   174        type_payload = proto_utils.parse_Bytes(
   175            dep.type_payload, beam_runner_api_pb2.ArtifactFilePayload)
   176        role_payload = proto_utils.parse_Bytes(
   177            dep.role_payload, beam_runner_api_pb2.ArtifactStagingToRolePayload)
   178        self.assertTrue(
   179            type_payload.path.endswith(role_payload.staged_name),
   180            type_payload.path)
   181        self.assertEqual(
   182            file_manager.get(type_payload.path),
   183            expected.pop(role_payload.staged_name))
   184      self.assertEqual(expected, {})
   185  
   186  
   187  if __name__ == '__main__':
   188    unittest.main()