github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/auth/auth.go (about) 1 package auth 2 3 import ( 4 "crypto/tls" 5 "encoding/base64" 6 "encoding/hex" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net" 13 "net/http" 14 "os" 15 "strings" 16 "time" 17 18 "github.com/bugfan/de" 19 ) 20 21 var ( 22 ServerIP string 23 AuthURL string 24 defaults map[string]string 25 ) 26 27 func init() { 28 defaults = map[string]string{ 29 "auth_url": "https://wg.lt53.cn", 30 "api_secret": "jwdlhtrh", 31 "ifconfig_url": "http://ifconfig.co", 32 } 33 de.SetKey(Get("api_secret")) 34 de.SetExp(100) 35 AuthURL = Get("auth_url") 36 ServerIP = getIp() 37 } 38 func getIp() string { 39 resp, err := http.Get(Get("ifconfig_url")) 40 if err != nil { 41 fmt.Println("get server ip error:", err) 42 return "" 43 } 44 defer resp.Body.Close() 45 data, _ := ioutil.ReadAll(resp.Body) 46 return strings.Replace(string(data), "\n", "", 1) 47 } 48 func getDefault(key string) string { 49 return defaults[key] 50 } 51 52 func Get(key string) string { 53 env := strings.TrimSpace(os.Getenv(strings.ToUpper(key))) 54 if env != "" { 55 return env 56 } 57 return getDefault(key) 58 } 59 60 type Peer struct { 61 Address string 62 PublicKey string 63 } 64 65 func Verify(clientPublicKey string) (*Peer, error) { 66 bearer, _ := de.EncodeWithBase64() 67 header := make(map[string]string) 68 69 header["Wgkey"] = clientPublicKey 70 header["Wgtoken"] = string(bearer) 71 header["Wgserverip"] = ServerIP 72 73 code, data, err := Request("GET", AuthURL+"/wireguard", header, nil) 74 if code >= 300 || err != nil { 75 return nil, errors.New(fmt.Sprintf("auth:request to auth server wireguard error:code is %v,error is %v\n", code, err)) 76 } 77 peer := &Peer{} 78 json.Unmarshal(data, peer) 79 80 if clientPublicKey != peer.PublicKey { 81 fmt.Println("err equals:", clientPublicKey, peer.PublicKey) 82 return nil, errors.New("key not equals") 83 } 84 85 return peer, nil 86 } 87 88 type Config struct { 89 ListenPort string `json:"wg_listen_port"` 90 PrivateKey string `json:"wg_private_key"` 91 } 92 93 func GetWireguardConfig() (*Config, error) { 94 bearer, _ := de.EncodeWithBase64() 95 header := make(map[string]string) 96 97 header["Wgtoken"] = string(bearer) 98 header["Wgserverip"] = ServerIP 99 100 code, data, err := Request("GET", AuthURL+"/config", header, nil) 101 if code >= 300 || err != nil { 102 errString := fmt.Sprintf("auth:request to auth server config error:code is %v,error is %v\n", code, err) 103 fmt.Println(errString) 104 return nil, errors.New(errString) 105 } 106 conf := &Config{} 107 json.Unmarshal(data, conf) 108 conf.PrivateKey = KeyToHex(conf.PrivateKey) 109 return conf, nil 110 } 111 func atob(data []byte) []byte { 112 // Base64 Standard Decoding 113 sDec, err := base64.StdEncoding.DecodeString(string(data)) 114 if err != nil { 115 fmt.Printf("Error decoding string: %s ", err.Error()) 116 return []byte{} 117 } 118 return sDec 119 } 120 func KeyToHex(key string) string { 121 data := atob([]byte(key)) 122 return hex.EncodeToString(data) 123 } 124 125 var ( 126 client *http.Client 127 ) 128 129 func init() { 130 client = new(http.Client) 131 netTransport := &http.Transport{ 132 Dial: func(netw, addr string) (net.Conn, error) { 133 c, err := net.DialTimeout(netw, addr, time.Second*time.Duration(20)) 134 if err != nil { 135 return nil, err 136 } 137 return c, nil 138 }, 139 DisableKeepAlives: true, 140 MaxIdleConnsPerHost: 20, //每个host最大空闲连接 141 ResponseHeaderTimeout: time.Second * time.Duration(60), //数据收发5秒超时 142 TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, 143 } 144 client.Timeout = time.Second * 30 145 client.Transport = netTransport 146 } 147 func NewHttpClient() *http.Client { 148 return client 149 } 150 func Request(method, target string, headers map[string]string, body io.ReadCloser) (int, []byte, error) { 151 req, _ := http.NewRequest(method, target, body) 152 req.Header.Add("cache-control", "no-cache") 153 req.Close = true 154 for k, v := range headers { 155 req.Header.Set(k, v) 156 } 157 cli := NewHttpClient() 158 res, err := cli.Transport.RoundTrip(req) 159 if err != nil { 160 return -1, nil, err 161 } 162 defer res.Body.Close() 163 data, err := ioutil.ReadAll(res.Body) 164 return res.StatusCode, data, err 165 }