github.com/webmeshproj/webmesh-cni@v0.0.27/internal/metadata/id_token_server.go (about)

     1  /*
     2  Copyright 2023 Avi Zimmerman <avi.zimmerman@gmail.com>.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package metadata
    18  
    19  import (
    20  	"crypto/ed25519"
    21  	"fmt"
    22  	"net/http"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/go-jose/go-jose/v3"
    27  	"github.com/go-jose/go-jose/v3/jwt"
    28  	"github.com/webmeshproj/webmesh/pkg/storage/types"
    29  	"sigs.k8s.io/controller-runtime/pkg/log"
    30  )
    31  
    32  // IDTokenServer is the server for ID tokens. It can create identification
    33  // tokens for clients to use to access other services in the cluster.
    34  type IDTokenServer struct{ *Server }
    35  
    36  // SignerHeader is the header specifying which node signed the token.
    37  const SignerHeader = "cni"
    38  
    39  // Now is a function that returns the current time. It is used to override
    40  // the time used for token validation.
    41  var Now = time.Now
    42  
// IDClaims are the claims carried by an ID token. The embedded jwt.Claims
// supplies the registered claims (iss, sub, aud, exp, nbf, iat, jti) and is
// serialized at the top level of the claims object. Groups holds the names of
// the RBAC groups that contained the subject node when the token was issued.
type IDClaims struct {
	jwt.Claims `json:",inline"`
	Groups     []string `json:"groups"`
}
    48  
    49  // ServeHTTP implements http.Handler and will handle token issuance and validation.
    50  func (i *IDTokenServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    51  	rlog := log.FromContext(r.Context())
    52  	w.Header().Set("Access-Control-Allow-Origin", "*")
    53  	w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
    54  	w.Header().Set("Access-Control-Allow-Headers", "Authorization")
    55  	w.Header().Set("Access-Control-Max-Age", "86400")
    56  	if r.Method == http.MethodOptions {
    57  		// Handle CORS preflight requests.
    58  		rlog.Info("Serving CORS preflight request", "path", r.URL.Path)
    59  		w.WriteHeader(http.StatusOK)
    60  		return
    61  	}
    62  	rlog.Info("Serving metadata request", "path", r.URL.Path)
    63  	switch r.URL.Path {
    64  	case "/id-tokens/issue":
    65  		i.issueToken(w, r)
    66  	case "/id-tokens/validate":
    67  		i.validateToken(w, r)
    68  	default:
    69  		http.NotFound(w, r)
    70  	}
    71  }
    72  
    73  func (i *IDTokenServer) issueToken(w http.ResponseWriter, r *http.Request) {
    74  	rlog := log.FromContext(r.Context())
    75  	rlog.Info("Issuing ID token")
    76  	info, err := i.getPeerInfoFromRequest(r)
    77  	if err != nil {
    78  		i.returnError(w, err)
    79  		return
    80  	}
    81  	sig, err := i.newSigner()
    82  	if err != nil {
    83  		i.returnError(w, err)
    84  		return
    85  	}
    86  	peerkey, err := info.Peer.DecodePublicKey()
    87  	if err != nil {
    88  		i.returnError(w, err)
    89  		return
    90  	}
    91  	cl := IDClaims{
    92  		Claims: jwt.Claims{
    93  			Issuer:    i.Host.ID().String(),
    94  			Subject:   info.Peer.GetId(),
    95  			Audience:  i.audience(),
    96  			Expiry:    jwt.NewNumericDate(Now().UTC().Add(5 * time.Minute)),
    97  			NotBefore: jwt.NewNumericDate(Now().UTC()),
    98  			IssuedAt:  jwt.NewNumericDate(Now().UTC()),
    99  			ID: func() string {
   100  				if info.Peer.GetId() == peerkey.ID() {
   101  					// Don't include the ID if it's the same as the subject.
   102  					// Saves space and makes it easier to read.
   103  					return ":sub"
   104  				}
   105  				return peerkey.ID()
   106  			}(),
   107  		},
   108  		Groups: []string{},
   109  	}
   110  	groups, err := i.Storage.MeshDB().RBAC().ListGroups(r.Context())
   111  	if err == nil {
   112  		for _, g := range groups {
   113  			if g.ContainsNode(info.Peer.NodeID()) {
   114  				cl.Groups = append(cl.Groups, g.GetName())
   115  			}
   116  		}
   117  	} else {
   118  		rlog.Error(err, "Failed to list groups, skipping")
   119  	}
   120  	raw, err := jwt.Signed(sig).Claims(cl).CompactSerialize()
   121  	if err != nil {
   122  		i.returnError(w, err)
   123  		return
   124  	}
   125  	out := map[string]any{
   126  		"id":      peerkey.ID(),
   127  		"token":   raw,
   128  		"expires": cl.Expiry.Time().Format(time.RFC3339),
   129  	}
   130  	i.returnJSON(w, out)
   131  }
   132  
   133  func (i *IDTokenServer) validateToken(w http.ResponseWriter, r *http.Request) {
   134  	rlog := log.FromContext(r.Context())
   135  	rlog.Info("Validating ID token")
   136  	token := r.Header.Get("Authorization")
   137  	if token == "" {
   138  		i.returnError(w, fmt.Errorf("missing Authorization header"))
   139  		return
   140  	}
   141  	tok, err := jwt.ParseSigned(token)
   142  	if err != nil {
   143  		i.returnError(w, err)
   144  		return
   145  	}
   146  	issuer := i.Host.ID().String()
   147  	if len(tok.Headers) > 0 {
   148  		peer, ok := tok.Headers[0].ExtraHeaders[SignerHeader].(string)
   149  		if ok {
   150  			issuer = peer
   151  		}
   152  	}
   153  	var pubkey ed25519.PublicKey
   154  	switch issuer {
   155  	case i.Host.ID().String():
   156  		rlog.V(1).Info("Token was signed by the local host")
   157  		pubkey = i.publicKey()
   158  	default:
   159  		rlog.V(1).Info("Token was signed by a peer node", "issuer", issuer)
   160  		issuingPeer, err := i.Storage.MeshDB().Peers().Get(r.Context(), types.NodeID(issuer))
   161  		if err != nil {
   162  			i.returnError(w, err)
   163  			return
   164  		}
   165  		wmkey, err := issuingPeer.DecodePublicKey()
   166  		if err != nil {
   167  			i.returnError(w, err)
   168  			return
   169  		}
   170  		pubkey = wmkey.AsNative()
   171  	}
   172  	var cl IDClaims
   173  	if err := tok.Claims(pubkey, &cl); err != nil {
   174  		i.returnError(w, err)
   175  		return
   176  	}
   177  	expected := jwt.Expected{
   178  		// Optional fields to validate based on the query.
   179  		ID:      r.URL.Query().Get("id"),
   180  		Subject: r.URL.Query().Get("subject"),
   181  		Issuer:  r.URL.Query().Get("issuer"),
   182  		// Ensure it's the audience we expect.
   183  		Audience: i.audience(),
   184  		// Ensure the token is not expired.
   185  		Time: Now().UTC(),
   186  	}
   187  	if err := cl.Validate(expected); err != nil {
   188  		i.returnError(w, err)
   189  		return
   190  	}
   191  	i.returnJSON(w, cl)
   192  }
   193  
   194  func (i *IDTokenServer) newSigner() (jose.Signer, error) {
   195  	return jose.NewSigner(i.signingKey(), i.signingOptions())
   196  }
   197  
   198  func (i *IDTokenServer) signingKey() jose.SigningKey {
   199  	return jose.SigningKey{
   200  		Algorithm: jose.EdDSA,
   201  		Key:       i.privateKey(),
   202  	}
   203  }
   204  
   205  func (i *IDTokenServer) signingOptions() *jose.SignerOptions {
   206  	return (&jose.SignerOptions{
   207  		ExtraHeaders: map[jose.HeaderKey]any{
   208  			SignerHeader: i.Host.ID().String(),
   209  		},
   210  	}).WithType("JWT")
   211  }
   212  
   213  func (i *IDTokenServer) privateKey() ed25519.PrivateKey {
   214  	return i.Host.Node().Key().AsNative()
   215  }
   216  
   217  func (i *IDTokenServer) publicKey() ed25519.PublicKey {
   218  	return i.Host.Node().Key().PublicKey().AsNative()
   219  }
   220  
   221  func (i *IDTokenServer) audience() jwt.Audience {
   222  	return jwt.Audience{strings.TrimSuffix(i.Host.Node().Domain(), ".")}
   223  }