# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/test_utils.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utility methods for testing

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import hashlib
import importlib
import os
import shutil
import tempfile

from apache_beam.io.filesystems import FileSystems
from apache_beam.utils import retry

DEFAULT_HASHING_ALG = 'sha1'


class TempDir(object):
  """Context Manager to create and clean-up a temporary directory.

  The directory is created eagerly in ``__init__`` (not in ``__enter__``) and
  is removed recursively, together with everything inside it, when the
  ``with`` block exits. Exceptions raised inside the block are not suppressed.
  """
  def __init__(self):
    self._tempdir = tempfile.mkdtemp()

  def __enter__(self):
    return self

  def __exit__(self, *args):
    # Tolerate the directory having already been removed (e.g. by the code
    # under test) instead of failing during cleanup.
    if os.path.exists(self._tempdir):
      shutil.rmtree(self._tempdir)

  def get_path(self):
    """Returns the path to the temporary directory."""
    return self._tempdir

  def create_temp_file(self, suffix='', lines=None):
    """Creates a temporary file in the temporary directory.

    Args:
      suffix (str): The filename suffix of the temporary file (e.g. '.txt')
      lines (List[Union[str, bytes]]): A list of lines that will be written to
        the temporary file verbatim; no newline is appended between them.
        ``str`` lines are encoded as UTF-8, ``bytes`` lines are written as-is.
    Returns:
      The name of the temporary file created.
    """
    with tempfile.NamedTemporaryFile(delete=False,
                                     dir=self._tempdir,
                                     suffix=suffix) as f:
      # NamedTemporaryFile opens in binary mode ('w+b'), so text must be
      # encoded before writing; passing a str directly raises TypeError.
      for line in lines or ():
        f.write(line.encode('utf-8') if isinstance(line, str) else line)

    return f.name


def compute_hash(content, hashing_alg=DEFAULT_HASHING_ALG):
  """Compute a hash value of a list of objects by hashing their string
  representations.

  Elements that are not ``bytes`` are converted with ``str()`` and encoded as
  UTF-8. The encoded elements are sorted before hashing, so the result does
  not depend on the order of ``content``.

  Note: elements are fed to the hash back-to-back without a separator, so for
  example ``['a', 'bc']`` and ``['ab', 'c']`` produce the same digest. This is
  intentionally left as-is because changing it would alter every digest that
  callers may compare against.

  Args:
    content: An iterable of objects to hash.
    hashing_alg: Name of any algorithm accepted by ``hashlib.new``.

  Returns:
    The hexadecimal digest string.
  """
  encoded = sorted(
      x if isinstance(x, bytes) else str(x).encode('utf-8') for x in content)
  m = hashlib.new(hashing_alg)
  for elem in encoded:
    m.update(elem)
  return m.hexdigest()


def patch_retry(testcase, module):
  """A function to patch retry module to use mock clock and logger.

  Clock and logger that defined in retry decorator will be replaced in test
  in order to skip sleep phase when retry happens.

  The patch is removed (and ``module`` reloaded again) through
  ``testcase.addCleanup``, so callers do not need to undo it themselves.

  Args:
    testcase: An instance of unittest.TestCase that calls this function to
      patch retry module.
    module: The module that uses retry and need to be replaced with mock
      clock and logger in test.
  """
  # Import mock here to avoid execution time errors for other utilities.
  # unittest.mock is part of the standard library, so this no longer requires
  # the third-party ``mock`` backport.
  from unittest.mock import Mock
  from unittest.mock import patch

  real_retry_with_exponential_backoff = retry.with_exponential_backoff

  def patched_retry_with_exponential_backoff(**kwargs):
    """A patch for retry decorator to use a mock dummy clock and logger."""
    kwargs.update(logger=Mock(), clock=Mock())
    return real_retry_with_exponential_backoff(**kwargs)

  patcher = patch.object(
      retry,
      'with_exponential_backoff',
      side_effect=patched_retry_with_exponential_backoff)
  patcher.start()

  def remove_patches():
    # Stop only the patch started above. patch.stopall() would also tear down
    # any unrelated patches the test case started on its own.
    patcher.stop()
    # Reload module again after removing patch.
    importlib.reload(module)

  # Register the cleanup before reloading so the patch is still removed if
  # the reload below raises.
  testcase.addCleanup(remove_patches)

  # Reload module after patching so its decorators pick up the patched
  # retry.with_exponential_backoff.
  importlib.reload(module)


@retry.with_exponential_backoff(
    num_retries=3, retry_filter=retry.retry_on_beam_io_error_filter)
