golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/conf/parser.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package conf
     7  
     8  import (
     9  	"encoding/base64"
    10  	"net/netip"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"golang.org/x/sys/windows"
    15  	"golang.org/x/text/encoding/unicode"
    16  
    17  	"golang.zx2c4.com/wireguard/windows/driver"
    18  	"golang.zx2c4.com/wireguard/windows/l18n"
    19  )
    20  
    21  type ParseError struct {
    22  	why      string
    23  	offender string
    24  }
    25  
    26  func (e *ParseError) Error() string {
    27  	return l18n.Sprintf("%s: %q", e.why, e.offender)
    28  }
    29  
    30  func parseIPCidr(s string) (netip.Prefix, error) {
    31  	ipcidr, err := netip.ParsePrefix(s)
    32  	if err == nil {
    33  		return ipcidr, nil
    34  	}
    35  	addr, err := netip.ParseAddr(s)
    36  	if err != nil {
    37  		return netip.Prefix{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s}
    38  	}
    39  	return netip.PrefixFrom(addr, addr.BitLen()), nil
    40  }
    41  
    42  func parseEndpoint(s string) (*Endpoint, error) {
    43  	i := strings.LastIndexByte(s, ':')
    44  	if i < 0 {
    45  		return nil, &ParseError{l18n.Sprintf("Missing port from endpoint"), s}
    46  	}
    47  	host, portStr := s[:i], s[i+1:]
    48  	if len(host) < 1 {
    49  		return nil, &ParseError{l18n.Sprintf("Invalid endpoint host"), host}
    50  	}
    51  	port, err := parsePort(portStr)
    52  	if err != nil {
    53  		return nil, err
    54  	}
    55  	hostColon := strings.IndexByte(host, ':')
    56  	if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 {
    57  		err := &ParseError{l18n.Sprintf("Brackets must contain an IPv6 address"), host}
    58  		if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 {
    59  			end := len(host) - 1
    60  			if i := strings.LastIndexByte(host, '%'); i > 1 {
    61  				end = i
    62  			}
    63  			maybeV6, err2 := netip.ParseAddr(host[1:end])
    64  			if err2 != nil || !maybeV6.Is6() {
    65  				return nil, err
    66  			}
    67  		} else {
    68  			return nil, err
    69  		}
    70  		host = host[1 : len(host)-1]
    71  	}
    72  	return &Endpoint{host, port}, nil
    73  }
    74  
    75  func parseMTU(s string) (uint16, error) {
    76  	m, err := strconv.Atoi(s)
    77  	if err != nil {
    78  		return 0, err
    79  	}
    80  	if m < 576 || m > 65535 {
    81  		return 0, &ParseError{l18n.Sprintf("Invalid MTU"), s}
    82  	}
    83  	return uint16(m), nil
    84  }
    85  
    86  func parsePort(s string) (uint16, error) {
    87  	m, err := strconv.Atoi(s)
    88  	if err != nil {
    89  		return 0, err
    90  	}
    91  	if m < 0 || m > 65535 {
    92  		return 0, &ParseError{l18n.Sprintf("Invalid port"), s}
    93  	}
    94  	return uint16(m), nil
    95  }
    96  
    97  func parsePersistentKeepalive(s string) (uint16, error) {
    98  	if s == "off" {
    99  		return 0, nil
   100  	}
   101  	m, err := strconv.Atoi(s)
   102  	if err != nil {
   103  		return 0, err
   104  	}
   105  	if m < 0 || m > 65535 {
   106  		return 0, &ParseError{l18n.Sprintf("Invalid persistent keepalive"), s}
   107  	}
   108  	return uint16(m), nil
   109  }
   110  
   111  func parseTableOff(s string) (bool, error) {
   112  	if s == "off" {
   113  		return true, nil
   114  	} else if s == "auto" || s == "main" {
   115  		return false, nil
   116  	}
   117  	_, err := strconv.ParseUint(s, 10, 32)
   118  	return false, err
   119  }
   120  
   121  func parseKeyBase64(s string) (*Key, error) {
   122  	k, err := base64.StdEncoding.DecodeString(s)
   123  	if err != nil {
   124  		return nil, &ParseError{l18n.Sprintf("Invalid key: %v", err), s}
   125  	}
   126  	if len(k) != KeyLength {
   127  		return nil, &ParseError{l18n.Sprintf("Keys must decode to exactly 32 bytes"), s}
   128  	}
   129  	var key Key
   130  	copy(key[:], k)
   131  	return &key, nil
   132  }
   133  
   134  func splitList(s string) ([]string, error) {
   135  	var out []string
   136  	for _, split := range strings.Split(s, ",") {
   137  		trim := strings.TrimSpace(split)
   138  		if len(trim) == 0 {
   139  			return nil, &ParseError{l18n.Sprintf("Two commas in a row"), s}
   140  		}
   141  		out = append(out, trim)
   142  	}
   143  	return out, nil
   144  }
   145  
   146  type parserState int
   147  
   148  const (
   149  	inInterfaceSection parserState = iota
   150  	inPeerSection
   151  	notInASection
   152  )
   153  
   154  func (c *Config) maybeAddPeer(p *Peer) {
   155  	if p != nil {
   156  		c.Peers = append(c.Peers, *p)
   157  	}
   158  }
   159  
   160  func FromWgQuick(s, name string) (*Config, error) {
   161  	if !TunnelNameIsValid(name) {
   162  		return nil, &ParseError{l18n.Sprintf("Tunnel name is not valid"), name}
   163  	}
   164  	lines := strings.Split(s, "\n")
   165  	parserState := notInASection
   166  	conf := Config{Name: name}
   167  	sawPrivateKey := false
   168  	var peer *Peer
   169  	for _, line := range lines {
   170  		line, _, _ = strings.Cut(line, "#")
   171  		line = strings.TrimSpace(line)
   172  		lineLower := strings.ToLower(line)
   173  		if len(line) == 0 {
   174  			continue
   175  		}
   176  		if lineLower == "[interface]" {
   177  			conf.maybeAddPeer(peer)
   178  			parserState = inInterfaceSection
   179  			continue
   180  		}
   181  		if lineLower == "[peer]" {
   182  			conf.maybeAddPeer(peer)
   183  			peer = &Peer{}
   184  			parserState = inPeerSection
   185  			continue
   186  		}
   187  		if parserState == notInASection {
   188  			return nil, &ParseError{l18n.Sprintf("Line must occur in a section"), line}
   189  		}
   190  		equals := strings.IndexByte(line, '=')
   191  		if equals < 0 {
   192  			return nil, &ParseError{l18n.Sprintf("Config key is missing an equals separator"), line}
   193  		}
   194  		key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:])
   195  		if len(val) == 0 {
   196  			return nil, &ParseError{l18n.Sprintf("Key must have a value"), line}
   197  		}
   198  		if parserState == inInterfaceSection {
   199  			switch key {
   200  			case "privatekey":
   201  				k, err := parseKeyBase64(val)
   202  				if err != nil {
   203  					return nil, err
   204  				}
   205  				conf.Interface.PrivateKey = *k
   206  				sawPrivateKey = true
   207  			case "listenport":
   208  				p, err := parsePort(val)
   209  				if err != nil {
   210  					return nil, err
   211  				}
   212  				conf.Interface.ListenPort = p
   213  			case "mtu":
   214  				m, err := parseMTU(val)
   215  				if err != nil {
   216  					return nil, err
   217  				}
   218  				conf.Interface.MTU = m
   219  			case "address":
   220  				addresses, err := splitList(val)
   221  				if err != nil {
   222  					return nil, err
   223  				}
   224  				for _, address := range addresses {
   225  					a, err := parseIPCidr(address)
   226  					if err != nil {
   227  						return nil, err
   228  					}
   229  					conf.Interface.Addresses = append(conf.Interface.Addresses, a)
   230  				}
   231  			case "dns":
   232  				addresses, err := splitList(val)
   233  				if err != nil {
   234  					return nil, err
   235  				}
   236  				for _, address := range addresses {
   237  					a, err := netip.ParseAddr(address)
   238  					if err != nil {
   239  						conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
   240  					} else {
   241  						conf.Interface.DNS = append(conf.Interface.DNS, a)
   242  					}
   243  				}
   244  			case "preup":
   245  				conf.Interface.PreUp = val
   246  			case "postup":
   247  				conf.Interface.PostUp = val
   248  			case "predown":
   249  				conf.Interface.PreDown = val
   250  			case "postdown":
   251  				conf.Interface.PostDown = val
   252  			case "table":
   253  				tableOff, err := parseTableOff(val)
   254  				if err != nil {
   255  					return nil, err
   256  				}
   257  				conf.Interface.TableOff = tableOff
   258  			default:
   259  				return nil, &ParseError{l18n.Sprintf("Invalid key for [Interface] section"), key}
   260  			}
   261  		} else if parserState == inPeerSection {
   262  			switch key {
   263  			case "publickey":
   264  				k, err := parseKeyBase64(val)
   265  				if err != nil {
   266  					return nil, err
   267  				}
   268  				peer.PublicKey = *k
   269  			case "presharedkey":
   270  				k, err := parseKeyBase64(val)
   271  				if err != nil {
   272  					return nil, err
   273  				}
   274  				peer.PresharedKey = *k
   275  			case "allowedips":
   276  				addresses, err := splitList(val)
   277  				if err != nil {
   278  					return nil, err
   279  				}
   280  				for _, address := range addresses {
   281  					a, err := parseIPCidr(address)
   282  					if err != nil {
   283  						return nil, err
   284  					}
   285  					peer.AllowedIPs = append(peer.AllowedIPs, a)
   286  				}
   287  			case "persistentkeepalive":
   288  				p, err := parsePersistentKeepalive(val)
   289  				if err != nil {
   290  					return nil, err
   291  				}
   292  				peer.PersistentKeepalive = p
   293  			case "endpoint":
   294  				e, err := parseEndpoint(val)
   295  				if err != nil {
   296  					return nil, err
   297  				}
   298  				peer.Endpoint = *e
   299  			default:
   300  				return nil, &ParseError{l18n.Sprintf("Invalid key for [Peer] section"), key}
   301  			}
   302  		}
   303  	}
   304  	conf.maybeAddPeer(peer)
   305  
   306  	if !sawPrivateKey {
   307  		return nil, &ParseError{l18n.Sprintf("An interface must have a private key"), l18n.Sprintf("[none specified]")}
   308  	}
   309  	for _, p := range conf.Peers {
   310  		if p.PublicKey.IsZero() {
   311  			return nil, &ParseError{l18n.Sprintf("All peers must have public keys"), l18n.Sprintf("[none specified]")}
   312  		}
   313  	}
   314  
   315  	return &conf, nil
   316  }
   317  
   318  func FromWgQuickWithUnknownEncoding(s, name string) (*Config, error) {
   319  	c, firstErr := FromWgQuick(s, name)
   320  	if firstErr == nil {
   321  		return c, nil
   322  	}
   323  	for _, encoding := range unicode.All {
   324  		decoded, err := encoding.NewDecoder().String(s)
   325  		if err == nil {
   326  			c, err := FromWgQuick(decoded, name)
   327  			if err == nil {
   328  				return c, nil
   329  			}
   330  		}
   331  	}
   332  	return nil, firstErr
   333  }
   334  
