# github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/artifact_service_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Test cases for :mod:`apache_beam.runners.portability.artifact_service`."""

# pytype: skip-file

import contextlib
import io
import threading
import unittest
from urllib.parse import quote

from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_artifact_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.portability import artifact_service
from apache_beam.utils import proto_utils


class InMemoryFileManager(object):
  """Dict-backed stand-in for file storage, used by the tests below.

  ``file_reader`` and ``file_writer`` are handed to the artifact services
  in place of real file-opening callables.
  """
  def __init__(self, contents=()):
    # Copy so that later writes never mutate the caller's mapping.
    self._contents = dict(contents)

  def get(self, path):
    """Return the bytes stored under ``path`` (KeyError if absent)."""
    return self._contents[path]

  def file_reader(self, path):
    """Return a readable binary stream over the bytes stored at ``path``."""
    return io.BytesIO(self._contents[path])

  def file_writer(self, name):
    """Return a ``(context manager, path)`` pair for writing ``name``.

    The path is ``'prefix:' + name``. Whatever is written to the buffer
    yielded by the context manager is stored under that path once the
    ``with`` block finishes without raising.
    """
    path = 'prefix:' + name

    @contextlib.contextmanager
    def writable():
      sink = io.BytesIO()
      yield sink
      # Only reached on a clean exit from the ``with`` block.
      self._contents[path] = sink.getvalue()

    return writable(), path


