github.com/icyphox/x@v0.0.355-0.20220311094250-029bd783e8b8/httpx/private_ip_validator.go (about)

     1  package httpx
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"net/http"
     7  	"net/url"
     8  
     9  	"github.com/ory/x/stringsx"
    10  
    11  	"github.com/pkg/errors"
    12  )
    13  
    14  // DisallowPrivateIPAddressesWhenSet is a wrapper for DisallowIPPrivateAddresses which returns valid
    15  // when ipOrHostnameOrURL is empty.
    16  func DisallowPrivateIPAddressesWhenSet(ipOrHostnameOrURL string) error {
    17  	if ipOrHostnameOrURL == "" {
    18  		return nil
    19  	}
    20  	return DisallowIPPrivateAddresses(ipOrHostnameOrURL)
    21  }
    22  
    23  // DisallowIPPrivateAddresses returns nil for a domain (with NS lookup), IP, or IPv6 address if it
    24  // does not resolve to a private IP subnet. This is a first level of defense against
    25  // SSRF attacks by disallowing any domain or IP to resolve to a private network range.
    26  //
    27  // Please keep in mind that validations for domains is valid only when looking up.
    28  // A malicious actor could easily update the DSN record post validation to point
    29  // to an internal IP
    30  func DisallowIPPrivateAddresses(ipOrHostnameOrURL string) error {
    31  	lookup := func(hostname string) ([]net.IP, error) {
    32  		lookup, err := net.LookupIP(hostname)
    33  		if err != nil {
    34  			if dnsErr := new(net.DNSError); errors.As(err, &dnsErr) && (dnsErr.IsNotFound || dnsErr.IsTemporary) {
    35  				// If the hostname does not resolve, we can't validate it. So yeah,
    36  				// I guess we're allowing it.
    37  				return nil, nil
    38  			}
    39  			return nil, errors.WithStack(err)
    40  		}
    41  		return lookup, nil
    42  	}
    43  
    44  	var ips []net.IP
    45  	ip := net.ParseIP(ipOrHostnameOrURL)
    46  	if ip == nil {
    47  		if result, err := lookup(ipOrHostnameOrURL); err != nil {
    48  			return err
    49  		} else if result != nil {
    50  			ips = append(ips, result...)
    51  		}
    52  
    53  		if parsed, err := url.Parse(ipOrHostnameOrURL); err == nil {
    54  			if result, err := lookup(parsed.Hostname()); err != nil {
    55  				return err
    56  			} else if result != nil {
    57  				ips = append(ips, result...)
    58  			}
    59  		}
    60  	} else {
    61  		ips = append(ips, ip)
    62  	}
    63  
    64  	for _, disabled := range []string{
    65  		"127.0.0.0/8",
    66  		"10.0.0.0/8",
    67  		"172.16.0.0/12",
    68  		"192.168.0.0/16",
    69  		"fd47:1ed0:805d:59f0::/64",
    70  		"fc00::/7",
    71  		"::1/128",
    72  	} {
    73  		_, cidr, err := net.ParseCIDR(disabled)
    74  		if err != nil {
    75  			return err
    76  		}
    77  
    78  		for _, ip := range ips {
    79  			if cidr.Contains(ip) {
    80  				return fmt.Errorf("ip %s is in the %s range", ip, disabled)
    81  			}
    82  		}
    83  	}
    84  
    85  	return nil
    86  }
    87  
    88  var _ http.RoundTripper = (*NoInternalIPRoundTripper)(nil)
    89  
    90  // NoInternalIPRoundTripper is a RoundTripper that disallows internal IP addresses.
    91  type NoInternalIPRoundTripper struct {
    92  	http.RoundTripper
    93  }
    94  
    95  func (n NoInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
    96  	host, _, _ := net.SplitHostPort(request.Host)
    97  	if err := DisallowIPPrivateAddresses(stringsx.Coalesce(host, request.Host)); err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	if n.RoundTripper == nil {
   102  		return http.DefaultTransport.RoundTrip(request)
   103  	}
   104  
   105  	return n.RoundTripper.RoundTrip(request)
   106  }