github.com/likebike/go--@v0.0.0-20190911215757-0bd925d16e96/go/src/encoding/asn1/marshal.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package asn1
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"math/big"
    11  	"reflect"
    12  	"time"
    13  	"unicode/utf8"
    14  )
    15  
    16  var (
    17  	byte00Encoder encoder = byteEncoder(0x00)
    18  	byteFFEncoder encoder = byteEncoder(0xff)
    19  )
    20  
    21  // encoder represents an ASN.1 element that is waiting to be marshaled.
    22  type encoder interface {
    23  	// Len returns the number of bytes needed to marshal this element.
    24  	Len() int
    25  	// Encode encodes this element by writing Len() bytes to dst.
    26  	Encode(dst []byte)
    27  }
    28  
    29  type byteEncoder byte
    30  
    31  func (c byteEncoder) Len() int {
    32  	return 1
    33  }
    34  
    35  func (c byteEncoder) Encode(dst []byte) {
    36  	dst[0] = byte(c)
    37  }
    38  
    39  type bytesEncoder []byte
    40  
    41  func (b bytesEncoder) Len() int {
    42  	return len(b)
    43  }
    44  
    45  func (b bytesEncoder) Encode(dst []byte) {
    46  	if copy(dst, b) != len(b) {
    47  		panic("internal error")
    48  	}
    49  }
    50  
    51  type stringEncoder string
    52  
    53  func (s stringEncoder) Len() int {
    54  	return len(s)
    55  }
    56  
    57  func (s stringEncoder) Encode(dst []byte) {
    58  	if copy(dst, s) != len(s) {
    59  		panic("internal error")
    60  	}
    61  }
    62  
    63  type multiEncoder []encoder
    64  
    65  func (m multiEncoder) Len() int {
    66  	var size int
    67  	for _, e := range m {
    68  		size += e.Len()
    69  	}
    70  	return size
    71  }
    72  
    73  func (m multiEncoder) Encode(dst []byte) {
    74  	var off int
    75  	for _, e := range m {
    76  		e.Encode(dst[off:])
    77  		off += e.Len()
    78  	}
    79  }
    80  
    81  type taggedEncoder struct {
    82  	// scratch contains temporary space for encoding the tag and length of
    83  	// an element in order to avoid extra allocations.
    84  	scratch [8]byte
    85  	tag     encoder
    86  	body    encoder
    87  }
    88  
    89  func (t *taggedEncoder) Len() int {
    90  	return t.tag.Len() + t.body.Len()
    91  }
    92  
    93  func (t *taggedEncoder) Encode(dst []byte) {
    94  	t.tag.Encode(dst)
    95  	t.body.Encode(dst[t.tag.Len():])
    96  }
    97  
    98  type int64Encoder int64
    99  
   100  func (i int64Encoder) Len() int {
   101  	n := 1
   102  
   103  	for i > 127 {
   104  		n++
   105  		i >>= 8
   106  	}
   107  
   108  	for i < -128 {
   109  		n++
   110  		i >>= 8
   111  	}
   112  
   113  	return n
   114  }
   115  
   116  func (i int64Encoder) Encode(dst []byte) {
   117  	n := i.Len()
   118  
   119  	for j := 0; j < n; j++ {
   120  		dst[j] = byte(i >> uint((n-1-j)*8))
   121  	}
   122  }
   123  
   124  func base128IntLength(n int64) int {
   125  	if n == 0 {
   126  		return 1
   127  	}
   128  
   129  	l := 0
   130  	for i := n; i > 0; i >>= 7 {
   131  		l++
   132  	}
   133  
   134  	return l
   135  }
   136  
   137  func appendBase128Int(dst []byte, n int64) []byte {
   138  	l := base128IntLength(n)
   139  
   140  	for i := l - 1; i >= 0; i-- {
   141  		o := byte(n >> uint(i*7))
   142  		o &= 0x7f
   143  		if i != 0 {
   144  			o |= 0x80
   145  		}
   146  
   147  		dst = append(dst, o)
   148  	}
   149  
   150  	return dst
   151  }
   152  
   153  func makeBigInt(n *big.Int) (encoder, error) {
   154  	if n == nil {
   155  		return nil, StructuralError{"empty integer"}
   156  	}
   157  
   158  	if n.Sign() < 0 {
   159  		// A negative number has to be converted to two's-complement
   160  		// form. So we'll invert and subtract 1. If the
   161  		// most-significant-bit isn't set then we'll need to pad the
   162  		// beginning with 0xff in order to keep the number negative.
   163  		nMinus1 := new(big.Int).Neg(n)
   164  		nMinus1.Sub(nMinus1, bigOne)
   165  		bytes := nMinus1.Bytes()
   166  		for i := range bytes {
   167  			bytes[i] ^= 0xff
   168  		}
   169  		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
   170  			return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
   171  		}
   172  		return bytesEncoder(bytes), nil
   173  	} else if n.Sign() == 0 {
   174  		// Zero is written as a single 0 zero rather than no bytes.
   175  		return byte00Encoder, nil
   176  	} else {
   177  		bytes := n.Bytes()
   178  		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
   179  			// We'll have to pad this with 0x00 in order to stop it
   180  			// looking like a negative number.
   181  			return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
   182  		}
   183  		return bytesEncoder(bytes), nil
   184  	}
   185  }
   186  
   187  func appendLength(dst []byte, i int) []byte {
   188  	n := lengthLength(i)
   189  
   190  	for ; n > 0; n-- {
   191  		dst = append(dst, byte(i>>uint((n-1)*8)))
   192  	}
   193  
   194  	return dst
   195  }
   196  
   197  func lengthLength(i int) (numBytes int) {
   198  	numBytes = 1
   199  	for i > 255 {
   200  		numBytes++
   201  		i >>= 8
   202  	}
   203  	return
   204  }
   205  
   206  func appendTagAndLength(dst []byte, t tagAndLength) []byte {
   207  	b := uint8(t.class) << 6
   208  	if t.isCompound {
   209  		b |= 0x20
   210  	}
   211  	if t.tag >= 31 {
   212  		b |= 0x1f
   213  		dst = append(dst, b)
   214  		dst = appendBase128Int(dst, int64(t.tag))
   215  	} else {
   216  		b |= uint8(t.tag)
   217  		dst = append(dst, b)
   218  	}
   219  
   220  	if t.length >= 128 {
   221  		l := lengthLength(t.length)
   222  		dst = append(dst, 0x80|byte(l))
   223  		dst = appendLength(dst, t.length)
   224  	} else {
   225  		dst = append(dst, byte(t.length))
   226  	}
   227  
   228  	return dst
   229  }
   230  
   231  type bitStringEncoder BitString
   232  
   233  func (b bitStringEncoder) Len() int {
   234  	return len(b.Bytes) + 1
   235  }
   236  
   237  func (b bitStringEncoder) Encode(dst []byte) {
   238  	dst[0] = byte((8 - b.BitLength%8) % 8)
   239  	if copy(dst[1:], b.Bytes) != len(b.Bytes) {
   240  		panic("internal error")
   241  	}
   242  }
   243  
   244  type oidEncoder []int
   245  
   246  func (oid oidEncoder) Len() int {
   247  	l := base128IntLength(int64(oid[0]*40 + oid[1]))
   248  	for i := 2; i < len(oid); i++ {
   249  		l += base128IntLength(int64(oid[i]))
   250  	}
   251  	return l
   252  }
   253  
   254  func (oid oidEncoder) Encode(dst []byte) {
   255  	dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
   256  	for i := 2; i < len(oid); i++ {
   257  		dst = appendBase128Int(dst, int64(oid[i]))
   258  	}
   259  }
   260  
   261  func makeObjectIdentifier(oid []int) (e encoder, err error) {
   262  	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
   263  		return nil, StructuralError{"invalid object identifier"}
   264  	}
   265  
   266  	return oidEncoder(oid), nil
   267  }
   268  
   269  func makePrintableString(s string) (e encoder, err error) {
   270  	for i := 0; i < len(s); i++ {
   271  		// The asterisk is often used in PrintableString, even though
   272  		// it is invalid. If a PrintableString was specifically
   273  		// requested then the asterisk is permitted by this code.
   274  		// Ampersand is allowed in parsing due a handful of CA
   275  		// certificates, however when making new certificates
   276  		// it is rejected.
   277  		if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
   278  			return nil, StructuralError{"PrintableString contains invalid character"}
   279  		}
   280  	}
   281  
   282  	return stringEncoder(s), nil
   283  }
   284  
   285  func makeIA5String(s string) (e encoder, err error) {
   286  	for i := 0; i < len(s); i++ {
   287  		if s[i] > 127 {
   288  			return nil, StructuralError{"IA5String contains invalid character"}
   289  		}
   290  	}
   291  
   292  	return stringEncoder(s), nil
   293  }
   294  
   295  func makeNumericString(s string) (e encoder, err error) {
   296  	for i := 0; i < len(s); i++ {
   297  		if !isNumeric(s[i]) {
   298  			return nil, StructuralError{"NumericString contains invalid character"}
   299  		}
   300  	}
   301  
   302  	return stringEncoder(s), nil
   303  }
   304  
   305  func makeUTF8String(s string) encoder {
   306  	return stringEncoder(s)
   307  }
   308  
   309  func appendTwoDigits(dst []byte, v int) []byte {
   310  	return append(dst, byte('0'+(v/10)%10), byte('0'+v%10))
   311  }
   312  
   313  func appendFourDigits(dst []byte, v int) []byte {
   314  	var bytes [4]byte
   315  	for i := range bytes {
   316  		bytes[3-i] = '0' + byte(v%10)
   317  		v /= 10
   318  	}
   319  	return append(dst, bytes[:]...)
   320  }
   321  
   322  func outsideUTCRange(t time.Time) bool {
   323  	year := t.Year()
   324  	return year < 1950 || year >= 2050
   325  }
   326  
   327  func makeUTCTime(t time.Time) (e encoder, err error) {
   328  	dst := make([]byte, 0, 18)
   329  
   330  	dst, err = appendUTCTime(dst, t)
   331  	if err != nil {
   332  		return nil, err
   333  	}
   334  
   335  	return bytesEncoder(dst), nil
   336  }
   337  
   338  func makeGeneralizedTime(t time.Time) (e encoder, err error) {
   339  	dst := make([]byte, 0, 20)
   340  
   341  	dst, err = appendGeneralizedTime(dst, t)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	return bytesEncoder(dst), nil
   347  }
   348  
   349  func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
   350  	year := t.Year()
   351  
   352  	switch {
   353  	case 1950 <= year && year < 2000:
   354  		dst = appendTwoDigits(dst, year-1900)
   355  	case 2000 <= year && year < 2050:
   356  		dst = appendTwoDigits(dst, year-2000)
   357  	default:
   358  		return nil, StructuralError{"cannot represent time as UTCTime"}
   359  	}
   360  
   361  	return appendTimeCommon(dst, t), nil
   362  }
   363  
   364  func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
   365  	year := t.Year()
   366  	if year < 0 || year > 9999 {
   367  		return nil, StructuralError{"cannot represent time as GeneralizedTime"}
   368  	}
   369  
   370  	dst = appendFourDigits(dst, year)
   371  
   372  	return appendTimeCommon(dst, t), nil
   373  }
   374  
   375  func appendTimeCommon(dst []byte, t time.Time) []byte {
   376  	_, month, day := t.Date()
   377  
   378  	dst = appendTwoDigits(dst, int(month))
   379  	dst = appendTwoDigits(dst, day)
   380  
   381  	hour, min, sec := t.Clock()
   382  
   383  	dst = appendTwoDigits(dst, hour)
   384  	dst = appendTwoDigits(dst, min)
   385  	dst = appendTwoDigits(dst, sec)
   386  
   387  	_, offset := t.Zone()
   388  
   389  	switch {
   390  	case offset/60 == 0:
   391  		return append(dst, 'Z')
   392  	case offset > 0:
   393  		dst = append(dst, '+')
   394  	case offset < 0:
   395  		dst = append(dst, '-')
   396  	}
   397  
   398  	offsetMinutes := offset / 60
   399  	if offsetMinutes < 0 {
   400  		offsetMinutes = -offsetMinutes
   401  	}
   402  
   403  	dst = appendTwoDigits(dst, offsetMinutes/60)
   404  	dst = appendTwoDigits(dst, offsetMinutes%60)
   405  
   406  	return dst
   407  }
   408  
   409  func stripTagAndLength(in []byte) []byte {
   410  	_, offset, err := parseTagAndLength(in, 0)
   411  	if err != nil {
   412  		return in
   413  	}
   414  	return in[offset:]
   415  }
   416  
   417  func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
   418  	switch value.Type() {
   419  	case flagType:
   420  		return bytesEncoder(nil), nil
   421  	case timeType:
   422  		t := value.Interface().(time.Time)
   423  		if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
   424  			return makeGeneralizedTime(t)
   425  		}
   426  		return makeUTCTime(t)
   427  	case bitStringType:
   428  		return bitStringEncoder(value.Interface().(BitString)), nil
   429  	case objectIdentifierType:
   430  		return makeObjectIdentifier(value.Interface().(ObjectIdentifier))
   431  	case bigIntType:
   432  		return makeBigInt(value.Interface().(*big.Int))
   433  	}
   434  
   435  	switch v := value; v.Kind() {
   436  	case reflect.Bool:
   437  		if v.Bool() {
   438  			return byteFFEncoder, nil
   439  		}
   440  		return byte00Encoder, nil
   441  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   442  		return int64Encoder(v.Int()), nil
   443  	case reflect.Struct:
   444  		t := v.Type()
   445  
   446  		for i := 0; i < t.NumField(); i++ {
   447  			if t.Field(i).PkgPath != "" {
   448  				return nil, StructuralError{"struct contains unexported fields"}
   449  			}
   450  		}
   451  
   452  		startingField := 0
   453  
   454  		n := t.NumField()
   455  		if n == 0 {
   456  			return bytesEncoder(nil), nil
   457  		}
   458  
   459  		// If the first element of the structure is a non-empty
   460  		// RawContents, then we don't bother serializing the rest.
   461  		if t.Field(0).Type == rawContentsType {
   462  			s := v.Field(0)
   463  			if s.Len() > 0 {
   464  				bytes := s.Bytes()
   465  				/* The RawContents will contain the tag and
   466  				 * length fields but we'll also be writing
   467  				 * those ourselves, so we strip them out of
   468  				 * bytes */
   469  				return bytesEncoder(stripTagAndLength(bytes)), nil
   470  			}
   471  
   472  			startingField = 1
   473  		}
   474  
   475  		switch n1 := n - startingField; n1 {
   476  		case 0:
   477  			return bytesEncoder(nil), nil
   478  		case 1:
   479  			return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
   480  		default:
   481  			m := make([]encoder, n1)
   482  			for i := 0; i < n1; i++ {
   483  				m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
   484  				if err != nil {
   485  					return nil, err
   486  				}
   487  			}
   488  
   489  			return multiEncoder(m), nil
   490  		}
   491  	case reflect.Slice:
   492  		sliceType := v.Type()
   493  		if sliceType.Elem().Kind() == reflect.Uint8 {
   494  			return bytesEncoder(v.Bytes()), nil
   495  		}
   496  
   497  		var fp fieldParameters
   498  
   499  		switch l := v.Len(); l {
   500  		case 0:
   501  			return bytesEncoder(nil), nil
   502  		case 1:
   503  			return makeField(v.Index(0), fp)
   504  		default:
   505  			m := make([]encoder, l)
   506  
   507  			for i := 0; i < l; i++ {
   508  				m[i], err = makeField(v.Index(i), fp)
   509  				if err != nil {
   510  					return nil, err
   511  				}
   512  			}
   513  
   514  			return multiEncoder(m), nil
   515  		}
   516  	case reflect.String:
   517  		switch params.stringType {
   518  		case TagIA5String:
   519  			return makeIA5String(v.String())
   520  		case TagPrintableString:
   521  			return makePrintableString(v.String())
   522  		case TagNumericString:
   523  			return makeNumericString(v.String())
   524  		default:
   525  			return makeUTF8String(v.String()), nil
   526  		}
   527  	}
   528  
   529  	return nil, StructuralError{"unknown Go type"}
   530  }
   531  
   532  func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
   533  	if !v.IsValid() {
   534  		return nil, fmt.Errorf("asn1: cannot marshal nil value")
   535  	}
   536  	// If the field is an interface{} then recurse into it.
   537  	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
   538  		return makeField(v.Elem(), params)
   539  	}
   540  
   541  	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
   542  		return bytesEncoder(nil), nil
   543  	}
   544  
   545  	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
   546  		defaultValue := reflect.New(v.Type()).Elem()
   547  		defaultValue.SetInt(*params.defaultValue)
   548  
   549  		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
   550  			return bytesEncoder(nil), nil
   551  		}
   552  	}
   553  
   554  	// If no default value is given then the zero value for the type is
   555  	// assumed to be the default value. This isn't obviously the correct
   556  	// behavior, but it's what Go has traditionally done.
   557  	if params.optional && params.defaultValue == nil {
   558  		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
   559  			return bytesEncoder(nil), nil
   560  		}
   561  	}
   562  
   563  	if v.Type() == rawValueType {
   564  		rv := v.Interface().(RawValue)
   565  		if len(rv.FullBytes) != 0 {
   566  			return bytesEncoder(rv.FullBytes), nil
   567  		}
   568  
   569  		t := new(taggedEncoder)
   570  
   571  		t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
   572  		t.body = bytesEncoder(rv.Bytes)
   573  
   574  		return t, nil
   575  	}
   576  
   577  	matchAny, tag, isCompound, ok := getUniversalType(v.Type())
   578  	if !ok || matchAny {
   579  		return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
   580  	}
   581  
   582  	if params.timeType != 0 && tag != TagUTCTime {
   583  		return nil, StructuralError{"explicit time type given to non-time member"}
   584  	}
   585  
   586  	if params.stringType != 0 && tag != TagPrintableString {
   587  		return nil, StructuralError{"explicit string type given to non-string member"}
   588  	}
   589  
   590  	switch tag {
   591  	case TagPrintableString:
   592  		if params.stringType == 0 {
   593  			// This is a string without an explicit string type. We'll use
   594  			// a PrintableString if the character set in the string is
   595  			// sufficiently limited, otherwise we'll use a UTF8String.
   596  			for _, r := range v.String() {
   597  				if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) {
   598  					if !utf8.ValidString(v.String()) {
   599  						return nil, errors.New("asn1: string not valid UTF-8")
   600  					}
   601  					tag = TagUTF8String
   602  					break
   603  				}
   604  			}
   605  		} else {
   606  			tag = params.stringType
   607  		}
   608  	case TagUTCTime:
   609  		if params.timeType == TagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
   610  			tag = TagGeneralizedTime
   611  		}
   612  	}
   613  
   614  	if params.set {
   615  		if tag != TagSequence {
   616  			return nil, StructuralError{"non sequence tagged as set"}
   617  		}
   618  		tag = TagSet
   619  	}
   620  
   621  	t := new(taggedEncoder)
   622  
   623  	t.body, err = makeBody(v, params)
   624  	if err != nil {
   625  		return nil, err
   626  	}
   627  
   628  	bodyLen := t.body.Len()
   629  
   630  	class := ClassUniversal
   631  	if params.tag != nil {
   632  		if params.application {
   633  			class = ClassApplication
   634  		} else {
   635  			class = ClassContextSpecific
   636  		}
   637  
   638  		if params.explicit {
   639  			t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))
   640  
   641  			tt := new(taggedEncoder)
   642  
   643  			tt.body = t
   644  
   645  			tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
   646  				class:      class,
   647  				tag:        *params.tag,
   648  				length:     bodyLen + t.tag.Len(),
   649  				isCompound: true,
   650  			}))
   651  
   652  			return tt, nil
   653  		}
   654  
   655  		// implicit tag.
   656  		tag = *params.tag
   657  	}
   658  
   659  	t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))
   660  
   661  	return t, nil
   662  }
   663  
   664  // Marshal returns the ASN.1 encoding of val.
   665  //
   666  // In addition to the struct tags recognised by Unmarshal, the following can be
   667  // used:
   668  //
   669  //	ia5:         causes strings to be marshaled as ASN.1, IA5String values
   670  //	omitempty:   causes empty slices to be skipped
   671  //	printable:   causes strings to be marshaled as ASN.1, PrintableString values
   672  //	utf8:        causes strings to be marshaled as ASN.1, UTF8String values
   673  //	utc:         causes time.Time to be marshaled as ASN.1, UTCTime values
   674  //	generalized: causes time.Time to be marshaled as ASN.1, GeneralizedTime values
   675  func Marshal(val interface{}) ([]byte, error) {
   676  	return MarshalWithParams(val, "")
   677  }
   678  
   679  // MarshalWithParams allows field parameters to be specified for the
   680  // top-level element. The form of the params is the same as the field tags.
   681  func MarshalWithParams(val interface{}, params string) ([]byte, error) {
   682  	e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params))
   683  	if err != nil {
   684  		return nil, err
   685  	}
   686  	b := make([]byte, e.Len())
   687  	e.Encode(b)
   688  	return b, nil
   689  }