github.com/bugfan/wireguard-go@v0.0.0-20230720020150-a7b2fa340c66/auth/auth.go (about)

     1  package auth
     2  
     3  import (
     4  	"crypto/tls"
     5  	"encoding/base64"
     6  	"encoding/hex"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net"
    13  	"net/http"
    14  	"os"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/bugfan/de"
    19  )
    20  
// Package-level state; all three variables are assigned by init.
var (
	// ServerIP is the text returned by getIp during init. It is sent to the
	// auth server in the Wgserverip header.
	ServerIP string
	// AuthURL is the auth server base URL, resolved through Get("auth_url").
	AuthURL  string
	// defaults maps a configuration key to the fallback value that Get
	// returns when the matching environment variable is unset or blank.
	defaults map[string]string
)
    26  
    27  func init() {
    28  	defaults = map[string]string{
    29  		"auth_url":     "https://wg.lt53.cn",
    30  		"api_secret":   "jwdlhtrh",
    31  		"ifconfig_url": "http://ifconfig.co",
    32  	}
    33  	de.SetKey(Get("api_secret"))
    34  	de.SetExp(100)
    35  	AuthURL = Get("auth_url")
    36  	ServerIP = getIp()
    37  }
    38  func getIp() string {
    39  	resp, err := http.Get(Get("ifconfig_url"))
    40  	if err != nil {
    41  		fmt.Println("get server ip error:", err)
    42  		return ""
    43  	}
    44  	defer resp.Body.Close()
    45  	data, _ := ioutil.ReadAll(resp.Body)
    46  	return strings.Replace(string(data), "\n", "", 1)
    47  }
    48  func getDefault(key string) string {
    49  	return defaults[key]
    50  }
    51  
    52  func Get(key string) string {
    53  	env := strings.TrimSpace(os.Getenv(strings.ToUpper(key)))
    54  	if env != "" {
    55  		return env
    56  	}
    57  	return getDefault(key)
    58  }
    59  
// Peer is the value Verify decodes from the auth server's /wireguard response.
// The fields carry no JSON tags, so encoding/json matches them by field name.
type Peer struct {
	// Address is taken from the response unchanged.
	// NOTE(review): presumably the address assigned to the client — confirm
	// against the auth server.
	Address   string
	// PublicKey must equal the key passed to Verify for verification to succeed.
	PublicKey string
}
    64  
    65  func Verify(clientPublicKey string) (*Peer, error) {
    66  	bearer, _ := de.EncodeWithBase64()
    67  	header := make(map[string]string)
    68  
    69  	header["Wgkey"] = clientPublicKey
    70  	header["Wgtoken"] = string(bearer)
    71  	header["Wgserverip"] = ServerIP
    72  
    73  	code, data, err := Request("GET", AuthURL+"/wireguard", header, nil)
    74  	if code >= 300 || err != nil {
    75  		return nil, errors.New(fmt.Sprintf("auth:request to auth server wireguard error:code is %v,error is %v\n", code, err))
    76  	}
    77  	peer := &Peer{}
    78  	json.Unmarshal(data, peer)
    79  
    80  	if clientPublicKey != peer.PublicKey {
    81  		fmt.Println("err equals:", clientPublicKey, peer.PublicKey)
    82  		return nil, errors.New("key not equals")
    83  	}
    84  
    85  	return peer, nil
    86  }
    87  
// Config is the value GetWireguardConfig decodes from the auth server's
// /config response.
type Config struct {
	// ListenPort is read from the "wg_listen_port" JSON field.
	ListenPort string `json:"wg_listen_port"`
	// PrivateKey is read from the "wg_private_key" JSON field;
	// GetWireguardConfig then replaces it with the result of KeyToHex.
	PrivateKey string `json:"wg_private_key"`
}
    92  
    93  func GetWireguardConfig() (*Config, error) {
    94  	bearer, _ := de.EncodeWithBase64()
    95  	header := make(map[string]string)
    96  
    97  	header["Wgtoken"] = string(bearer)
    98  	header["Wgserverip"] = ServerIP
    99  
   100  	code, data, err := Request("GET", AuthURL+"/config", header, nil)
   101  	if code >= 300 || err != nil {
   102  		errString := fmt.Sprintf("auth:request to auth server config error:code is %v,error is %v\n", code, err)
   103  		fmt.Println(errString)
   104  		return nil, errors.New(errString)
   105  	}
   106  	conf := &Config{}
   107  	json.Unmarshal(data, conf)
   108  	conf.PrivateKey = KeyToHex(conf.PrivateKey)
   109  	return conf, nil
   110  }
   111  func atob(data []byte) []byte {
   112  	// Base64 Standard Decoding
   113  	sDec, err := base64.StdEncoding.DecodeString(string(data))
   114  	if err != nil {
   115  		fmt.Printf("Error decoding string: %s ", err.Error())
   116  		return []byte{}
   117  	}
   118  	return sDec
   119  }
   120  func KeyToHex(key string) string {
   121  	data := atob([]byte(key))
   122  	return hex.EncodeToString(data)
   123  }
   124  
   125  var (
   126  	client *http.Client
   127  )
   128  
   129  func init() {
   130  	client = new(http.Client)
   131  	netTransport := &http.Transport{
   132  		Dial: func(netw, addr string) (net.Conn, error) {
   133  			c, err := net.DialTimeout(netw, addr, time.Second*time.Duration(20))
   134  			if err != nil {
   135  				return nil, err
   136  			}
   137  			return c, nil
   138  		},
   139  		DisableKeepAlives:     true,
   140  		MaxIdleConnsPerHost:   20,                              //每个host最大空闲连接
   141  		ResponseHeaderTimeout: time.Second * time.Duration(60), //数据收发5秒超时
   142  		TLSClientConfig:       &tls.Config{InsecureSkipVerify: true},
   143  	}
   144  	client.Timeout = time.Second * 30
   145  	client.Transport = netTransport
   146  }
   147  func NewHttpClient() *http.Client {
   148  	return client
   149  }
   150  func Request(method, target string, headers map[string]string, body io.ReadCloser) (int, []byte, error) {
   151  	req, _ := http.NewRequest(method, target, body)
   152  	req.Header.Add("cache-control", "no-cache")
   153  	req.Close = true
   154  	for k, v := range headers {
   155  		req.Header.Set(k, v)
   156  	}
   157  	cli := NewHttpClient()
   158  	res, err := cli.Transport.RoundTrip(req)
   159  	if err != nil {
   160  		return -1, nil, err
   161  	}
   162  	defer res.Body.Close()
   163  	data, err := ioutil.ReadAll(res.Body)
   164  	return res.StatusCode, data, err
   165  }