bosun.org@v0.0.0-20210513094433-e25bc3e69a1f/snmp/asn1/marshal.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package asn1
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"math/big"
    13  	"reflect"
    14  	"time"
    15  	"unicode/utf8"
    16  )
    17  
    18  // A forkableWriter is an in-memory buffer that can be
    19  // 'forked' to create new forkableWriters that bracket the
    20  // original.  After
    21  //    pre, post := w.fork()
    22  // the overall sequence of bytes represented is logically w+pre+post.
    23  type forkableWriter struct {
    24  	*bytes.Buffer
    25  	pre, post *forkableWriter
    26  }
    27  
    28  func newForkableWriter() *forkableWriter {
    29  	return &forkableWriter{new(bytes.Buffer), nil, nil}
    30  }
    31  
    32  func (f *forkableWriter) fork() (pre, post *forkableWriter) {
    33  	if f.pre != nil || f.post != nil {
    34  		panic("have already forked")
    35  	}
    36  	f.pre = newForkableWriter()
    37  	f.post = newForkableWriter()
    38  	return f.pre, f.post
    39  }
    40  
    41  func (f *forkableWriter) Len() (l int) {
    42  	l += f.Buffer.Len()
    43  	if f.pre != nil {
    44  		l += f.pre.Len()
    45  	}
    46  	if f.post != nil {
    47  		l += f.post.Len()
    48  	}
    49  	return
    50  }
    51  
    52  func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
    53  	n, err = out.Write(f.Bytes())
    54  	if err != nil {
    55  		return
    56  	}
    57  
    58  	var nn int
    59  
    60  	if f.pre != nil {
    61  		nn, err = f.pre.writeTo(out)
    62  		n += nn
    63  		if err != nil {
    64  			return
    65  		}
    66  	}
    67  
    68  	if f.post != nil {
    69  		nn, err = f.post.writeTo(out)
    70  		n += nn
    71  	}
    72  	return
    73  }
    74  
    75  func marshalBase128Int(out *forkableWriter, n int64) (err error) {
    76  	if n == 0 {
    77  		err = out.WriteByte(0)
    78  		return
    79  	}
    80  
    81  	l := 0
    82  	for i := n; i > 0; i >>= 7 {
    83  		l++
    84  	}
    85  
    86  	for i := l - 1; i >= 0; i-- {
    87  		o := byte(n >> uint(i*7))
    88  		o &= 0x7f
    89  		if i != 0 {
    90  			o |= 0x80
    91  		}
    92  		err = out.WriteByte(o)
    93  		if err != nil {
    94  			return
    95  		}
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  func marshalInt64(out *forkableWriter, i int64) (err error) {
   102  	n := int64Length(i)
   103  
   104  	for ; n > 0; n-- {
   105  		err = out.WriteByte(byte(i >> uint((n-1)*8)))
   106  		if err != nil {
   107  			return
   108  		}
   109  	}
   110  
   111  	return nil
   112  }
   113  
   114  func int64Length(i int64) (numBytes int) {
   115  	numBytes = 1
   116  
   117  	for i > 127 {
   118  		numBytes++
   119  		i >>= 8
   120  	}
   121  
   122  	for i < -128 {
   123  		numBytes++
   124  		i >>= 8
   125  	}
   126  
   127  	return
   128  }
   129  
   130  func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
   131  	if n.Sign() < 0 {
   132  		// A negative number has to be converted to two's-complement
   133  		// form. So we'll subtract 1 and invert. If the
   134  		// most-significant-bit isn't set then we'll need to pad the
   135  		// beginning with 0xff in order to keep the number negative.
   136  		nMinus1 := new(big.Int).Neg(n)
   137  		nMinus1.Sub(nMinus1, bigOne)
   138  		bytes := nMinus1.Bytes()
   139  		for i := range bytes {
   140  			bytes[i] ^= 0xff
   141  		}
   142  		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
   143  			err = out.WriteByte(0xff)
   144  			if err != nil {
   145  				return
   146  			}
   147  		}
   148  		_, err = out.Write(bytes)
   149  	} else if n.Sign() == 0 {
   150  		// Zero is written as a single 0 zero rather than no bytes.
   151  		err = out.WriteByte(0x00)
   152  	} else {
   153  		bytes := n.Bytes()
   154  		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
   155  			// We'll have to pad this with 0x00 in order to stop it
   156  			// looking like a negative number.
   157  			err = out.WriteByte(0)
   158  			if err != nil {
   159  				return
   160  			}
   161  		}
   162  		_, err = out.Write(bytes)
   163  	}
   164  	return
   165  }
   166  
   167  func marshalLength(out *forkableWriter, i int) (err error) {
   168  	n := lengthLength(i)
   169  
   170  	for ; n > 0; n-- {
   171  		err = out.WriteByte(byte(i >> uint((n-1)*8)))
   172  		if err != nil {
   173  			return
   174  		}
   175  	}
   176  
   177  	return nil
   178  }
   179  
   180  func lengthLength(i int) (numBytes int) {
   181  	numBytes = 1
   182  	for i > 255 {
   183  		numBytes++
   184  		i >>= 8
   185  	}
   186  	return
   187  }
   188  
   189  func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
   190  	b := uint8(t.class) << 6
   191  	if t.isCompound {
   192  		b |= 0x20
   193  	}
   194  	if t.tag >= 31 {
   195  		b |= 0x1f
   196  		err = out.WriteByte(b)
   197  		if err != nil {
   198  			return
   199  		}
   200  		err = marshalBase128Int(out, int64(t.tag))
   201  		if err != nil {
   202  			return
   203  		}
   204  	} else {
   205  		b |= uint8(t.tag)
   206  		err = out.WriteByte(b)
   207  		if err != nil {
   208  			return
   209  		}
   210  	}
   211  
   212  	if t.length >= 128 {
   213  		l := lengthLength(t.length)
   214  		err = out.WriteByte(0x80 | byte(l))
   215  		if err != nil {
   216  			return
   217  		}
   218  		err = marshalLength(out, t.length)
   219  		if err != nil {
   220  			return
   221  		}
   222  	} else {
   223  		err = out.WriteByte(byte(t.length))
   224  		if err != nil {
   225  			return
   226  		}
   227  	}
   228  
   229  	return nil
   230  }
   231  
   232  func marshalBitString(out *forkableWriter, b BitString) (err error) {
   233  	paddingBits := byte((8 - b.BitLength%8) % 8)
   234  	err = out.WriteByte(paddingBits)
   235  	if err != nil {
   236  		return
   237  	}
   238  	_, err = out.Write(b.Bytes)
   239  	return
   240  }
   241  
   242  func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
   243  	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
   244  		return StructuralError{"invalid object identifier"}
   245  	}
   246  
   247  	err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
   248  	if err != nil {
   249  		return
   250  	}
   251  	for i := 2; i < len(oid); i++ {
   252  		err = marshalBase128Int(out, int64(oid[i]))
   253  		if err != nil {
   254  			return
   255  		}
   256  	}
   257  
   258  	return
   259  }
   260  
   261  func marshalPrintableString(out *forkableWriter, s string) (err error) {
   262  	b := []byte(s)
   263  	for _, c := range b {
   264  		if !isPrintable(c) {
   265  			return StructuralError{"PrintableString contains invalid character"}
   266  		}
   267  	}
   268  
   269  	_, err = out.Write(b)
   270  	return
   271  }
   272  
   273  func marshalIA5String(out *forkableWriter, s string) (err error) {
   274  	b := []byte(s)
   275  	for _, c := range b {
   276  		if c > 127 {
   277  			return StructuralError{"IA5String contains invalid character"}
   278  		}
   279  	}
   280  
   281  	_, err = out.Write(b)
   282  	return
   283  }
   284  
   285  func marshalUTF8String(out *forkableWriter, s string) (err error) {
   286  	_, err = out.Write([]byte(s))
   287  	return
   288  }
   289  
   290  func marshalTwoDigits(out *forkableWriter, v int) (err error) {
   291  	err = out.WriteByte(byte('0' + (v/10)%10))
   292  	if err != nil {
   293  		return
   294  	}
   295  	return out.WriteByte(byte('0' + v%10))
   296  }
   297  
   298  func marshalFourDigits(out *forkableWriter, v int) (err error) {
   299  	var bytes [4]byte
   300  	for i := range bytes {
   301  		bytes[3-i] = '0' + byte(v%10)
   302  		v /= 10
   303  	}
   304  	_, err = out.Write(bytes[:])
   305  	return
   306  }
   307  
   308  func outsideUTCRange(t time.Time) bool {
   309  	year := t.Year()
   310  	return year < 1950 || year >= 2050
   311  }
   312  
   313  func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
   314  	year := t.Year()
   315  
   316  	switch {
   317  	case 1950 <= year && year < 2000:
   318  		err = marshalTwoDigits(out, int(year-1900))
   319  	case 2000 <= year && year < 2050:
   320  		err = marshalTwoDigits(out, int(year-2000))
   321  	default:
   322  		return StructuralError{"cannot represent time as UTCTime"}
   323  	}
   324  	if err != nil {
   325  		return
   326  	}
   327  
   328  	return marshalTimeCommon(out, t)
   329  }
   330  
   331  func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
   332  	year := t.Year()
   333  	if year < 0 || year > 9999 {
   334  		return StructuralError{"cannot represent time as GeneralizedTime"}
   335  	}
   336  	if err = marshalFourDigits(out, year); err != nil {
   337  		return
   338  	}
   339  
   340  	return marshalTimeCommon(out, t)
   341  }
   342  
   343  func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
   344  	_, month, day := t.Date()
   345  
   346  	err = marshalTwoDigits(out, int(month))
   347  	if err != nil {
   348  		return
   349  	}
   350  
   351  	err = marshalTwoDigits(out, day)
   352  	if err != nil {
   353  		return
   354  	}
   355  
   356  	hour, min, sec := t.Clock()
   357  
   358  	err = marshalTwoDigits(out, hour)
   359  	if err != nil {
   360  		return
   361  	}
   362  
   363  	err = marshalTwoDigits(out, min)
   364  	if err != nil {
   365  		return
   366  	}
   367  
   368  	err = marshalTwoDigits(out, sec)
   369  	if err != nil {
   370  		return
   371  	}
   372  
   373  	_, offset := t.Zone()
   374  
   375  	switch {
   376  	case offset/60 == 0:
   377  		err = out.WriteByte('Z')
   378  		return
   379  	case offset > 0:
   380  		err = out.WriteByte('+')
   381  	case offset < 0:
   382  		err = out.WriteByte('-')
   383  	}
   384  
   385  	if err != nil {
   386  		return
   387  	}
   388  
   389  	offsetMinutes := offset / 60
   390  	if offsetMinutes < 0 {
   391  		offsetMinutes = -offsetMinutes
   392  	}
   393  
   394  	err = marshalTwoDigits(out, offsetMinutes/60)
   395  	if err != nil {
   396  		return
   397  	}
   398  
   399  	err = marshalTwoDigits(out, offsetMinutes%60)
   400  	return
   401  }
   402  
   403  func stripTagAndLength(in []byte) []byte {
   404  	_, offset, err := parseTagAndLength(in, 0)
   405  	if err != nil {
   406  		return in
   407  	}
   408  	return in[offset:]
   409  }
   410  
   411  func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
   412  	switch value.Type() {
   413  	case flagType:
   414  		return nil
   415  	case timeType:
   416  		t := value.Interface().(time.Time)
   417  		if params.timeType == tagGeneralizedTime || outsideUTCRange(t) {
   418  			return marshalGeneralizedTime(out, t)
   419  		} else {
   420  			return marshalUTCTime(out, t)
   421  		}
   422  	case bitStringType:
   423  		return marshalBitString(out, value.Interface().(BitString))
   424  	case objectIdentifierType:
   425  		return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
   426  	case bigIntType:
   427  		return marshalBigInt(out, value.Interface().(*big.Int))
   428  	}
   429  
   430  	switch v := value; v.Kind() {
   431  	case reflect.Bool:
   432  		if v.Bool() {
   433  			return out.WriteByte(255)
   434  		} else {
   435  			return out.WriteByte(0)
   436  		}
   437  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   438  		return marshalInt64(out, int64(v.Int()))
   439  	case reflect.Struct:
   440  		t := v.Type()
   441  
   442  		startingField := 0
   443  
   444  		// If the first element of the structure is a non-empty
   445  		// RawContents, then we don't bother serializing the rest.
   446  		if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
   447  			s := v.Field(0)
   448  			if s.Len() > 0 {
   449  				bytes := make([]byte, s.Len())
   450  				for i := 0; i < s.Len(); i++ {
   451  					bytes[i] = uint8(s.Index(i).Uint())
   452  				}
   453  				/* The RawContents will contain the tag and
   454  				 * length fields but we'll also be writing
   455  				 * those ourselves, so we strip them out of
   456  				 * bytes */
   457  				_, err = out.Write(stripTagAndLength(bytes))
   458  				return
   459  			} else {
   460  				startingField = 1
   461  			}
   462  		}
   463  
   464  		for i := startingField; i < t.NumField(); i++ {
   465  			var pre *forkableWriter
   466  			pre, out = out.fork()
   467  			err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
   468  			if err != nil {
   469  				return
   470  			}
   471  		}
   472  		return
   473  	case reflect.Slice:
   474  		sliceType := v.Type()
   475  		if sliceType.Elem().Kind() == reflect.Uint8 {
   476  			bytes := make([]byte, v.Len())
   477  			for i := 0; i < v.Len(); i++ {
   478  				bytes[i] = uint8(v.Index(i).Uint())
   479  			}
   480  			_, err = out.Write(bytes)
   481  			return
   482  		}
   483  
   484  		var fp fieldParameters
   485  		for i := 0; i < v.Len(); i++ {
   486  			var pre *forkableWriter
   487  			pre, out = out.fork()
   488  			err = marshalField(pre, v.Index(i), fp)
   489  			if err != nil {
   490  				return
   491  			}
   492  		}
   493  		return
   494  	case reflect.String:
   495  		switch params.stringType {
   496  		case tagIA5String:
   497  			return marshalIA5String(out, v.String())
   498  		case tagPrintableString:
   499  			return marshalPrintableString(out, v.String())
   500  		default:
   501  			return marshalUTF8String(out, v.String())
   502  		}
   503  	}
   504  
   505  	return StructuralError{"unknown Go type"}
   506  }
   507  
   508  func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
   509  	// If the field is an interface{} then recurse into it.
   510  	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
   511  		return marshalField(out, v.Elem(), params)
   512  	}
   513  
   514  	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
   515  		return
   516  	}
   517  
   518  	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
   519  		defaultValue := reflect.New(v.Type()).Elem()
   520  		defaultValue.SetInt(*params.defaultValue)
   521  
   522  		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
   523  			return
   524  		}
   525  	}
   526  
   527  	// If no default value is given then the zero value for the type is
   528  	// assumed to be the default value. This isn't obviously the correct
   529  	// behaviour, but it's what Go has traditionally done.
   530  	if params.optional && params.defaultValue == nil {
   531  		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
   532  			return
   533  		}
   534  	}
   535  
   536  	if v.Type() == rawValueType {
   537  		rv := v.Interface().(RawValue)
   538  		if len(rv.FullBytes) != 0 {
   539  			_, err = out.Write(rv.FullBytes)
   540  		} else {
   541  			err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
   542  			if err != nil {
   543  				return
   544  			}
   545  			_, err = out.Write(rv.Bytes)
   546  		}
   547  		return
   548  	}
   549  
   550  	tag, isCompound, ok := getUniversalType(v.Type())
   551  	if !ok {
   552  		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
   553  		return
   554  	}
   555  	class := classUniversal
   556  
   557  	if params.timeType != 0 && tag != tagUTCTime {
   558  		return StructuralError{"explicit time type given to non-time member"}
   559  	}
   560  
   561  	if params.stringType != 0 && tag != tagPrintableString {
   562  		return StructuralError{"explicit string type given to non-string member"}
   563  	}
   564  
   565  	switch tag {
   566  	case tagPrintableString:
   567  		if params.stringType == 0 {
   568  			// This is a string without an explicit string type. We'll use
   569  			// a PrintableString if the character set in the string is
   570  			// sufficiently limited, otherwise we'll use a UTF8String.
   571  			for _, r := range v.String() {
   572  				if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
   573  					if !utf8.ValidString(v.String()) {
   574  						return errors.New("asn1: string not valid UTF-8")
   575  					}
   576  					tag = tagUTF8String
   577  					break
   578  				}
   579  			}
   580  		} else {
   581  			tag = params.stringType
   582  		}
   583  	case tagUTCTime:
   584  		if params.timeType == tagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
   585  			tag = tagGeneralizedTime
   586  		}
   587  	}
   588  
   589  	if params.set {
   590  		if tag != tagSequence {
   591  			return StructuralError{"non sequence tagged as set"}
   592  		}
   593  		tag = tagSet
   594  	}
   595  
   596  	tags, body := out.fork()
   597  
   598  	err = marshalBody(body, v, params)
   599  	if err != nil {
   600  		return
   601  	}
   602  
   603  	bodyLen := body.Len()
   604  
   605  	var explicitTag *forkableWriter
   606  	if params.explicit {
   607  		explicitTag, tags = tags.fork()
   608  	}
   609  
   610  	if !params.explicit && params.tag != nil {
   611  		// implicit tag.
   612  		tag = *params.tag
   613  		class = classContextSpecific
   614  	}
   615  
   616  	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
   617  	if err != nil {
   618  		return
   619  	}
   620  
   621  	if params.explicit {
   622  		err = marshalTagAndLength(explicitTag, tagAndLength{
   623  			class:      classContextSpecific,
   624  			tag:        *params.tag,
   625  			length:     bodyLen + tags.Len(),
   626  			isCompound: true,
   627  		})
   628  	}
   629  
   630  	return nil
   631  }
   632  
   633  // Marshal returns the ASN.1 encoding of val.
   634  //
   635  // In addition to the struct tags recognised by Unmarshal, the following can be
   636  // used:
   637  //
   638  //	ia5:		causes strings to be marshaled as ASN.1, IA5 strings
   639  //	omitempty:	causes empty slices to be skipped
   640  //	printable:	causes strings to be marshaled as ASN.1, PrintableString strings.
   641  //	utf8:		causes strings to be marshaled as ASN.1, UTF8 strings
   642  func Marshal(val interface{}) ([]byte, error) {
   643  	var out bytes.Buffer
   644  	v := reflect.ValueOf(val)
   645  	f := newForkableWriter()
   646  	err := marshalField(f, v, fieldParameters{})
   647  	if err != nil {
   648  		return nil, err
   649  	}
   650  	_, err = f.writeTo(&out)
   651  	return out.Bytes(), nil
   652  }