github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/aws/clients/s3/fake_client.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import datetime
import time

import pytz

from apache_beam.io.aws.clients.s3 import messages


class FakeFile(object):
  """An in-memory stand-in for an S3 object, tracked by FakeS3Client."""
  def __init__(self, bucket, key, contents, etag=None):
    self.bucket = bucket
    self.key = key
    self.contents = contents

    self.last_modified = time.time()

    if not etag:
      self.etag = '"%s-1"' % ('x' * 32)
    else:
      self.etag = etag

  def get_metadata(self):
    last_modified_datetime = None
    if self.last_modified:
      last_modified_datetime = datetime.datetime.fromtimestamp(
          self.last_modified, pytz.utc)

    return messages.Item(
        self.etag,
        self.key,
        last_modified_datetime,
        len(self.contents),
        mime_type=None)


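# A minimal usage sketch (illustration only, not part of the module): FakeFile
# pairs raw bytes with a bucket/key, and get_metadata() renders them as the
# messages.Item that the fake client's metadata and list calls return.
#
#   f = FakeFile('my-bucket', 'path/to/key', b'hello world')
#   item = f.get_metadata()  # Item carrying etag, key, last_modified and size

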
class FakeS3Client(object):
  """An in-memory fake of the boto3-backed S3 client, for use in tests."""
  def __init__(self):
    self.files = {}
    self.list_continuation_tokens = {}
    self.multipart_uploads = {}

    # boto3 has different behavior when running some operations against a
    # bucket that exists vs. against one that doesn't. To emulate that
    # behavior, the fake client keeps a set of bucket names it knows "exist".
    self.known_buckets = set()

  def add_file(self, f):
    self.files[(f.bucket, f.key)] = f
    self.known_buckets.add(f.bucket)

  def get_file(self, bucket, obj):
    try:
      return self.files[(bucket, obj)]
    except KeyError:
      raise messages.S3ClientError('Not Found', 404)

  def delete_file(self, bucket, obj):
    del self.files[(bucket, obj)]

  def get_object_metadata(self, request):
    r"""Retrieves an object's metadata.

    Args:
      request: (GetRequest) input message

    Returns:
      (Item) The response message.
    """
    # TODO: Do we want to mock out a lack of credentials?
    file_ = self.get_file(request.bucket, request.object)
    return file_.get_metadata()

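  # Metadata-lookup sketch (illustration only; the messages.GetRequest
  # constructor is an assumption, the method above only reads request.bucket
  # and request.object):
  #
  #   client = FakeS3Client()
  #   client.add_file(FakeFile('bkt', 'key', b'data'))
  #   item = client.get_object_metadata(messages.GetRequest('bkt', 'key'))
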
  def list(self, request):
    bucket = request.bucket
    prefix = request.prefix or ''
    matching_files = []

    for file_bucket, file_name in sorted(self.files):
      if bucket == file_bucket and file_name.startswith(prefix):
        file_object = self.get_file(file_bucket, file_name).get_metadata()
        matching_files.append(file_object)

    if not matching_files:
      message = 'Tried to list nonexistent S3 path: s3://%s/%s' % (
          bucket, prefix)
      raise messages.S3ClientError(message, 404)

    # Handle pagination.
    items_per_page = 5
    if not request.continuation_token:
      range_start = 0
    else:
      if request.continuation_token not in self.list_continuation_tokens:
        raise ValueError('Invalid page token.')
      range_start = self.list_continuation_tokens[request.continuation_token]
      del self.list_continuation_tokens[request.continuation_token]

    result = messages.ListResponse(
        items=matching_files[range_start:range_start + items_per_page])

    if range_start + items_per_page < len(matching_files):
      next_range_start = range_start + items_per_page
      next_continuation_token = '_page_token_%s_%s_%d' % (
          bucket, prefix, next_range_start)
      self.list_continuation_tokens[next_continuation_token] = next_range_start
      result.next_token = next_continuation_token

    return result

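  # Paging sketch (illustration only): walking every page that list() returns,
  # five items at a time. The messages.ListRequest constructor used here is an
  # assumption; the handler above only reads its bucket, prefix and
  # continuation_token attributes, and relies on ListResponse.next_token being
  # absent/None on the final page.
  #
  #   client = FakeS3Client()
  #   for i in range(12):
  #     client.add_file(FakeFile('bkt', 'dir/file-%02d' % i, b'x'))
  #   items, token = [], None
  #   while True:
  #     response = client.list(
  #         messages.ListRequest(
  #             bucket='bkt', prefix='dir/', continuation_token=token))
  #     items.extend(response.items)
  #     token = getattr(response, 'next_token', None)
  #     if not token:
  #       break
  #   # items now holds all 12 entries, fetched 5 per page.
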
  def get_range(self, request, start, end):
    r"""Retrieves a byte range of an object.

    Args:
      request: (GetRequest) request
      start: (int) start of the byte range, inclusive
      end: (int) end of the byte range, exclusive

    Returns:
      (bytes) The requested portion of the object's contents.
    """
    file_ = self.get_file(request.bucket, request.object)

    # Replicates S3's behavior, per the spec here:
    # https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35
    # An unsatisfiable range falls back to returning the whole object.
    if start < 0 or end <= start:
      return file_.contents

    return file_.contents[start:end]

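  # Range-read sketch (illustration only; the messages.GetRequest constructor
  # is an assumption, get_range() itself only reads request.bucket and
  # request.object):
  #
  #   client = FakeS3Client()
  #   client.add_file(FakeFile('bkt', 'key', b'0123456789'))
  #   client.get_range(messages.GetRequest('bkt', 'key'), 2, 5)   # b'234'
  #   client.get_range(messages.GetRequest('bkt', 'key'), -1, 5)  # whole object
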
  def delete(self, request):
    if request.bucket not in self.known_buckets:
      raise messages.S3ClientError('The specified bucket does not exist', 404)

    # S3 doesn't raise an error if you try to delete a nonexistent file from an
    # existing bucket, so only remove the key if it is actually present.
    if (request.bucket, request.object) in self.files:
      self.delete_file(request.bucket, request.object)

  def delete_batch(self, request):
    deleted, failed, errors = [], [], []
    for obj in request.objects:
      try:
        delete_request = messages.DeleteRequest(request.bucket, obj)
        self.delete(delete_request)
        deleted.append(obj)
      except messages.S3ClientError as e:
        failed.append(obj)
        errors.append(e)

    return messages.DeleteBatchResponse(deleted, failed, errors)

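  # Batch-delete sketch (illustration only; messages.DeleteBatchRequest is
  # assumed to expose the bucket and objects fields read above, and the
  # response to store its constructor arguments):
  #
  #   client = FakeS3Client()
  #   client.add_file(FakeFile('bkt', 'a', b'1'))
  #   client.add_file(FakeFile('bkt', 'b', b'2'))
  #   response = client.delete_batch(
  #       messages.DeleteBatchRequest('bkt', ['a', 'b', 'missing']))
  #   # 'missing' is accepted silently, matching S3's behavior for existing
  #   # buckets, so response.deleted lists all three keys and failed is empty.
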
  def copy(self, request):
    src_file = self.get_file(request.src_bucket, request.src_key)
    dest_file = FakeFile(
        request.dest_bucket, request.dest_key, src_file.contents)
    self.add_file(dest_file)

  def create_multipart_upload(self, request):
    # Derive an upload_id from the bucket and key, and track the upload's
    # parts internally.
    upload_id = request.bucket + request.object
    self.multipart_uploads[upload_id] = {}
    return messages.UploadResponse(upload_id)

  def upload_part(self, request):
    # Save the bytes for this part in the internal data store.
    upload_id, part_number = request.upload_id, request.part_number

    # Check isinstance first: comparing a non-numeric part number against 0
    # would raise a TypeError before the validation error could be reported.
    if not isinstance(part_number, int) or part_number < 0:
      raise messages.S3ClientError(
          'Param validation failed on part number', 400)

    if upload_id not in self.multipart_uploads:
      raise messages.S3ClientError('The specified upload does not exist', 404)

    self.multipart_uploads[upload_id][part_number] = request.bytes

    etag = '"%s"' % ('x' * 32)
    return messages.UploadPartResponse(etag, part_number)

  def complete_multipart_upload(self, request):
    MIN_PART_SIZE = 5 * 2**10  # 5 KiB

    parts_received = self.multipart_uploads[request.upload_id]

    # Check that every part the caller says was sent is present, and that no
    # unexpected parts were uploaded
    part_numbers_to_confirm = set(part['PartNumber'] for part in request.parts)
    if part_numbers_to_confirm != set(parts_received.keys()):
      raise messages.S3ClientError(
          'One or more of the specified parts could not be found', 400)

    # Sort by part number
    sorted_parts = sorted(parts_received.items(), key=lambda pair: pair[0])
    sorted_bytes = [bytes_ for (_, bytes_) in sorted_parts]

    # Make sure that the parts aren't too small (except the last part)
    part_sizes = [len(bytes_) for bytes_ in sorted_bytes]
    if any(size < MIN_PART_SIZE for size in part_sizes[:-1]):
      e_message = """
      All parts but the last must be larger than %d bytes
      """ % MIN_PART_SIZE
      raise messages.S3ClientError(e_message, 400)

    # Concatenate the bytes of all parts for the given upload
    final_contents = b''.join(sorted_bytes)

    # Create the finished FakeFile, with an etag that records the part count
    num_parts = len(parts_received)
    etag = '"%s-%d"' % ('x' * 32, num_parts)
    file_ = FakeFile(request.bucket, request.object, final_contents, etag=etag)

    # Store the FakeFile in self.files
    self.add_file(file_)
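

# Multipart-upload sketch (illustration only): the three calls below mirror the
# real S3 flow against this fake. The exact messages.* request constructors and
# response attribute names are assumptions; the handlers above only read
# upload_id, part_number, bytes, parts, bucket and object off the requests.
#
#   client = FakeS3Client()
#   upload = client.create_multipart_upload(
#       messages.UploadRequest('bkt', 'big-key', None))
#   parts = []
#   for n, chunk in enumerate([b'a' * (5 * 2**10), b'tail'], start=1):
#     part = client.upload_part(
#         messages.UploadPartRequest(
#             'bkt', 'big-key', upload.upload_id, n, chunk))
#     parts.append({'PartNumber': part.part_number, 'ETag': part.etag})
#   client.complete_multipart_upload(
#       messages.CompleteMultipartUploadRequest(
#           'bkt', 'big-key', upload.upload_id, parts))
#   # The reassembled object is now visible to get_file, list and get_range.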