github.com/Venafi/vcert/v5@v5.10.2/pkg/playbook/app/vcertutil/helper.go (about)

     1  /*
     2   * Copyright 2023 Venafi, Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *  http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package vcertutil
    18  
    19  import (
    20  	"encoding/pem"
    21  	"fmt"
    22  	"net"
    23  	"net/url"
    24  	"os"
    25  	"strconv"
    26  	"strings"
    27  	"time"
    28  
    29  	"go.uber.org/zap"
    30  
    31  	"github.com/Venafi/vcert/v5/pkg/certificate"
    32  	"github.com/Venafi/vcert/v5/pkg/playbook/app/domain"
    33  	"github.com/Venafi/vcert/v5/pkg/util"
    34  )
    35  
    36  const (
    37  	// DefaultRSALength represents the default length of an RSA Private Key
    38  	DefaultRSALength = 2048
    39  
    40  	// DefaultTimeout represents the default time in seconds vcert will try to retrieve a certificate
    41  	DefaultTimeout = 180
    42  
    43  	// OriginName represents the Origin of the Request set in a Custom Field
    44  	OriginName = "Venafi VCert Playbook"
    45  
    46  	filePrefix = "file:"
    47  )
    48  
    49  func loadTrustBundle(path string) string {
    50  	if path != "" {
    51  		buf, err := os.ReadFile(path)
    52  		if err != nil {
    53  			zap.L().Fatal("could not read TrustBundle", zap.String("location", path), zap.Error(err))
    54  		}
    55  		return string(buf)
    56  	}
    57  	return ""
    58  }
    59  
    60  func getIPAddresses(ips []string) []net.IP {
    61  	netIps := make([]net.IP, 0)
    62  	for _, ipStr := range ips {
    63  		ip := net.ParseIP(ipStr)
    64  		if ip != nil {
    65  			netIps = append(netIps, ip)
    66  		}
    67  	}
    68  	return netIps
    69  }
    70  
    71  func getURIs(uris []string) []*url.URL {
    72  	urls := make([]*url.URL, 0)
    73  
    74  	for _, uriStr := range uris {
    75  		uri, err := url.Parse(uriStr)
    76  		if err != nil {
    77  			zap.L().Error("could not parse URI", zap.String("uri", uriStr), zap.Error(err))
    78  			continue
    79  		}
    80  		urls = append(urls, uri)
    81  	}
    82  	return urls
    83  }
    84  
    85  func setKeyType(request domain.PlaybookRequest, vcertRequest *certificate.Request) {
    86  	switch request.KeyType {
    87  	case certificate.KeyTypeRSA:
    88  		vcertRequest.KeyType = request.KeyType
    89  		if request.KeyLength <= 0 {
    90  			vcertRequest.KeyLength = DefaultRSALength
    91  		} else {
    92  			vcertRequest.KeyLength = request.KeyLength
    93  		}
    94  	case certificate.KeyTypeECDSA:
    95  		vcertRequest.KeyType = request.KeyType
    96  		vcertRequest.KeyCurve = request.KeyCurve
    97  	case certificate.KeyTypeED25519:
    98  		vcertRequest.KeyType = request.KeyType
    99  		vcertRequest.KeyCurve = certificate.EllipticCurveED25519
   100  	default:
   101  		vcertRequest.KeyType = certificate.KeyTypeRSA
   102  		vcertRequest.KeyLength = DefaultRSALength
   103  	}
   104  }
   105  
   106  func setOrigin(request domain.PlaybookRequest, vcertRequest *certificate.Request) {
   107  	origin := OriginName
   108  	if request.Origin != "" {
   109  		origin = request.Origin
   110  	}
   111  	originCustomField := certificate.CustomField{
   112  		Name:  "Origin",
   113  		Value: origin,
   114  		Type:  certificate.CustomFieldOrigin,
   115  	}
   116  	vcertRequest.CustomFields = append(vcertRequest.CustomFields, originCustomField)
   117  
   118  }
   119  
   120  func setValidity(request domain.PlaybookRequest, vcertRequest *certificate.Request) {
   121  	if request.ValidDays == "" {
   122  		return
   123  	}
   124  
   125  	data := strings.Split(request.ValidDays, "#")
   126  	days, _ := strconv.ParseInt(data[0], 10, 64)
   127  	hours := days * 24
   128  
   129  	vcertRequest.ValidityHours = int(hours) //nolint:staticcheck
   130  
   131  	var issuerHint util.IssuerHint
   132  	if len(data) > 1 { // means that issuer hint is set
   133  		option := strings.ToLower(data[1])
   134  		switch option {
   135  		case "m":
   136  			issuerHint = util.IssuerHintMicrosoft
   137  		case "d":
   138  			issuerHint = util.IssuerHintDigicert
   139  		case "e":
   140  			issuerHint = util.IssuerHintEntrust
   141  		}
   142  	}
   143  	vcertRequest.IssuerHint = issuerHint
   144  
   145  	// If IssuerHint is declared in playbook, override issuerHint from validDays string
   146  	if request.IssuerHint != util.IssuerHintGeneric {
   147  		vcertRequest.IssuerHint = request.IssuerHint
   148  	}
   149  }
   150  
   151  func setLocationWorkload(playbookRequest domain.PlaybookRequest, vcertRequest *certificate.Request) {
   152  	if playbookRequest.Location.Instance == "" {
   153  		return
   154  	}
   155  
   156  	segments := strings.Split(playbookRequest.Location.Instance, ":")
   157  	instance := segments[0]
   158  	workload := ""
   159  	// take workload from instance string
   160  	if len(segments) > 1 {
   161  		workload = segments[1]
   162  	}
   163  	// take workload from attribute.
   164  	// workload attribute has priority over workload string declared in request.Location.Instance
   165  	if playbookRequest.Location.Workload != "" {
   166  		workload = playbookRequest.Location.Workload
   167  	}
   168  
   169  	newLocation := certificate.Location{
   170  		Instance:   instance,
   171  		Workload:   workload,
   172  		TLSAddress: playbookRequest.Location.TLSAddress,
   173  		Replace:    playbookRequest.Location.Replace,
   174  		Zone:       playbookRequest.Location.Zone,
   175  	}
   176  	vcertRequest.Location = &newLocation
   177  }
   178  
   179  func setTimeout(playbookRequest domain.PlaybookRequest, vcertRequest *certificate.Request) {
   180  	timeout := DefaultTimeout
   181  	if playbookRequest.Timeout > 0 {
   182  		timeout = playbookRequest.Timeout
   183  	}
   184  	vcertRequest.Timeout = time.Duration(timeout) * time.Second
   185  }
   186  
   187  func setCSR(playbookRequest domain.PlaybookRequest, vcertRequest *certificate.Request) {
   188  	vcertRequest.CsrOrigin = certificate.LocalGeneratedCSR
   189  
   190  	//CSR is user provided. Load CSR from file
   191  	if strings.HasPrefix(playbookRequest.CsrOrigin, filePrefix) {
   192  		file := playbookRequest.CsrOrigin[len(filePrefix):]
   193  		csr, err := readCSRFromFile(file)
   194  		if err != nil {
   195  			zap.L().Warn("failed to read CSR from file", zap.String("file", file), zap.Error(err))
   196  			vcertRequest.CsrOrigin = certificate.LocalGeneratedCSR
   197  			return
   198  		}
   199  		err = vcertRequest.SetCSR(csr)
   200  		if err != nil {
   201  			zap.L().Warn("failed to set CSR", zap.Error(err))
   202  			vcertRequest.CsrOrigin = certificate.LocalGeneratedCSR
   203  			return
   204  		}
   205  
   206  		vcertRequest.CsrOrigin = certificate.UserProvidedCSR
   207  		return
   208  	}
   209  
   210  	origin := certificate.ParseCSROrigin(playbookRequest.CsrOrigin)
   211  	if origin == certificate.UnknownCSR {
   212  		vcertRequest.CsrOrigin = certificate.LocalGeneratedCSR
   213  	} else {
   214  		vcertRequest.CsrOrigin = origin
   215  	}
   216  }
   217  
   218  func readCSRFromFile(fileName string) ([]byte, error) {
   219  	bytes, err := readFile(fileName)
   220  	if err != nil {
   221  		return nil, err
   222  	}
   223  	for {
   224  		block, rest := pem.Decode(bytes)
   225  		if block != nil && strings.HasSuffix(block.Type, "CERTIFICATE REQUEST") {
   226  			return pem.EncodeToMemory(block), nil
   227  		}
   228  		if block == nil || len(rest) == 0 {
   229  			return nil, fmt.Errorf("failed to find CSR in file: %s", fileName)
   230  		}
   231  		bytes = rest
   232  	}
   233  }
   234  
   235  func readFile(fileName string) ([]byte, error) {
   236  	bytes, err := os.ReadFile(fileName)
   237  	if err != nil {
   238  		return bytes, err
   239  	}
   240  	return bytes, nil
   241  }