github.com/secure-build/gitlab-runner@v12.5.0+incompatible/executors/docker/terminal_test.go (about)

     1  package docker
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"errors"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"os"
    12  	"strings"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/docker/docker/api/types"
    17  	"github.com/gorilla/websocket"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/mock"
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"gitlab.com/gitlab-org/gitlab-runner/common"
    23  	"gitlab.com/gitlab-org/gitlab-runner/executors"
    24  	"gitlab.com/gitlab-org/gitlab-runner/helpers"
    25  	"gitlab.com/gitlab-org/gitlab-runner/helpers/docker"
    26  	"gitlab.com/gitlab-org/gitlab-runner/session"
    27  )
    28  
    29  func TestInteractiveTerminal(t *testing.T) {
    30  	if helpers.SkipIntegrationTests(t, "docker", "info") {
    31  		return
    32  	}
    33  
    34  	successfulBuild, err := common.GetRemoteLongRunningBuild()
    35  	assert.NoError(t, err)
    36  
    37  	sess, err := session.NewSession(nil)
    38  	require.NoError(t, err)
    39  
    40  	build := &common.Build{
    41  		JobResponse: successfulBuild,
    42  		Runner: &common.RunnerConfig{
    43  			RunnerSettings: common.RunnerSettings{
    44  				Executor: "docker",
    45  				Docker: &common.DockerConfig{
    46  					Image:      common.TestAlpineImage,
    47  					PullPolicy: common.PullPolicyIfNotPresent,
    48  				},
    49  			},
    50  		},
    51  		Session: sess,
    52  	}
    53  
    54  	// Start build
    55  	go func() {
    56  		_ = build.Run(&common.Config{}, &common.Trace{Writer: os.Stdout})
    57  	}()
    58  
    59  	srv := httptest.NewServer(build.Session.Mux())
    60  	defer srv.Close()
    61  
    62  	u := url.URL{
    63  		Scheme: "ws",
    64  		Host:   srv.Listener.Addr().String(),
    65  		Path:   build.Session.Endpoint + "/exec",
    66  	}
    67  	headers := http.Header{
    68  		"Authorization": []string{build.Session.Token},
    69  	}
    70  
    71  	var webSocket *websocket.Conn
    72  	var resp *http.Response
    73  
    74  	started := time.Now()
    75  
    76  	for time.Since(started) < 25*time.Second {
    77  		webSocket, resp, err = websocket.DefaultDialer.Dial(u.String(), headers)
    78  		if err == nil {
    79  			break
    80  		}
    81  
    82  		time.Sleep(50 * time.Millisecond)
    83  	}
    84  
    85  	require.NotNil(t, webSocket)
    86  	require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
    87  
    88  	defer webSocket.Close()
    89  
    90  	err = webSocket.WriteMessage(websocket.BinaryMessage, []byte("uname\n"))
    91  	require.NoError(t, err)
    92  
    93  	readStarted := time.Now()
    94  	var tty []byte
    95  	for time.Since(readStarted) < 5*time.Second {
    96  		typ, b, err := webSocket.ReadMessage()
    97  		require.NoError(t, err)
    98  		require.Equal(t, websocket.BinaryMessage, typ)
    99  		tty = append(tty, b...)
   100  
   101  		if strings.Contains(string(b), "Linux") {
   102  			break
   103  		}
   104  
   105  		time.Sleep(50 * time.Microsecond)
   106  	}
   107  
   108  	t.Log(string(tty))
   109  	assert.Contains(t, string(tty), "Linux")
   110  }
   111  
   112  func TestCommandExecutor_Connect(t *testing.T) {
   113  	tests := []struct {
   114  		name                  string
   115  		buildContainerRunning bool
   116  		hasBuildContainer     bool
   117  		containerInspectErr   error
   118  		expectedErr           error
   119  	}{
   120  		{
   121  			name: "Connect Timeout",
   122  			buildContainerRunning: false,
   123  			hasBuildContainer:     true,
   124  			expectedErr:           buildContainerTerminalTimeout{},
   125  		},
   126  		{
   127  			name: "Successful connect",
   128  			buildContainerRunning: true,
   129  			hasBuildContainer:     true,
   130  			containerInspectErr:   nil,
   131  		},
   132  		{
   133  			name: "Container inspect failed",
   134  			buildContainerRunning: false,
   135  			hasBuildContainer:     true,
   136  			containerInspectErr:   errors.New("container not found"),
   137  			expectedErr:           errors.New("container not found"),
   138  		},
   139  		{
   140  			name: "No build container",
   141  			buildContainerRunning: false,
   142  			hasBuildContainer:     false,
   143  			expectedErr:           buildContainerTerminalTimeout{},
   144  		},
   145  	}
   146  
   147  	for _, test := range tests {
   148  		t.Run(test.name, func(t *testing.T) {
   149  			c := &docker_helpers.MockClient{}
   150  			defer c.AssertExpectations(t)
   151  
   152  			s := commandExecutor{
   153  				executor: executor{
   154  					AbstractExecutor: executors.AbstractExecutor{
   155  						Context: context.Background(),
   156  						BuildShell: &common.ShellConfiguration{
   157  							DockerCommand: []string{"/bin/sh"},
   158  						},
   159  					},
   160  					client: c,
   161  				},
   162  			}
   163  
   164  			if test.hasBuildContainer {
   165  				s.buildContainer = &types.ContainerJSON{
   166  					ContainerJSONBase: &types.ContainerJSONBase{
   167  						ID: "1234",
   168  					},
   169  				}
   170  
   171  				c.On("ContainerInspect", s.Context, "1234").Return(types.ContainerJSON{
   172  					ContainerJSONBase: &types.ContainerJSONBase{
   173  						State: &types.ContainerState{
   174  							Running: test.buildContainerRunning,
   175  						},
   176  					},
   177  				}, test.containerInspectErr)
   178  			}
   179  
   180  			conn, err := s.Connect()
   181  
   182  			if test.buildContainerRunning {
   183  				assert.NoError(t, err)
   184  				assert.NotNil(t, conn)
   185  				assert.IsType(t, terminalConn{}, conn)
   186  				return
   187  			}
   188  
   189  			assert.EqualError(t, err, test.expectedErr.Error())
   190  			assert.Nil(t, conn)
   191  		})
   192  	}
   193  
   194  }
   195  
   196  func TestTerminalConn_FailToStart(t *testing.T) {
   197  	tests := []struct {
   198  		name                   string
   199  		containerExecCreateErr error
   200  		containerExecAttachErr error
   201  	}{
   202  		{
   203  			name: "Failed to create exec container",
   204  			containerExecCreateErr: errors.New("failed to create exec container"),
   205  			containerExecAttachErr: nil,
   206  		},
   207  		{
   208  			name: "Failed to attach exec container",
   209  			containerExecCreateErr: nil,
   210  			containerExecAttachErr: errors.New("failed to attach exec container"),
   211  		},
   212  	}
   213  
   214  	for _, test := range tests {
   215  		t.Run(test.name, func(t *testing.T) {
   216  			c := &docker_helpers.MockClient{}
   217  			defer c.AssertExpectations(t)
   218  
   219  			s := commandExecutor{
   220  				executor: executor{
   221  					AbstractExecutor: executors.AbstractExecutor{
   222  						Context: context.Background(),
   223  						BuildShell: &common.ShellConfiguration{
   224  							DockerCommand: []string{"/bin/sh"},
   225  						},
   226  					},
   227  					client: c,
   228  				},
   229  				buildContainer: &types.ContainerJSON{
   230  					ContainerJSONBase: &types.ContainerJSONBase{
   231  						ID: "1234",
   232  					},
   233  				},
   234  			}
   235  
   236  			c.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{
   237  				ContainerJSONBase: &types.ContainerJSONBase{
   238  					State: &types.ContainerState{
   239  						Running: true,
   240  					},
   241  				},
   242  			}, nil)
   243  
   244  			c.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(
   245  				types.IDResponse{},
   246  				test.containerExecCreateErr,
   247  			).Once()
   248  
   249  			if test.containerExecCreateErr == nil {
   250  				c.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(
   251  					types.HijackedResponse{},
   252  					test.containerExecAttachErr,
   253  				).Once()
   254  			}
   255  
   256  			conn, err := s.Connect()
   257  			require.NoError(t, err)
   258  
   259  			timeoutCh := make(chan error)
   260  			disconnectCh := make(chan error)
   261  			w := httptest.NewRecorder()
   262  			req := httptest.NewRequest(http.MethodGet, "wss://example.com/foo", nil)
   263  			conn.Start(w, req, timeoutCh, disconnectCh)
   264  
   265  			resp := w.Result()
   266  			assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
   267  		})
   268  	}
   269  }
   270  
   271  type nopReader struct {
   272  }
   273  
   274  func (w *nopReader) Read(b []byte) (int, error) {
   275  	return len(b), nil
   276  }
   277  
   278  type nopConn struct {
   279  }
   280  
   281  func (nopConn) Read(b []byte) (n int, err error) {
   282  	return len(b), nil
   283  }
   284  
   285  func (nopConn) Write(b []byte) (n int, err error) {
   286  	return len(b), nil
   287  }
   288  
   289  func (nopConn) Close() error {
   290  	return nil
   291  }
   292  
   293  func (nopConn) LocalAddr() net.Addr {
   294  	return &net.TCPAddr{}
   295  }
   296  
   297  func (nopConn) RemoteAddr() net.Addr {
   298  	return &net.TCPAddr{}
   299  }
   300  
   301  func (nopConn) SetDeadline(t time.Time) error {
   302  	return nil
   303  }
   304  
   305  func (nopConn) SetReadDeadline(t time.Time) error {
   306  	return nil
   307  }
   308  
   309  func (nopConn) SetWriteDeadline(t time.Time) error {
   310  	return nil
   311  }
   312  
   313  func TestTerminalConn_Start(t *testing.T) {
   314  	c := &docker_helpers.MockClient{}
   315  	defer c.AssertExpectations(t)
   316  
   317  	s := commandExecutor{
   318  		executor: executor{
   319  			AbstractExecutor: executors.AbstractExecutor{
   320  				Context: context.Background(),
   321  				BuildShell: &common.ShellConfiguration{
   322  					DockerCommand: []string{"/bin/sh"},
   323  				},
   324  			},
   325  			client: c,
   326  		},
   327  		buildContainer: &types.ContainerJSON{
   328  			ContainerJSONBase: &types.ContainerJSONBase{
   329  				ID: "1234",
   330  			},
   331  		},
   332  	}
   333  
   334  	c.On("ContainerInspect", mock.Anything, "1234").Return(types.ContainerJSON{
   335  		ContainerJSONBase: &types.ContainerJSONBase{
   336  			State: &types.ContainerState{
   337  				Running: true,
   338  			},
   339  		},
   340  	}, nil).Once()
   341  
   342  	c.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(types.IDResponse{
   343  		ID: "4321",
   344  	}, nil).Once()
   345  
   346  	c.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{
   347  		Conn:   nopConn{},
   348  		Reader: bufio.NewReader(&nopReader{}),
   349  	}, nil).Once()
   350  
   351  	c.On("ContainerInspect", mock.Anything, "1234").Return(types.ContainerJSON{
   352  		ContainerJSONBase: &types.ContainerJSONBase{
   353  			State: &types.ContainerState{
   354  				Running: false,
   355  			},
   356  		},
   357  	}, nil)
   358  
   359  	session, err := session.NewSession(nil)
   360  	require.NoError(t, err)
   361  	session.Token = "validToken"
   362  
   363  	session.SetInteractiveTerminal(&s)
   364  
   365  	srv := httptest.NewServer(session.Mux())
   366  
   367  	u := url.URL{
   368  		Scheme: "ws",
   369  		Host:   srv.Listener.Addr().String(),
   370  		Path:   session.Endpoint + "/exec",
   371  	}
   372  	headers := http.Header{
   373  		"Authorization": []string{"validToken"},
   374  	}
   375  
   376  	conn, resp, err := websocket.DefaultDialer.Dial(u.String(), headers)
   377  	require.NoError(t, err)
   378  	require.NotNil(t, conn)
   379  	require.Equal(t, resp.StatusCode, http.StatusSwitchingProtocols)
   380  
   381  	defer conn.Close()
   382  
   383  	go func() {
   384  		for {
   385  			err := conn.WriteMessage(websocket.BinaryMessage, []byte("data"))
   386  			if err != nil {
   387  				return
   388  			}
   389  
   390  			time.Sleep(time.Second)
   391  		}
   392  	}()
   393  
   394  	started := time.Now()
   395  
   396  	for time.Since(started) < 5*time.Second {
   397  		if !session.Connected() {
   398  			break
   399  		}
   400  
   401  		time.Sleep(50 * time.Microsecond)
   402  	}
   403  
   404  	assert.False(t, session.Connected())
   405  }