github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/tools/lxdclient/addserver.go (about)

     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  // +build go1.3
     5  
     6  package lxdclient
     7  
     8  import (
     9  	"fmt"
    10  	"net"
    11  	"net/url"
    12  	"strconv"
    13  	"strings"
    14  
    15  	"github.com/juju/errors"
    16  	"github.com/lxc/lxd/shared"
    17  )
    18  
    19  // TODO(ericsnow) Rename addr -> remoteURL?
    20  
    21  func fixAddr(addr string) (string, error) {
    22  	if addr == "" {
    23  		// TODO(ericsnow) Return lxd.LocalRemote.Addr?
    24  		return addr, nil
    25  	}
    26  	if strings.HasPrefix(addr, "unix:") {
    27  		return "", errors.NewNotValid(nil, fmt.Sprintf("unix socket URLs not supported (got %q)", addr))
    28  	}
    29  
    30  	// Fix IPv6 URLs.
    31  	if strings.HasPrefix(addr, ":") {
    32  		parts := strings.SplitN(addr, "/", 2)
    33  		if net.ParseIP(parts[0]) != nil {
    34  			addr = fmt.Sprintf("[%s]", parts[0])
    35  			if len(parts) == 2 {
    36  				addr = "/" + parts[1]
    37  			}
    38  		}
    39  	}
    40  
    41  	parsedURL, err := url.Parse(addr)
    42  	if err != nil {
    43  		return "", errors.Trace(err)
    44  	}
    45  	if parsedURL.RawQuery != "" {
    46  		return "", errors.NewNotValid(nil, fmt.Sprintf("URL queries not supported (got %q)", addr))
    47  	}
    48  	if parsedURL.Fragment != "" {
    49  		return "", errors.NewNotValid(nil, fmt.Sprintf("URL fragments not supported (got %q)", addr))
    50  	}
    51  	if parsedURL.Opaque != "" {
    52  		if strings.Contains(parsedURL.Scheme, ".") {
    53  			addr, err := fixAddr("https://" + addr)
    54  			if err != nil {
    55  				return "", errors.Trace(err)
    56  			}
    57  			return addr, nil
    58  		}
    59  		return "", errors.NewNotValid(nil, fmt.Sprintf("opaque URLs not supported (got %q)", addr))
    60  	}
    61  
    62  	remoteURL := url.URL{
    63  		Scheme: parsedURL.Scheme,
    64  		Host:   parsedURL.Host,
    65  		Path:   strings.TrimRight(parsedURL.Path, "/"),
    66  	}
    67  
    68  	// Fix the scheme.
    69  	remoteURL.Scheme = fixScheme(remoteURL)
    70  	if err := validateScheme(remoteURL); err != nil {
    71  		return "", errors.Trace(err)
    72  	}
    73  
    74  	// Fix the host.
    75  	if remoteURL.Host == "" {
    76  		if strings.HasPrefix(remoteURL.Path, "/") {
    77  			return "", errors.NewNotValid(nil, fmt.Sprintf("unix socket URLs not supported (got %q)", addr))
    78  		}
    79  		addr = fmt.Sprintf("%s://%s%s", remoteURL.Scheme, remoteURL.Host, remoteURL.Path)
    80  		addr, err := fixAddr(addr)
    81  		if err != nil {
    82  			return "", errors.Trace(err)
    83  		}
    84  		return addr, nil
    85  	}
    86  	remoteURL.Host = fixHost(remoteURL.Host, shared.DefaultPort)
    87  	if err := validateHost(remoteURL); err != nil {
    88  		return "", errors.Trace(err)
    89  	}
    90  
    91  	// TODO(ericsnow) Use remoteUrl.String()
    92  	return fmt.Sprintf("%s://%s%s", remoteURL.Scheme, remoteURL.Host, remoteURL.Path), nil
    93  }
    94  
    95  func fixScheme(url url.URL) string {
    96  	switch url.Scheme {
    97  	case "https":
    98  		return url.Scheme
    99  	case "http":
   100  		return "https"
   101  	case "":
   102  		return "https"
   103  	default:
   104  		return url.Scheme
   105  	}
   106  }
   107  
   108  func validateScheme(url url.URL) error {
   109  	switch url.Scheme {
   110  	case "https":
   111  	default:
   112  		return errors.NewNotValid(nil, fmt.Sprintf("unsupported URL scheme %q", url.Scheme))
   113  	}
   114  	return nil
   115  }
   116  
   117  func fixHost(host, defaultPort string) string {
   118  	// Handle IPv6 hosts.
   119  	if strings.Count(host, ":") > 1 {
   120  		if !strings.HasPrefix(host, "[") {
   121  			return fmt.Sprintf("[%s]:%s", host, defaultPort)
   122  		} else if !strings.Contains(host, "]:") {
   123  			return host + ":" + defaultPort
   124  		}
   125  		return host
   126  	}
   127  
   128  	// Handle ports.
   129  	if !strings.Contains(host, ":") {
   130  		return host + ":" + defaultPort
   131  	}
   132  
   133  	return host
   134  }
   135  
   136  func validateHost(url url.URL) error {
   137  	if url.Host == "" {
   138  		return errors.NewNotValid(nil, "URL missing host")
   139  	}
   140  
   141  	host, port, err := net.SplitHostPort(url.Host)
   142  	if err != nil {
   143  		return errors.NewNotValid(err, "")
   144  	}
   145  
   146  	// Check the host.
   147  	if net.ParseIP(host) == nil {
   148  		if err := validateDomainName(host); err != nil {
   149  			return errors.Trace(err)
   150  		}
   151  	}
   152  
   153  	// Check the port.
   154  	if p, err := strconv.Atoi(port); err != nil {
   155  		return errors.NewNotValid(err, fmt.Sprintf("invalid port in host %q", url.Host))
   156  	} else if p <= 0 || p > 0xFFFF {
   157  		return errors.NewNotValid(err, fmt.Sprintf("invalid port in host %q", url.Host))
   158  	}
   159  
   160  	return nil
   161  }
   162  
   163  func validateDomainName(fqdn string) error {
   164  	// TODO(ericsnow) Do checks for a valid domain name.
   165  
   166  	return nil
   167  }