
     1  #
     2  # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
     3  #
import os
from typing import TypeVar, Type, Any, Dict, Optional
from urllib.parse import urljoin, urlencode, urlparse

import requests

from aistore.sdk.const import (
    JSON_CONTENT_TYPE,
    HEADER_CONTENT_TYPE,
    HEADER_USER_AGENT,
    USER_AGENT_BASE,
    AIS_SERVER_CRT,
)
from aistore.sdk.utils import handle_errors, decode_response
from aistore.version import __version__ as sdk_version
# Generic type variable: request_deserialize returns an instance of whatever
# type is passed as its res_model argument
T = TypeVar("T")
    23  # pylint: disable=unused-variable, duplicate-code
    24  class RequestClient:
    25      """
    26      Internal client for buckets, objects, jobs, etc. to use for making requests to an AIS cluster
    28      Args:
    29          endpoint (str): AIStore endpoint
    30      """
    32      def __init__(
    33          self,
    34          endpoint: str,
    35          skip_verify: bool = False,
    36          ca_cert: str = None,
    37          timeout=None,
    38      ):
    39          self._endpoint = endpoint
    40          self._base_url = urljoin(endpoint, "v1")
    41          self._session = requests.sessions.session()
    42          self._timeout = timeout
    43          if "https" in self._endpoint:
    44              self._set_session_verification(skip_verify, ca_cert)
    46      def _set_session_verification(self, skip_verify: bool, ca_cert: str):
    47          """
    48          Set session verify value for validating the server's SSL certificate
    49          The requests library allows this to be a boolean or a string path to the cert
    50          If we do not skip verification, the order is:
    51            1. Provided cert path
    52            2. Cert path from env var.
    53            3. True (verify with system's approved CA list)
    54          """
    55          if skip_verify:
    56              self._session.verify = False
    57              return
    58          if ca_cert:
    59              self._session.verify = ca_cert
    60              return
    61          env_crt = os.getenv(AIS_SERVER_CRT)
    62          self._session.verify = env_crt if env_crt else True
    64      @property
    65      def base_url(self):
    66          """
    67          Returns: AIS cluster base url
    68          """
    69          return self._base_url
    71      @property
    72      def endpoint(self):
    73          """
    74          Returns: AIS cluster endpoint
    75          """
    76          return self._endpoint
    78      @property
    79      def session(self):
    80          """
    81          Returns: Active request session
    82          """
    83          return self._session
    85      def request_deserialize(
    86          self, method: str, path: str, res_model: Type[T], **kwargs
    87      ) -> T:
    88          """
    89          Make a request to the AIS cluster and deserialize the response to a defined type
    90          Args:
    91              method (str): HTTP method, e.g. POST, GET, PUT, DELETE
    92              path (str): URL path to call
    93              res_model (Type[T]): Resulting type to which the response should be deserialized
    94              **kwargs (optional): Optional keyword arguments to pass with the call to request
    96          Returns:
    97              Parsed result of the call to the API, as res_model
    98          """
    99          resp = self.request(method, path, **kwargs)
   100          return decode_response(res_model, resp)
   102      def request(
   103          self,
   104          method: str,
   105          path: str,
   106          endpoint: str = None,
   107          headers: dict = None,
   108          **kwargs,
   109      ) -> requests.Response:
   110          """
   111          Make a request to the AIS cluster
   112          Args:
   113              method (str): HTTP method, e.g. POST, GET, PUT, DELETE
   114              path (str): URL path to call
   115              endpoint (str): Alternative endpoint for the AIS cluster, e.g. for connecting to a specific proxy
   116              headers (dict): Extra headers to be passed with the request. Content-Type and User-Agent will be overridden
   117              **kwargs (optional): Optional keyword arguments to pass with the call to request
   119          Returns:
   120              Raw response from the API
   121          """
   122          base = urljoin(endpoint, "v1") if endpoint else self._base_url
   123          url = f"{base}/{path.lstrip('/')}"
   124          if headers is None:
   125              headers = {}
   126          headers[HEADER_CONTENT_TYPE] = JSON_CONTENT_TYPE
   127          headers[HEADER_USER_AGENT] = f"{USER_AGENT_BASE}/{sdk_version}"
   128          resp = self._session.request(
   129              method,
   130              url,
   131              headers=headers,
   132              timeout=self._timeout,
   133              **kwargs,
   134          )
   135          if resp.status_code < 200 or resp.status_code >= 300:
   136              handle_errors(resp)
   137          return resp
   139      def get_full_url(self, path: str, params: Dict[str, Any]):
   140          """
   141          Get the full URL to the path on the cluster with the parameters given
   143          Args:
   144              path: Path on the cluster
   145              params: Query parameters to include
   147          Returns:
   148              URL including cluster base url and parameters
   150          """
   151          return f"{self._base_url}/{path.lstrip('/')}?{urlencode(params)}"