github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak/src/server/https/message_server.go (about)

     1  // Copyright 2017 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package https
    16  
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"time"

	log "github.com/golang/glog"
	"google.golang.org/protobuf/proto"

	"github.com/google/fleetspeak/fleetspeak/src/common"
	"github.com/google/fleetspeak/fleetspeak/src/server/comms"
	"github.com/google/fleetspeak/fleetspeak/src/server/db"
	"github.com/google/fleetspeak/fleetspeak/src/server/stats"

	fspb "github.com/google/fleetspeak/fleetspeak/src/common/proto/fleetspeak"
)
    36  
    37  // messageServer wraps a Communicator in order to handle clients polls.
    38  type messageServer struct {
    39  	*Communicator
    40  }
    41  
    42  type unknownAddr struct {
    43  	address string
    44  }
    45  
    46  func (a unknownAddr) Network() string { return "unknown" }
    47  func (a unknownAddr) String() string  { return a.address }
    48  
    49  // addrFromString takes an address in string form, e.g. from
    50  // http.Request.RemoteAddress and attempts to create an appropriate
    51  // implementation of net.Addr.
    52  //
    53  // Currently it recognizes numeric TCP Addresses (e.g. 127.0.0.1:80,
    54  // or [::1]:80) and puts them into a TCPAddr. Anything else is just
    55  // wrapped in an unknownAddr.
    56  func addrFromString(addr string) net.Addr {
    57  	host, port, err := net.SplitHostPort(addr)
    58  	if err != nil {
    59  		return unknownAddr{address: addr}
    60  	}
    61  	ip := net.ParseIP(host)
    62  	if ip == nil {
    63  		return unknownAddr{address: addr}
    64  	}
    65  	p, err := strconv.Atoi(port)
    66  	if err != nil {
    67  		return unknownAddr{address: addr}
    68  	}
    69  	return &net.TCPAddr{
    70  		IP:   ip,
    71  		Port: p,
    72  	}
    73  }
    74  
    75  // ServeHTTP implements http.Handler
    76  func (s messageServer) ServeHTTP(res http.ResponseWriter, req *http.Request) {
    77  	ctx, fin := context.WithTimeout(req.Context(), 5*time.Minute)
    78  
    79  	pi := stats.PollInfo{
    80  		CTX:    req.Context(),
    81  		Start:  db.Now(),
    82  		Status: http.StatusTeapot, // Should never actually be returned
    83  	}
    84  	defer func() {
    85  		fin()
    86  		if pi.Status == http.StatusTeapot {
    87  			log.Errorf("Forgot to set status.")
    88  		}
    89  		pi.End = db.Now()
    90  		s.fs.StatsCollector().ClientPoll(pi)
    91  	}()
    92  
    93  	if !s.startProcessing() {
    94  		log.Error("InternalServerError: server not ready.")
    95  		pi.Status = http.StatusInternalServerError
    96  		http.Error(res, "Server not ready.", pi.Status)
    97  		return
    98  	}
    99  	defer s.stopProcessing()
   100  
   101  	if req.Method != http.MethodPost {
   102  		pi.Status = http.StatusBadRequest
   103  		http.Error(res, fmt.Sprintf("%v not supported", req.Method), pi.Status)
   104  		return
   105  	}
   106  	if req.ContentLength > MaxContactSize {
   107  		pi.Status = http.StatusBadRequest
   108  		http.Error(res, fmt.Sprintf("content length too large: %v", req.ContentLength), pi.Status)
   109  		return
   110  	}
   111  
   112  	cert, err := GetClientCert(req, s.p.FrontendConfig)
   113  	if err != nil {
   114  		pi.Status = http.StatusBadRequest
   115  		http.Error(res, err.Error(), pi.Status)
   116  		return
   117  	}
   118  
   119  	if cert.PublicKey == nil {
   120  		pi.Status = http.StatusBadRequest
   121  		http.Error(res, "public key not present in client cert", pi.Status)
   122  		return
   123  	}
   124  	id, err := common.MakeClientID(cert.PublicKey)
   125  	if err != nil {
   126  		pi.Status = http.StatusBadRequest
   127  		http.Error(res, fmt.Sprintf("unable to create client id from public key: %v", err), pi.Status)
   128  		return
   129  	}
   130  	pi.ID = id
   131  
   132  	req.Body = http.MaxBytesReader(res, req.Body, MaxContactSize+1)
   133  	st := time.Now()
   134  	buf, err := io.ReadAll(req.Body)
   135  	pi.ReadTime = time.Since(st)
   136  	pi.ReadBytes = len(buf)
   137  
   138  	if len(buf) > MaxContactSize {
   139  		pi.Status = http.StatusBadRequest
   140  		http.Error(res, fmt.Sprintf("body can't be larger than %v bytes", MaxContactSize), pi.Status)
   141  		return
   142  	}
   143  	if err != nil {
   144  		pi.Status = http.StatusBadRequest
   145  		http.Error(res, fmt.Sprintf("error reading body: %v", err), pi.Status)
   146  		return
   147  	}
   148  	var wcd fspb.WrappedContactData
   149  	if err = proto.Unmarshal(buf, &wcd); err != nil {
   150  		pi.Status = http.StatusBadRequest
   151  		http.Error(res, fmt.Sprintf("error parsing body: %v", err), pi.Status)
   152  		return
   153  	}
   154  	addr := addrFromString(req.RemoteAddr)
   155  
   156  	info, toSend, _, err := s.fs.InitializeConnection(ctx, addr, cert.PublicKey, &wcd, false)
   157  	if err == comms.ErrNotAuthorized {
   158  		pi.Status = http.StatusServiceUnavailable
   159  		http.Error(res, "not authorized", pi.Status)
   160  		return
   161  	}
   162  	if err != nil {
   163  		log.Errorf("InternalServiceError: error processing contact: %v", err)
   164  		pi.Status = http.StatusInternalServerError
   165  		http.Error(res, fmt.Sprintf("error processing contact: %v", err), pi.Status)
   166  		return
   167  	}
   168  	pi.CacheHit = info.Client.Cached
   169  
   170  	bytes, err := proto.Marshal(toSend)
   171  	if err != nil {
   172  		log.Errorf("InternalServerError: proto.Marshal returned error: %v", err)
   173  		pi.Status = http.StatusInternalServerError
   174  		http.Error(res, fmt.Sprintf("error preparing messages: %v", err), pi.Status)
   175  		return
   176  	}
   177  
   178  	res.Header().Set("Content-Type", "application/octet-stream")
   179  	res.WriteHeader(http.StatusOK)
   180  	st = time.Now()
   181  	size, err := res.Write(bytes)
   182  	if err != nil {
   183  		log.Warningf("Error writing body: %v", err)
   184  		pi.Status = http.StatusBadRequest
   185  		return
   186  	}
   187  
   188  	pi.WriteTime = time.Since(st)
   189  	pi.WriteBytes = size
   190  	pi.Status = http.StatusOK
   191  }