github.com/rsc/go@v0.0.0-20150416155037-e040fd465409/src/encoding/asn1/marshal.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package asn1
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"math/big"
    13  	"reflect"
    14  	"time"
    15  	"unicode/utf8"
    16  )
    17  
    18  // A forkableWriter is an in-memory buffer that can be
    19  // 'forked' to create new forkableWriters that bracket the
    20  // original.  After
    21  //    pre, post := w.fork()
    22  // the overall sequence of bytes represented is logically w+pre+post.
    23  type forkableWriter struct {
    24  	*bytes.Buffer
    25  	pre, post *forkableWriter
    26  }
    27  
    28  func newForkableWriter() *forkableWriter {
    29  	return &forkableWriter{new(bytes.Buffer), nil, nil}
    30  }
    31  
    32  func (f *forkableWriter) fork() (pre, post *forkableWriter) {
    33  	if f.pre != nil || f.post != nil {
    34  		panic("have already forked")
    35  	}
    36  	f.pre = newForkableWriter()
    37  	f.post = newForkableWriter()
    38  	return f.pre, f.post
    39  }
    40  
    41  func (f *forkableWriter) Len() (l int) {
    42  	l += f.Buffer.Len()
    43  	if f.pre != nil {
    44  		l += f.pre.Len()
    45  	}
    46  	if f.post != nil {
    47  		l += f.post.Len()
    48  	}
    49  	return
    50  }
    51  
    52  func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
    53  	n, err = out.Write(f.Bytes())
    54  	if err != nil {
    55  		return
    56  	}
    57  
    58  	var nn int
    59  
    60  	if f.pre != nil {
    61  		nn, err = f.pre.writeTo(out)
    62  		n += nn
    63  		if err != nil {
    64  			return
    65  		}
    66  	}
    67  
    68  	if f.post != nil {
    69  		nn, err = f.post.writeTo(out)
    70  		n += nn
    71  	}
    72  	return
    73  }
    74  
    75  func marshalBase128Int(out *forkableWriter, n int64) (err error) {
    76  	if n == 0 {
    77  		err = out.WriteByte(0)
    78  		return
    79  	}
    80  
    81  	l := 0
    82  	for i := n; i > 0; i >>= 7 {
    83  		l++
    84  	}
    85  
    86  	for i := l - 1; i >= 0; i-- {
    87  		o := byte(n >> uint(i*7))
    88  		o &= 0x7f
    89  		if i != 0 {
    90  			o |= 0x80
    91  		}
    92  		err = out.WriteByte(o)
    93  		if err != nil {
    94  			return
    95  		}
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  func marshalInt64(out *forkableWriter, i int64) (err error) {
   102  	n := int64Length(i)
   103  
   104  	for ; n > 0; n-- {
   105  		err = out.WriteByte(byte(i >> uint((n-1)*8)))
   106  		if err != nil {
   107  			return
   108  		}
   109  	}
   110  
   111  	return nil
   112  }
   113  
   114  func int64Length(i int64) (numBytes int) {
   115  	numBytes = 1
   116  
   117  	for i > 127 {
   118  		numBytes++
   119  		i >>= 8
   120  	}
   121  
   122  	for i < -128 {
   123  		numBytes++
   124  		i >>= 8
   125  	}
   126  
   127  	return
   128  }
   129  
   130  func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
   131  	if n.Sign() < 0 {
   132  		// A negative number has to be converted to two's-complement
   133  		// form. So we'll subtract 1 and invert. If the
   134  		// most-significant-bit isn't set then we'll need to pad the
   135  		// beginning with 0xff in order to keep the number negative.
   136  		nMinus1 := new(big.Int).Neg(n)
   137  		nMinus1.Sub(nMinus1, bigOne)
   138  		bytes := nMinus1.Bytes()
   139  		for i := range bytes {
   140  			bytes[i] ^= 0xff
   141  		}
   142  		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
   143  			err = out.WriteByte(0xff)
   144  			if err != nil {
   145  				return
   146  			}
   147  		}
   148  		_, err = out.Write(bytes)
   149  	} else if n.Sign() == 0 {
   150  		// Zero is written as a single 0 zero rather than no bytes.
   151  		err = out.WriteByte(0x00)
   152  	} else {
   153  		bytes := n.Bytes()
   154  		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
   155  			// We'll have to pad this with 0x00 in order to stop it
   156  			// looking like a negative number.
   157  			err = out.WriteByte(0)
   158  			if err != nil {
   159  				return
   160  			}
   161  		}
   162  		_, err = out.Write(bytes)
   163  	}
   164  	return
   165  }
   166  
   167  func marshalLength(out *forkableWriter, i int) (err error) {
   168  	n := lengthLength(i)
   169  
   170  	for ; n > 0; n-- {
   171  		err = out.WriteByte(byte(i >> uint((n-1)*8)))
   172  		if err != nil {
   173  			return
   174  		}
   175  	}
   176  
   177  	return nil
   178  }
   179  
   180  func lengthLength(i int) (numBytes int) {
   181  	numBytes = 1
   182  	for i > 255 {
   183  		numBytes++
   184  		i >>= 8
   185  	}
   186  	return
   187  }
   188  
   189  func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
   190  	b := uint8(t.class) << 6
   191  	if t.isCompound {
   192  		b |= 0x20
   193  	}
   194  	if t.tag >= 31 {
   195  		b |= 0x1f
   196  		err = out.WriteByte(b)
   197  		if err != nil {
   198  			return
   199  		}
   200  		err = marshalBase128Int(out, int64(t.tag))
   201  		if err != nil {
   202  			return
   203  		}
   204  	} else {
   205  		b |= uint8(t.tag)
   206  		err = out.WriteByte(b)
   207  		if err != nil {
   208  			return
   209  		}
   210  	}
   211  
   212  	if t.length >= 128 {
   213  		l := lengthLength(t.length)
   214  		err = out.WriteByte(0x80 | byte(l))
   215  		if err != nil {
   216  			return
   217  		}
   218  		err = marshalLength(out, t.length)
   219  		if err != nil {
   220  			return
   221  		}
   222  	} else {
   223  		err = out.WriteByte(byte(t.length))
   224  		if err != nil {
   225  			return
   226  		}
   227  	}
   228  
   229  	return nil
   230  }
   231  
   232  func marshalBitString(out *forkableWriter, b BitString) (err error) {
   233  	paddingBits := byte((8 - b.BitLength%8) % 8)
   234  	err = out.WriteByte(paddingBits)
   235  	if err != nil {
   236  		return
   237  	}
   238  	_, err = out.Write(b.Bytes)
   239  	return
   240  }
   241  
   242  func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
   243  	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
   244  		return StructuralError{"invalid object identifier"}
   245  	}
   246  
   247  	err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
   248  	if err != nil {
   249  		return
   250  	}
   251  	for i := 2; i < len(oid); i++ {
   252  		err = marshalBase128Int(out, int64(oid[i]))
   253  		if err != nil {
   254  			return
   255  		}
   256  	}
   257  
   258  	return
   259  }
   260  
   261  func marshalPrintableString(out *forkableWriter, s string) (err error) {
   262  	b := []byte(s)
   263  	for _, c := range b {
   264  		if !isPrintable(c) {
   265  			return StructuralError{"PrintableString contains invalid character"}
   266  		}
   267  	}
   268  
   269  	_, err = out.Write(b)
   270  	return
   271  }
   272  
   273  func marshalIA5String(out *forkableWriter, s string) (err error) {
   274  	b := []byte(s)
   275  	for _, c := range b {
   276  		if c > 127 {
   277  			return StructuralError{"IA5String contains invalid character"}
   278  		}
   279  	}
   280  
   281  	_, err = out.Write(b)
   282  	return
   283  }
   284  
   285  func marshalUTF8String(out *forkableWriter, s string) (err error) {
   286  	_, err = out.Write([]byte(s))
   287  	return
   288  }
   289  
   290  func marshalTwoDigits(out *forkableWriter, v int) (err error) {
   291  	err = out.WriteByte(byte('0' + (v/10)%10))
   292  	if err != nil {
   293  		return
   294  	}
   295  	return out.WriteByte(byte('0' + v%10))
   296  }
   297  
   298  func marshalFourDigits(out *forkableWriter, v int) (err error) {
   299  	var bytes [4]byte
   300  	for i := range bytes {
   301  		bytes[3-i] = '0' + byte(v%10)
   302  		v /= 10
   303  	}
   304  	_, err = out.Write(bytes[:])
   305  	return
   306  }
   307  
   308  func outsideUTCRange(t time.Time) bool {
   309  	year := t.Year()
   310  	return year < 1950 || year >= 2050
   311  }
   312  
   313  func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
   314  	year := t.Year()
   315  
   316  	switch {
   317  	case 1950 <= year && year < 2000:
   318  		err = marshalTwoDigits(out, int(year-1900))
   319  	case 2000 <= year && year < 2050:
   320  		err = marshalTwoDigits(out, int(year-2000))
   321  	default:
   322  		return StructuralError{"cannot represent time as UTCTime"}
   323  	}
   324  	if err != nil {
   325  		return
   326  	}
   327  
   328  	return marshalTimeCommon(out, t)
   329  }
   330  
   331  func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
   332  	year := t.Year()
   333  	if year < 0 || year > 9999 {
   334  		return StructuralError{"cannot represent time as GeneralizedTime"}
   335  	}
   336  	if err = marshalFourDigits(out, year); err != nil {
   337  		return
   338  	}
   339  
   340  	return marshalTimeCommon(out, t)
   341  }
   342  
   343  func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
   344  	_, month, day := t.Date()
   345  
   346  	err = marshalTwoDigits(out, int(month))
   347  	if err != nil {
   348  		return
   349  	}
   350  
   351  	err = marshalTwoDigits(out, day)
   352  	if err != nil {
   353  		return
   354  	}
   355  
   356  	hour, min, sec := t.Clock()
   357  
   358  	err = marshalTwoDigits(out, hour)
   359  	if err != nil {
   360  		return
   361  	}
   362  
   363  	err = marshalTwoDigits(out, min)
   364  	if err != nil {
   365  		return
   366  	}
   367  
   368  	err = marshalTwoDigits(out, sec)
   369  	if err != nil {
   370  		return
   371  	}
   372  
   373  	_, offset := t.Zone()
   374  
   375  	switch {
   376  	case offset/60 == 0:
   377  		err = out.WriteByte('Z')
   378  		return
   379  	case offset > 0:
   380  		err = out.WriteByte('+')
   381  	case offset < 0:
   382  		err = out.WriteByte('-')
   383  	}
   384  
   385  	if err != nil {
   386  		return
   387  	}
   388  
   389  	offsetMinutes := offset / 60
   390  	if offsetMinutes < 0 {
   391  		offsetMinutes = -offsetMinutes
   392  	}
   393  
   394  	err = marshalTwoDigits(out, offsetMinutes/60)
   395  	if err != nil {
   396  		return
   397  	}
   398  
   399  	err = marshalTwoDigits(out, offsetMinutes%60)
   400  	return
   401  }
   402  
   403  func stripTagAndLength(in []byte) []byte {
   404  	_, offset, err := parseTagAndLength(in, 0)
   405  	if err != nil {
   406  		return in
   407  	}
   408  	return in[offset:]
   409  }
   410  
   411  func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
   412  	switch value.Type() {
   413  	case timeType:
   414  		t := value.Interface().(time.Time)
   415  		if outsideUTCRange(t) {
   416  			return marshalGeneralizedTime(out, t)
   417  		} else {
   418  			return marshalUTCTime(out, t)
   419  		}
   420  	case bitStringType:
   421  		return marshalBitString(out, value.Interface().(BitString))
   422  	case objectIdentifierType:
   423  		return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
   424  	case bigIntType:
   425  		return marshalBigInt(out, value.Interface().(*big.Int))
   426  	}
   427  
   428  	switch v := value; v.Kind() {
   429  	case reflect.Bool:
   430  		if v.Bool() {
   431  			return out.WriteByte(255)
   432  		} else {
   433  			return out.WriteByte(0)
   434  		}
   435  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   436  		return marshalInt64(out, int64(v.Int()))
   437  	case reflect.Struct:
   438  		t := v.Type()
   439  
   440  		startingField := 0
   441  
   442  		// If the first element of the structure is a non-empty
   443  		// RawContents, then we don't bother serializing the rest.
   444  		if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
   445  			s := v.Field(0)
   446  			if s.Len() > 0 {
   447  				bytes := make([]byte, s.Len())
   448  				for i := 0; i < s.Len(); i++ {
   449  					bytes[i] = uint8(s.Index(i).Uint())
   450  				}
   451  				/* The RawContents will contain the tag and
   452  				 * length fields but we'll also be writing
   453  				 * those ourselves, so we strip them out of
   454  				 * bytes */
   455  				_, err = out.Write(stripTagAndLength(bytes))
   456  				return
   457  			} else {
   458  				startingField = 1
   459  			}
   460  		}
   461  
   462  		for i := startingField; i < t.NumField(); i++ {
   463  			var pre *forkableWriter
   464  			pre, out = out.fork()
   465  			err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
   466  			if err != nil {
   467  				return
   468  			}
   469  		}
   470  		return
   471  	case reflect.Slice:
   472  		sliceType := v.Type()
   473  		if sliceType.Elem().Kind() == reflect.Uint8 {
   474  			bytes := make([]byte, v.Len())
   475  			for i := 0; i < v.Len(); i++ {
   476  				bytes[i] = uint8(v.Index(i).Uint())
   477  			}
   478  			_, err = out.Write(bytes)
   479  			return
   480  		}
   481  
   482  		var fp fieldParameters
   483  		for i := 0; i < v.Len(); i++ {
   484  			var pre *forkableWriter
   485  			pre, out = out.fork()
   486  			err = marshalField(pre, v.Index(i), fp)
   487  			if err != nil {
   488  				return
   489  			}
   490  		}
   491  		return
   492  	case reflect.String:
   493  		switch params.stringType {
   494  		case tagIA5String:
   495  			return marshalIA5String(out, v.String())
   496  		case tagPrintableString:
   497  			return marshalPrintableString(out, v.String())
   498  		default:
   499  			return marshalUTF8String(out, v.String())
   500  		}
   501  	}
   502  
   503  	return StructuralError{"unknown Go type"}
   504  }
   505  
   506  func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
   507  	// If the field is an interface{} then recurse into it.
   508  	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
   509  		return marshalField(out, v.Elem(), params)
   510  	}
   511  
   512  	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
   513  		return
   514  	}
   515  
   516  	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
   517  		defaultValue := reflect.New(v.Type()).Elem()
   518  		defaultValue.SetInt(*params.defaultValue)
   519  
   520  		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
   521  			return
   522  		}
   523  	}
   524  
   525  	// If no default value is given then the zero value for the type is
   526  	// assumed to be the default value. This isn't obviously the correct
   527  	// behaviour, but it's what Go has traditionally done.
   528  	if params.optional && params.defaultValue == nil {
   529  		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
   530  			return
   531  		}
   532  	}
   533  
   534  	if v.Type() == rawValueType {
   535  		rv := v.Interface().(RawValue)
   536  		if len(rv.FullBytes) != 0 {
   537  			_, err = out.Write(rv.FullBytes)
   538  		} else {
   539  			err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
   540  			if err != nil {
   541  				return
   542  			}
   543  			_, err = out.Write(rv.Bytes)
   544  		}
   545  		return
   546  	}
   547  
   548  	tag, isCompound, ok := getUniversalType(v.Type())
   549  	if !ok {
   550  		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
   551  		return
   552  	}
   553  	class := classUniversal
   554  
   555  	if params.stringType != 0 && tag != tagPrintableString {
   556  		return StructuralError{"explicit string type given to non-string member"}
   557  	}
   558  
   559  	switch tag {
   560  	case tagPrintableString:
   561  		if params.stringType == 0 {
   562  			// This is a string without an explicit string type. We'll use
   563  			// a PrintableString if the character set in the string is
   564  			// sufficiently limited, otherwise we'll use a UTF8String.
   565  			for _, r := range v.String() {
   566  				if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
   567  					if !utf8.ValidString(v.String()) {
   568  						return errors.New("asn1: string not valid UTF-8")
   569  					}
   570  					tag = tagUTF8String
   571  					break
   572  				}
   573  			}
   574  		} else {
   575  			tag = params.stringType
   576  		}
   577  	case tagUTCTime:
   578  		if outsideUTCRange(v.Interface().(time.Time)) {
   579  			tag = tagGeneralizedTime
   580  		}
   581  	}
   582  
   583  	if params.set {
   584  		if tag != tagSequence {
   585  			return StructuralError{"non sequence tagged as set"}
   586  		}
   587  		tag = tagSet
   588  	}
   589  
   590  	tags, body := out.fork()
   591  
   592  	err = marshalBody(body, v, params)
   593  	if err != nil {
   594  		return
   595  	}
   596  
   597  	bodyLen := body.Len()
   598  
   599  	var explicitTag *forkableWriter
   600  	if params.explicit {
   601  		explicitTag, tags = tags.fork()
   602  	}
   603  
   604  	if !params.explicit && params.tag != nil {
   605  		// implicit tag.
   606  		tag = *params.tag
   607  		class = classContextSpecific
   608  	}
   609  
   610  	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
   611  	if err != nil {
   612  		return
   613  	}
   614  
   615  	if params.explicit {
   616  		err = marshalTagAndLength(explicitTag, tagAndLength{
   617  			class:      classContextSpecific,
   618  			tag:        *params.tag,
   619  			length:     bodyLen + tags.Len(),
   620  			isCompound: true,
   621  		})
   622  	}
   623  
   624  	return nil
   625  }
   626  
   627  // Marshal returns the ASN.1 encoding of val.
   628  //
   629  // In addition to the struct tags recognised by Unmarshal, the following can be
   630  // used:
   631  //
   632  //	ia5:		causes strings to be marshaled as ASN.1, IA5 strings
   633  //	omitempty:	causes empty slices to be skipped
   634  //	printable:	causes strings to be marshaled as ASN.1, PrintableString strings.
   635  //	utf8:		causes strings to be marshaled as ASN.1, UTF8 strings
   636  func Marshal(val interface{}) ([]byte, error) {
   637  	var out bytes.Buffer
   638  	v := reflect.ValueOf(val)
   639  	f := newForkableWriter()
   640  	err := marshalField(f, v, fieldParameters{})
   641  	if err != nil {
   642  		return nil, err
   643  	}
   644  	_, err = f.writeTo(&out)
   645  	return out.Bytes(), nil
   646  }