github.com/snowflakedb/gosnowflake@v1.9.0/ci/scripts/hang_webserver.py (about)

     1  #!/usr/bin/env python3
     2  import sys
     3  from http.server import BaseHTTPRequestHandler,HTTPServer
     4  from socketserver import ThreadingMixIn
     5  import threading
     6  import time
     7  import json
     8  
     9  class HTTPRequestHandler(BaseHTTPRequestHandler):
    10      invocations = 0
    11  
    12      def do_POST(self):
    13          if self.path.startswith('/reset'):
    14              print("Resetting HTTP mocks")
    15              HTTPRequestHandler.invocations = 0
    16              self.__respond(200)
    17          elif self.path.startswith('/invocations'):
    18              self.__respond(200, body=str(HTTPRequestHandler.invocations))
    19          elif self.path.startswith('/ocsp'):
    20              print("ocsp")
    21              self.ocspMocks()
    22          elif self.path.startswith('/session/v1/login-request'):
    23              self.authMocks()
    24  
    25      def ocspMocks(self):
    26          if self.path.startswith('/ocsp/403'):
    27              self.send_response(403)
    28              self.send_header('Content-Type', 'text/plain')
    29              self.end_headers()
    30          elif self.path.startswith('/ocsp/404'):
    31              self.send_response(404)
    32              self.send_header('Content-Type', 'text/plain')
    33              self.end_headers()
    34          elif self.path.startswith('/ocsp/hang'):
    35              print("Hanging")
    36              time.sleep(300)
    37              self.send_response(200, 'OK')
    38              self.send_header('Content-Type', 'text/plain')
    39              self.end_headers()
    40          else:
    41              self.send_response(200, 'OK')
    42              self.send_header('Content-Type', 'text/plain')
    43              self.end_headers()
    44  
    45      def authMocks(self):
    46          content_length = int(self.headers.get('content-length', 0))
    47          body = self.rfile.read(content_length)
    48          jsonBody = json.loads(body)
    49          if jsonBody['data']['ACCOUNT_NAME'] == "jwtAuthTokenTimeout":
    50              HTTPRequestHandler.invocations += 1
    51              if HTTPRequestHandler.invocations >= 3:
    52                  self.__respond(200, body='''{
    53                      "data": {
    54                          "token": "someToken"
    55                      },
    56                      "success": true
    57                  }''')
    58              else:
    59                  time.sleep(2000)
    60                  self.send_response(200)
    61          else:
    62              print("Unknown auth request")
    63              self.send_response(500)
    64  
    65      def __respond(self, http_code, content_type='application/json', body=None):
    66          print("responding:", body)
    67          self.send_response(http_code)
    68          self.send_header('Content-Type', content_type)
    69          self.end_headers()
    70          if body != None:
    71              responseBody = bytes(body, "utf-8")
    72              self.wfile.write(responseBody)
    73  
    74      do_GET = do_POST
    75  
    76  class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    77    allow_reuse_address = True
    78  
    79    def shutdown(self):
    80      self.socket.close()
    81      HTTPServer.shutdown(self)
    82  
    83  class SimpleHttpServer():
    84    def __init__(self, ip, port):
    85      self.server = ThreadedHTTPServer((ip,port), HTTPRequestHandler)
    86  
    87    def start(self):
    88      self.server_thread = threading.Thread(target=self.server.serve_forever)
    89      self.server_thread.daemon = True
    90      self.server_thread.start()
    91  
    92    def waitForThread(self):
    93      self.server_thread.join()
    94  
    95    def stop(self):
    96      self.server.shutdown()
    97      self.waitForThread()
    98  
    99  if __name__=='__main__':
   100      if len(sys.argv) != 2:
   101          print("Usage: python3 {} PORT".format(sys.argv[0]))
   102          sys.exit(2)
   103  
   104      PORT = int(sys.argv[1])
   105  
   106      server = SimpleHttpServer('localhost', PORT)
   107      print('HTTP Server Running on PORT {}..........'.format(PORT))
   108      server.start()
   109      server.waitForThread()
   110