github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/protocol/v2/secure_reply.go (about)

     1  // Copyright (c) 2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package v2
     6  
import (
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync"

	"github.com/choria-io/go-choria/inter"
	"github.com/choria-io/go-choria/protocol"
)
    17  
    18  // SecureReply contains 1 serialized Reply hashed
    19  type SecureReply struct {
    20  	// The protocol version for this secure reply `io.choria.protocol.v2.secure_reply` / protocol.SecureReplyV2
    21  	Protocol protocol.ProtocolVersion `json:"protocol"`
    22  	// The reply held in the Secure Request
    23  	MessageBody []byte `json:"reply"`
    24  	// A sha256 of the reply
    25  	Hash string `json:"hash"`
    26  	// A signature made using the ed25519 seed of the sender
    27  	Signature []byte `json:"signature,omitempty"`
    28  	// The JWT of the sending host
    29  	SenderJWT string `json:"sender,omitempty"`
    30  
    31  	security inter.SecurityProvider
    32  	mu       sync.Mutex
    33  }
    34  
    35  // NewSecureReply creates a io.choria.protocol.v2.secure_reply
    36  func NewSecureReply(reply protocol.Reply, security inter.SecurityProvider) (protocol.SecureReply, error) {
    37  	if security.BackingTechnology() != inter.SecurityTechnologyED25519JWT {
    38  		return nil, fmt.Errorf("version 2 protocol requires a ed25519+jwt based security system")
    39  	}
    40  
    41  	secure := &SecureReply{
    42  		Protocol: protocol.SecureReplyV2,
    43  		security: security,
    44  	}
    45  
    46  	err := secure.SetMessage(reply)
    47  	if err != nil {
    48  		return nil, fmt.Errorf("could not set message on SecureReply structure: %s", err)
    49  	}
    50  
    51  	return secure, nil
    52  }
    53  
    54  // NewSecureReplyFromTransport creates a new io.choria.protocol.v2.secure_reply from the data contained in a Transport message
    55  func NewSecureReplyFromTransport(message protocol.TransportMessage, security inter.SecurityProvider, skipvalidate bool) (protocol.SecureReply, error) {
    56  	if security.BackingTechnology() != inter.SecurityTechnologyED25519JWT {
    57  		return nil, fmt.Errorf("version 2 protocol requires a ed25519+jwt based security system")
    58  	}
    59  
    60  	secure := &SecureReply{
    61  		Protocol: protocol.SecureReplyV2,
    62  		security: security,
    63  	}
    64  
    65  	data, err := message.Message()
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	err = secure.IsValidJSON(data)
    71  	if err != nil {
    72  		return nil, fmt.Errorf("the JSON body from the TransportMessage is not a valid SecureReply message: %s", err)
    73  	}
    74  
    75  	err = json.Unmarshal(data, &secure)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	if !skipvalidate {
    81  		if !secure.Valid() {
    82  			return nil, fmt.Errorf("SecureReply message created from the Transport Message is not valid")
    83  		}
    84  	}
    85  
    86  	return secure, nil
    87  }
    88  
    89  func (r *SecureReply) SetMessage(reply protocol.Reply) error {
    90  	r.mu.Lock()
    91  	defer r.mu.Unlock()
    92  
    93  	j, err := reply.JSON()
    94  	if err != nil {
    95  		protocolErrorCtr.Inc()
    96  		return fmt.Errorf("could not JSON encode reply: %v", err)
    97  	}
    98  
    99  	if r.security.ShouldSignReplies() {
   100  		jwt, err := r.security.TokenBytes()
   101  		if err != nil {
   102  			return fmt.Errorf("could not read caller token: %v", err)
   103  		}
   104  
   105  		sig, err := r.security.SignBytes(j)
   106  		if err != nil {
   107  			return err
   108  		}
   109  
   110  		r.SenderJWT = string(jwt)
   111  		r.Signature = sig
   112  	}
   113  
   114  	r.MessageBody = j
   115  	r.Hash = base64.StdEncoding.EncodeToString(r.security.ChecksumBytes(j))
   116  
   117  	return nil
   118  }
   119  
   120  func (r *SecureReply) Valid() bool {
   121  	r.mu.Lock()
   122  	defer r.mu.Unlock()
   123  
   124  	if base64.StdEncoding.EncodeToString(r.security.ChecksumBytes(r.MessageBody)) != r.Hash {
   125  		invalidCtr.Inc()
   126  		return false
   127  	}
   128  
   129  	validCtr.Inc()
   130  	return true
   131  }
   132  
   133  func (r *SecureReply) JSON() ([]byte, error) {
   134  	r.mu.Lock()
   135  	j, err := json.Marshal(r)
   136  	r.mu.Unlock()
   137  	if err != nil {
   138  		protocolErrorCtr.Inc()
   139  		return nil, fmt.Errorf("could not JSON Marshal: %s", err)
   140  	}
   141  
   142  	if err = r.IsValidJSON(j); err != nil {
   143  		return nil, fmt.Errorf("%w: %s", ErrInvalidJSON, err)
   144  	}
   145  
   146  	return j, nil
   147  }
   148  
   149  func (r *SecureReply) Message() []byte {
   150  	r.mu.Lock()
   151  	defer r.mu.Unlock()
   152  
   153  	return r.MessageBody
   154  }
   155  
   156  func (r *SecureReply) Version() protocol.ProtocolVersion {
   157  	r.mu.Lock()
   158  	defer r.mu.Unlock()
   159  
   160  	return r.Protocol
   161  }
   162  
   163  func (r *SecureReply) IsValidJSON(data []byte) error {
   164  	if !protocol.ClientStrictValidation {
   165  		return nil
   166  	}
   167  
   168  	_, errors, err := schemaValidate(protocol.SecureReplyV2, data)
   169  	if err != nil {
   170  		return fmt.Errorf("could not validate SecureReply JSON data: %s", err)
   171  	}
   172  
   173  	if len(errors) != 0 {
   174  		return fmt.Errorf("%w: %s", ErrInvalidJSON, strings.Join(errors, ", "))
   175  	}
   176  
   177  	return nil
   178  }