github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/sni/sni.go (about)

     1  package sni
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  )
    10  
    11  func ServerNameFromBytes(data []byte) (sn string, err error) {
    12  	reader := bytes.NewReader(data)
    13  	bufferedReader := bufio.NewReader(reader)
    14  	c := bufferedConn{bufferedReader, nil, nil}
    15  	sn, _, err = ServerNameFromConn(c)
    16  	return
    17  }
    18  
    19  type bufferedConn struct {
    20  	r    *bufio.Reader
    21  	rout io.Reader
    22  	net.Conn
    23  }
    24  
    25  func newBufferedConn(c net.Conn) bufferedConn {
    26  	return bufferedConn{bufio.NewReader(c), nil, c}
    27  }
    28  
    29  func (b bufferedConn) Peek(n int) ([]byte, error) {
    30  	return b.r.Peek(n)
    31  }
    32  
    33  func (b bufferedConn) Read(p []byte) (int, error) {
    34  	if b.rout != nil {
    35  		return b.rout.Read(p)
    36  	}
    37  	return b.r.Read(p)
    38  }
    39  
    40  var malformedError = errors.New("malformed client hello")
    41  
    42  func getHello(b []byte) (string, error) {
    43  	rest := b[5:]
    44  
    45  	if len(rest) == 0 {
    46  		return "", malformedError
    47  	}
    48  
    49  	current := 0
    50  	handshakeType := rest[0]
    51  	current += 1
    52  	if handshakeType != 0x1 {
    53  		return "", errors.New("Not a ClientHello")
    54  	}
    55  
    56  	// Skip over another length
    57  	current += 3
    58  	// Skip over protocolversion
    59  	current += 2
    60  	// Skip over random number
    61  	current += 4 + 28
    62  
    63  	if current > len(rest) {
    64  		return "", malformedError
    65  	}
    66  
    67  	// Skip over session ID
    68  	sessionIDLength := int(rest[current])
    69  	current += 1
    70  	current += sessionIDLength
    71  
    72  	if current+1 > len(rest) {
    73  		return "", malformedError
    74  	}
    75  
    76  	cipherSuiteLength := (int(rest[current]) << 8) + int(rest[current+1])
    77  	current += 2
    78  	current += cipherSuiteLength
    79  
    80  	if current > len(rest) {
    81  		return "", malformedError
    82  	}
    83  	compressionMethodLength := int(rest[current])
    84  	current += 1
    85  	current += compressionMethodLength
    86  
    87  	if current > len(rest) {
    88  		return "", errors.New("no extensions")
    89  	}
    90  
    91  	current += 2
    92  
    93  	hostname := ""
    94  	for current+4 < len(rest) && hostname == "" {
    95  		extensionType := (int(rest[current]) << 8) + int(rest[current+1])
    96  		current += 2
    97  
    98  		extensionDataLength := (int(rest[current]) << 8) + int(rest[current+1])
    99  		current += 2
   100  
   101  		if extensionType == 0 {
   102  
   103  			// Skip over number of names as we're assuming there's just one
   104  			current += 2
   105  			if current > len(rest) {
   106  				return "", malformedError
   107  			}
   108  
   109  			nameType := rest[current]
   110  			current += 1
   111  			if nameType != 0 {
   112  				return "", errors.New("Not a hostname")
   113  			}
   114  			if current+1 > len(rest) {
   115  				return "", malformedError
   116  			}
   117  			nameLen := (int(rest[current]) << 8) + int(rest[current+1])
   118  			current += 2
   119  			if current+nameLen > len(rest) {
   120  				return "", malformedError
   121  			}
   122  			hostname = string(rest[current : current+nameLen])
   123  		}
   124  
   125  		current += extensionDataLength
   126  	}
   127  	if hostname == "" {
   128  		return "", errors.New("No hostname")
   129  	}
   130  	return hostname, nil
   131  
   132  }
   133  
   134  func getHelloBytes(c bufferedConn) ([]byte, error) {
   135  	b, err := c.Peek(5)
   136  	if err != nil {
   137  		return []byte{}, err
   138  	}
   139  
   140  	if b[0] != 0x16 {
   141  		return []byte{}, errors.New("not TLS")
   142  	}
   143  
   144  	restLengthBytes := b[3:]
   145  	restLength := (int(restLengthBytes[0]) << 8) + int(restLengthBytes[1])
   146  
   147  	return c.Peek(5 + restLength)
   148  
   149  }
   150  
   151  func getServername(c bufferedConn) (string, []byte, error) {
   152  	all, err := getHelloBytes(c)
   153  	if err != nil {
   154  		return "", nil, err
   155  	}
   156  	name, err := getHello(all)
   157  	if err != nil {
   158  		return "", nil, err
   159  	}
   160  	return name, all, err
   161  
   162  }
   163  
   164  // Uses SNI to get the name of the server from the connection. Returns the ServerName and a buffered connection that will not have been read off of.
   165  func ServerNameFromConn(c net.Conn) (string, net.Conn, error) {
   166  	bufconn := newBufferedConn(c)
   167  	sn, helloBytes, err := getServername(bufconn)
   168  	if err != nil {
   169  		return "", nil, err
   170  	}
   171  	bufconn.rout = io.MultiReader(bytes.NewBuffer(helloBytes), c)
   172  	return sn, bufconn, nil
   173  }