// FromDriverConfiguration builds a Config from the state reported by the
// driver (interfaze), combined with settings copied from existingConfig.
//
// From existingConfig: the name, addresses, DNS servers and search domains,
// MTU, the PreUp/PostUp/PreDown/PostDown commands, and TableOff. The slices
// are shared with existingConfig, not cloned.
//
// From interfaze: the private key and listen port, each only when its
// presence flag is set. Peers are built entirely from the driver's peer
// list, including transfer counters and the last handshake time.
func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config) *Config {
	conf := Config{
		Name: existingConfig.Name,
		Interface: Interface{
			Addresses: existingConfig.Interface.Addresses,
			DNS:       existingConfig.Interface.DNS,
			DNSSearch: existingConfig.Interface.DNSSearch,
			MTU:       existingConfig.Interface.MTU,
			PreUp:     existingConfig.Interface.PreUp,
			PostUp:    existingConfig.Interface.PostUp,
			PreDown:   existingConfig.Interface.PreDown,
			PostDown:  existingConfig.Interface.PostDown,
			TableOff:  existingConfig.Interface.TableOff,
		},
	}
	if interfaze.Flags&driver.InterfaceHasPrivateKey != 0 {
		conf.Interface.PrivateKey = interfaze.PrivateKey
	}
	if interfaze.Flags&driver.InterfaceHasListenPort != 0 {
		conf.Interface.ListenPort = interfaze.ListenPort
	}
	// Peers are reached by walking FirstPeer/NextPeer exactly PeerCount
	// times. The first element is fetched lazily inside the loop, so no
	// accessor is called when PeerCount is zero.
	// NOTE(review): presumably these walk a packed variable-length driver
	// buffer, so the call order matters — confirm against the driver package.
	var p *driver.Peer
	for i := uint32(0); i < interfaze.PeerCount; i++ {
		if p == nil {
			p = interfaze.FirstPeer()
		} else {
			p = p.NextPeer()
		}
		peer := Peer{}
		if p.Flags&driver.PeerHasPublicKey != 0 {
			peer.PublicKey = p.PublicKey
		}
		if p.Flags&driver.PeerHasPresharedKey != 0 {
			peer.PresharedKey = p.PresharedKey
		}
		if p.Flags&driver.PeerHasEndpoint != 0 {
			// The host is stored as the textual address without brackets,
			// the same form parseEndpoint produces for IPv6 literals.
			peer.Endpoint.Port = p.Endpoint.Port()
			peer.Endpoint.Host = p.Endpoint.Addr().String()
		}
		if p.Flags&driver.PeerHasPersistentKeepalive != 0 {
			peer.PersistentKeepalive = p.PersistentKeepalive
		}
		peer.TxBytes = Bytes(p.TxBytes)
		peer.RxBytes = Bytes(p.RxBytes)
		// Zero means no handshake has been recorded; the field stays unset.
		// Otherwise, 116444736000000000 is the number of 100 ns intervals
		// between 1601-01-01 and 1970-01-01. Subtracting it and multiplying
		// by 100 converts a Windows FILETIME-style count into nanoseconds
		// since the Unix epoch.
		// NOTE(review): assumes LastHandshake is such a FILETIME count and
		// that HandshakeTime holds Unix nanoseconds — confirm.
		if p.LastHandshake != 0 {
			peer.LastHandshakeTime = HandshakeTime((p.LastHandshake - 116444736000000000) * 100)
		}
		// Allowed IPs are walked the same lazy way as peers, AllowedIPsCount
		// times.
		var a *driver.AllowedIP
		for j := uint32(0); j < p.AllowedIPsCount; j++ {
			if a == nil {
				a = p.FirstAllowedIP()
			} else {
				a = a.NextAllowedIP()
			}
			// Copy the first 4 (IPv4) or 16 (IPv6) bytes of the address
			// field into a fixed-size array via slice-to-array-pointer
			// conversion. For any other family, ip stays the zero Addr;
			// netip.PrefixFrom then yields an invalid Prefix, which is
			// still appended.
			// NOTE(review): presumably the driver only reports AF_INET and
			// AF_INET6 — confirm.
			var ip netip.Addr
			if a.AddressFamily == windows.AF_INET {
				ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4]))
			} else if a.AddressFamily == windows.AF_INET6 {
				ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16]))
			}
			peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(ip, int(a.Cidr)))
		}
		conf.Peers = append(conf.Peers, peer)
	}
	return &conf
}