github.com/Uptycs/basequery-go@v0.8.0/transport/transport.go (about)

     1  // +build !windows
     2  
     3  package transport
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"os"
     9  	"time"
    10  
    11  	"github.com/apache/thrift/lib/go/thrift"
    12  	"github.com/pkg/errors"
    13  )
    14  
    15  // Open opens the unix domain socket with the provided path and timeout,
    16  // returning a TTransport.
    17  func Open(sockPath string, timeout time.Duration) (*thrift.TSocket, error) {
    18  	addr, err := net.ResolveUnixAddr("unix", sockPath)
    19  	if err != nil {
    20  		return nil, errors.Wrapf(err, "resolving socket path '%s'", sockPath)
    21  	}
    22  
    23  	// the timeout parameter is passed to thrift, which passes it to net.DialTimeout
    24  	// but it looks like net.DialTimeout ignores timeouts for unix socket and immediately returns an error
    25  	// waitForSocket will loop every 200ms to stat the socket path,
    26  	// or until the timeout value passes, similar to the C++ and python implementations.
    27  	if err := waitForSocket(sockPath, timeout); err != nil {
    28  		return nil, errors.Wrapf(err, "waiting for unix socket to be available: %s", sockPath)
    29  	}
    30  
    31  	trans := thrift.NewTSocketFromAddrTimeout(addr, timeout, timeout)
    32  	if err := trans.Open(); err != nil {
    33  		return nil, errors.Wrap(err, "opening socket transport")
    34  	}
    35  
    36  	return trans, nil
    37  }
    38  
    39  // OpenServer resolves the specified listenPath and creates new thrift server socket on specified listen path.
    40  func OpenServer(listenPath string, timeout time.Duration) (*thrift.TServerSocket, error) {
    41  	addr, err := net.ResolveUnixAddr("unix", listenPath)
    42  	if err != nil {
    43  		return nil, errors.Wrapf(err, "resolving addr (%s)", addr)
    44  	}
    45  
    46  	return thrift.NewTServerSocketFromAddrTimeout(addr, 0), nil
    47  }
    48  
    49  func waitForSocket(sockPath string, timeout time.Duration) error {
    50  	ticker := time.NewTicker(200 * time.Millisecond)
    51  	defer ticker.Stop()
    52  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
    53  	defer cancel()
    54  	for {
    55  		select {
    56  		case <-ctx.Done():
    57  			return ctx.Err()
    58  		case <-ticker.C:
    59  			if _, err := os.Stat(sockPath); err == nil {
    60  				return nil
    61  			}
    62  		}
    63  	}
    64  }