github.com/argoproj/argo-cd/v2@v2.10.9/server/application/websocket.go (about)

     1  package application
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"github.com/argoproj/argo-cd/v2/common"
     7  	httputil "github.com/argoproj/argo-cd/v2/util/http"
     8  	util_session "github.com/argoproj/argo-cd/v2/util/session"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  	log "github.com/sirupsen/logrus"
    15  	"k8s.io/client-go/tools/remotecommand"
    16  )
    17  
    18  const (
    19  	ReconnectCode    = 1
    20  	ReconnectMessage = "\nReconnect because the token was refreshed...\n"
    21  )
    22  
    23  var upgrader = func() websocket.Upgrader {
    24  	upgrader := websocket.Upgrader{}
    25  	upgrader.HandshakeTimeout = time.Second * 2
    26  	upgrader.CheckOrigin = func(r *http.Request) bool {
    27  		return true
    28  	}
    29  	return upgrader
    30  }()
    31  
    32  // terminalSession implements PtyHandler
    33  type terminalSession struct {
    34  	wsConn         *websocket.Conn
    35  	sizeChan       chan remotecommand.TerminalSize
    36  	doneChan       chan struct{}
    37  	tty            bool
    38  	readLock       sync.Mutex
    39  	writeLock      sync.Mutex
    40  	sessionManager *util_session.SessionManager
    41  	token          *string
    42  }
    43  
    44  // getToken get auth token from web socket request
    45  func getToken(r *http.Request) (string, error) {
    46  	cookies := r.Cookies()
    47  	return httputil.JoinCookies(common.AuthCookieName, cookies)
    48  }
    49  
    50  // newTerminalSession create terminalSession
    51  func newTerminalSession(w http.ResponseWriter, r *http.Request, responseHeader http.Header, sessionManager *util_session.SessionManager) (*terminalSession, error) {
    52  	token, err := getToken(r)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	conn, err := upgrader.Upgrade(w, r, responseHeader)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	session := &terminalSession{
    62  		wsConn:         conn,
    63  		tty:            true,
    64  		sizeChan:       make(chan remotecommand.TerminalSize),
    65  		doneChan:       make(chan struct{}),
    66  		sessionManager: sessionManager,
    67  		token:          &token,
    68  	}
    69  	return session, nil
    70  }
    71  
    72  // Done close the done channel.
    73  func (t *terminalSession) Done() {
    74  	close(t.doneChan)
    75  }
    76  
    77  func (t *terminalSession) StartKeepalives(dur time.Duration) {
    78  	ticker := time.NewTicker(dur)
    79  	defer ticker.Stop()
    80  	for {
    81  		select {
    82  		case <-ticker.C:
    83  			err := t.Ping()
    84  			if err != nil {
    85  				log.Errorf("ping error: %v", err)
    86  				return
    87  			}
    88  		case <-t.doneChan:
    89  			return
    90  		}
    91  	}
    92  }
    93  
    94  // Next called in a loop from remotecommand as long as the process is running
    95  func (t *terminalSession) Next() *remotecommand.TerminalSize {
    96  	select {
    97  	case size := <-t.sizeChan:
    98  		return &size
    99  	case <-t.doneChan:
   100  		return nil
   101  	}
   102  }
   103  
   104  // reconnect send reconnect code to client and ask them init new ws session
   105  func (t *terminalSession) reconnect() (int, error) {
   106  	reconnectCommand, _ := json.Marshal(TerminalCommand{
   107  		Code: ReconnectCode,
   108  	})
   109  	reconnectMessage, _ := json.Marshal(TerminalMessage{
   110  		Operation: "stdout",
   111  		Data:      ReconnectMessage,
   112  	})
   113  	t.writeLock.Lock()
   114  	err := t.wsConn.WriteMessage(websocket.TextMessage, reconnectMessage)
   115  	if err != nil {
   116  		log.Errorf("write message err: %v", err)
   117  		return 0, err
   118  	}
   119  	err = t.wsConn.WriteMessage(websocket.TextMessage, reconnectCommand)
   120  	if err != nil {
   121  		log.Errorf("write message err: %v", err)
   122  		return 0, err
   123  	}
   124  	t.writeLock.Unlock()
   125  	return 0, nil
   126  }
   127  
   128  // Read called in a loop from remotecommand as long as the process is running
   129  func (t *terminalSession) Read(p []byte) (int, error) {
   130  	// check if token still valid
   131  	_, newToken, err := t.sessionManager.VerifyToken(*t.token)
   132  	// err in case if token is revoked, newToken in case if refresh happened
   133  	if err != nil || newToken != "" {
   134  		// need to send reconnect code in case if token was refreshed
   135  		return t.reconnect()
   136  	}
   137  
   138  	t.readLock.Lock()
   139  	_, message, err := t.wsConn.ReadMessage()
   140  	t.readLock.Unlock()
   141  	if err != nil {
   142  		log.Errorf("read message err: %v", err)
   143  		return copy(p, EndOfTransmission), err
   144  	}
   145  	var msg TerminalMessage
   146  	if err := json.Unmarshal(message, &msg); err != nil {
   147  		log.Errorf("read parse message err: %v", err)
   148  		return copy(p, EndOfTransmission), err
   149  	}
   150  	switch msg.Operation {
   151  	case "stdin":
   152  		return copy(p, msg.Data), nil
   153  	case "resize":
   154  		t.sizeChan <- remotecommand.TerminalSize{Width: msg.Cols, Height: msg.Rows}
   155  		return 0, nil
   156  	default:
   157  		return copy(p, EndOfTransmission), fmt.Errorf("unknown message type %s", msg.Operation)
   158  	}
   159  }
   160  
   161  // Ping called periodically to ensure connection stays alive through load balancers
   162  func (t *terminalSession) Ping() error {
   163  	t.writeLock.Lock()
   164  	err := t.wsConn.WriteMessage(websocket.PingMessage, []byte("ping"))
   165  	t.writeLock.Unlock()
   166  	if err != nil {
   167  		log.Errorf("ping message err: %v", err)
   168  	}
   169  	return err
   170  }
   171  
   172  // Write called from remotecommand whenever there is any output
   173  func (t *terminalSession) Write(p []byte) (int, error) {
   174  	msg, err := json.Marshal(TerminalMessage{
   175  		Operation: "stdout",
   176  		Data:      string(p),
   177  	})
   178  	if err != nil {
   179  		log.Errorf("write parse message err: %v", err)
   180  		return 0, err
   181  	}
   182  	t.writeLock.Lock()
   183  	err = t.wsConn.WriteMessage(websocket.TextMessage, msg)
   184  	t.writeLock.Unlock()
   185  	if err != nil {
   186  		log.Errorf("write message err: %v", err)
   187  		return 0, err
   188  	}
   189  	return len(p), nil
   190  }
   191  
   192  // Close closes websocket connection
   193  func (t *terminalSession) Close() error {
   194  	return t.wsConn.Close()
   195  }