github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/containerserver/hijack.go (about)

     1  package containerserver
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"time"
     9  
    10  	"code.cloudfoundry.org/garden"
    11  	"code.cloudfoundry.org/lager"
    12  	"github.com/pf-qiu/concourse/v6/atc"
    13  	"github.com/pf-qiu/concourse/v6/atc/api/accessor"
    14  	"github.com/pf-qiu/concourse/v6/atc/db"
    15  	"github.com/pf-qiu/concourse/v6/atc/worker"
    16  	"github.com/gorilla/websocket"
    17  )
    18  
    19  var upgrader = websocket.Upgrader{
    20  	HandshakeTimeout: 5 * time.Second,
    21  }
    22  
    23  type InterceptTimeoutError struct {
    24  	duration time.Duration
    25  }
    26  
    27  func (err InterceptTimeoutError) Error() string {
    28  	return fmt.Sprintf("idle timeout (%s) reached", err.duration)
    29  }
    30  
    31  //go:generate counterfeiter . InterceptTimeoutFactory
    32  
    33  type InterceptTimeoutFactory interface {
    34  	NewInterceptTimeout() InterceptTimeout
    35  }
    36  
    37  func NewInterceptTimeoutFactory(duration time.Duration) InterceptTimeoutFactory {
    38  	return &interceptTimeoutFactory{
    39  		duration: duration,
    40  	}
    41  }
    42  
    43  type interceptTimeoutFactory struct {
    44  	duration time.Duration
    45  }
    46  
    47  func (t *interceptTimeoutFactory) NewInterceptTimeout() InterceptTimeout {
    48  	return &interceptTimeout{
    49  		duration: t.duration,
    50  		timer:    time.NewTimer(t.duration),
    51  	}
    52  }
    53  
    54  //go:generate counterfeiter . InterceptTimeout
    55  
    56  type InterceptTimeout interface {
    57  	Reset()
    58  	Channel() <-chan time.Time
    59  	Error() error
    60  }
    61  
    62  type interceptTimeout struct {
    63  	duration time.Duration
    64  	timer    *time.Timer
    65  }
    66  
    67  func (t *interceptTimeout) Reset() {
    68  	if t.duration > 0 {
    69  		t.timer.Reset(t.duration)
    70  	}
    71  }
    72  
    73  func (t *interceptTimeout) Channel() <-chan time.Time {
    74  	if t.duration > 0 {
    75  		return t.timer.C
    76  	}
    77  	return make(chan time.Time)
    78  }
    79  
    80  func (t *interceptTimeout) Error() error {
    81  	return InterceptTimeoutError{duration: t.duration}
    82  }
    83  
    84  func (s *Server) HijackContainer(team db.Team) http.Handler {
    85  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    86  		handle := r.FormValue(":id")
    87  
    88  		hLog := s.logger.Session("hijack", lager.Data{
    89  			"handle": handle,
    90  		})
    91  
    92  		container, found, err := s.workerClient.FindContainer(hLog, team.ID(), handle)
    93  		if err != nil {
    94  			hLog.Error("failed-to-find-container", err)
    95  			w.WriteHeader(http.StatusInternalServerError)
    96  			return
    97  		}
    98  
    99  		if !found {
   100  			hLog.Info("container-not-found")
   101  			w.WriteHeader(http.StatusNotFound)
   102  			return
   103  		}
   104  
   105  		isCheckContainer, err := team.IsCheckContainer(handle)
   106  		if err != nil {
   107  			hLog.Error("failed-to-find-container", err)
   108  			w.WriteHeader(http.StatusInternalServerError)
   109  			return
   110  		}
   111  
   112  		if isCheckContainer {
   113  			acc := accessor.GetAccessor(r)
   114  			if !acc.IsAdmin() {
   115  				hLog.Error("user-not-authorized-to-hijack-check-container", err)
   116  				w.WriteHeader(http.StatusForbidden)
   117  				return
   118  			}
   119  		}
   120  
   121  		ok, err := team.IsContainerWithinTeam(handle, isCheckContainer)
   122  		if err != nil {
   123  			hLog.Error("failed-to-find-container-within-team", err)
   124  			w.WriteHeader(http.StatusInternalServerError)
   125  			return
   126  		}
   127  
   128  		if !ok {
   129  			hLog.Error("container-not-found-within-team", err)
   130  			w.WriteHeader(http.StatusNotFound)
   131  			return
   132  		}
   133  
   134  		hLog.Debug("found-container")
   135  
   136  		conn, err := upgrader.Upgrade(w, r, nil)
   137  		if err != nil {
   138  			hLog.Error("unable-to-upgrade-connection-for-websockets", err)
   139  			return
   140  		}
   141  
   142  		defer db.Close(conn)
   143  
   144  		var processSpec atc.HijackProcessSpec
   145  		err = conn.ReadJSON(&processSpec)
   146  		if err != nil {
   147  			hLog.Error("malformed-process-spec", err)
   148  			closeWithErr(hLog, conn, websocket.CloseUnsupportedData, fmt.Sprintf("malformed process spec"))
   149  			return
   150  		}
   151  
   152  		hijackRequest := hijackRequest{
   153  			Container: container,
   154  			Process:   processSpec,
   155  		}
   156  
   157  		s.hijack(hLog, conn, hijackRequest)
   158  	})
   159  }
   160  
   161  type hijackRequest struct {
   162  	Container worker.Container
   163  	Process   atc.HijackProcessSpec
   164  }
   165  
   166  func closeWithErr(log lager.Logger, conn *websocket.Conn, code int, reason string) {
   167  	err := conn.WriteControl(
   168  		websocket.CloseMessage,
   169  		websocket.FormatCloseMessage(code, reason),
   170  		time.Time{},
   171  	)
   172  
   173  	if err != nil {
   174  		log.Error("failed-to-close-websocket-connection", err)
   175  	}
   176  }
   177  
   178  func (s *Server) hijack(hLog lager.Logger, conn *websocket.Conn, request hijackRequest) {
   179  	hLog = hLog.Session("hijack", lager.Data{
   180  		"handle":  request.Container.Handle(),
   181  		"process": request.Process,
   182  	})
   183  
   184  	stdinR, stdinW := io.Pipe()
   185  	defer db.Close(stdinW)
   186  
   187  	inputs := make(chan atc.HijackInput)
   188  	outputs := make(chan atc.HijackOutput)
   189  	exited := make(chan int, 1)
   190  	errs := make(chan error, 1)
   191  
   192  	cleanup := make(chan struct{})
   193  	defer close(cleanup)
   194  
   195  	outW := &stdoutWriter{
   196  		outputs: outputs,
   197  		done:    cleanup,
   198  	}
   199  
   200  	errW := &stderrWriter{
   201  		outputs: outputs,
   202  		done:    cleanup,
   203  	}
   204  
   205  	var tty *garden.TTYSpec
   206  	var idle InterceptTimeout
   207  
   208  	if request.Process.TTY != nil {
   209  		tty = &garden.TTYSpec{
   210  			WindowSize: &garden.WindowSize{
   211  				Columns: request.Process.TTY.WindowSize.Columns,
   212  				Rows:    request.Process.TTY.WindowSize.Rows,
   213  			},
   214  		}
   215  	}
   216  
   217  	process, err := request.Container.Run(context.Background(), garden.ProcessSpec{
   218  		Path: request.Process.Path,
   219  		Args: request.Process.Args,
   220  		Env:  request.Process.Env,
   221  		Dir:  request.Process.Dir,
   222  
   223  		User: request.Process.User,
   224  
   225  		TTY: tty,
   226  	}, garden.ProcessIO{
   227  		Stdin:  stdinR,
   228  		Stdout: outW,
   229  		Stderr: errW,
   230  	})
   231  	if err != nil {
   232  		if _, ok := err.(garden.ExecutableNotFoundError); ok {
   233  			hLog.Info("executable-not-found")
   234  
   235  			_ = conn.WriteJSON(atc.HijackOutput{
   236  				ExecutableNotFound: true,
   237  			})
   238  		}
   239  
   240  		_ = conn.WriteJSON(atc.HijackOutput{
   241  			Error: err.Error(),
   242  		})
   243  		hLog.Error("failed-to-hijack", err)
   244  		return
   245  	}
   246  
   247  	err = request.Container.UpdateLastHijack()
   248  	if err != nil {
   249  		hLog.Error("failed-to-update-container-hijack-time", err)
   250  		return
   251  	}
   252  
   253  	go func() {
   254  		for {
   255  			select {
   256  			case <-s.clock.After(s.interceptUpdateInterval):
   257  				err = request.Container.UpdateLastHijack()
   258  				if err != nil {
   259  					hLog.Error("failed-to-update-container-hijack-time", err)
   260  					return
   261  				}
   262  
   263  			case <-cleanup:
   264  				return
   265  			}
   266  		}
   267  	}()
   268  
   269  	hLog.Info("hijacked")
   270  
   271  	go func() {
   272  		for {
   273  			var input atc.HijackInput
   274  			err := conn.ReadJSON(&input)
   275  			if err != nil {
   276  				break
   277  			}
   278  
   279  			select {
   280  			case inputs <- input:
   281  			case <-cleanup:
   282  				return
   283  			}
   284  		}
   285  	}()
   286  
   287  	go func() {
   288  		status, err := process.Wait()
   289  		if err != nil {
   290  			errs <- err
   291  		} else {
   292  			exited <- status
   293  		}
   294  	}()
   295  
   296  	idle = s.interceptTimeoutFactory.NewInterceptTimeout()
   297  	idleChan := idle.Channel()
   298  
   299  	for {
   300  		select {
   301  		case input := <-inputs:
   302  			idle.Reset()
   303  			if input.Closed {
   304  				_ = stdinW.Close()
   305  			} else if input.TTYSpec != nil {
   306  				err := process.SetTTY(garden.TTYSpec{
   307  					WindowSize: &garden.WindowSize{
   308  						Columns: input.TTYSpec.WindowSize.Columns,
   309  						Rows:    input.TTYSpec.WindowSize.Rows,
   310  					},
   311  				})
   312  				if err != nil {
   313  					_ = conn.WriteJSON(atc.HijackOutput{
   314  						Error: err.Error(),
   315  					})
   316  				}
   317  			} else {
   318  				_, _ = stdinW.Write(input.Stdin)
   319  			}
   320  
   321  		case <-idleChan:
   322  			errs <- idle.Error()
   323  
   324  		case output := <-outputs:
   325  			err := conn.WriteJSON(output)
   326  			if err != nil {
   327  				return
   328  			}
   329  
   330  		case status := <-exited:
   331  			_ = conn.WriteJSON(atc.HijackOutput{
   332  				ExitStatus: &status,
   333  			})
   334  
   335  			return
   336  
   337  		case err := <-errs:
   338  			_ = conn.WriteJSON(atc.HijackOutput{
   339  				Error: err.Error(),
   340  			})
   341  
   342  			return
   343  		}
   344  	}
   345  }
   346  
   347  type stdoutWriter struct {
   348  	outputs chan<- atc.HijackOutput
   349  	done    chan struct{}
   350  }
   351  
   352  func (writer *stdoutWriter) Write(b []byte) (int, error) {
   353  	chunk := make([]byte, len(b))
   354  	copy(chunk, b)
   355  
   356  	output := atc.HijackOutput{
   357  		Stdout: chunk,
   358  	}
   359  
   360  	select {
   361  	case writer.outputs <- output:
   362  	case <-writer.done:
   363  	}
   364  
   365  	return len(b), nil
   366  }
   367  
   368  func (writer *stdoutWriter) Close() error {
   369  	close(writer.done)
   370  	return nil
   371  }
   372  
   373  type stderrWriter struct {
   374  	outputs chan<- atc.HijackOutput
   375  	done    chan struct{}
   376  }
   377  
   378  func (writer *stderrWriter) Write(b []byte) (int, error) {
   379  	chunk := make([]byte, len(b))
   380  	copy(chunk, b)
   381  
   382  	output := atc.HijackOutput{
   383  		Stderr: chunk,
   384  	}
   385  
   386  	select {
   387  	case writer.outputs <- output:
   388  	case <-writer.done:
   389  	}
   390  
   391  	return len(b), nil
   392  }
   393  
   394  func (writer *stderrWriter) Close() error {
   395  	close(writer.done)
   396  	return nil
   397  }