github.com/puellanivis/breton@v0.2.16/lib/files/socketfiles/socket.go (about) 1 // Package socketfiles implements the "tcp:", "udp:", and "unix:" URL schemes. 2 package socketfiles 3 4 import ( 5 "context" 6 "errors" 7 "net" 8 "net/url" 9 "strconv" 10 "syscall" 11 12 "golang.org/x/net/ipv4" 13 ) 14 15 var ( 16 errInvalidURL = errors.New("invalid url") 17 errInvalidIP = errors.New("invalid ip") 18 ) 19 20 // URL query field keys. 21 const ( 22 FieldBufferSize = "buffer_size" 23 FieldLocalAddress = "localaddr" 24 FieldLocalPort = "localport" 25 FieldMaxBitrate = "max_bitrate" 26 FieldMaxPacketSize = "max_pkt_size" 27 FieldPacketSize = "pkt_size" 28 FieldTOS = "tos" 29 FieldTTL = "ttl" 30 ) 31 32 type socket struct { 33 conn net.Conn 34 35 addr, qaddr net.Addr 36 37 bufferSize int 38 packetSize int 39 maxPacketSize int 40 41 tos, ttl int 42 43 throttler 44 } 45 46 func (s *socket) uri() *url.URL { 47 q := s.uriQuery() 48 49 switch qaddr := s.qaddr.(type) { 50 case *net.TCPAddr: 51 q.Set(FieldLocalAddress, qaddr.IP.String()) 52 q.Set(FieldLocalPort, strconv.Itoa(qaddr.Port)) 53 54 case *net.UDPAddr: 55 q.Set(FieldLocalAddress, qaddr.IP.String()) 56 q.Set(FieldLocalPort, strconv.Itoa(qaddr.Port)) 57 58 case *net.UnixAddr: 59 q.Set(FieldLocalAddress, qaddr.String()) 60 } 61 62 host, path := s.addr.String(), "" 63 64 switch s.addr.Network() { 65 case "unix", "unixgram", "unixpacket": 66 host, path = "", host 67 } 68 69 return &url.URL{ 70 Scheme: s.addr.Network(), 71 Host: host, 72 Path: path, 73 RawQuery: q.Encode(), 74 } 75 } 76 77 func (s *socket) uriQuery() url.Values { 78 q := make(url.Values) 79 80 if s.bitrate > 0 { 81 q.Set(FieldMaxBitrate, strconv.Itoa(s.bitrate)) 82 } 83 84 if s.bufferSize > 0 { 85 q.Set(FieldBufferSize, strconv.Itoa(s.bufferSize)) 86 } 87 88 network := s.addr.Network() 89 90 switch network { 91 case "udp", "udp4", "udp6", "unixgram", "unixpacket": 92 if s.packetSize > 0 { 93 q.Set(FieldPacketSize, strconv.Itoa(s.packetSize)) 94 } 95 if s.maxPacketSize > 0 { 96 q.Set(FieldMaxPacketSize, strconv.Itoa(s.maxPacketSize)) 97 } 98 } 99 100 switch network { 101 case "udp", "udp4", "tcp", "tcp4": 102 if s.tos > 0 { 103 q.Set(FieldTOS, "0x"+strconv.FormatInt(int64(s.tos), 16)) 104 } 105 106 if s.ttl > 0 { 107 q.Set(FieldTTL, strconv.Itoa(s.ttl)) 108 } 109 } 110 111 return q 112 } 113 114 func sockReader(conn net.Conn, q url.Values) (*socket, error) { 115 bufferSize, err := getSize(q, FieldBufferSize) 116 if err != nil { 117 return nil, err 118 } 119 120 if bufferSize > 0 { 121 type readBufferSetter interface { 122 SetReadBuffer(int) error 123 } 124 125 conn, ok := conn.(readBufferSetter) 126 if !ok { 127 return nil, syscall.EINVAL 128 } 129 130 if err := conn.SetReadBuffer(bufferSize); err != nil { 131 return nil, err 132 } 133 } 134 135 laddr := conn.LocalAddr() 136 137 var maxPacketSize int 138 switch laddr.Network() { 139 case "udp", "udp4", "udp6", "unixgram", "unixpacket": 140 maxPacketSize, err = getSize(q, FieldMaxPacketSize) 141 if err != nil { 142 return nil, err 143 } 144 } 145 146 return &socket{ 147 conn: conn, 148 149 addr: conn.LocalAddr(), 150 151 bufferSize: bufferSize, 152 maxPacketSize: maxPacketSize, 153 }, nil 154 } 155 156 func sockWriter(conn net.Conn, showLocalAddr bool, q url.Values) (*socket, error) { 157 raddr := conn.RemoteAddr() 158 159 bufferSize, err := getSize(q, FieldBufferSize) 160 if err != nil { 161 return nil, err 162 } 163 164 if bufferSize > 0 { 165 type writeBufferSetter interface { 166 SetWriteBuffer(int) error 167 } 168 169 conn, ok := conn.(writeBufferSetter) 170 if !ok { 171 return nil, syscall.EINVAL 172 } 173 174 if err := conn.SetWriteBuffer(bufferSize); err != nil { 175 return nil, err 176 } 177 } 178 179 var packetSize int 180 switch raddr.Network() { 181 case "udp", "udp4", "udp6", "unixgram", "unixpacket": 182 packetSize, err = getSize(q, FieldPacketSize) 183 if err != nil { 184 return nil, err 185 } 186 } 187 188 bitrate, err := getSize(q, FieldMaxBitrate) 189 if err != nil { 190 return nil, err 191 } 192 193 var t throttler 194 if bitrate > 0 { 195 t.setBitrate(bitrate, packetSize) 196 } 197 198 var tos, ttl int 199 200 switch raddr.Network() { 201 case "udp", "udp4", "tcp", "tcp4": 202 var p *ipv4.Conn 203 204 tos, err = getInt(q, FieldTOS) 205 if err != nil { 206 return nil, err 207 } 208 209 if tos > 0 { 210 if p == nil { 211 p = ipv4.NewConn(conn) 212 } 213 214 if err := p.SetTOS(tos); err != nil { 215 return nil, err 216 } 217 218 tos, _ = p.TOS() 219 } 220 221 ttl, err = getInt(q, FieldTTL) 222 if err != nil { 223 return nil, err 224 } 225 226 if ttl > 0 { 227 if p == nil { 228 p = ipv4.NewConn(conn) 229 } 230 231 if err := p.SetTTL(ttl); err != nil { 232 return nil, err 233 } 234 235 ttl, _ = p.TTL() 236 } 237 } 238 239 var laddr net.Addr 240 if showLocalAddr { 241 laddr = conn.LocalAddr() 242 } 243 244 return &socket{ 245 conn: conn, 246 247 addr: raddr, 248 qaddr: laddr, 249 250 bufferSize: bufferSize, 251 packetSize: packetSize, 252 253 tos: tos, 254 ttl: ttl, 255 256 throttler: t, 257 }, nil 258 } 259 260 var scales = map[byte]int{ 261 'G': 1000000000, 262 'g': 1000000000, 263 'M': 1000000, 264 'm': 1000000, 265 'K': 1000, 266 'k': 1000, 267 } 268 269 func getSize(q url.Values, field string) (val int, err error) { 270 value := q.Get(field) 271 if value == "" { 272 return 0, nil 273 } 274 275 suffix := value[len(value)-1] 276 277 scale := 1 278 if s := scales[suffix]; s > 0 { 279 scale = s 280 value = value[:len(value)-1] 281 } 282 283 i, err := strconv.ParseInt(value, 0, strconv.IntSize) 284 if err != nil { 285 return 0, err 286 } 287 288 return int(i) * scale, nil 289 } 290 291 func getInt(q url.Values, field string) (val int, err error) { 292 value := q.Get(field) 293 if value == "" { 294 return 0, nil 295 } 296 297 i, err := strconv.ParseInt(value, 0, strconv.IntSize) 298 if err != nil { 299 return 0, err 300 } 301 302 return int(i), nil 303 } 304 305 func do(ctx context.Context, fn func() error) error { 306 done := make(chan struct{}) 307 308 var err error 309 go func() { 310 defer close(done) 311 312 err = fn() 313 }() 314 315 select { 316 case <-done: 317 case <-ctx.Done(): 318 return ctx.Err() 319 } 320 321 return err 322 }