gitee.com/mysnapcore/mysnapd@v0.1.0/daemon/ucrednet.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon
    21  
    22  import (
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"regexp"
    27  	"strconv"
    28  	"sync"
    29  	sys "syscall"
    30  )
    31  
    32  var errNoID = errors.New("no pid/uid found")
    33  
    34  const (
    35  	ucrednetNoProcess = int32(0)
    36  	ucrednetNobody    = uint32((1 << 32) - 1)
    37  )
    38  
    39  var raddrRegexp = regexp.MustCompile(`^pid=(\d+);uid=(\d+);socket=([^;]*);$`)
    40  
    41  var ucrednetGet = ucrednetGetImpl
    42  
    43  func ucrednetGetImpl(remoteAddr string) (*ucrednet, error) {
    44  	// NOTE treat remoteAddr at one point included a user-controlled
    45  	// string. In case that happens again by accident, treat it as tainted,
    46  	// and be very suspicious of it.
    47  	u := &ucrednet{
    48  		Pid: ucrednetNoProcess,
    49  		Uid: ucrednetNobody,
    50  	}
    51  	subs := raddrRegexp.FindStringSubmatch(remoteAddr)
    52  	if subs != nil {
    53  		if v, err := strconv.ParseInt(subs[1], 10, 32); err == nil {
    54  			u.Pid = int32(v)
    55  		}
    56  		if v, err := strconv.ParseUint(subs[2], 10, 32); err == nil {
    57  			u.Uid = uint32(v)
    58  		}
    59  		u.Socket = subs[3]
    60  	}
    61  	if u.Pid == ucrednetNoProcess || u.Uid == ucrednetNobody {
    62  		return nil, errNoID
    63  	}
    64  
    65  	return u, nil
    66  }
    67  
    68  type ucrednet struct {
    69  	Pid    int32
    70  	Uid    uint32
    71  	Socket string
    72  }
    73  
    74  func (un *ucrednet) String() string {
    75  	if un == nil {
    76  		return "pid=;uid=;socket=;"
    77  	}
    78  	return fmt.Sprintf("pid=%d;uid=%d;socket=%s;", un.Pid, un.Uid, un.Socket)
    79  }
    80  
    81  type ucrednetAddr struct {
    82  	net.Addr
    83  	*ucrednet
    84  }
    85  
    86  func (wa *ucrednetAddr) String() string {
    87  	// NOTE we drop the original (user-supplied) net.Addr from the
    88  	// serialization entirely. We carry it this far so it helps debugging
    89  	// (via %#v logging), but from here on in it's not helpful.
    90  	return wa.ucrednet.String()
    91  }
    92  
    93  type ucrednetConn struct {
    94  	net.Conn
    95  	*ucrednet
    96  }
    97  
    98  func (wc *ucrednetConn) RemoteAddr() net.Addr {
    99  	return &ucrednetAddr{wc.Conn.RemoteAddr(), wc.ucrednet}
   100  }
   101  
   102  type ucrednetListener struct {
   103  	net.Listener
   104  
   105  	idempotClose sync.Once
   106  	closeErr     error
   107  }
   108  
   109  var getUcred = sys.GetsockoptUcred
   110  
   111  func (wl *ucrednetListener) Accept() (net.Conn, error) {
   112  	con, err := wl.Listener.Accept()
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	var unet *ucrednet
   118  	if ucon, ok := con.(*net.UnixConn); ok {
   119  		syscallConn, err := ucon.SyscallConn()
   120  		if err != nil {
   121  			return nil, err
   122  		}
   123  
   124  		var ucred *sys.Ucred
   125  		scErr := syscallConn.Control(func(fd uintptr) {
   126  			ucred, err = getUcred(int(fd), sys.SOL_SOCKET, sys.SO_PEERCRED)
   127  		})
   128  		if scErr != nil {
   129  			return nil, scErr
   130  		}
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  
   135  		unet = &ucrednet{
   136  			Pid:    ucred.Pid,
   137  			Uid:    ucred.Uid,
   138  			Socket: ucon.LocalAddr().String(),
   139  		}
   140  	}
   141  
   142  	return &ucrednetConn{con, unet}, nil
   143  }
   144  
   145  func (wl *ucrednetListener) Close() error {
   146  	wl.idempotClose.Do(func() {
   147  		wl.closeErr = wl.Listener.Close()
   148  	})
   149  	return wl.closeErr
   150  }