github.com/Cloud-Foundations/Dominator@v0.3.4/lib/srpc/localTransport.go (about)

     1  package srpc
     2  
import (
	"bufio"
	"bytes"
	"crypto/rand"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
	"runtime"
	"sync"
)
    14  
    15  const (
    16  	unixClientCookieLength = 8
    17  	unixServerCookieLength = 8
    18  	unixBufferSize         = 1 << 16
    19  )
    20  
    21  var (
    22  	srpcUnixSocketPath = flag.String("srpcUnixSocketPath",
    23  		defaultUnixSocketPath(),
    24  		"Pathname for server Unix sockets")
    25  
    26  	unixCookieToConnMapLock sync.Mutex
    27  	unixCookieToConn        map[[unixServerCookieLength]byte]net.Conn
    28  	unixListenerSetup       sync.Once
    29  )
    30  
    31  type localUpgradeToUnixRequestOne struct {
    32  	ClientCookie []byte
    33  }
    34  
    35  type localUpgradeToUnixResponseOne struct {
    36  	Error          string
    37  	ServerCookie   []byte
    38  	SocketPathname string
    39  }
    40  
    41  type localUpgradeToUnixRequestTwo struct {
    42  	SentServerCookie bool
    43  }
    44  
    45  type localUpgradeToUnixResponseTwo struct {
    46  	Error string
    47  }
    48  
    49  func acceptUnix(conn net.Conn,
    50  	unixCookieToConn map[[unixServerCookieLength]byte]net.Conn) {
    51  	doClose := true
    52  	defer func() {
    53  		if doClose {
    54  			conn.Close()
    55  		}
    56  	}()
    57  	var cookie [unixServerCookieLength]byte
    58  	if length, err := conn.Read(cookie[:]); err != nil {
    59  		return
    60  	} else if length != unixServerCookieLength {
    61  		return
    62  	}
    63  	unixCookieToConnMapLock.Lock()
    64  	if _, ok := unixCookieToConn[cookie]; ok {
    65  		unixCookieToConnMapLock.Unlock()
    66  		return
    67  	}
    68  	unixCookieToConn[cookie] = conn
    69  	unixCookieToConnMapLock.Unlock()
    70  	var ack [1]byte
    71  	length, err := conn.Write(ack[:])
    72  	if err != nil || length != 1 {
    73  		unixCookieToConnMapLock.Lock()
    74  		delete(unixCookieToConn, cookie)
    75  		unixCookieToConnMapLock.Unlock()
    76  		return
    77  	}
    78  	doClose = false
    79  }
    80  
    81  func acceptUnixLoop(l net.Listener,
    82  	unixCookieToConn map[[unixServerCookieLength]byte]net.Conn) {
    83  	defer l.Close()
    84  	for {
    85  		conn, err := l.Accept()
    86  		if err != nil {
    87  			fmt.Fprintf(os.Stderr, "Error accepting Unix connection: %s\n", err)
    88  			return
    89  		}
    90  		go acceptUnix(conn, unixCookieToConn)
    91  	}
    92  }
    93  
    94  func defaultUnixSocketPath() string {
    95  	if runtime.GOOS != "linux" {
    96  		return ""
    97  	}
    98  	return fmt.Sprintf("@SRPC.%d", os.Getpid())
    99  }
   100  
   101  func isLocal(client *Client) bool {
   102  	lhost, _, err := net.SplitHostPort(client.localAddr)
   103  	if err != nil {
   104  		return false
   105  	}
   106  	rhost, _, err := net.SplitHostPort(client.remoteAddr)
   107  	if err != nil {
   108  		return false
   109  	}
   110  	return lhost == rhost
   111  }
   112  
   113  func setupUnixListener() {
   114  	if *srpcUnixSocketPath == "" {
   115  		return
   116  	}
   117  	if (*srpcUnixSocketPath)[0] != '@' {
   118  		os.Remove(*srpcUnixSocketPath)
   119  	}
   120  	l, err := net.Listen("unix", *srpcUnixSocketPath)
   121  	if err != nil {
   122  		fmt.Fprintf(os.Stderr, "Error listening on Unix socket: %s\n", err)
   123  		return
   124  	}
   125  	unixCookieToConn = make(map[[unixServerCookieLength]byte]net.Conn)
   126  	go acceptUnixLoop(l, unixCookieToConn)
   127  }
   128  
   129  func (*builtinReceiver) LocalUpgradeToUnix(conn *Conn) error {
   130  	unixListenerSetup.Do(setupUnixListener)
   131  	var requestOne localUpgradeToUnixRequestOne
   132  	if err := conn.Decode(&requestOne); err != nil {
   133  		return err
   134  	}
   135  	if *srpcUnixSocketPath == "" || unixCookieToConn == nil {
   136  		return conn.Encode(localUpgradeToUnixResponseOne{Error: "no socket"})
   137  	}
   138  	var cookie [unixServerCookieLength]byte
   139  	if length, err := rand.Read(cookie[:]); err != nil {
   140  		return conn.Encode(localUpgradeToUnixResponseOne{Error: err.Error()})
   141  	} else if length != unixServerCookieLength {
   142  		return conn.Encode(localUpgradeToUnixResponseOne{Error: "bad length"})
   143  	}
   144  	err := conn.Encode(localUpgradeToUnixResponseOne{
   145  		ServerCookie:   cookie[:],
   146  		SocketPathname: *srpcUnixSocketPath,
   147  	})
   148  	if err != nil {
   149  		return err
   150  	}
   151  	if err := conn.Flush(); err != nil {
   152  		return err
   153  	}
   154  	var requestTwo localUpgradeToUnixRequestTwo
   155  	if err := conn.Decode(&requestTwo); err != nil {
   156  		return err
   157  	}
   158  	if !requestTwo.SentServerCookie {
   159  		return nil
   160  	}
   161  	unixCookieToConnMapLock.Lock()
   162  	newConn, ok := unixCookieToConn[cookie]
   163  	unixCookieToConnMapLock.Unlock()
   164  	doClose := true
   165  	defer func() {
   166  		if doClose && newConn != nil {
   167  			newConn.Close()
   168  		}
   169  	}()
   170  	if !ok {
   171  		return conn.Encode(
   172  			localUpgradeToUnixResponseTwo{Error: "cookie not found"})
   173  	}
   174  	if err := conn.Encode(localUpgradeToUnixResponseTwo{}); err != nil {
   175  		return err
   176  	}
   177  	if err := conn.Flush(); err != nil {
   178  		return err
   179  	}
   180  	if length, err := newConn.Write(requestOne.ClientCookie); err != nil {
   181  		return err
   182  	} else if length != len(requestOne.ClientCookie) {
   183  		return fmt.Errorf("could not write full client cookie")
   184  	}
   185  	doClose = false
   186  	conn.conn.Close()
   187  	conn.conn = newConn
   188  	conn.ReadWriter = bufio.NewReadWriter(
   189  		bufio.NewReaderSize(newConn, unixBufferSize),
   190  		bufio.NewWriterSize(newConn, unixBufferSize))
   191  	logger.Debugf(0, "upgraded connection from: %s to Unix\n", conn.remoteAddr)
   192  	return nil
   193  }
   194  
   195  func (client *Client) localAttemptUpgradeToUnix() (bool, error) {
   196  	if !isLocal(client) {
   197  		return false, nil
   198  	}
   199  	var cookie [unixClientCookieLength]byte
   200  	if length, err := rand.Read(cookie[:]); err != nil {
   201  		return false, nil
   202  	} else if length != unixClientCookieLength {
   203  		return false, nil
   204  	}
   205  	conn, err := client.Call(".LocalUpgradeToUnix")
   206  	if err != nil {
   207  		return false, nil
   208  	}
   209  	defer conn.Close()
   210  	defer conn.Flush()
   211  	err = conn.Encode(localUpgradeToUnixRequestOne{ClientCookie: cookie[:]})
   212  	if err != nil {
   213  		return false, err
   214  	}
   215  	if err := conn.Flush(); err != nil {
   216  		return false, err
   217  	}
   218  	var replyOne localUpgradeToUnixResponseOne
   219  	if err := conn.Decode(&replyOne); err != nil {
   220  		return false, err
   221  	}
   222  	if replyOne.Error != "" {
   223  		return false, nil
   224  	}
   225  	newConn, err := net.Dial("unix", replyOne.SocketPathname)
   226  	if err != nil {
   227  		conn.Encode(localUpgradeToUnixRequestTwo{})
   228  		logger.Println(err)
   229  		return false, nil
   230  	}
   231  	doClose := true
   232  	defer func() {
   233  		if doClose {
   234  			newConn.Close()
   235  		}
   236  	}()
   237  	if length, err := newConn.Write(replyOne.ServerCookie); err != nil {
   238  		conn.Encode(localUpgradeToUnixRequestTwo{})
   239  		return false, err
   240  	} else if length != len(replyOne.ServerCookie) {
   241  		conn.Encode(localUpgradeToUnixRequestTwo{})
   242  		return false, fmt.Errorf("bad cookie length: %d", length)
   243  	}
   244  	var ack [1]byte
   245  	if length, err := newConn.Read(ack[:]); err != nil {
   246  		conn.Encode(localUpgradeToUnixRequestTwo{})
   247  		return false, err
   248  	} else if length != 1 {
   249  		conn.Encode(localUpgradeToUnixRequestTwo{})
   250  		return false, fmt.Errorf("bad ack length: %d", length)
   251  	}
   252  	err = conn.Encode(localUpgradeToUnixRequestTwo{SentServerCookie: true})
   253  	if err != nil {
   254  		return false, err
   255  	}
   256  	if err := conn.Flush(); err != nil {
   257  		return false, err
   258  	}
   259  	var replyTwo localUpgradeToUnixResponseTwo
   260  	if err := conn.Decode(&replyTwo); err != nil {
   261  		return false, err
   262  	}
   263  	if replyTwo.Error != "" {
   264  		return false, nil
   265  	}
   266  	returnedClientCookie := make([]byte, len(cookie))
   267  	if length, err := newConn.Read(returnedClientCookie); err != nil {
   268  		return false, err
   269  	} else if length != len(cookie) {
   270  		return false, fmt.Errorf("bad returned cookie length: %d", length)
   271  	}
   272  	if !bytes.Equal(returnedClientCookie, cookie[:]) {
   273  		return false, fmt.Errorf("returned client cookie does not match")
   274  	}
   275  	doClose = false
   276  	client.conn.Close()
   277  	client.conn = newConn
   278  	client.connType = "Unix"
   279  	client.tcpConn = nil
   280  	client.bufrw = bufio.NewReadWriter(
   281  		bufio.NewReaderSize(newConn, unixBufferSize),
   282  		bufio.NewWriterSize(newConn, unixBufferSize))
   283  	return true, nil
   284  }