github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/sni/sni.go (about) 1 package sni 2 3 import ( 4 "bufio" 5 "bytes" 6 "errors" 7 "io" 8 "net" 9 ) 10 11 func ServerNameFromBytes(data []byte) (sn string, err error) { 12 reader := bytes.NewReader(data) 13 bufferedReader := bufio.NewReader(reader) 14 c := bufferedConn{bufferedReader, nil, nil} 15 sn, _, err = ServerNameFromConn(c) 16 return 17 } 18 19 type bufferedConn struct { 20 r *bufio.Reader 21 rout io.Reader 22 net.Conn 23 } 24 25 func newBufferedConn(c net.Conn) bufferedConn { 26 return bufferedConn{bufio.NewReader(c), nil, c} 27 } 28 29 func (b bufferedConn) Peek(n int) ([]byte, error) { 30 return b.r.Peek(n) 31 } 32 33 func (b bufferedConn) Read(p []byte) (int, error) { 34 if b.rout != nil { 35 return b.rout.Read(p) 36 } 37 return b.r.Read(p) 38 } 39 40 var malformedError = errors.New("malformed client hello") 41 42 func getHello(b []byte) (string, error) { 43 rest := b[5:] 44 45 if len(rest) == 0 { 46 return "", malformedError 47 } 48 49 current := 0 50 handshakeType := rest[0] 51 current += 1 52 if handshakeType != 0x1 { 53 return "", errors.New("Not a ClientHello") 54 } 55 56 // Skip over another length 57 current += 3 58 // Skip over protocolversion 59 current += 2 60 // Skip over random number 61 current += 4 + 28 62 63 if current > len(rest) { 64 return "", malformedError 65 } 66 67 // Skip over session ID 68 sessionIDLength := int(rest[current]) 69 current += 1 70 current += sessionIDLength 71 72 if current+1 > len(rest) { 73 return "", malformedError 74 } 75 76 cipherSuiteLength := (int(rest[current]) << 8) + int(rest[current+1]) 77 current += 2 78 current += cipherSuiteLength 79 80 if current > len(rest) { 81 return "", malformedError 82 } 83 compressionMethodLength := int(rest[current]) 84 current += 1 85 current += compressionMethodLength 86 87 if current > len(rest) { 88 return "", errors.New("no extensions") 89 } 90 91 current += 2 92 93 hostname := "" 94 for current+4 < len(rest) && hostname == "" { 95 extensionType := (int(rest[current]) << 8) + int(rest[current+1]) 96 current += 2 97 98 extensionDataLength := (int(rest[current]) << 8) + int(rest[current+1]) 99 current += 2 100 101 if extensionType == 0 { 102 103 // Skip over number of names as we're assuming there's just one 104 current += 2 105 if current > len(rest) { 106 return "", malformedError 107 } 108 109 nameType := rest[current] 110 current += 1 111 if nameType != 0 { 112 return "", errors.New("Not a hostname") 113 } 114 if current+1 > len(rest) { 115 return "", malformedError 116 } 117 nameLen := (int(rest[current]) << 8) + int(rest[current+1]) 118 current += 2 119 if current+nameLen > len(rest) { 120 return "", malformedError 121 } 122 hostname = string(rest[current : current+nameLen]) 123 } 124 125 current += extensionDataLength 126 } 127 if hostname == "" { 128 return "", errors.New("No hostname") 129 } 130 return hostname, nil 131 132 } 133 134 func getHelloBytes(c bufferedConn) ([]byte, error) { 135 b, err := c.Peek(5) 136 if err != nil { 137 return []byte{}, err 138 } 139 140 if b[0] != 0x16 { 141 return []byte{}, errors.New("not TLS") 142 } 143 144 restLengthBytes := b[3:] 145 restLength := (int(restLengthBytes[0]) << 8) + int(restLengthBytes[1]) 146 147 return c.Peek(5 + restLength) 148 149 } 150 151 func getServername(c bufferedConn) (string, []byte, error) { 152 all, err := getHelloBytes(c) 153 if err != nil { 154 return "", nil, err 155 } 156 name, err := getHello(all) 157 if err != nil { 158 return "", nil, err 159 } 160 return name, all, err 161 162 } 163 164 // Uses SNI to get the name of the server from the connection. Returns the ServerName and a buffered connection that will not have been read off of. 165 func ServerNameFromConn(c net.Conn) (string, net.Conn, error) { 166 bufconn := newBufferedConn(c) 167 sn, helloBytes, err := getServername(bufconn) 168 if err != nil { 169 return "", nil, err 170 } 171 bufconn.rout = io.MultiReader(bytes.NewBuffer(helloBytes), c) 172 return sn, bufconn, nil 173 }