def delete_files(file_paths):
  """A function to clean up files or directories using ``FileSystems``.

  Glob is supported in file path and directories will be deleted recursively.
  The call is retried (up to 3 times) when the retry filter
  ``retry.retry_on_beam_io_error_filter`` accepts the raised error.

  Args:
    file_paths: A list of strings contains file paths or directories.

  Raises:
    RuntimeError: If ``file_paths`` is empty or None.
  """
  # ``not`` also covers None, which previously surfaced as a TypeError from
  # len() instead of the intended error message.
  if not file_paths:
    raise RuntimeError('Clean up failed. Invalid file path: %s.' % file_paths)
  FileSystems.delete(file_paths)


def cleanup_subscriptions(sub_client, subs):
  """Delete each PubSub subscription in ``subs``.

  No existence check is performed: every subscription is passed to
  ``sub_client.delete_subscription`` by its ``name`` attribute, and any error
  raised by the client propagates to the caller.

  Args:
    sub_client: The PubSub subscriber client used to issue the deletes.
    subs: An iterable of subscription objects exposing a ``name`` attribute.
  """
  for sub in subs:
    sub_client.delete_subscription(subscription=sub.name)


def cleanup_topics(pub_client, topics):
  """Delete each PubSub topic in ``topics``.

  No existence check is performed: every topic is passed to
  ``pub_client.delete_topic`` by its ``name`` attribute, and any error raised
  by the client propagates to the caller.

  Args:
    pub_client: The PubSub publisher client used to issue the deletes.
    topics: An iterable of topic objects exposing a ``name`` attribute.
  """
  for topic in topics:
    pub_client.delete_topic(topic=topic.name)


class PullResponseMessage(object):
  """Data representing a pull request response.

  Utility class for ``create_pull_response``. Every field except ``data`` is
  optional; fields left as None are not set on the generated message.
  """
  def __init__(
      self,
      data,
      attributes=None,
      publish_time_secs=None,
      publish_time_nanos=None,
      ack_id=None):
    self.data = data
    self.attributes = attributes
    self.publish_time_secs = publish_time_secs
    self.publish_time_nanos = publish_time_nanos
    self.ack_id = ack_id


def create_pull_response(responses):
  """Create an instance of ``google.cloud.pubsub.types.PullResponse``.

  Used to simulate the response from pubsub.SubscriberClient().pull(). Each
  ``PullResponseMessage`` becomes one ``ReceivedMessage`` in the response.

  Args:
    responses: list of ``PullResponseMessage``

  Returns:
    An instance of ``google.cloud.pubsub.types.PullResponse`` populated with
    responses.
  """
  # Imported lazily so the PubSub/protobuf packages are only required by the
  # tests that actually build pull responses.
  from google.cloud import pubsub
  from google.protobuf import timestamp_pb2

  res = pubsub.types.PullResponse()
  for response in responses:
    received_message = pubsub.types.ReceivedMessage()

    message = received_message.message
    message.data = response.data
    if response.attributes is not None:
      for k, v in response.attributes.items():
        message.attributes[k] = v

    # Unset seconds/nanos stay at the Timestamp default rather than being
    # omitted, so every message carries a publish_time.
    publish_time = timestamp_pb2.Timestamp()
    if response.publish_time_secs is not None:
      publish_time.seconds = response.publish_time_secs
    if response.publish_time_nanos is not None:
      publish_time.nanos = response.publish_time_nanos
    message.publish_time = publish_time

    if response.ack_id is not None:
      received_message.ack_id = response.ack_id

    res.received_messages.append(received_message)

  return res


def create_file(path, contents):
  """Create a file to use as input to test pipelines.

  Args:
    path: Destination path, in any scheme supported by ``FileSystems``.
    contents: The file contents. ``str`` is encoded as UTF-8; ``bytes`` is
      written unchanged.

  Returns:
    ``path``, for convenient chaining.
  """
  data = contents if isinstance(contents, bytes) else contents.encode('utf-8')
  with FileSystems.create(path) as f:
    f.write(data)
  return path


def read_files_from_pattern(file_pattern):
  """Reads the files that match a pattern.

  Each matched file is decoded as UTF-8 and stripped of leading/trailing
  whitespace; the results are joined with newlines in match order.

  Args:
    file_pattern: A path or glob pattern understood by ``FileSystems.match``.

  Returns:
    A single string with the contents of all matched files.
  """
  metadata_list = FileSystems.match([file_pattern])[0].metadata_list
  output = []
  for metadata in metadata_list:
    with FileSystems.open(metadata.path) as f:
      output.append(f.read().decode('utf-8').strip())
  return '\n'.join(output)