github.com/letsencrypt/boulder@v0.20251208.0/core/util.go (about)

     1  package core
     2  
     3  import (
     4  	"context"
     5  	"crypto"
     6  	"crypto/ecdsa"
     7  	"crypto/rand"
     8  	"crypto/rsa"
     9  	"crypto/sha256"
    10  	"crypto/x509"
    11  	"encoding/base64"
    12  	"encoding/hex"
    13  	"encoding/pem"
    14  	"errors"
    15  	"expvar"
    16  	"fmt"
    17  	"io"
    18  	"math/big"
    19  	mrand "math/rand/v2"
    20  	"os"
    21  	"path"
    22  	"reflect"
    23  	"regexp"
    24  	"sort"
    25  	"strings"
    26  	"time"
    27  	"unicode"
    28  
    29  	"github.com/go-jose/go-jose/v4"
    30  	"google.golang.org/grpc/codes"
    31  	"google.golang.org/grpc/status"
    32  	"google.golang.org/protobuf/types/known/durationpb"
    33  	"google.golang.org/protobuf/types/known/timestamppb"
    34  
    35  	"github.com/letsencrypt/boulder/identifier"
    36  )
    37  
    38  const Unspecified = "Unspecified"
    39  
    40  // Package Variables Variables
    41  
    42  // BuildID is set by the compiler (using -ldflags "-X core.BuildID $(git rev-parse --short HEAD)")
    43  // and is used by GetBuildID
    44  var BuildID string
    45  
    46  // BuildHost is set by the compiler and is used by GetBuildHost
    47  var BuildHost string
    48  
    49  // BuildTime is set by the compiler and is used by GetBuildTime
    50  var BuildTime string
    51  
    52  func init() {
    53  	expvar.NewString("BuildID").Set(BuildID)
    54  	expvar.NewString("BuildTime").Set(BuildTime)
    55  }
    56  
    57  // Random stuff
    58  
    59  type randSource interface {
    60  	Read(p []byte) (n int, err error)
    61  }
    62  
    63  // RandReader is used so that it can be replaced in tests that require
    64  // deterministic output
    65  var RandReader randSource = rand.Reader
    66  
    67  // RandomString returns a randomly generated string of the requested length.
    68  func RandomString(byteLength int) string {
    69  	b := make([]byte, byteLength)
    70  	_, err := io.ReadFull(RandReader, b)
    71  	if err != nil {
    72  		panic(fmt.Sprintf("Error reading random bytes: %s", err))
    73  	}
    74  	return base64.RawURLEncoding.EncodeToString(b)
    75  }
    76  
    77  // NewToken produces a random string for Challenges, etc.
    78  func NewToken() string {
    79  	return RandomString(32)
    80  }
    81  
    82  var tokenFormat = regexp.MustCompile(`^[\w-]{43}$`)
    83  
    84  // looksLikeAToken checks whether a string represents a 32-octet value in
    85  // the URL-safe base64 alphabet.
    86  func looksLikeAToken(token string) bool {
    87  	return tokenFormat.MatchString(token)
    88  }
    89  
    90  // Fingerprints
    91  
    92  // Fingerprint256 produces an unpadded, URL-safe Base64-encoded SHA256 digest
    93  // of the data.
    94  func Fingerprint256(data []byte) string {
    95  	d := sha256.New()
    96  	_, _ = d.Write(data) // Never returns an error
    97  	return base64.RawURLEncoding.EncodeToString(d.Sum(nil))
    98  }
    99  
   100  type Sha256Digest [sha256.Size]byte
   101  
   102  // KeyDigest produces the SHA256 digest of a provided public key.
   103  func KeyDigest(key crypto.PublicKey) (Sha256Digest, error) {
   104  	switch t := key.(type) {
   105  	case *jose.JSONWebKey:
   106  		if t == nil {
   107  			return Sha256Digest{}, errors.New("cannot compute digest of nil key")
   108  		}
   109  		return KeyDigest(t.Key)
   110  	case jose.JSONWebKey:
   111  		return KeyDigest(t.Key)
   112  	default:
   113  		keyDER, err := x509.MarshalPKIXPublicKey(key)
   114  		if err != nil {
   115  			return Sha256Digest{}, err
   116  		}
   117  		return sha256.Sum256(keyDER), nil
   118  	}
   119  }
   120  
   121  // KeyDigestB64 produces a padded, standard Base64-encoded SHA256 digest of a
   122  // provided public key.
   123  func KeyDigestB64(key crypto.PublicKey) (string, error) {
   124  	digest, err := KeyDigest(key)
   125  	if err != nil {
   126  		return "", err
   127  	}
   128  	return base64.StdEncoding.EncodeToString(digest[:]), nil
   129  }
   130  
   131  // KeyDigestEquals determines whether two public keys have the same digest.
   132  func KeyDigestEquals(j, k crypto.PublicKey) bool {
   133  	digestJ, errJ := KeyDigestB64(j)
   134  	digestK, errK := KeyDigestB64(k)
   135  	// Keys that don't have a valid digest (due to marshalling problems)
   136  	// are never equal. So, e.g. nil keys are not equal.
   137  	if errJ != nil || errK != nil {
   138  		return false
   139  	}
   140  	return digestJ == digestK
   141  }
   142  
   143  // PublicKeysEqual determines whether two public keys are identical.
   144  func PublicKeysEqual(a, b crypto.PublicKey) (bool, error) {
   145  	switch ak := a.(type) {
   146  	case *rsa.PublicKey:
   147  		return ak.Equal(b), nil
   148  	case *ecdsa.PublicKey:
   149  		return ak.Equal(b), nil
   150  	default:
   151  		return false, fmt.Errorf("unsupported public key type %T", ak)
   152  	}
   153  }
   154  
   155  // SerialToString converts a certificate serial number (big.Int) to a String
   156  // consistently.
   157  func SerialToString(serial *big.Int) string {
   158  	return fmt.Sprintf("%036x", serial)
   159  }
   160  
   161  // StringToSerial converts a string into a certificate serial number (big.Int)
   162  // consistently.
   163  func StringToSerial(serial string) (*big.Int, error) {
   164  	var serialNum big.Int
   165  	if !ValidSerial(serial) {
   166  		return &serialNum, fmt.Errorf("invalid serial number %q", serial)
   167  	}
   168  	_, err := fmt.Sscanf(serial, "%036x", &serialNum)
   169  	return &serialNum, err
   170  }
   171  
   172  // ValidSerial tests whether the input string represents a syntactically
   173  // valid serial number, i.e., that it is a valid hex string between 32
   174  // and 36 characters long.
   175  func ValidSerial(serial string) bool {
   176  	// Originally, serial numbers were 32 hex characters long. We later increased
   177  	// them to 36, but we allow the shorter ones because they exist in some
   178  	// production databases.
   179  	if len(serial) != 32 && len(serial) != 36 {
   180  		return false
   181  	}
   182  	_, err := hex.DecodeString(serial)
   183  	return err == nil
   184  }
   185  
   186  // GetBuildID identifies what build is running.
   187  func GetBuildID() (retID string) {
   188  	retID = BuildID
   189  	if retID == "" {
   190  		retID = Unspecified
   191  	}
   192  	return
   193  }
   194  
   195  // GetBuildTime identifies when this build was made
   196  func GetBuildTime() (retID string) {
   197  	retID = BuildTime
   198  	if retID == "" {
   199  		retID = Unspecified
   200  	}
   201  	return
   202  }
   203  
   204  // GetBuildHost identifies the building host
   205  func GetBuildHost() (retID string) {
   206  	retID = BuildHost
   207  	if retID == "" {
   208  		retID = Unspecified
   209  	}
   210  	return
   211  }
   212  
   213  // IsAnyNilOrZero returns whether any of the supplied values are nil, or (if not)
   214  // if any of them is its type's zero-value. This is useful for validating that
   215  // all required fields on a proto message are present.
   216  func IsAnyNilOrZero(vals ...any) bool {
   217  	for _, val := range vals {
   218  		switch v := val.(type) {
   219  		case nil:
   220  			return true
   221  		case bool:
   222  			if !v {
   223  				return true
   224  			}
   225  		case string:
   226  			if v == "" {
   227  				return true
   228  			}
   229  		case []string:
   230  			if len(v) == 0 {
   231  				return true
   232  			}
   233  		case byte:
   234  			// Byte is an alias for uint8 and will cover that case.
   235  			if v == 0 {
   236  				return true
   237  			}
   238  		case []byte:
   239  			if len(v) == 0 {
   240  				return true
   241  			}
   242  		case int:
   243  			if v == 0 {
   244  				return true
   245  			}
   246  		case int8:
   247  			if v == 0 {
   248  				return true
   249  			}
   250  		case int16:
   251  			if v == 0 {
   252  				return true
   253  			}
   254  		case int32:
   255  			if v == 0 {
   256  				return true
   257  			}
   258  		case int64:
   259  			if v == 0 {
   260  				return true
   261  			}
   262  		case uint:
   263  			if v == 0 {
   264  				return true
   265  			}
   266  		case uint16:
   267  			if v == 0 {
   268  				return true
   269  			}
   270  		case uint32:
   271  			if v == 0 {
   272  				return true
   273  			}
   274  		case uint64:
   275  			if v == 0 {
   276  				return true
   277  			}
   278  		case float32:
   279  			if v == 0 {
   280  				return true
   281  			}
   282  		case float64:
   283  			if v == 0 {
   284  				return true
   285  			}
   286  		case time.Time:
   287  			if v.IsZero() {
   288  				return true
   289  			}
   290  		case *timestamppb.Timestamp:
   291  			if v == nil || v.AsTime().IsZero() {
   292  				return true
   293  			}
   294  		case *durationpb.Duration:
   295  			if v == nil || v.AsDuration() == time.Duration(0) {
   296  				return true
   297  			}
   298  		default:
   299  			if reflect.ValueOf(v).IsZero() {
   300  				return true
   301  			}
   302  		}
   303  	}
   304  	return false
   305  }
   306  
   307  // UniqueLowerNames returns the set of all unique names in the input after all
   308  // of them are lowercased. The returned names will be in their lowercased form
   309  // and sorted alphabetically.
   310  func UniqueLowerNames(names []string) (unique []string) {
   311  	nameMap := make(map[string]int, len(names))
   312  	for _, name := range names {
   313  		nameMap[strings.ToLower(name)] = 1
   314  	}
   315  
   316  	unique = make([]string, 0, len(nameMap))
   317  	for name := range nameMap {
   318  		unique = append(unique, name)
   319  	}
   320  	sort.Strings(unique)
   321  	return
   322  }
   323  
   324  // HashIdentifiers returns a hash of the identifiers requested. This is intended
   325  // for use when interacting with the orderFqdnSets table and rate limiting.
   326  func HashIdentifiers(idents identifier.ACMEIdentifiers) []byte {
   327  	var values []string
   328  	for _, ident := range identifier.Normalize(idents) {
   329  		values = append(values, ident.Value)
   330  	}
   331  
   332  	hash := sha256.Sum256([]byte(strings.Join(values, ",")))
   333  	return hash[:]
   334  }
   335  
   336  // LoadCert loads a PEM certificate specified by filename or returns an error
   337  func LoadCert(filename string) (*x509.Certificate, error) {
   338  	certPEM, err := os.ReadFile(filename)
   339  	if err != nil {
   340  		return nil, err
   341  	}
   342  	block, _ := pem.Decode(certPEM)
   343  	if block == nil {
   344  		return nil, fmt.Errorf("no data in cert PEM file %q", filename)
   345  	}
   346  	cert, err := x509.ParseCertificate(block.Bytes)
   347  	if err != nil {
   348  		return nil, err
   349  	}
   350  	return cert, nil
   351  }
   352  
   353  // retryJitter is used to prevent bunched retried queries from falling into lockstep
   354  const retryJitter = 0.2
   355  
   356  // RetryBackoff calculates a backoff time based on number of retries, will always
   357  // add jitter so requests that start in unison won't fall into lockstep. Because of
   358  // this the returned duration can always be larger than the maximum by a factor of
   359  // retryJitter. Adapted from
   360  // https://github.com/grpc/grpc-go/blob/v1.11.3/backoff.go#L77-L96
   361  func RetryBackoff(retries int, base, max time.Duration, factor float64) time.Duration {
   362  	if retries == 0 {
   363  		return 0
   364  	}
   365  	backoff, fMax := float64(base), float64(max)
   366  	for backoff < fMax && retries > 1 {
   367  		backoff *= factor
   368  		retries--
   369  	}
   370  	if backoff > fMax {
   371  		backoff = fMax
   372  	}
   373  	// Randomize backoff delays so that if a cluster of requests start at
   374  	// the same time, they won't operate in lockstep.
   375  	backoff *= (1 - retryJitter) + 2*retryJitter*mrand.Float64()
   376  	return time.Duration(backoff)
   377  }
   378  
   379  // IsASCII determines if every character in a string is encoded in
   380  // the ASCII character set.
   381  func IsASCII(str string) bool {
   382  	for _, r := range str {
   383  		if r > unicode.MaxASCII {
   384  			return false
   385  		}
   386  	}
   387  	return true
   388  }
   389  
   390  // IsCanceled returns true if err is non-nil and is either context.Canceled, or
   391  // has a grpc code of Canceled. This is useful because cancellations propagate
   392  // through gRPC boundaries, and if we choose to treat in-process cancellations a
   393  // certain way, we usually want to treat cross-process cancellations the same way.
   394  func IsCanceled(err error) bool {
   395  	return errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled
   396  }
   397  
   398  func Command() string {
   399  	return path.Base(os.Args[0])
   400  }