github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/clients/python-wrapper/lakefs/branch.py (about)

     1  """
     2  Module containing lakeFS branch implementation
     3  """
     4  
     5  from __future__ import annotations
     6  
     7  import uuid
     8  import warnings
     9  from contextlib import contextmanager
    10  from typing import Optional, Generator, Iterable, Literal, Dict
    11  
    12  import lakefs_sdk
    13  from lakefs.client import Client
    14  from lakefs.object import WriteableObject
    15  from lakefs.object import StoredObject
    16  from lakefs.import_manager import ImportManager
    17  from lakefs.reference import Reference, ReferenceType, generate_listing
    18  from lakefs.models import Change, Commit
    19  from lakefs.exceptions import (
    20      api_exception_handler,
    21      ConflictException,
    22      LakeFSException,
    23      TransactionException
    24  )
    25  
    26  
    27  class _BaseBranch(Reference):
    28  
    29      def object(self, path: str) -> WriteableObject:
    30          """
    31          Returns a writable object using the current repo id, reference and path
    32  
    33          :param path: The object's path
    34          """
    35  
    36          return WriteableObject(self.repo_id, self._id, path, client=self._client)
    37  
    38      def uncommitted(self, max_amount: Optional[int] = None, after: Optional[str] = None, prefix: Optional[str] = None,
    39                      **kwargs) -> Generator[Change]:
    40          """
    41          Returns a diff generator of uncommitted changes on this branch
    42  
    43          :param max_amount: Stop showing changes after this amount
    44          :param after: Return items after this value
    45          :param prefix: Return items prefixed with this value
    46          :param kwargs: Additional Keyword Arguments to send to the server
    47          :raise NotFoundException: if branch or repository do not exist
    48          :raise NotAuthorizedException: if user is not authorized to perform this operation
    49          :raise ServerException: for any other errors
    50          """
    51  
    52          for diff in generate_listing(self._client.sdk_client.branches_api.diff_branch,
    53                                       self._repo_id, self._id, max_amount=max_amount, after=after, prefix=prefix,
    54                                       **kwargs):
    55              yield Change(**diff.dict())
    56  
    57      def delete_objects(self, object_paths: str | StoredObject | Iterable[str | StoredObject]) -> None:
    58          """
    59          Delete objects from lakeFS
    60  
    61          This method can be used to delete single/multiple objects from branch. It accepts both str and StoredObject
    62          types as well as Iterables of these types.
    63          Using this method is more performant than sequentially calling delete on objects as it saves the back and forth
    64          from the server.
    65  
    66          This can also be used in combination with object listing. For example:
    67  
    68          .. code-block:: python
    69  
    70              import lakefs
    71  
    72              branch = lakefs.repository("<repository_name>").branch("<branch_name>")
    73              # list objects on a common prefix
    74              objs = branch.objects(prefix="my-object-prefix/", max_amount=100)
    75              # delete objects which have "foo" in their name
    76              branch.delete_objects([o.path for o in objs if "foo" in o.path])
    77  
    78          :param object_paths: a single path or an iterable of paths to delete
    79          :raise NotFoundException: if branch or repository do not exist
    80          :raise NotAuthorizedException: if user is not authorized to perform this operation
    81          :raise ServerException: for any other errors
    82          """
    83          if isinstance(object_paths, str):
    84              object_paths = [object_paths]
    85          elif isinstance(object_paths, StoredObject):
    86              object_paths = [object_paths.path]
    87          elif isinstance(object_paths, Iterable):
    88              object_paths = [o.path if isinstance(o, StoredObject) else o for o in object_paths]
    89          with api_exception_handler():
    90              return self._client.sdk_client.objects_api.delete_objects(
    91                  self._repo_id,
    92                  self._id,
    93                  lakefs_sdk.PathList(paths=object_paths)
    94              )
    95  
    96      def reset_changes(self, path_type: Literal["common_prefix", "object", "reset"] = "reset",
    97                        path: Optional[str] = None) -> None:
    98          """
    99          Reset uncommitted changes (if any) on this branch
   100  
   101          :param path_type: the type of path to reset ('common_prefix', 'object', 'reset' - for all changes)
   102          :param path: the path to reset (optional) - if path_type is 'reset' this parameter is ignored
   103          :raise ValidationError: if path_type is not one of the allowed values
   104          :raise NotFoundException: if branch or repository do not exist
   105          :raise NotAuthorizedException: if user is not authorized to perform this operation
   106          :raise ServerException: for any other errors
   107          """
   108  
   109          reset_creation = lakefs_sdk.ResetCreation(path=path, type=path_type)
   110          return self._client.sdk_client.branches_api.reset_branch(self._repo_id, self.id, reset_creation)
   111  
   112  
   113  class Branch(_BaseBranch):
   114      """
   115      Class representing a branch in lakeFS.
   116      """
   117  
   118      def __init__(self, repository_id: str, branch_id: str, client: Optional[Client] = None):
   119          super().__init__(repository_id, reference_id=branch_id, client=client)
   120  
   121      def get_commit(self):
   122          """
   123          For branches override the default _get_commit method to ensure we always fetch the latest head
   124          """
   125          self._commit = None
   126          return super().get_commit()
   127  
   128      def cherry_pick(self, reference: ReferenceType, parent_number: Optional[int] = None) -> Commit:
   129          """
   130          Cherry-pick a given reference onto the branch.
   131  
   132          :param reference: ID of the reference to cherry-pick.
   133          :param parent_number: When cherry-picking a merge commit, the parent number (starting from 1)
   134              with which to perform the diff. The default branch is parent 1.
   135          :return: The cherry-picked commit at the head of the branch.
   136          :raise NotFoundException: If either the repository or target reference do not exist.
   137          :raise NotAuthorizedException: If the user is not authorized to perform this operation.
   138          :raise ServerException: For any other errors.
   139          """
   140          ref = reference if isinstance(reference, str) else reference.id
   141          cherry_pick_creation = lakefs_sdk.CherryPickCreation(ref=ref, parent_number=parent_number)
   142          with api_exception_handler():
   143              res = self._client.sdk_client.branches_api.cherry_pick(self._repo_id, self._id, cherry_pick_creation)
   144              return Commit(**res.dict())
   145  
   146      def create(self, source_reference: ReferenceType, exist_ok: bool = False) -> Branch:
   147          """
   148          Create a new branch in lakeFS from this object
   149  
   150          Example of creating a new branch:
   151  
   152          .. code-block:: python
   153  
   154              import lakefs
   155  
   156              branch = lakefs.repository("<repository_name>").branch("<branch_name>").create("<source_reference>")
   157  
   158          :param source_reference: The reference to create the branch from (reference ID, object or Commit object)
   159          :param exist_ok: If False will throw an exception if a branch by this name already exists. Otherwise,
   160              return the existing branch without creating a new one
   161          :return: The lakeFS SDK object representing the branch
   162          :raise NotFoundException: if repo, branch or source reference id does not exist
   163          :raise ConflictException: if branch already exists and exist_ok is False
   164          :raise NotAuthorizedException: if user is not authorized to perform this operation
   165          :raise ServerException: for any other errors
   166          """
   167  
   168          def handle_conflict(e: LakeFSException):
   169              if isinstance(e, ConflictException) and exist_ok:
   170                  return None
   171              return e
   172  
   173          reference_id = source_reference if isinstance(source_reference, str) else source_reference.id
   174          branch_creation = lakefs_sdk.BranchCreation(name=self._id, source=reference_id)
   175          with api_exception_handler(handle_conflict):
   176              self._client.sdk_client.branches_api.create_branch(self._repo_id, branch_creation)
   177          return self
   178  
   179      @property
   180      def head(self) -> Reference:
   181          """
   182          Get the commit reference this branch is pointing to
   183  
   184          :return: The commit reference this branch is pointing to
   185          :raise NotFoundException: if branch by this id does not exist
   186          :raise NotAuthorizedException: if user is not authorized to perform this operation
   187          :raise ServerException: for any other errors
   188          """
   189          with api_exception_handler():
   190              branch = self._client.sdk_client.branches_api.get_branch(self._repo_id, self._id)
   191          return Reference(self._repo_id, branch.commit_id, self._client)
   192  
   193      def commit(self, message: str, metadata: dict = None, **kwargs) -> Reference:
   194          """
   195          Commit changes on the current branch
   196  
   197          :param message: Commit message
   198          :param metadata: Metadata to attach to the commit
   199          :param kwargs: Additional Keyword Arguments for commit creation
   200          :return: The new reference after the commit
   201          :raise NotFoundException: if branch by this id does not exist
   202          :raise ForbiddenException: if commit is not allowed on this branch
   203          :raise NotAuthorizedException: if user is not authorized to perform this operation
   204          :raise ServerException: for any other errors
   205          """
   206          commits_creation = lakefs_sdk.CommitCreation(message=message, metadata=metadata, **kwargs)
   207  
   208          with api_exception_handler():
   209              c = self._client.sdk_client.commits_api.commit(self._repo_id, self._id, commits_creation)
   210          return Reference(self._repo_id, c.id, self._client)
   211  
   212      def delete(self) -> None:
   213          """
   214          Delete branch from lakeFS server
   215  
   216          :raise NotFoundException: if branch or repository do not exist
   217          :raise NotAuthorizedException: if user is not authorized to perform this operation
   218          :raise ForbiddenException: for branches that are protected
   219          :raise ServerException: for any other errors
   220          """
   221          with api_exception_handler():
   222              return self._client.sdk_client.branches_api.delete_branch(self._repo_id, self._id)
   223  
   224      def revert(self, reference: Optional[ReferenceType], parent_number: int = 0, *,
   225                 reference_id: Optional[str] = None) -> Commit:
   226          """
   227          revert the changes done by the provided reference on the current branch
   228  
   229          :param reference_id: (Optional) The reference ID to revert
   230  
   231              .. deprecated:: 0.4.0
   232                  Use ``reference`` instead.
   233  
   234          :param parent_number: when reverting a merge commit, the parent number (starting from 1) relative to which to
   235              perform the revert. The default for non merge commits is 0
   236          :param reference: the reference to revert
   237          :return: The commit created by the revert
   238          :raise NotFoundException: if branch by this id does not exist
   239          :raise NotAuthorizedException: if user is not authorized to perform this operation
   240          :raise ServerException: for any other errors
   241          """
   242          if parent_number < 0:
   243              raise ValueError("parent_number must be a non-negative integer")
   244  
   245          if reference_id is not None:
   246              warnings.warn(
   247                  "reference_id is deprecated, please use the `reference` argument.", DeprecationWarning
   248              )
   249  
   250          # Handle reference_id as a deprecated alias to reference.
   251          reference = reference or reference_id
   252          if reference is None:
   253              raise ValueError("reference to revert must be specified")
   254          ref = reference if isinstance(reference, str) else reference.id
   255  
   256          with api_exception_handler():
   257              self._client.sdk_client.branches_api.revert_branch(
   258                  self._repo_id,
   259                  self._id,
   260                  lakefs_sdk.RevertCreation(ref=ref, parent_number=parent_number)
   261              )
   262              commit = self._client.sdk_client.commits_api.get_commit(self._repo_id, self._id)
   263              return Commit(**commit.dict())
   264  
   265      def import_data(self, commit_message: str = "", metadata: Optional[dict] = None) -> ImportManager:
   266          """
   267          Import data to lakeFS
   268  
   269          :param metadata: metadata to attach to the commit
   270          :param commit_message: once the data is imported, a commit is created with this message. If default (empty)
   271              message is provided, uses the default server commit message for imports.
   272          :return: an ImportManager object
   273          """
   274          return ImportManager(self._repo_id, self._id, commit_message, metadata, self._client)
   275  
   276      @contextmanager
   277      def transact(self, commit_message: str = "", commit_metadata: Optional[Dict] = None,
   278                   delete_branch_on_error: bool = True) -> _Transaction:
   279          """
   280          Create a transaction for multiple operations.
   281          Transaction allows for multiple modifications to be performed atomically on a branch,
   282          similar to a database transaction.
   283          It ensures that the branch remains unaffected until the transaction is successfully completed.
   284          The process includes:
   285  
   286          1. Creating an ephemeral branch from this branch
   287          2. Perform object operations on ephemeral branch
   288          3. Commit changes
   289          4. Merge back to source branch
   290          5. Delete ephemeral branch
   291  
   292          Using a transaction the code for this flow will look like this:
   293  
   294          .. code-block:: python
   295  
   296              import lakefs
   297  
   298              branch = lakefs.repository("<repository_name>").branch("<branch_name>")
   299              with branch.transact(commit_message="my transaction") as tx:
   300                  for obj in tx.objects(prefix="prefix_to_delete/"):  # Delete some objects
   301                      obj.delete()
   302  
   303                  # Create new object
   304                  tx.object("new_object").upload("new object data")
   305  
   306          Note that unlike database transactions, lakeFS transaction does not take a "lock" on the branch, and therefore
   307          the transaction might fail due to changes in source branch after the transaction was created.
   308  
   309          :param commit_message: once the transaction is committed, a commit is created with this message
   310          :param commit_metadata: user metadata for the transaction commit
   311          :param delete_branch_on_error: Defaults to True. Ensures ephemeral branch is deleted on error.
   312          :return: a Transaction object to perform the operations on
   313          """
   314          with Transaction(self._repo_id, self._id, commit_message, commit_metadata, delete_branch_on_error,
   315                           self._client) as tx:
   316              yield tx
   317  
   318  
   319  class _Transaction(_BaseBranch):
   320      @staticmethod
   321      def _get_tx_name() -> str:
   322          return f"tx-{uuid.uuid4()}"  # Don't rely on source branch name as this might exceed valid branch length
   323  
   324      def __init__(self, repository_id: str, branch_id: str, commit_message: str = "",
   325                   commit_metadata: Optional[Dict] = None, client: Client = None):
   326          self._commit_message = commit_message
   327          self._commit_metadata = commit_metadata
   328          self._source_branch = branch_id
   329  
   330          tx_name = self._get_tx_name()
   331          self._tx_branch = Branch(repository_id, tx_name, client).create(branch_id)
   332          super().__init__(repository_id, tx_name, client)
   333  
   334      @property
   335      def source_id(self) -> str:
   336          """
   337          Returns the source branch ID the transaction is associated to
   338          """
   339          return self._source_branch
   340  
   341      @property
   342      def commit_message(self) -> str:
   343          """
   344          Return the commit message configured for this transaction completion
   345          """
   346          return self._commit_message
   347  
   348      @commit_message.setter
   349      def commit_message(self, message: str) -> None:
   350          """
   351          Set the commit message for this transaction completion
   352          :param message: The commit message to use on the transaction merge commit
   353          """
   354          self._commit_message = message
   355  
   356      @property
   357      def commit_metadata(self) -> Optional[Dict]:
   358          """
   359          Return the commit metadata configured for this transaction completion
   360          """
   361          return self._commit_metadata
   362  
   363      @commit_metadata.setter
   364      def commit_metadata(self, metadata: Optional[Dict]) -> None:
   365          """
   366          Set the commit metadata for this transaction completion
   367          :param metadata: The metadata to use on the transaction merge commit
   368          """
   369          self._commit_metadata = metadata
   370  
   371  
   372  class Transaction:
   373      """
   374      Manage a transaction on a given branch
   375  
   376      The transaction creates an ephemeral branch from the source branch. The transaction can then be used to perform
   377      operations on the branch which will later be merged back into the source branch.
   378      Currently, transaction is supported only as a context manager.
   379      """
   380  
   381      def __init__(self, repository_id: str, branch_id: str, commit_message: str = "",
   382                   commit_metadata: Optional[Dict] = None, delete_branch_on_error: bool = True, client: Client = None):
   383          self._repo_id = repository_id
   384          self._commit_message = commit_message
   385          self._commit_metadata = commit_metadata
   386          self._source_branch = branch_id
   387          self._client = client
   388          self._tx = None
   389          self._tx_branch = None
   390          self._cleanup_branch = delete_branch_on_error
   391  
   392      def __enter__(self):
   393          self._tx = _Transaction(self._repo_id, self._source_branch, self._commit_message, self._commit_metadata,
   394                                  self._client)
   395          self._tx_branch = Branch(self._repo_id, self._tx.id, self._client)
   396          return self._tx
   397  
   398      def __exit__(self, typ, value, traceback) -> bool:
   399          if typ is not None:  # Perform only cleanup in case exception occurred
   400              if self._cleanup_branch:
   401                  self._tx_branch.delete()
   402              return False  # Raise the underlying exception
   403  
   404          try:
   405              self._tx_branch.commit(message=self._tx.commit_message, metadata=self._tx.commit_metadata)
   406              self._tx_branch.merge_into(self._source_branch, message=f"Merge transaction {self._tx.id} to branch")
   407              self._tx_branch.delete()
   408  
   409              return False
   410          except LakeFSException as e:
   411              if self._cleanup_branch:
   412                  self._tx_branch.delete()
   413              raise TransactionException(f"Failed committing transaction {self._tx.id}: {e}") from e