github.com/aitjcize/Overlord@v0.0.0-20240314041920-104a804cf5e8/test/overlord_e2e_unittest.py (about)

     1  #!/usr/bin/env python3
     2  # Copyright 2015 The Chromium OS Authors. All rights reserved.
     3  # Use of this source code is governed by a BSD-style license that can be
     4  # found in the LICENSE file.
     5  
     6  from contextlib import closing
     7  import json
     8  import os
     9  import shutil
    10  import socket
    11  import subprocess
    12  import tempfile
    13  import time
    14  import unittest
    15  import urllib.parse
    16  import urllib.request
    17  
    18  from ws4py.client import WebSocketBaseClient
    19  
    20  
    21  # Constants.
    22  _HOST = '127.0.0.1'
    23  _INCREMENT = 42
    24  
    25  
    26  class TestError(Exception):
    27    pass
    28  
    29  
    30  class CloseWebSocket(Exception):
    31    pass
    32  
    33  
    34  def FindUnusedPort():
    35    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
    36      s.bind(('', 0))
    37      s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    38      return s.getsockname()[1]
    39  
    40  
    41  class TestOverlord(unittest.TestCase):
    42    @classmethod
    43    def setUpClass(cls):
    44      # Build overlord, only do this once over all tests.
    45      gitroot = os.path.normpath(os.path.join(os.path.dirname(__file__),
    46                                 '..'))
    47      cls.bindir = tempfile.mkdtemp()
    48      subprocess.call('make -C %s BIN=%s' % (gitroot, cls.bindir), shell=True)
    49  
    50    @classmethod
    51    def tearDownClass(cls):
    52      if os.path.isdir(cls.bindir):
    53        shutil.rmtree(cls.bindir)
    54  
    55    def setUp(self):
    56      self.basedir = os.path.dirname(__file__)
    57      bindir = self.__class__.bindir
    58      scriptdir = os.path.normpath(os.path.join(self.basedir, '../scripts'))
    59  
    60      env = os.environ.copy()
    61      env['SHELL'] = os.path.join(os.getcwd(), self.basedir, 'test_shell.sh')
    62  
    63      # set ports for overlord to bind
    64      overlord_http_port = FindUnusedPort()
    65      self.host = '%s:%d' % (_HOST, overlord_http_port)
    66      env['OVERLORD_LD_PORT'] = str(FindUnusedPort())
    67      env['GHOST_RPC_PORT'] = str(FindUnusedPort())
    68  
    69      # Launch overlord
    70      self.ovl = subprocess.Popen(['%s/overlordd' % bindir, '-no-auth',
    71                                   '-port', str(overlord_http_port)], env=env)
    72  
    73      # Launch go implementation of ghost
    74      self.goghost = subprocess.Popen(['%s/ghost' % bindir,
    75                                       '-mid=go', '-no-lan-disc',
    76                                       '-no-rpc-server', '-tls=n',
    77                                       'localhost:%d' % overlord_http_port],
    78                                      env=env)
    79  
    80      # Launch python implementation of ghost
    81      self.pyghost = subprocess.Popen(['%s/ghost.py' % scriptdir,
    82                                       '--mid=python', '--no-lan-disc',
    83                                       '--no-rpc-server', '--tls=n',
    84                                       'localhost:%d' % overlord_http_port],
    85                                      env=env)
    86  
    87      def CheckClient():
    88        try:
    89          clients = self._GetJSON('/api/agents/list')
    90          return len(clients) == 2
    91        except IOError:
    92          # overlordd is not ready yet.
    93          return False
    94  
    95      # Wait for clients to connect
    96      try:
    97        for unused_i in range(30):
    98          if CheckClient():
    99            return
   100          time.sleep(1)
   101        raise RuntimeError('client not connected')
   102      except Exception:
   103        self.tearDown()
   104        raise
   105  
   106    def tearDown(self):
   107      self.goghost.kill()
   108      self.goghost.wait()
   109  
   110      # Python implementation uses process instead of goroutine, also kill those
   111      with subprocess.Popen('pkill -P %d' % self.pyghost.pid, shell=True) as p:
   112        p.wait()
   113  
   114      self.pyghost.kill()
   115      self.pyghost.wait()
   116  
   117      self.ovl.kill()
   118      self.ovl.wait()
   119  
   120    def _GetJSON(self, path):
   121      return json.loads(urllib.request.urlopen(
   122          'http://' + self.host + path).read())
   123  
   124    def testWebAPI(self):
   125      # Test /api/app/list
   126      appdir = os.path.join(self.basedir, '../overlord/app')
   127      specialApps = ['common', 'upgrade', 'third_party']
   128      apps = [x for x in os.listdir(appdir)
   129              if os.path.isdir(os.path.join(appdir, x)) and x not in specialApps]
   130      res = self._GetJSON('/api/apps/list')
   131      assert len(res['apps']) == len(apps)
   132  
   133      # Test /api/agents/list
   134      assert len(self._GetJSON('/api/agents/list')) == 2
   135  
   136      # Test /api/logcats/list. TODO(wnhuang): test this properly
   137      assert not self._GetJSON('/api/logcats/list')
   138  
   139      # Test /api/agent/properties/mid
   140      for client in self._GetJSON('/api/agents/list'):
   141        assert self._GetJSON(
   142            '/api/agent/properties/%s' % client['mid']) is not None
   143  
   144    def testShellCommand(self):
   145      class TestClient(WebSocketBaseClient):
   146        def __init__(self, *args, **kwargs):
   147          super().__init__(*args, **kwargs)
   148          self.message = b''
   149  
   150        def handshake_ok(self):
   151          pass
   152  
   153        def received_message(self, message):
   154          self.message += message.data
   155  
   156      clients = self._GetJSON('/api/agents/list')
   157      self.assertTrue(clients)
   158      answer = subprocess.check_output(['uname', '-r'])
   159  
   160      for client in clients:
   161        ws = TestClient('ws://' + self.host + '/api/agent/shell/%s' %
   162                        urllib.parse.quote(client['mid']) + '?command=' +
   163                        urllib.parse.quote('uname -r'))
   164        ws.connect()
   165        ws.run()
   166        self.assertEqual(ws.message, answer)
   167  
   168    def testTerminalCommand(self):
   169      class TestClient(WebSocketBaseClient):
   170        NONE, PROMPT, RESPONSE = range(0, 3)
   171  
   172        def __init__(self, *args, **kwargs):
   173          super().__init__(*args, **kwargs)
   174          self.state = self.NONE
   175          self.answer = 0
   176          self.test_run = False
   177          self.buffer = b''
   178  
   179        def handshake_ok(self):
   180          pass
   181  
   182        def closed(self, code, reason=None):
   183          if not self.test_run:
   184            raise RuntimeError('test exit before being run: %s' % reason)
   185  
   186        def received_message(self, message):
   187          if message.is_text:
   188            # Ignore control messages.
   189            return
   190  
   191          self.buffer += message.data
   192          if b'\r\n' not in self.buffer:
   193            return
   194  
   195          self.test_run = True
   196          msg_text, self.buffer = self.buffer.split(b'\r\n', 1)
   197          if self.state == self.NONE:
   198            if msg_text.startswith(b'TEST-SHELL-CHALLENGE'):
   199              self.state = self.PROMPT
   200              challenge_number = int(msg_text.split()[1])
   201              self.answer = challenge_number + _INCREMENT
   202              self.send('%d\n' % self.answer)
   203          elif self.state == self.PROMPT:
   204            msg_text = msg_text.strip()
   205            if msg_text == b'SUCCESS':
   206              raise CloseWebSocket
   207            if msg_text == b'FAILED':
   208              raise TestError('Challange failed')
   209            if msg_text and int(msg_text) == self.answer:
   210              pass
   211            else:
   212              raise TestError('Unexpected response: %r' % msg_text)
   213  
   214      clients = self._GetJSON('/api/agents/list')
   215      assert clients
   216  
   217      for client in clients:
   218        ws = TestClient('ws://' + self.host + '/api/agent/tty/%s' %
   219                        urllib.parse.quote(client['mid']))
   220        ws.connect()
   221        try:
   222          ws.run()
   223        except TestError as e:
   224          raise e
   225        except CloseWebSocket:
   226          ws.close()
   227  
   228  
   229  if __name__ == '__main__':
   230    unittest.main()