go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/auth/integration/internal/localsrv/localsrv.go (about)

     1  // Copyright 2017 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package localsrv provides helpers for running local TCP servers.
    16  //
    17  // It is used by various machine-local authentication protocols to launch
    18  // local listening servers.
    19  package localsrv
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"net"
    25  	"sync"
    26  	"time"
    27  
    28  	"go.chromium.org/luci/common/clock"
    29  	"go.chromium.org/luci/common/errors"
    30  	"go.chromium.org/luci/common/logging"
    31  )
    32  
// Server runs a local TCP server bound to 127.0.0.1.
//
// Launch it with Start and shut it down with Stop. Start can succeed only once
// per Server instance (ctx is never reset after a successful Start), so make
// a new Server to serve again.
type Server struct {
	l        sync.Mutex         // serializes Start, serve and killServe
	name     string             // name passed to Start
	listener net.Listener       // to know what to stop in killServe, nil after that
	wg       sync.WaitGroup     // +1 for each request being processed now
	ctx      context.Context    // derived from ctx in Start, never resets to nil after that
	cancel   context.CancelFunc // cancels 'ctx'
	stopped  chan struct{}      // closed when serve() goroutine stops
}
    43  
    44  // ServeFunc is called from internal goroutine to run the server loop.
    45  //
    46  // When server stops, the given listener will be closed and the given context
    47  // will be canceled. The wait group is used to wait for pending requests:
    48  // increment it when starting processing a request, and decrement when done.
    49  //
    50  // If ServeFunc returns after the listener is closed, the returned error is
    51  // ignored (it is most likely caused by the closed listener).
    52  type ServeFunc func(c context.Context, l net.Listener, wg *sync.WaitGroup) error
    53  
    54  // Start launches background goroutine with the serving loop 'serve'.
    55  //
    56  // Returns the address the listening socket is bound to.
    57  //
    58  // The provided context is used as base context for request handlers and for
    59  // logging. 'name' identifies this server in logs, and 'port' specifies a TCP
    60  // port number to bind to (or 0 to auto-pick one).
    61  //
    62  // The server must be eventually stopped with Stop().
    63  func (s *Server) Start(ctx context.Context, name string, port int, serve ServeFunc) (*net.TCPAddr, error) {
    64  	s.l.Lock()
    65  	defer s.l.Unlock()
    66  
    67  	if s.ctx != nil {
    68  		return nil, errors.New("already initialized")
    69  	}
    70  
    71  	ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
    72  	if err != nil {
    73  		return nil, errors.Annotate(err, "failed to create listening socket").Err()
    74  	}
    75  
    76  	s.name = name
    77  	s.ctx, s.cancel = context.WithCancel(ctx)
    78  	s.listener = ln
    79  
    80  	// Start serving in background.
    81  	s.stopped = make(chan struct{})
    82  	go func() {
    83  		defer close(s.stopped)
    84  		if err := s.serve(serve); err != nil {
    85  			logging.WithError(err).Errorf(s.ctx, "Unexpected error in the server loop of %q", s.name)
    86  		}
    87  	}()
    88  
    89  	return ln.Addr().(*net.TCPAddr), nil
    90  }
    91  
    92  // Stop closes the listening socket, notifies pending requests to abort and
    93  // stops the internal serving goroutine.
    94  //
    95  // Safe to call multiple times. Once stopped, the server cannot be started again
    96  // (make a new instance of Server instead).
    97  //
    98  // Uses the given context for the deadline when waiting for the serving loop
    99  // to stop.
   100  func (s *Server) Stop(ctx context.Context) error {
   101  	// Close the socket. It notifies the serving loop to stop.
   102  	if err := s.killServe(); err != nil {
   103  		return err
   104  	}
   105  
   106  	// Wait for the serving loop to actually stop.
   107  	select {
   108  	case <-s.stopped:
   109  		logging.Debugf(ctx, "The local server %q has stopped", s.name)
   110  	case <-clock.After(ctx, 10*time.Second):
   111  		logging.Errorf(ctx, "Giving up waiting for the local server %q to stop", s.name)
   112  	}
   113  
   114  	return nil
   115  }
   116  
   117  // serve runs the serving loop.
   118  //
   119  // It unblocks once killServe is called and all pending requests are served.
   120  //
   121  // Returns nil if serving was stopped by killServe or non-nil if it failed for
   122  // some other reason.
   123  func (s *Server) serve(cb ServeFunc) error {
   124  	s.l.Lock()
   125  	if s.listener == nil {
   126  		s.l.Unlock()
   127  		return errors.New("already closed")
   128  	}
   129  	listener := s.listener // accessed outside the lock
   130  	ctx := s.ctx
   131  	s.l.Unlock()
   132  
   133  	err := cb(ctx, listener, &s.wg) // blocks until killServe() is called
   134  	s.wg.Wait()                     // waits for all pending requests
   135  
   136  	// If it was a planned shutdown with killServe(), ignore the error. It says
   137  	// that the listening socket was closed.
   138  	s.l.Lock()
   139  	if s.listener == nil {
   140  		err = nil
   141  	}
   142  	s.l.Unlock()
   143  
   144  	if err != nil {
   145  		return errors.Annotate(err, "error in the serving loop").Err()
   146  	}
   147  	return nil
   148  }
   149  
   150  // killServe notifies the serving goroutine to stop (if it is running).
   151  func (s *Server) killServe() error {
   152  	s.l.Lock()
   153  	defer s.l.Unlock()
   154  
   155  	if s.ctx == nil {
   156  		return errors.New("not initialized")
   157  	}
   158  
   159  	// Stop accepting requests, unblocks serve(). Do it only once.
   160  	if s.listener != nil {
   161  		logging.Debugf(s.ctx, "Stopping the local server %q...", s.name)
   162  		if err := s.listener.Close(); err != nil {
   163  			logging.WithError(err).Errorf(s.ctx, "Failed to close the listening socket of %q", s.name)
   164  		}
   165  		s.listener = nil
   166  	}
   167  	s.cancel() // notify all running handlers to stop
   168  
   169  	return nil
   170  }