go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/devshell/server.go (about)

     1  // Copyright 2017 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package devshell implements the Devshell protocol for obtaining an auth token locally.
    16  //
    17  // Some Google Cloud tools know how to use it for authentication.
    18  package devshell
    19  
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"
	"sync"

	"golang.org/x/oauth2"

	"go.chromium.org/luci/common/clock"
	"go.chromium.org/luci/common/logging"
	"go.chromium.org/luci/common/runtime/paniccatcher"

	"go.chromium.org/luci/auth/integration/internal/localsrv"
)
    38  
// EnvKey is the name of the environment variable which contains the Devshell
// server port number which is picked up by Devshell clients.
//
// NOTE(review): nothing in this file sets the variable; presumably callers
// export it using the port of the address returned by Server.Start — confirm.
const EnvKey = "DEVSHELL_CLIENT_PORT"
    42  
// Server runs a local Devshell server that hands out OAuth2 access tokens.
//
// Each accepted connection carries one length-prefixed JSON request and gets
// one length-prefixed JSON array back. On success that array holds the email,
// the access token and the token's remaining lifetime in seconds.
//
// Fill in the exported fields, call Start, and eventually call Stop. A stopped
// Server cannot be started again.
type Server struct {
	// Source is used to obtain OAuth2 tokens.
	//
	// Its Token method is called once for every connection that sent a
	// well-formed request.
	Source oauth2.TokenSource
	// Email is the email associated with the token.
	//
	// It is sent back to clients verbatim as the first element of the reply.
	Email string
	// Port is a local TCP port to bind to or 0 to allow the OS to pick one.
	Port int

	// srv manages the listening socket and the serving goroutine; Start and
	// Stop delegate to it.
	srv localsrv.Server
}
    54  
    55  // Start launches background goroutine with the serving loop.
    56  //
    57  // The provided context is used as base context for request handlers and for
    58  // logging.
    59  //
    60  // The server must be eventually stopped with Stop().
    61  func (s *Server) Start(ctx context.Context) (*net.TCPAddr, error) {
    62  	return s.srv.Start(ctx, "devshell", s.Port, s.serve)
    63  }
    64  
    65  // Stop closes the listening socket, notifies pending requests to abort and
    66  // stops the internal serving goroutine.
    67  //
    68  // Safe to call multiple times. Once stopped, the server cannot be started again
    69  // (make a new instance of Server instead).
    70  //
    71  // Uses the given context for the deadline when waiting for the serving loop
    72  // to stop.
    73  func (s *Server) Stop(ctx context.Context) error {
    74  	return s.srv.Stop(ctx)
    75  }
    76  
    77  // serve runs the serving loop.
    78  func (s *Server) serve(ctx context.Context, l net.Listener, wg *sync.WaitGroup) error {
    79  	for {
    80  		conn, err := l.Accept()
    81  		if err != nil {
    82  			return err
    83  		}
    84  
    85  		client := &client{
    86  			conn:   conn,
    87  			source: s.Source,
    88  			email:  s.Email,
    89  			ctx:    ctx,
    90  		}
    91  
    92  		wg.Add(1)
    93  		go func() {
    94  			defer wg.Done()
    95  
    96  			paniccatcher.Do(func() {
    97  				if err := client.handle(); err != nil {
    98  					logging.Fields{
    99  						logging.ErrorKey: err,
   100  					}.Errorf(client.ctx, "failed to handle client request")
   101  				}
   102  			}, func(p *paniccatcher.Panic) {
   103  				logging.Fields{
   104  					"panicReason": p.Reason,
   105  				}.Errorf(client.ctx, "panic during client handshake:\n%s", p.Stack)
   106  			})
   107  		}()
   108  	}
   109  }
   110  
// client holds the state needed to serve one accepted Devshell connection.
type client struct {
	// conn is the accepted connection; handle closes it when done.
	conn net.Conn

	// source and email are copied from the owning Server.
	source oauth2.TokenSource
	email  string

	// ctx is the Server's base context. It is used for logging and for reading
	// the current time when computing the token's remaining lifetime.
	ctx context.Context
}
   119  
   120  func (c *client) handle() error {
   121  	defer c.conn.Close()
   122  
   123  	if _, err := c.readRequest(); err != nil {
   124  		if err := c.sendResponse([]any{err.Error()}); err != nil {
   125  			return fmt.Errorf("failed to send error: %v", err)
   126  		}
   127  		return nil
   128  	}
   129  
   130  	// Get the token.
   131  	t, err := c.source.Token()
   132  	if err != nil {
   133  		if err := c.sendResponse([]any{"cannot get access token"}); err != nil {
   134  			return fmt.Errorf("failed to send error: %v", err)
   135  		}
   136  		return err
   137  	}
   138  
   139  	// Expiration is in seconds from now so compute the correct format.
   140  	expiry := int(t.Expiry.Sub(clock.Now(c.ctx)).Seconds())
   141  
   142  	return c.sendResponse([]any{c.email, nil, t.AccessToken, expiry})
   143  }
   144  
   145  func (c *client) readRequest() ([]any, error) {
   146  	header := make([]byte, 6)
   147  	if _, err := c.conn.Read(header); err != nil {
   148  		return nil, fmt.Errorf("failed to read the header: %v", err)
   149  	}
   150  
   151  	// The first six bytes contain the length separated by a newline.
   152  	str := strings.SplitN(string(header), "\n", 2)
   153  	if len(str) != 2 {
   154  		return nil, fmt.Errorf("no newline in the first 6 bytes")
   155  	}
   156  
   157  	l, err := strconv.Atoi(str[0])
   158  	if err != nil {
   159  		return nil, fmt.Errorf("length is not a number: %v", err)
   160  	}
   161  
   162  	data := make([]byte, l)
   163  	copy(data, str[1][:])
   164  
   165  	// Read the rest of the message.
   166  	if l > len(str[1]) {
   167  		if _, err := c.conn.Read(data[len(str[1]):]); err != nil {
   168  			return nil, fmt.Errorf("failed to receive request: %v", err)
   169  		}
   170  	}
   171  
   172  	// Parse the message to ensure it's a correct JSON.
   173  	request := []any{}
   174  	if err := json.Unmarshal(data, &request); err != nil {
   175  		return nil, fmt.Errorf("failed to deserialize from JSON: %v", err)
   176  	}
   177  
   178  	return request, nil
   179  }
   180  
   181  func (c *client) sendResponse(response []any) error {
   182  	// Encode the response as JSON array (aka JsPbLite format).
   183  	payload, err := json.Marshal(response)
   184  	if err != nil {
   185  		return fmt.Errorf("failed to serialize to JSON: %v", err)
   186  	}
   187  
   188  	var buf bytes.Buffer
   189  	buf.WriteString(fmt.Sprintf("%d\n", len(payload)))
   190  	buf.Write(payload)
   191  	if _, err := c.conn.Write(buf.Bytes()); err != nil {
   192  		return fmt.Errorf("failed to send response: %v", err)
   193  	}
   194  	return nil
   195  }