github.com/bepass-org/wireguard-go@v1.0.4-rc2.0.20240304192354-ebce6572bc24/warp/account.go (about)

     1  package warp
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"log"
    11  	"net"
    12  	"net/http"
    13  	"os"
    14  	"path/filepath"
    15  	"time"
    16  )
    17  
    18  const (
    19  	apiVersion    = "v0a1922"
    20  	apiURL        = "https://api.cloudflareclient.com"
    21  	regURL        = apiURL + "/" + apiVersion + "/reg"
    22  	_identityFile = "wgcf-identity.json"
    23  	_profileFile  = "wgcf-profile.ini"
    24  )
    25  
    26  var (
    27  	identityFile = "wgcf-identity.json"
    28  	profileFile  = "wgcf-profile.ini"
    29  	dnsAddresses = []string{"8.8.8.8", "8.8.4.4"}
    30  	dc           = 0
    31  )
    32  
    33  var defaultHeaders = makeDefaultHeaders()
    34  var client = makeClient()
    35  
    36  type AccountData struct {
    37  	AccountID   string `json:"account_id"`
    38  	AccessToken string `json:"access_token"`
    39  	PrivateKey  string `json:"private_key"`
    40  	LicenseKey  string `json:"license_key"`
    41  }
    42  
    43  type ConfigurationData struct {
    44  	LocalAddressIPv4    string `json:"local_address_ipv4"`
    45  	LocalAddressIPv6    string `json:"local_address_ipv6"`
    46  	EndpointAddressHost string `json:"endpoint_address_host"`
    47  	EndpointAddressIPv4 string `json:"endpoint_address_ipv4"`
    48  	EndpointAddressIPv6 string `json:"endpoint_address_ipv6"`
    49  	EndpointPublicKey   string `json:"endpoint_public_key"`
    50  	WarpEnabled         bool   `json:"warp_enabled"`
    51  	AccountType         string `json:"account_type"`
    52  	WarpPlusEnabled     bool   `json:"warp_plus_enabled"`
    53  	LicenseKeyUpdated   bool   `json:"license_key_updated"`
    54  }
    55  
    56  func makeDefaultHeaders() map[string]string {
    57  	return map[string]string{
    58  		"User-Agent":        "okhttp/3.12.1",
    59  		"CF-Client-Version": "a-6.3-1922",
    60  	}
    61  }
    62  
    63  func makeClient() *http.Client {
    64  	// Create a custom dialer using the TLS config
    65  	plainDialer := &net.Dialer{
    66  		Timeout:   5 * time.Second,
    67  		KeepAlive: 5 * time.Second,
    68  	}
    69  	tlsDialer := Dialer{}
    70  	// Create a custom HTTP transport
    71  	transport := &http.Transport{
    72  		DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
    73  			return tlsDialer.TLSDial(plainDialer, network, addr)
    74  		},
    75  	}
    76  
    77  	// Create a custom HTTP client using the transport
    78  	return &http.Client{
    79  		Transport: transport,
    80  		// Other client configurations can be added here
    81  	}
    82  }
    83  
    84  func MergeMaps(maps ...map[string]string) map[string]string {
    85  	out := make(map[string]string)
    86  
    87  	for _, m := range maps {
    88  		for k, v := range m {
    89  			out[k] = v
    90  		}
    91  	}
    92  
    93  	return out
    94  }
    95  
    96  func getConfigURL(accountID string) string {
    97  	return fmt.Sprintf("%s/%s", regURL, accountID)
    98  }
    99  
   100  func getAccountURL(accountID string) string {
   101  	return fmt.Sprintf("%s/account", getConfigURL(accountID))
   102  }
   103  
   104  func getDevicesURL(accountID string) string {
   105  	return fmt.Sprintf("%s/devices", getAccountURL(accountID))
   106  }
   107  
   108  func getAccountRegURL(accountID, deviceToken string) string {
   109  	return fmt.Sprintf("%s/reg/%s", getAccountURL(accountID), deviceToken)
   110  }
   111  
   112  func getTimestamp() string {
   113  	timestamp := time.Now().Format(time.RFC3339Nano)
   114  	return timestamp
   115  }
   116  
   117  func genKeyPair() (string, string, error) {
   118  	// Generate private key
   119  	priv, err := GeneratePrivateKey()
   120  	if err != nil {
   121  		fmt.Println("Error generating private key:", err)
   122  		return "", "", err
   123  	}
   124  	privateKey := priv.String()
   125  	publicKey := priv.PublicKey().String()
   126  	return privateKey, publicKey, nil
   127  }
   128  
   129  func doRegister() (*AccountData, error) {
   130  	timestamp := getTimestamp()
   131  	privateKey, publicKey, err := genKeyPair()
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	data := map[string]interface{}{
   136  		"install_id": "",
   137  		"fcm_token":  "",
   138  		"tos":        timestamp,
   139  		"key":        publicKey,
   140  		"type":       "Android",
   141  		"model":      "PC",
   142  		"locale":     "en_US",
   143  	}
   144  
   145  	headers := map[string]string{
   146  		"Content-Type": "application/json; charset=UTF-8",
   147  	}
   148  
   149  	jsonBody, _ := json.Marshal(data)
   150  
   151  	req, err := http.NewRequest("POST", regURL, bytes.NewBuffer(jsonBody))
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	// Set headers
   157  	for k, v := range MergeMaps(defaultHeaders, headers) {
   158  		req.Header.Set(k, v)
   159  	}
   160  
   161  	// Create HTTP client and execute request
   162  	response, err := client.Do(req)
   163  	if err != nil {
   164  		fmt.Println("sending request to remote server", err)
   165  		return nil, err
   166  	}
   167  
   168  	// convert response to byte array
   169  	responseData, err := io.ReadAll(response.Body)
   170  	if err != nil {
   171  		fmt.Println("reading response body", err)
   172  		return nil, err
   173  	}
   174  
   175  	var rspData interface{}
   176  
   177  	err = json.Unmarshal(responseData, &rspData)
   178  	if err != nil {
   179  		fmt.Println("Error:", err)
   180  		return nil, err
   181  	}
   182  
   183  	m := rspData.(map[string]interface{})
   184  
   185  	return &AccountData{
   186  		AccountID:   m["id"].(string),
   187  		AccessToken: m["token"].(string),
   188  		PrivateKey:  privateKey,
   189  		LicenseKey:  m["account"].(map[string]interface{})["license"].(string),
   190  	}, nil
   191  }
   192  
   193  func saveIdentity(accountData *AccountData, identityPath string) error {
   194  	file, err := os.Create(identityPath)
   195  	if err != nil {
   196  		fmt.Println("Error:", err)
   197  		return err
   198  	}
   199  
   200  	encoder := json.NewEncoder(file)
   201  	encoder.SetIndent("", "    ")
   202  	err = encoder.Encode(accountData)
   203  	if err != nil {
   204  		fmt.Println("Error:", err)
   205  		return err
   206  	}
   207  
   208  	return file.Close()
   209  }
   210  
   211  func loadIdentity(identityPath string) (accountData *AccountData, err error) {
   212  	file, err := os.Open(identityPath)
   213  	if err != nil {
   214  		fmt.Println("Error:", err)
   215  		return nil, err
   216  	}
   217  
   218  	defer func(file *os.File) {
   219  		err = file.Close()
   220  		if err != nil {
   221  			fmt.Println("Error:", err)
   222  		}
   223  	}(file)
   224  
   225  	accountData = &AccountData{}
   226  	decoder := json.NewDecoder(file)
   227  	err = decoder.Decode(&accountData)
   228  	if err != nil {
   229  		fmt.Println("Error:", err)
   230  		return nil, err
   231  	}
   232  
   233  	return accountData, nil
   234  }
   235  
   236  func enableWarp(accountData *AccountData) error {
   237  	data := map[string]interface{}{
   238  		"warp_enabled": true,
   239  	}
   240  
   241  	jsonData, _ := json.Marshal(data)
   242  
   243  	url := getConfigURL(accountData.AccountID)
   244  
   245  	req, err := http.NewRequest("PATCH", url, bytes.NewBuffer(jsonData))
   246  	if err != nil {
   247  		return err
   248  	}
   249  
   250  	// Set headers
   251  	headers := map[string]string{
   252  		"Authorization": "Bearer " + accountData.AccessToken,
   253  		"Content-Type":  "application/json; charset=UTF-8",
   254  	}
   255  
   256  	for k, v := range MergeMaps(defaultHeaders, headers) {
   257  		req.Header.Set(k, v)
   258  	}
   259  
   260  	resp, err := client.Do(req)
   261  	if err != nil {
   262  		return err
   263  	}
   264  	defer resp.Body.Close()
   265  
   266  	if resp.StatusCode != http.StatusOK {
   267  		return fmt.Errorf("error enabling WARP, status %d", resp.StatusCode)
   268  	}
   269  
   270  	var response map[string]interface{}
   271  	err = json.NewDecoder(resp.Body).Decode(&response)
   272  	if err != nil {
   273  		return err
   274  	}
   275  
   276  	if !response["warp_enabled"].(bool) {
   277  		return errors.New("warp not enabled")
   278  	}
   279  
   280  	return nil
   281  }
   282  
   283  func getServerConf(accountData *AccountData) (*ConfigurationData, error) {
   284  
   285  	req, err := http.NewRequest("GET", getConfigURL(accountData.AccountID), nil)
   286  	if err != nil {
   287  		return nil, err
   288  	}
   289  
   290  	// Set headers
   291  	headers := map[string]string{
   292  		"Authorization": "Bearer " + accountData.AccessToken,
   293  		"Content-Type":  "application/json; charset=UTF-8",
   294  	}
   295  
   296  	for k, v := range MergeMaps(defaultHeaders, headers) {
   297  		req.Header.Set(k, v)
   298  	}
   299  
   300  	resp, err := client.Do(req)
   301  	if err != nil {
   302  		return nil, err
   303  	}
   304  	defer resp.Body.Close()
   305  
   306  	if resp.StatusCode != http.StatusOK {
   307  		return nil, fmt.Errorf("error getting config, status %d", resp.StatusCode)
   308  	}
   309  
   310  	var response map[string]interface{}
   311  	err = json.NewDecoder(resp.Body).Decode(&response)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  
   316  	addresses := response["config"].(map[string]interface{})["interface"].(map[string]interface{})["addresses"]
   317  	lv4 := addresses.(map[string]interface{})["v4"].(string)
   318  	lv6 := addresses.(map[string]interface{})["v6"].(string)
   319  
   320  	peer := response["config"].(map[string]interface{})["peers"].([]interface{})[0].(map[string]interface{})
   321  	publicKey := peer["public_key"].(string)
   322  
   323  	endpoint := peer["endpoint"].(map[string]interface{})
   324  	host := endpoint["host"].(string)
   325  	v4 := endpoint["v4"].(string)
   326  	v6 := endpoint["v6"].(string)
   327  
   328  	account, ok := response["account"].(map[string]interface{})
   329  	if !ok {
   330  		account = make(map[string]interface{})
   331  	}
   332  
   333  	warpEnabled := response["warp_enabled"].(bool)
   334  
   335  	return &ConfigurationData{
   336  		LocalAddressIPv4:    lv4,
   337  		LocalAddressIPv6:    lv6,
   338  		EndpointAddressHost: host,
   339  		EndpointAddressIPv4: v4,
   340  		EndpointAddressIPv6: v6,
   341  		EndpointPublicKey:   publicKey,
   342  		WarpEnabled:         warpEnabled,
   343  		AccountType:         account["account_type"].(string),
   344  		WarpPlusEnabled:     account["warp_plus"].(bool),
   345  		LicenseKeyUpdated:   false, // omit for brevity
   346  	}, nil
   347  }
   348  
   349  func updateLicenseKey(accountData *AccountData, confData *ConfigurationData) (bool, error) {
   350  
   351  	if confData.AccountType == "free" && accountData.LicenseKey != "" {
   352  
   353  		data := map[string]interface{}{
   354  			"license": accountData.LicenseKey,
   355  		}
   356  
   357  		jsonData, _ := json.Marshal(data)
   358  
   359  		url := getAccountURL(accountData.AccountID)
   360  
   361  		req, err := http.NewRequest("PUT", url, bytes.NewBuffer(jsonData))
   362  		if err != nil {
   363  			return false, err
   364  		}
   365  
   366  		// Set headers
   367  		headers := map[string]string{
   368  			"Authorization": "Bearer " + accountData.AccessToken,
   369  			"Content-Type":  "application/json; charset=UTF-8",
   370  		}
   371  
   372  		for k, v := range MergeMaps(defaultHeaders, headers) {
   373  			req.Header.Set(k, v)
   374  		}
   375  
   376  		resp, err := client.Do(req)
   377  		if err != nil {
   378  			return false, err
   379  		}
   380  		defer resp.Body.Close()
   381  
   382  		if resp.StatusCode != http.StatusOK {
   383  			s, _ := io.ReadAll(resp.Body)
   384  			return false, fmt.Errorf("activation error, status %d %s", resp.StatusCode, string(s))
   385  		}
   386  
   387  		var activationResp map[string]interface{}
   388  		err = json.NewDecoder(resp.Body).Decode(&activationResp)
   389  		if err != nil {
   390  			return false, err
   391  		}
   392  
   393  		return activationResp["warp_plus"].(bool), nil
   394  
   395  	} else if confData.AccountType == "unlimited" {
   396  		return true, nil
   397  	}
   398  
   399  	return false, nil
   400  }
   401  
   402  func getDeviceActive(accountData *AccountData) (bool, error) {
   403  
   404  	req, err := http.NewRequest("GET", getDevicesURL(accountData.AccountID), nil)
   405  	if err != nil {
   406  		return false, err
   407  	}
   408  
   409  	// Set headers
   410  	headers := map[string]string{
   411  		"Authorization": "Bearer " + accountData.AccessToken,
   412  		"Accept":        "application/json",
   413  	}
   414  
   415  	for k, v := range MergeMaps(defaultHeaders, headers) {
   416  		req.Header.Set(k, v)
   417  	}
   418  
   419  	resp, err := client.Do(req)
   420  	if err != nil {
   421  		return false, err
   422  	}
   423  	defer resp.Body.Close()
   424  
   425  	if resp.StatusCode != http.StatusOK {
   426  		return false, fmt.Errorf("error getting devices, status %d", resp.StatusCode)
   427  	}
   428  
   429  	var devices []map[string]interface{}
   430  	json.NewDecoder(resp.Body).Decode(&devices)
   431  
   432  	for _, d := range devices {
   433  		if d["id"] == accountData.AccountID {
   434  			active := d["active"].(bool)
   435  			return active, nil
   436  		}
   437  	}
   438  
   439  	return false, nil
   440  }
   441  
   442  func setDeviceActive(accountData *AccountData, status bool) (bool, error) {
   443  
   444  	data := map[string]interface{}{
   445  		"active": status,
   446  	}
   447  
   448  	jsonData, _ := json.Marshal(data)
   449  
   450  	url := getAccountRegURL(accountData.AccountID, accountData.AccountID)
   451  
   452  	req, err := http.NewRequest("PATCH", url, bytes.NewBuffer(jsonData))
   453  	if err != nil {
   454  		return false, err
   455  	}
   456  
   457  	// Set headers
   458  	headers := map[string]string{
   459  		"Authorization": "Bearer " + accountData.AccessToken,
   460  		"Accept":        "application/json",
   461  	}
   462  
   463  	for k, v := range MergeMaps(defaultHeaders, headers) {
   464  		req.Header.Set(k, v)
   465  	}
   466  
   467  	resp, err := client.Do(req)
   468  	if err != nil {
   469  		return false, err
   470  	}
   471  	defer resp.Body.Close()
   472  
   473  	if resp.StatusCode != http.StatusOK {
   474  		return false, fmt.Errorf("error setting active status, status %d", resp.StatusCode)
   475  	}
   476  
   477  	var devices []map[string]interface{}
   478  	json.NewDecoder(resp.Body).Decode(&devices)
   479  
   480  	for _, d := range devices {
   481  		if d["id"] == accountData.AccountID {
   482  			return d["active"].(bool), nil
   483  		}
   484  	}
   485  
   486  	return false, nil
   487  }
   488  
   489  func getWireguardConfig(privateKey, address1, address2, publicKey, endpoint string) string {
   490  
   491  	var buffer bytes.Buffer
   492  
   493  	buffer.WriteString("[Interface]\n")
   494  	buffer.WriteString(fmt.Sprintf("PrivateKey = %s\n", privateKey))
   495  	buffer.WriteString(fmt.Sprintf("DNS = %s\n", dnsAddresses[dc%len(dnsAddresses)]))
   496  	dc++
   497  	buffer.WriteString(fmt.Sprintf("Address = %s\n", address1+"/24"))
   498  	buffer.WriteString(fmt.Sprintf("Address = %s\n", address2+"/128"))
   499  
   500  	buffer.WriteString("[Peer]\n")
   501  	buffer.WriteString(fmt.Sprintf("PublicKey = %s\n", publicKey))
   502  	buffer.WriteString("AllowedIPs = 0.0.0.0/0\n")
   503  	buffer.WriteString("AllowedIPs = ::/0\n")
   504  	buffer.WriteString(fmt.Sprintf("Endpoint = %s\n", endpoint))
   505  
   506  	return buffer.String()
   507  }
   508  
   509  func createConf(accountData *AccountData, confData *ConfigurationData) error {
   510  
   511  	config := getWireguardConfig(accountData.PrivateKey, confData.LocalAddressIPv4,
   512  		confData.LocalAddressIPv6, confData.EndpointPublicKey, confData.EndpointAddressHost)
   513  
   514  	return os.WriteFile(profileFile, []byte(config), 0600)
   515  }
   516  
   517  func LoadOrCreateIdentity(license string) error {
   518  	var accountData *AccountData
   519  
   520  	if _, err := os.Stat(identityFile); os.IsNotExist(err) {
   521  		fmt.Println("Creating new identity...")
   522  		accountData, err = doRegister()
   523  		if err != nil {
   524  			return err
   525  		}
   526  		accountData.LicenseKey = license
   527  		saveIdentity(accountData, identityFile)
   528  	} else {
   529  		fmt.Println("Loading existing identity...")
   530  		accountData, err = loadIdentity(identityFile)
   531  		if err != nil {
   532  			return err
   533  		}
   534  	}
   535  
   536  	fmt.Println("Getting configuration...")
   537  	confData, err := getServerConf(accountData)
   538  	if err != nil {
   539  		return err
   540  	}
   541  
   542  	// updating license key
   543  	fmt.Println("Updating account license key...")
   544  	result, err := updateLicenseKey(accountData, confData)
   545  	if err != nil {
   546  		return err
   547  	}
   548  	if result {
   549  		confData, err = getServerConf(accountData)
   550  		if err != nil {
   551  			return err
   552  		}
   553  	}
   554  
   555  	deviceStatus, err := getDeviceActive(accountData)
   556  	if err != nil {
   557  		return err
   558  	}
   559  	if !deviceStatus {
   560  		fmt.Println("This device is not registered to the account!")
   561  	}
   562  
   563  	if confData.WarpPlusEnabled && !deviceStatus {
   564  		fmt.Println("Enabling device...")
   565  		deviceStatus, _ = setDeviceActive(accountData, true)
   566  	}
   567  
   568  	if !confData.WarpEnabled {
   569  		fmt.Println("Enabling Warp...")
   570  		err := enableWarp(accountData)
   571  		if err != nil {
   572  			return err
   573  		}
   574  		confData.WarpEnabled = true
   575  	}
   576  
   577  	fmt.Printf("Warp+ enabled: %t\n", confData.WarpPlusEnabled)
   578  	fmt.Printf("Device activated: %t\n", deviceStatus)
   579  	fmt.Printf("Account type: %s\n", confData.AccountType)
   580  	fmt.Printf("Warp+ enabled: %t\n", confData.WarpPlusEnabled)
   581  
   582  	fmt.Println("Creating WireGuard configuration...")
   583  	err = createConf(accountData, confData)
   584  	if err != nil {
   585  		return fmt.Errorf("unable to enable write config file, Error: %v", err.Error())
   586  	}
   587  
   588  	fmt.Println("All done! Find your files here:")
   589  	fmt.Println(filepath.Abs(identityFile))
   590  	fmt.Println(filepath.Abs(profileFile))
   591  	return nil
   592  }
   593  
   594  func fileExist(f string) bool {
   595  	if _, err := os.Stat(f); os.IsNotExist(err) {
   596  		return false
   597  	}
   598  	return true
   599  }
   600  func removeFile(f string) {
   601  	if fileExist(f) {
   602  		e := os.Remove(f)
   603  		if e != nil {
   604  			log.Fatal(e)
   605  		}
   606  	}
   607  }
   608  
   609  func UpdatePath(path string) {
   610  	identityFile = path + "/" + _identityFile
   611  	profileFile = path + "/" + _profileFile
   612  }
   613  
   614  func CheckProfileExists(license string) bool {
   615  	isOk := true
   616  	if !fileExist(identityFile) || !fileExist(profileFile) {
   617  		isOk = false
   618  	}
   619  
   620  	ad := &AccountData{} // Read errors caught by unmarshal
   621  	if isOk {
   622  		fileBytes, _ := os.ReadFile(identityFile)
   623  		err := json.Unmarshal(fileBytes, ad)
   624  		if err != nil {
   625  			isOk = false
   626  		} else if license != "notset" && ad.LicenseKey != license {
   627  			isOk = false
   628  		}
   629  	}
   630  	if !isOk {
   631  		removeFile(profileFile)
   632  		removeFile(identityFile)
   633  	}
   634  	return isOk
   635  }
   636  
   637  func RemoveDevice(account AccountData) error {
   638  
   639  	headers := map[string]string{
   640  		"Content-Type":      "application/json",
   641  		"User-Agent":        "okhttp/3.12.1",
   642  		"CF-Client-Version": "a-6.30-3596",
   643  		"Authorization":     "Bearer " + account.AccessToken,
   644  	}
   645  
   646  	req, err := http.NewRequest("DELETE", "https://api.cloudflareclient.com/v0a3596/reg/"+account.AccountID, nil)
   647  	if err != nil {
   648  		return err
   649  	}
   650  
   651  	// Set headers
   652  	for k, v := range MergeMaps(defaultHeaders, headers) {
   653  		req.Header.Set(k, v)
   654  	}
   655  
   656  	// Create HTTP client and execute request
   657  	response, err := client.Do(req)
   658  	if err != nil {
   659  		fmt.Println("sending request to remote server", err)
   660  		return err
   661  	}
   662  
   663  	if response.StatusCode != 204 {
   664  		return fmt.Errorf("error in deleting account %d %s", response.StatusCode, response.Status)
   665  	}
   666  
   667  	return nil
   668  }