github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/libkb/socket.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package libkb
     5  
     6  import (
     7  	"fmt"
     8  	"net"
     9  
    10  	"github.com/keybase/client/go/logger"
    11  	"github.com/keybase/client/go/protocol/keybase1"
    12  	"github.com/keybase/go-framed-msgpack-rpc/rpc"
    13  )
    14  
    15  // NewSocket() (Socket, err) is defined in the various platform-specific socket_*.go files.
    16  type Socket interface {
    17  	BindToSocket() (net.Listener, error)
    18  	DialSocket() (net.Conn, error)
    19  }
    20  
// SocketInfo holds the socket file paths used for binding and dialing.
// NOTE(review): the fields are populated and consumed by the
// platform-specific socket_*.go files, which are not visible here;
// presumably SocketInfo implements Socket there — confirm.
type SocketInfo struct {
	// log is a logger carried alongside the paths; it is not read in this file.
	log       logger.Logger
	// bindFile is the path exposed by GetBindFile.
	bindFile  string
	// dialFiles are the paths exposed by GetDialFiles.
	dialFiles []string
	// testOwner is not referenced in this file; the nolint directive
	// presumably silences an unused-field lint on some platforms — confirm.
	testOwner bool //nolint
}
    27  
    28  func (s SocketInfo) GetBindFile() string {
    29  	return s.bindFile
    30  }
    31  
    32  func (s SocketInfo) GetDialFiles() []string {
    33  	return s.dialFiles
    34  }
    35  
// SocketWrapper caches the outcome of dialing the local RPC socket: the
// connection, the RPC transport built on top of it, and any dial error.
// When built by GlobalContext.getSocketLocked, Transporter is only set if
// the dial succeeded, so a non-nil Err implies a nil Transporter.
type SocketWrapper struct {
	// Conn is the dialed connection (may be nil when Err is set).
	Conn        net.Conn
	// Transporter is the RPC transport wrapping Conn.
	Transporter rpc.Transporter
	// Err is the error from the dial attempt, if any.
	Err         error
}
    41  
    42  func (g *GlobalContext) MakeLoopbackServer() (l net.Listener, err error) {
    43  	g.socketWrapperMu.Lock()
    44  	defer g.socketWrapperMu.Unlock()
    45  	g.LoopbackListener = NewLoopbackListener(g)
    46  	l = g.LoopbackListener
    47  	return l, err
    48  }
    49  
    50  func (g *GlobalContext) BindToSocket() (net.Listener, error) {
    51  	return g.SocketInfo.BindToSocket()
    52  }
    53  
    54  func NewTransportFromSocket(g *GlobalContext, s net.Conn, src keybase1.NetworkSource) rpc.Transporter {
    55  	return rpc.NewTransport(s, NewRPCLogFactory(g), NetworkInstrumenterStorageFromSrc(g, src), MakeWrapError(g), rpc.DefaultMaxFrameLength)
    56  }
    57  
    58  // ResetSocket clears and returns a new socket
    59  func (g *GlobalContext) ResetSocket(clearError bool) (net.Conn, rpc.Transporter, bool, error) {
    60  	g.socketWrapperMu.Lock()
    61  	defer g.socketWrapperMu.Unlock()
    62  
    63  	g.SocketWrapper = nil
    64  	return g.getSocketLocked(clearError)
    65  }
    66  
    67  func (g *GlobalContext) GetSocket(clearError bool) (conn net.Conn, xp rpc.Transporter, isNew bool, err error) {
    68  	g.Trace("GetSocket", &err)()
    69  	g.socketWrapperMu.Lock()
    70  	defer g.socketWrapperMu.Unlock()
    71  	return g.getSocketLocked(clearError)
    72  }
    73  
    74  func (g *GlobalContext) getSocketLocked(clearError bool) (conn net.Conn, xp rpc.Transporter, isNew bool, err error) {
    75  	needWrapper := false
    76  	if g.SocketWrapper == nil {
    77  		needWrapper = true
    78  		g.Log.Debug("| empty socket wrapper; need a new one")
    79  	} else if g.SocketWrapper.Transporter != nil && !g.SocketWrapper.Transporter.IsConnected() {
    80  		// need reconnect
    81  		g.Log.Debug("| rpc transport isn't connected, reconnecting...")
    82  		needWrapper = true
    83  	}
    84  
    85  	if needWrapper {
    86  		sw := SocketWrapper{}
    87  		if g.LoopbackListener != nil {
    88  			sw.Conn, sw.Err = g.LoopbackListener.Dial()
    89  		} else if g.SocketInfo == nil {
    90  			sw.Err = fmt.Errorf("Cannot get socket in standalone mode")
    91  		} else {
    92  			sw.Conn, sw.Err = g.SocketInfo.DialSocket()
    93  			g.Log.Debug("| DialSocket -> %s", ErrToOk(sw.Err))
    94  			isNew = true
    95  		}
    96  		if sw.Err == nil {
    97  			sw.Transporter = NewTransportFromSocket(g, sw.Conn, keybase1.NetworkSource_LOCAL)
    98  		}
    99  		g.SocketWrapper = &sw
   100  	}
   101  
   102  	sw := g.SocketWrapper
   103  	if sw.Err != nil && clearError {
   104  		g.SocketWrapper = nil
   105  	}
   106  	err = sw.Err
   107  
   108  	return sw.Conn, sw.Transporter, isNew, err
   109  }