class ArtifactServiceTest(unittest.TestCase):
  """Exercises the retrieval and staging services of ``artifact_service``."""
  def file_artifact(self, path):
    """Build a FILE-typed ArtifactInformation pointing at ``path``."""
    file_payload = beam_runner_api_pb2.ArtifactFilePayload(path=path)
    return beam_runner_api_pb2.ArtifactInformation(
        type_urn=common_urns.artifact_types.FILE.urn,
        type_payload=file_payload.SerializeToString())

  def embedded_artifact(self, data, name=None):
    """Build an EMBEDDED-typed ArtifactInformation carrying ``data``.

    When ``name`` is truthy the artifact also gets a STAGING_TO role whose
    staged name is ``name``; otherwise both role fields are passed as None.
    """
    if name:
      role_urn = common_urns.artifact_roles.STAGING_TO.urn
      role_payload = beam_runner_api_pb2.ArtifactStagingToRolePayload(
          staged_name=name).SerializeToString()
    else:
      role_urn = None
      role_payload = None

    embedded_payload = beam_runner_api_pb2.EmbeddedFilePayload(data=data)
    return beam_runner_api_pb2.ArtifactInformation(
        type_urn=common_urns.artifact_types.EMBEDDED.urn,
        type_payload=embedded_payload.SerializeToString(),
        role_urn=role_urn,
        role_payload=role_payload)

  def test_file_retrieval(self):
    """FILE artifacts resolve to themselves and stream in 10-byte chunks."""
    file_manager = InMemoryFileManager({
        'path/to/a': b'a',
        'path/to/b': b'b' * 37,
    })
    retrieval_service = artifact_service.ArtifactRetrievalService(
        file_manager.file_reader, chunk_size=10)
    dep_a = self.file_artifact('path/to/a')

    resolve_request = beam_artifact_api_pb2.ResolveArtifactsRequest(
        artifacts=[dep_a])
    self.assertEqual(
        retrieval_service.ResolveArtifacts(resolve_request),
        beam_artifact_api_pb2.ResolveArtifactsResponse(replacements=[dep_a]))

    chunks_a = list(
        retrieval_service.GetArtifact(
            beam_artifact_api_pb2.GetArtifactRequest(artifact=dep_a)))
    self.assertEqual(
        chunks_a, [beam_artifact_api_pb2.GetArtifactResponse(data=b'a')])

    # 37 bytes with chunk_size=10 -> three full chunks plus a 7-byte tail.
    dep_b = self.file_artifact('path/to/b')
    chunks_b = list(
        retrieval_service.GetArtifact(
            beam_artifact_api_pb2.GetArtifactRequest(artifact=dep_b)))
    self.assertEqual(
        chunks_b,
        [
            beam_artifact_api_pb2.GetArtifactResponse(data=b'b' * size)
            for size in (10, 10, 10, 7)
        ])

  def test_embedded_retrieval(self):
    """EMBEDDED artifacts are served from their payload, no file reader."""
    retrieval_service = artifact_service.ArtifactRetrievalService(None)
    embedded_dep = self.embedded_artifact(b'some_data')
    request = beam_artifact_api_pb2.GetArtifactRequest(artifact=embedded_dep)
    self.assertEqual(
        list(retrieval_service.GetArtifact(request)),
        [beam_artifact_api_pb2.GetArtifactResponse(data=b'some_data')])

  def test_url_retrieval(self):
    """URL artifacts are fetched; checked against this file via file: URL."""
    retrieval_service = artifact_service.ArtifactRetrievalService(None)
    url_payload = beam_runner_api_pb2.ArtifactUrlPayload(
        url='file:' + quote(__file__))
    url_dep = beam_runner_api_pb2.ArtifactInformation(
        type_urn=common_urns.artifact_types.URL.urn,
        type_payload=url_payload.SerializeToString())
    request = beam_artifact_api_pb2.GetArtifactRequest(artifact=url_dep)
    content = b''.join(
        response.data for response in retrieval_service.GetArtifact(request))
    with open(__file__, 'rb') as fin:
      self.assertEqual(content, fin.read())

  def test_push_artifacts(self):
    """Artifacts offered to the staging service end up as staged files.

    One registered dependency has the made-up type 'unresolved'; the fake
    client-side service below replaces it with two embedded artifacts. Every
    dependency reported back is expected to be a FILE artifact with a
    STAGING_TO role whose contents were written through the file manager.
    """
    unresolved = beam_runner_api_pb2.ArtifactInformation(type_urn='unresolved')
    resolved_a = self.embedded_artifact(data=b'a', name='a.txt')
    resolved_b = self.embedded_artifact(data=b'bb', name='b.txt')
    dep_big = self.embedded_artifact(data=b'big ' * 100, name='big.txt')

    class TestArtifacts(object):
      """Client-side retrieval service double used by offer_artifacts."""
      def ResolveArtifacts(self, request):
        replacements = []
        for artifact in request.artifacts:
          if artifact.type_urn == 'unresolved':
            replacements.extend([resolved_a, resolved_b])
          else:
            replacements.append(artifact)
        return beam_artifact_api_pb2.ResolveArtifactsResponse(
            replacements=replacements)

      def GetArtifact(self, request):
        # Generator: the error below surfaces on first iteration, not on call.
        embedded_urn = common_urns.artifact_types.EMBEDDED.urn
        if request.artifact.type_urn != embedded_urn:
          raise NotImplementedError
        content = proto_utils.parse_Bytes(
            request.artifact.type_payload,
            beam_runner_api_pb2.EmbeddedFilePayload).data
        # Stream in 13-byte slices so big.txt spans several responses.
        for start in range(0, len(content), 13):
          yield beam_artifact_api_pb2.GetArtifactResponse(
              data=content[start:start + 13])

    file_manager = InMemoryFileManager()
    server = artifact_service.ArtifactStagingService(file_manager.file_writer)

    server.register_job('staging_token', {'env': [unresolved, dep_big]})

    # "Push" artifacts as if from a client.
    def push():
      artifact_service.offer_artifacts(
          server, TestArtifacts(), 'staging_token')

    pusher = threading.Thread(target=push, daemon=True)
    pusher.start()

    resolved_deps = server.resolved_deps('staging_token', timeout=5)['env']
    expected = {
        'a.txt': b'a',
        'b.txt': b'bb',
        'big.txt': b'big ' * 100,
    }
    file_urn = common_urns.artifact_types.FILE.urn
    staging_to_urn = common_urns.artifact_roles.STAGING_TO.urn
    for dep in resolved_deps:
      self.assertEqual(dep.type_urn, file_urn)
      self.assertEqual(dep.role_urn, staging_to_urn)
      staged_path = proto_utils.parse_Bytes(
          dep.type_payload, beam_runner_api_pb2.ArtifactFilePayload).path
      staged_name = proto_utils.parse_Bytes(
          dep.role_payload,
          beam_runner_api_pb2.ArtifactStagingToRolePayload).staged_name
      self.assertTrue(staged_path.endswith(staged_name), staged_path)
      stored = file_manager.get(staged_path)
      # pop() doubles as a check that no staged name shows up twice.
      self.assertEqual(stored, expected.pop(staged_name))
    # Everything expected must have been seen exactly once.
    self.assertEqual(expected, {})


if __name__ == '__main__':
  unittest.main()