github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/aws/clients/s3/fake_client.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import datetime
import time

import pytz

from apache_beam.io.aws.clients.s3 import messages


class FakeFile(object):
  def __init__(self, bucket, key, contents, etag=None):
    self.bucket = bucket
    self.key = key
    self.contents = contents

    self.last_modified = time.time()

    if not etag:
      self.etag = '"%s-1"' % ('x' * 32)
    else:
      self.etag = etag

  def get_metadata(self):
    last_modified_datetime = None
    if self.last_modified:
      last_modified_datetime = datetime.datetime.fromtimestamp(
          self.last_modified, pytz.utc)

    return messages.Item(
        self.etag,
        self.key,
        last_modified_datetime,
        len(self.contents),
        mime_type=None)


class FakeS3Client(object):
  def __init__(self):
    self.files = {}
    self.list_continuation_tokens = {}
    self.multipart_uploads = {}

    # boto3 has different behavior when running some operations against a
    # bucket that exists vs. against one that doesn't. To emulate that
    # behavior, the mock client keeps a set of bucket names that it knows
    # "exist".
    self.known_buckets = set()

  def add_file(self, f):
    self.files[(f.bucket, f.key)] = f
    if f.bucket not in self.known_buckets:
      self.known_buckets.add(f.bucket)

  def get_file(self, bucket, obj):
    try:
      return self.files[bucket, obj]
    except KeyError:
      raise messages.S3ClientError('Not Found', 404)

  def delete_file(self, bucket, obj):
    del self.files[(bucket, obj)]

  def get_object_metadata(self, request):
    r"""Retrieves an object's metadata.

    Args:
      request: (GetRequest) input message

    Returns:
      (Item) The response message.
    """
    # TODO: Do we want to mock out a lack of credentials?
    file_ = self.get_file(request.bucket, request.object)
    return file_.get_metadata()

  def list(self, request):
    bucket = request.bucket
    prefix = request.prefix or ''
    matching_files = []

    for file_bucket, file_name in sorted(iter(self.files)):
      if bucket == file_bucket and file_name.startswith(prefix):
        file_object = self.get_file(file_bucket, file_name).get_metadata()
        matching_files.append(file_object)

    if not matching_files:
      message = 'Tried to list nonexistent S3 path: s3://%s/%s' % (
          bucket, prefix)
      raise messages.S3ClientError(message, 404)

    # Handle pagination.
    items_per_page = 5
    if not request.continuation_token:
      range_start = 0
    else:
      if request.continuation_token not in self.list_continuation_tokens:
        raise ValueError('Invalid page token.')
      range_start = self.list_continuation_tokens[request.continuation_token]
      del self.list_continuation_tokens[request.continuation_token]

    result = messages.ListResponse(
        items=matching_files[range_start:range_start + items_per_page])

    if range_start + items_per_page < len(matching_files):
      next_range_start = range_start + items_per_page
      next_continuation_token = '_page_token_%s_%s_%d' % (
          bucket, prefix, next_range_start)
      self.list_continuation_tokens[next_continuation_token] = next_range_start
      result.next_token = next_continuation_token

    return result

  def get_range(self, request, start, end):
    r"""Retrieves an object.

    Args:
      request: (GetRequest) request
    Returns:
      (bytes) The response message.
    """
    file_ = self.get_file(request.bucket, request.object)

    # Replicates S3's behavior, per the spec here:
    # https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
    if start < 0 or end <= start:
      return file_.contents

    return file_.contents[start:end]

  def delete(self, request):
    if request.bucket not in self.known_buckets:
      raise messages.S3ClientError('The specified bucket does not exist', 404)

    if (request.bucket, request.object) in self.files:
      self.delete_file(request.bucket, request.object)
    else:
      # S3 doesn't raise an error if you try to delete a nonexistent file from
      # an extant bucket.
      return

  def delete_batch(self, request):
    deleted, failed, errors = [], [], []
    for object in request.objects:
      try:
        delete_request = messages.DeleteRequest(request.bucket, object)
        self.delete(delete_request)
        deleted.append(object)
      except messages.S3ClientError as e:
        failed.append(object)
        errors.append(e)

    return messages.DeleteBatchResponse(deleted, failed, errors)

  def copy(self, request):
    src_file = self.get_file(request.src_bucket, request.src_key)
    dest_file = FakeFile(
        request.dest_bucket, request.dest_key, src_file.contents)
    self.add_file(dest_file)

  def create_multipart_upload(self, request):
    # Derive an upload_id from the bucket and key and store it internally.
    upload_id = request.bucket + request.object
    self.multipart_uploads[upload_id] = {}
    return messages.UploadResponse(upload_id)

  def upload_part(self, request):
    # Save off the bytes passed to the internal data store.
    upload_id, part_number = request.upload_id, request.part_number

    if part_number < 0 or not isinstance(part_number, int):
      raise messages.S3ClientError(
          'Param validation failed on part number', 400)

    if upload_id not in self.multipart_uploads:
      raise messages.S3ClientError('The specified upload does not exist', 404)

    self.multipart_uploads[upload_id][part_number] = request.bytes

    etag = '"%s"' % ('x' * 32)
    return messages.UploadPartResponse(etag, part_number)

  def complete_multipart_upload(self, request):
    MIN_PART_SIZE = 5 * 2**10  # 5 KiB

    parts_received = self.multipart_uploads[request.upload_id]

    # Make sure all the parts the caller intended to send are present.
    part_numbers_to_confirm = set(part['PartNumber'] for part in request.parts)
    if part_numbers_to_confirm != set(parts_received.keys()):
      raise messages.S3ClientError(
          'One or more of the specified parts could not be found', 400)

    # Sort by part number.
    sorted_parts = sorted(parts_received.items(), key=lambda pair: pair[0])
    sorted_bytes = [bytes_ for (_, bytes_) in sorted_parts]

    # Make sure that the parts aren't too small (except the last part).
    part_sizes = [len(bytes_) for bytes_ in sorted_bytes]
    if any(size < MIN_PART_SIZE for size in part_sizes[:-1]):
      e_message = """
      All parts but the last must be larger than %d bytes
      """ % MIN_PART_SIZE
      raise messages.S3ClientError(e_message, 400)

    # String together all bytes for the given upload.
    final_contents = b''.join(sorted_bytes)

    # Create the resulting FakeFile and store it in self.files.
    num_parts = len(parts_received)
    etag = '"%s-%d"' % ('x' * 32, num_parts)
    file_ = FakeFile(request.bucket, request.object, final_contents, etag=etag)
    self.add_file(file_)
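
Below is a minimal, hypothetical usage sketch (not part of fake_client.py) showing how the fake client can back a test. The _Request class is a stand-in for the request messages defined in apache_beam.io.aws.clients.s3.messages; FakeS3Client only reads plain attributes off whatever request object it is handed, so any duck-typed object with the right fields works.

# Hypothetical test sketch; _Request stands in for the real request messages.
from apache_beam.io.aws.clients.s3.fake_client import FakeFile, FakeS3Client


class _Request(object):
  """Duck-typed request exposing only the attributes FakeS3Client reads."""
  def __init__(self, bucket, object=None, prefix=None, continuation_token=None):
    self.bucket = bucket
    self.object = object
    self.prefix = prefix
    self.continuation_token = continuation_token


client = FakeS3Client()
client.add_file(FakeFile('mybucket', 'dir/key.txt', b'hello world'))

# Ranged reads slice the stored bytes; an invalid range returns the whole object.
req = _Request('mybucket', object='dir/key.txt')
assert client.get_range(req, 0, 5) == b'hello'
assert client.get_range(req, -1, 0) == b'hello world'

# Metadata and listing go through the same in-memory dict of FakeFile objects.
client.get_object_metadata(req)                    # returns a messages.Item
client.list(_Request('mybucket', prefix='dir/'))   # one page, five items max

# Deleting a missing key from a known bucket is silent, mirroring S3.
client.delete(_Request('mybucket', object='missing.txt'))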