github.com/puellanivis/breton@v0.2.16/lib/files/socketfiles/socket.go (about)

     1  // Package socketfiles implements the "tcp:", "udp:", and "unix:" URL schemes.
     2  package socketfiles
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"net"
     8  	"net/url"
     9  	"strconv"
    10  	"syscall"
    11  
    12  	"golang.org/x/net/ipv4"
    13  )
    14  
    15  var (
    16  	errInvalidURL = errors.New("invalid url")
    17  	errInvalidIP  = errors.New("invalid ip")
    18  )
    19  
    20  // URL query field keys.
    21  const (
    22  	FieldBufferSize    = "buffer_size"
    23  	FieldLocalAddress  = "localaddr"
    24  	FieldLocalPort     = "localport"
    25  	FieldMaxBitrate    = "max_bitrate"
    26  	FieldMaxPacketSize = "max_pkt_size"
    27  	FieldPacketSize    = "pkt_size"
    28  	FieldTOS           = "tos"
    29  	FieldTTL           = "ttl"
    30  )
    31  
    32  type socket struct {
    33  	conn net.Conn
    34  
    35  	addr, qaddr net.Addr
    36  
    37  	bufferSize    int
    38  	packetSize    int
    39  	maxPacketSize int
    40  
    41  	tos, ttl int
    42  
    43  	throttler
    44  }
    45  
    46  func (s *socket) uri() *url.URL {
    47  	q := s.uriQuery()
    48  
    49  	switch qaddr := s.qaddr.(type) {
    50  	case *net.TCPAddr:
    51  		q.Set(FieldLocalAddress, qaddr.IP.String())
    52  		q.Set(FieldLocalPort, strconv.Itoa(qaddr.Port))
    53  
    54  	case *net.UDPAddr:
    55  		q.Set(FieldLocalAddress, qaddr.IP.String())
    56  		q.Set(FieldLocalPort, strconv.Itoa(qaddr.Port))
    57  
    58  	case *net.UnixAddr:
    59  		q.Set(FieldLocalAddress, qaddr.String())
    60  	}
    61  
    62  	host, path := s.addr.String(), ""
    63  
    64  	switch s.addr.Network() {
    65  	case "unix", "unixgram", "unixpacket":
    66  		host, path = "", host
    67  	}
    68  
    69  	return &url.URL{
    70  		Scheme:   s.addr.Network(),
    71  		Host:     host,
    72  		Path:     path,
    73  		RawQuery: q.Encode(),
    74  	}
    75  }
    76  
    77  func (s *socket) uriQuery() url.Values {
    78  	q := make(url.Values)
    79  
    80  	if s.bitrate > 0 {
    81  		q.Set(FieldMaxBitrate, strconv.Itoa(s.bitrate))
    82  	}
    83  
    84  	if s.bufferSize > 0 {
    85  		q.Set(FieldBufferSize, strconv.Itoa(s.bufferSize))
    86  	}
    87  
    88  	network := s.addr.Network()
    89  
    90  	switch network {
    91  	case "udp", "udp4", "udp6", "unixgram", "unixpacket":
    92  		if s.packetSize > 0 {
    93  			q.Set(FieldPacketSize, strconv.Itoa(s.packetSize))
    94  		}
    95  		if s.maxPacketSize > 0 {
    96  			q.Set(FieldMaxPacketSize, strconv.Itoa(s.maxPacketSize))
    97  		}
    98  	}
    99  
   100  	switch network {
   101  	case "udp", "udp4", "tcp", "tcp4":
   102  		if s.tos > 0 {
   103  			q.Set(FieldTOS, "0x"+strconv.FormatInt(int64(s.tos), 16))
   104  		}
   105  
   106  		if s.ttl > 0 {
   107  			q.Set(FieldTTL, strconv.Itoa(s.ttl))
   108  		}
   109  	}
   110  
   111  	return q
   112  }
   113  
   114  func sockReader(conn net.Conn, q url.Values) (*socket, error) {
   115  	bufferSize, err := getSize(q, FieldBufferSize)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	if bufferSize > 0 {
   121  		type readBufferSetter interface {
   122  			SetReadBuffer(int) error
   123  		}
   124  
   125  		conn, ok := conn.(readBufferSetter)
   126  		if !ok {
   127  			return nil, syscall.EINVAL
   128  		}
   129  
   130  		if err := conn.SetReadBuffer(bufferSize); err != nil {
   131  			return nil, err
   132  		}
   133  	}
   134  
   135  	laddr := conn.LocalAddr()
   136  
   137  	var maxPacketSize int
   138  	switch laddr.Network() {
   139  	case "udp", "udp4", "udp6", "unixgram", "unixpacket":
   140  		maxPacketSize, err = getSize(q, FieldMaxPacketSize)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  	}
   145  
   146  	return &socket{
   147  		conn: conn,
   148  
   149  		addr: conn.LocalAddr(),
   150  
   151  		bufferSize:    bufferSize,
   152  		maxPacketSize: maxPacketSize,
   153  	}, nil
   154  }
   155  
   156  func sockWriter(conn net.Conn, showLocalAddr bool, q url.Values) (*socket, error) {
   157  	raddr := conn.RemoteAddr()
   158  
   159  	bufferSize, err := getSize(q, FieldBufferSize)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	if bufferSize > 0 {
   165  		type writeBufferSetter interface {
   166  			SetWriteBuffer(int) error
   167  		}
   168  
   169  		conn, ok := conn.(writeBufferSetter)
   170  		if !ok {
   171  			return nil, syscall.EINVAL
   172  		}
   173  
   174  		if err := conn.SetWriteBuffer(bufferSize); err != nil {
   175  			return nil, err
   176  		}
   177  	}
   178  
   179  	var packetSize int
   180  	switch raddr.Network() {
   181  	case "udp", "udp4", "udp6", "unixgram", "unixpacket":
   182  		packetSize, err = getSize(q, FieldPacketSize)
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  	}
   187  
   188  	bitrate, err := getSize(q, FieldMaxBitrate)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	var t throttler
   194  	if bitrate > 0 {
   195  		t.setBitrate(bitrate, packetSize)
   196  	}
   197  
   198  	var tos, ttl int
   199  
   200  	switch raddr.Network() {
   201  	case "udp", "udp4", "tcp", "tcp4":
   202  		var p *ipv4.Conn
   203  
   204  		tos, err = getInt(q, FieldTOS)
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  
   209  		if tos > 0 {
   210  			if p == nil {
   211  				p = ipv4.NewConn(conn)
   212  			}
   213  
   214  			if err := p.SetTOS(tos); err != nil {
   215  				return nil, err
   216  			}
   217  
   218  			tos, _ = p.TOS()
   219  		}
   220  
   221  		ttl, err = getInt(q, FieldTTL)
   222  		if err != nil {
   223  			return nil, err
   224  		}
   225  
   226  		if ttl > 0 {
   227  			if p == nil {
   228  				p = ipv4.NewConn(conn)
   229  			}
   230  
   231  			if err := p.SetTTL(ttl); err != nil {
   232  				return nil, err
   233  			}
   234  
   235  			ttl, _ = p.TTL()
   236  		}
   237  	}
   238  
   239  	var laddr net.Addr
   240  	if showLocalAddr {
   241  		laddr = conn.LocalAddr()
   242  	}
   243  
   244  	return &socket{
   245  		conn: conn,
   246  
   247  		addr:  raddr,
   248  		qaddr: laddr,
   249  
   250  		bufferSize: bufferSize,
   251  		packetSize: packetSize,
   252  
   253  		tos: tos,
   254  		ttl: ttl,
   255  
   256  		throttler: t,
   257  	}, nil
   258  }
   259  
   260  var scales = map[byte]int{
   261  	'G': 1000000000,
   262  	'g': 1000000000,
   263  	'M': 1000000,
   264  	'm': 1000000,
   265  	'K': 1000,
   266  	'k': 1000,
   267  }
   268  
   269  func getSize(q url.Values, field string) (val int, err error) {
   270  	value := q.Get(field)
   271  	if value == "" {
   272  		return 0, nil
   273  	}
   274  
   275  	suffix := value[len(value)-1]
   276  
   277  	scale := 1
   278  	if s := scales[suffix]; s > 0 {
   279  		scale = s
   280  		value = value[:len(value)-1]
   281  	}
   282  
   283  	i, err := strconv.ParseInt(value, 0, strconv.IntSize)
   284  	if err != nil {
   285  		return 0, err
   286  	}
   287  
   288  	return int(i) * scale, nil
   289  }
   290  
   291  func getInt(q url.Values, field string) (val int, err error) {
   292  	value := q.Get(field)
   293  	if value == "" {
   294  		return 0, nil
   295  	}
   296  
   297  	i, err := strconv.ParseInt(value, 0, strconv.IntSize)
   298  	if err != nil {
   299  		return 0, err
   300  	}
   301  
   302  	return int(i), nil
   303  }
   304  
   305  func do(ctx context.Context, fn func() error) error {
   306  	done := make(chan struct{})
   307  
   308  	var err error
   309  	go func() {
   310  		defer close(done)
   311  
   312  		err = fn()
   313  	}()
   314  
   315  	select {
   316  	case <-done:
   317  	case <-ctx.Done():
   318  		return ctx.Err()
   319  	}
   320  
   321  	return err
   322  }