github.com/zmap/zcrypto@v0.0.0-20240512203510-0fef58d9a9db/ct/asn1/marshal.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package asn1
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"math/big"
    13  	"reflect"
    14  	"time"
    15  	"unicode/utf8"
    16  )
    17  
    18  // A forkableWriter is an in-memory buffer that can be
    19  // 'forked' to create new forkableWriters that bracket the
    20  // original.  After
    21  //
    22  //	pre, post := w.fork();
    23  //
    24  // the overall sequence of bytes represented is logically w+pre+post.
    25  type forkableWriter struct {
    26  	*bytes.Buffer
    27  	pre, post *forkableWriter
    28  }
    29  
    30  func newForkableWriter() *forkableWriter {
    31  	return &forkableWriter{new(bytes.Buffer), nil, nil}
    32  }
    33  
    34  func (f *forkableWriter) fork() (pre, post *forkableWriter) {
    35  	if f.pre != nil || f.post != nil {
    36  		panic("have already forked")
    37  	}
    38  	f.pre = newForkableWriter()
    39  	f.post = newForkableWriter()
    40  	return f.pre, f.post
    41  }
    42  
    43  func (f *forkableWriter) Len() (l int) {
    44  	l += f.Buffer.Len()
    45  	if f.pre != nil {
    46  		l += f.pre.Len()
    47  	}
    48  	if f.post != nil {
    49  		l += f.post.Len()
    50  	}
    51  	return
    52  }
    53  
    54  func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
    55  	n, err = out.Write(f.Bytes())
    56  	if err != nil {
    57  		return
    58  	}
    59  
    60  	var nn int
    61  
    62  	if f.pre != nil {
    63  		nn, err = f.pre.writeTo(out)
    64  		n += nn
    65  		if err != nil {
    66  			return
    67  		}
    68  	}
    69  
    70  	if f.post != nil {
    71  		nn, err = f.post.writeTo(out)
    72  		n += nn
    73  	}
    74  	return
    75  }
    76  
    77  func marshalBase128Int(out *forkableWriter, n int64) (err error) {
    78  	if n == 0 {
    79  		err = out.WriteByte(0)
    80  		return
    81  	}
    82  
    83  	l := 0
    84  	for i := n; i > 0; i >>= 7 {
    85  		l++
    86  	}
    87  
    88  	for i := l - 1; i >= 0; i-- {
    89  		o := byte(n >> uint(i*7))
    90  		o &= 0x7f
    91  		if i != 0 {
    92  			o |= 0x80
    93  		}
    94  		err = out.WriteByte(o)
    95  		if err != nil {
    96  			return
    97  		}
    98  	}
    99  
   100  	return nil
   101  }
   102  
   103  func marshalInt64(out *forkableWriter, i int64) (err error) {
   104  	n := int64Length(i)
   105  
   106  	for ; n > 0; n-- {
   107  		err = out.WriteByte(byte(i >> uint((n-1)*8)))
   108  		if err != nil {
   109  			return
   110  		}
   111  	}
   112  
   113  	return nil
   114  }
   115  
   116  func int64Length(i int64) (numBytes int) {
   117  	numBytes = 1
   118  
   119  	for i > 127 {
   120  		numBytes++
   121  		i >>= 8
   122  	}
   123  
   124  	for i < -128 {
   125  		numBytes++
   126  		i >>= 8
   127  	}
   128  
   129  	return
   130  }
   131  
   132  func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
   133  	if n.Sign() < 0 {
   134  		// A negative number has to be converted to two's-complement
   135  		// form. So we'll subtract 1 and invert. If the
   136  		// most-significant-bit isn't set then we'll need to pad the
   137  		// beginning with 0xff in order to keep the number negative.
   138  		nMinus1 := new(big.Int).Neg(n)
   139  		nMinus1.Sub(nMinus1, bigOne)
   140  		bytes := nMinus1.Bytes()
   141  		for i := range bytes {
   142  			bytes[i] ^= 0xff
   143  		}
   144  		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
   145  			err = out.WriteByte(0xff)
   146  			if err != nil {
   147  				return
   148  			}
   149  		}
   150  		_, err = out.Write(bytes)
   151  	} else if n.Sign() == 0 {
   152  		// Zero is written as a single 0 zero rather than no bytes.
   153  		err = out.WriteByte(0x00)
   154  	} else {
   155  		bytes := n.Bytes()
   156  		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
   157  			// We'll have to pad this with 0x00 in order to stop it
   158  			// looking like a negative number.
   159  			err = out.WriteByte(0)
   160  			if err != nil {
   161  				return
   162  			}
   163  		}
   164  		_, err = out.Write(bytes)
   165  	}
   166  	return
   167  }
   168  
   169  func marshalLength(out *forkableWriter, i int) (err error) {
   170  	n := lengthLength(i)
   171  
   172  	for ; n > 0; n-- {
   173  		err = out.WriteByte(byte(i >> uint((n-1)*8)))
   174  		if err != nil {
   175  			return
   176  		}
   177  	}
   178  
   179  	return nil
   180  }
   181  
   182  func lengthLength(i int) (numBytes int) {
   183  	numBytes = 1
   184  	for i > 255 {
   185  		numBytes++
   186  		i >>= 8
   187  	}
   188  	return
   189  }
   190  
   191  func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
   192  	b := uint8(t.class) << 6
   193  	if t.isCompound {
   194  		b |= 0x20
   195  	}
   196  	if t.tag >= 31 {
   197  		b |= 0x1f
   198  		err = out.WriteByte(b)
   199  		if err != nil {
   200  			return
   201  		}
   202  		err = marshalBase128Int(out, int64(t.tag))
   203  		if err != nil {
   204  			return
   205  		}
   206  	} else {
   207  		b |= uint8(t.tag)
   208  		err = out.WriteByte(b)
   209  		if err != nil {
   210  			return
   211  		}
   212  	}
   213  
   214  	if t.length >= 128 {
   215  		l := lengthLength(t.length)
   216  		err = out.WriteByte(0x80 | byte(l))
   217  		if err != nil {
   218  			return
   219  		}
   220  		err = marshalLength(out, t.length)
   221  		if err != nil {
   222  			return
   223  		}
   224  	} else {
   225  		err = out.WriteByte(byte(t.length))
   226  		if err != nil {
   227  			return
   228  		}
   229  	}
   230  
   231  	return nil
   232  }
   233  
   234  func marshalBitString(out *forkableWriter, b BitString) (err error) {
   235  	paddingBits := byte((8 - b.BitLength%8) % 8)
   236  	err = out.WriteByte(paddingBits)
   237  	if err != nil {
   238  		return
   239  	}
   240  	_, err = out.Write(b.Bytes)
   241  	return
   242  }
   243  
   244  func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
   245  	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
   246  		return StructuralError{"invalid object identifier"}
   247  	}
   248  
   249  	err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
   250  	if err != nil {
   251  		return
   252  	}
   253  	for i := 2; i < len(oid); i++ {
   254  		err = marshalBase128Int(out, int64(oid[i]))
   255  		if err != nil {
   256  			return
   257  		}
   258  	}
   259  
   260  	return
   261  }
   262  
   263  func marshalPrintableString(out *forkableWriter, s string) (err error) {
   264  	b := []byte(s)
   265  	for _, c := range b {
   266  		if !isPrintable(c) {
   267  			return StructuralError{"PrintableString contains invalid character"}
   268  		}
   269  	}
   270  
   271  	_, err = out.Write(b)
   272  	return
   273  }
   274  
   275  func marshalIA5String(out *forkableWriter, s string) (err error) {
   276  	b := []byte(s)
   277  	for _, c := range b {
   278  		if c > 127 {
   279  			return StructuralError{"IA5String contains invalid character"}
   280  		}
   281  	}
   282  
   283  	_, err = out.Write(b)
   284  	return
   285  }
   286  
   287  func marshalUTF8String(out *forkableWriter, s string) (err error) {
   288  	_, err = out.Write([]byte(s))
   289  	return
   290  }
   291  
   292  func marshalTwoDigits(out *forkableWriter, v int) (err error) {
   293  	err = out.WriteByte(byte('0' + (v/10)%10))
   294  	if err != nil {
   295  		return
   296  	}
   297  	return out.WriteByte(byte('0' + v%10))
   298  }
   299  
   300  func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
   301  	year, month, day := t.Date()
   302  
   303  	switch {
   304  	case 1950 <= year && year < 2000:
   305  		err = marshalTwoDigits(out, int(year-1900))
   306  	case 2000 <= year && year < 2050:
   307  		err = marshalTwoDigits(out, int(year-2000))
   308  	default:
   309  		return StructuralError{"cannot represent time as UTCTime"}
   310  	}
   311  	if err != nil {
   312  		return
   313  	}
   314  
   315  	err = marshalTwoDigits(out, int(month))
   316  	if err != nil {
   317  		return
   318  	}
   319  
   320  	err = marshalTwoDigits(out, day)
   321  	if err != nil {
   322  		return
   323  	}
   324  
   325  	hour, min, sec := t.Clock()
   326  
   327  	err = marshalTwoDigits(out, hour)
   328  	if err != nil {
   329  		return
   330  	}
   331  
   332  	err = marshalTwoDigits(out, min)
   333  	if err != nil {
   334  		return
   335  	}
   336  
   337  	err = marshalTwoDigits(out, sec)
   338  	if err != nil {
   339  		return
   340  	}
   341  
   342  	_, offset := t.Zone()
   343  
   344  	switch {
   345  	case offset/60 == 0:
   346  		err = out.WriteByte('Z')
   347  		return
   348  	case offset > 0:
   349  		err = out.WriteByte('+')
   350  	case offset < 0:
   351  		err = out.WriteByte('-')
   352  	}
   353  
   354  	if err != nil {
   355  		return
   356  	}
   357  
   358  	offsetMinutes := offset / 60
   359  	if offsetMinutes < 0 {
   360  		offsetMinutes = -offsetMinutes
   361  	}
   362  
   363  	err = marshalTwoDigits(out, offsetMinutes/60)
   364  	if err != nil {
   365  		return
   366  	}
   367  
   368  	err = marshalTwoDigits(out, offsetMinutes%60)
   369  	return
   370  }
   371  
   372  func stripTagAndLength(in []byte) []byte {
   373  	_, offset, err := parseTagAndLength(in, 0)
   374  	if err != nil {
   375  		return in
   376  	}
   377  	return in[offset:]
   378  }
   379  
   380  func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
   381  	switch value.Type() {
   382  	case timeType:
   383  		return marshalUTCTime(out, value.Interface().(time.Time))
   384  	case bitStringType:
   385  		return marshalBitString(out, value.Interface().(BitString))
   386  	case objectIdentifierType:
   387  		return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
   388  	case bigIntType:
   389  		return marshalBigInt(out, value.Interface().(*big.Int))
   390  	}
   391  
   392  	switch v := value; v.Kind() {
   393  	case reflect.Bool:
   394  		if v.Bool() {
   395  			return out.WriteByte(255)
   396  		} else {
   397  			return out.WriteByte(0)
   398  		}
   399  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   400  		return marshalInt64(out, int64(v.Int()))
   401  	case reflect.Struct:
   402  		t := v.Type()
   403  
   404  		startingField := 0
   405  
   406  		// If the first element of the structure is a non-empty
   407  		// RawContents, then we don't bother serializing the rest.
   408  		if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
   409  			s := v.Field(0)
   410  			if s.Len() > 0 {
   411  				bytes := make([]byte, s.Len())
   412  				for i := 0; i < s.Len(); i++ {
   413  					bytes[i] = uint8(s.Index(i).Uint())
   414  				}
   415  				/* The RawContents will contain the tag and
   416  				 * length fields but we'll also be writing
   417  				 * those ourselves, so we strip them out of
   418  				 * bytes */
   419  				_, err = out.Write(stripTagAndLength(bytes))
   420  				return
   421  			} else {
   422  				startingField = 1
   423  			}
   424  		}
   425  
   426  		for i := startingField; i < t.NumField(); i++ {
   427  			var pre *forkableWriter
   428  			pre, out = out.fork()
   429  			err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
   430  			if err != nil {
   431  				return
   432  			}
   433  		}
   434  		return
   435  	case reflect.Slice:
   436  		sliceType := v.Type()
   437  		if sliceType.Elem().Kind() == reflect.Uint8 {
   438  			bytes := make([]byte, v.Len())
   439  			for i := 0; i < v.Len(); i++ {
   440  				bytes[i] = uint8(v.Index(i).Uint())
   441  			}
   442  			_, err = out.Write(bytes)
   443  			return
   444  		}
   445  
   446  		var fp fieldParameters
   447  		for i := 0; i < v.Len(); i++ {
   448  			var pre *forkableWriter
   449  			pre, out = out.fork()
   450  			err = marshalField(pre, v.Index(i), fp)
   451  			if err != nil {
   452  				return
   453  			}
   454  		}
   455  		return
   456  	case reflect.String:
   457  		switch params.stringType {
   458  		case tagIA5String:
   459  			return marshalIA5String(out, v.String())
   460  		case tagPrintableString:
   461  			return marshalPrintableString(out, v.String())
   462  		default:
   463  			return marshalUTF8String(out, v.String())
   464  		}
   465  	}
   466  
   467  	return StructuralError{"unknown Go type"}
   468  }
   469  
   470  func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
   471  	// If the field is an interface{} then recurse into it.
   472  	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
   473  		return marshalField(out, v.Elem(), params)
   474  	}
   475  
   476  	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
   477  		return
   478  	}
   479  
   480  	if params.optional && reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
   481  		return
   482  	}
   483  
   484  	if v.Type() == rawValueType {
   485  		rv := v.Interface().(RawValue)
   486  		if len(rv.FullBytes) != 0 {
   487  			_, err = out.Write(rv.FullBytes)
   488  		} else {
   489  			err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
   490  			if err != nil {
   491  				return
   492  			}
   493  			_, err = out.Write(rv.Bytes)
   494  		}
   495  		return
   496  	}
   497  
   498  	tag, isCompound, ok := getUniversalType(v.Type())
   499  	if !ok {
   500  		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
   501  		return
   502  	}
   503  	class := classUniversal
   504  
   505  	if params.stringType != 0 && tag != tagPrintableString {
   506  		return StructuralError{"explicit string type given to non-string member"}
   507  	}
   508  
   509  	if tag == tagPrintableString {
   510  		if params.stringType == 0 {
   511  			// This is a string without an explicit string type. We'll use
   512  			// a PrintableString if the character set in the string is
   513  			// sufficiently limited, otherwise we'll use a UTF8String.
   514  			for _, r := range v.String() {
   515  				if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
   516  					if !utf8.ValidString(v.String()) {
   517  						return errors.New("asn1: string not valid UTF-8")
   518  					}
   519  					tag = tagUTF8String
   520  					break
   521  				}
   522  			}
   523  		} else {
   524  			tag = params.stringType
   525  		}
   526  	}
   527  
   528  	if params.set {
   529  		if tag != tagSequence {
   530  			return StructuralError{"non sequence tagged as set"}
   531  		}
   532  		tag = tagSet
   533  	}
   534  
   535  	tags, body := out.fork()
   536  
   537  	err = marshalBody(body, v, params)
   538  	if err != nil {
   539  		return
   540  	}
   541  
   542  	bodyLen := body.Len()
   543  
   544  	var explicitTag *forkableWriter
   545  	if params.explicit {
   546  		explicitTag, tags = tags.fork()
   547  	}
   548  
   549  	if !params.explicit && params.tag != nil {
   550  		// implicit tag.
   551  		tag = *params.tag
   552  		class = classContextSpecific
   553  	}
   554  
   555  	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
   556  	if err != nil {
   557  		return
   558  	}
   559  
   560  	if params.explicit {
   561  		err = marshalTagAndLength(explicitTag, tagAndLength{
   562  			class:      classContextSpecific,
   563  			tag:        *params.tag,
   564  			length:     bodyLen + tags.Len(),
   565  			isCompound: true,
   566  		})
   567  		if err != nil {
   568  			return
   569  		}
   570  	}
   571  
   572  	return nil
   573  }
   574  
   575  // Marshal returns the ASN.1 encoding of val.
   576  func Marshal(val interface{}) ([]byte, error) {
   577  	var out bytes.Buffer
   578  	v := reflect.ValueOf(val)
   579  	f := newForkableWriter()
   580  	err := marshalField(f, v, fieldParameters{})
   581  	if err != nil {
   582  		return nil, err
   583  	}
   584  	_, err = f.writeTo(&out)
   585  	if err != nil {
   586  		return nil, err
   587  	}
   588  	return out.Bytes(), nil
   589  }