github.com/keysonZZZ/kmg@v0.0.0-20151121023212-05317bfd7d39/third/kmgRadius/packet.go (about)

     1  package kmgRadius
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/hmac"
     6  	_ "crypto/md5"
     7  	"crypto/rand"
     8  	"encoding/binary"
     9  	"fmt"
    10  	"net"
    11  	"strconv"
    12  
    13  	"github.com/bronze1man/kmg/third/kmgRadius/eap"
    14  )
    15  
// ErrMessageAuthenticatorCheckFail is the error reported when the
// Message-Authenticator AVP of a packet does not match the HMAC-MD5 computed
// locally from the packet and the shared secret.
// NOTE(review): the message text mentions "Response-Authenticator" although
// the check is about the Message-Authenticator AVP — confirm before rewording.
var ErrMessageAuthenticatorCheckFail = fmt.Errorf("RADIUS Response-Authenticator verification failed")

// maxPacketLength is the largest size, in bytes, of an encoded packet.
const maxPacketLength = 4096 // rfc2058 Page 9 Length
    19  
// Packet is the in-memory form of a RADIUS packet, together with the shared
// secret that is used as the HMAC/MD5 key when the packet is encoded or its
// Message-Authenticator is checked.
type Packet struct {
	Secret        []byte
	Code          Code
	Identifier    uint8
	Authenticator [16]byte // The Authenticator found in the corresponding Request.
	AVPs          []AVP
}
    27  
    28  func (p *Packet) Copy() *Packet {
    29  	outP := &Packet{
    30  		Secret:        p.Secret,
    31  		Code:          p.Code,
    32  		Identifier:    p.Identifier,
    33  		Authenticator: p.Authenticator, //这个应该是拷贝
    34  	}
    35  	outP.AVPs = make([]AVP, len(p.AVPs))
    36  	for i := range p.AVPs {
    37  		outP.AVPs[i] = p.AVPs[i].Copy()
    38  	}
    39  	return outP
    40  }
    41  
    42  //此方法保证不修改包的内容
    43  func (p *Packet) Encode() (b []byte, err error) {
    44  	p = p.Copy()
    45  	p.SetAVP(&BinaryAVP{
    46  		Type:  AVPTypeMessageAuthenticator,
    47  		Value: make([]byte, 16),
    48  	})
    49  	if p.Code == CodeAccessRequest {
    50  		_, err := rand.Read(p.Authenticator[:])
    51  		if err != nil {
    52  			return nil, err
    53  		}
    54  	}
    55  	//TODO request的时候重新计算密码
    56  	b, err = p.encodeNoHash()
    57  	if err != nil {
    58  		return
    59  	}
    60  	//计算Message-Authenticator这个AVP的值,特殊化处理,Message-Authenticator这个AVP被放在最后面
    61  	hasher := hmac.New(crypto.MD5.New, []byte(p.Secret))
    62  	hasher.Write(b)
    63  	copy(b[len(b)-16:len(b)], hasher.Sum(nil))
    64  
    65  	// fix up the authenticator
    66  	// handle request and response stuff.
    67  	// here only handle response part.
    68  	switch p.Code {
    69  	case CodeAccessRequest:
    70  	case CodeAccessAccept, CodeAccessReject, CodeAccessChallenge,
    71  		CodeAccountingRequest, CodeAccountingResponse:
    72  		//rfc2865 page 15 Response Authenticator
    73  		//rfc2866 page 6 Response Authenticator
    74  		//rfc2866 page 6 Request Authenticator
    75  		hasher := crypto.Hash(crypto.MD5).New()
    76  		hasher.Write(b)
    77  		hasher.Write([]byte(p.Secret))
    78  		copy(b[4:20], hasher.Sum(nil)) //返回值再把Authenticator写回去
    79  	default:
    80  		return nil, fmt.Errorf("not handle p.Code %d", p.Code)
    81  	}
    82  
    83  	return b, err
    84  }
    85  
    86  func (p *Packet) encodeNoHash() (b []byte, err error) {
    87  	b = make([]byte, maxPacketLength)
    88  	b[0] = uint8(p.Code)
    89  	b[1] = uint8(p.Identifier)
    90  	copy(b[4:20], p.Authenticator[:])
    91  	written := 20
    92  	bb := b[20:]
    93  	for i, _ := range p.AVPs {
    94  		bb1, err := p.AVPs[i].Encode()
    95  		if err != nil {
    96  			return nil, err
    97  		}
    98  		written += len(bb1)
    99  		if written > maxPacketLength {
   100  			return nil, fmt.Errorf("[Packet.encodeNoHash] packet too large written[%d]>maxPacketLength[%d]",
   101  				written, maxPacketLength)
   102  		}
   103  		copy(bb, bb1)
   104  		bb = bb[len(bb1):]
   105  	}
   106  	binary.BigEndian.PutUint16(b[2:4], uint16(written))
   107  	return b[:written], nil
   108  }
   109  
   110  //get one avp
   111  func (p *Packet) GetAVP(attrType AVPType) AVP {
   112  	for i := range p.AVPs {
   113  		if p.AVPs[i].GetType() == attrType {
   114  			return p.AVPs[i]
   115  		}
   116  	}
   117  	return nil
   118  }
   119  
   120  //set one avp,remove all other same type
   121  func (p *Packet) SetAVP(avp AVP) {
   122  	p.DeleteOneType(avp.GetType())
   123  	p.AddAVP(avp)
   124  }
   125  
   126  func (p *Packet) AddAVP(avp AVP) {
   127  	p.AVPs = append(p.AVPs, avp)
   128  }
   129  
   130  func (p *Packet) GetVsa(typ VendorType) VSA {
   131  	for i := range p.AVPs {
   132  		if p.AVPs[i].GetType() != AVPTypeVendorSpecific {
   133  			continue
   134  		}
   135  		vsa, ok := p.AVPs[i].(*VendorSpecificAVP)
   136  		if !ok {
   137  			continue //允许使用binaryAVP代表一个AVPTypeVendorSpecific
   138  		}
   139  		if vsa.Value.GetType() != typ {
   140  			continue
   141  		}
   142  		return vsa.Value
   143  	}
   144  	return nil
   145  }
   146  
   147  //删除一个AVP
   148  /*
   149  func (p *Packet) DeleteAVP(avp AVP) {
   150  	for i := range p.AVPs {
   151  		if &(p.AVPs[i]) == avp {
   152  			for j := i; j < len(p.AVPs)-1; j++ {
   153  				p.AVPs[j] = p.AVPs[j+1]
   154  			}
   155  			p.AVPs = p.AVPs[:len(p.AVPs)-1]
   156  			break
   157  		}
   158  	}
   159  	return
   160  }
   161  */
   162  
   163  //delete all avps with this type
   164  func (p *Packet) DeleteOneType(attrType AVPType) {
   165  	for i := 0; i < len(p.AVPs); i++ {
   166  		if p.AVPs[i].GetType() == attrType {
   167  			for j := i; j < len(p.AVPs)-1; j++ {
   168  				p.AVPs[j] = p.AVPs[j+1]
   169  			}
   170  			p.AVPs = p.AVPs[:len(p.AVPs)-1]
   171  			i--
   172  			break
   173  		}
   174  	}
   175  	return
   176  }
   177  
   178  func (p *Packet) Reply() *Packet {
   179  	pac := new(Packet)
   180  	pac.Authenticator = p.Authenticator
   181  	pac.Identifier = p.Identifier
   182  	pac.Secret = p.Secret
   183  	state := p.GetState()
   184  	if len(state) > 0 {
   185  		pac.SetState(state)
   186  	}
   187  	return pac
   188  }
   189  
   190  func (p *Packet) Send(c net.PacketConn, addr net.Addr) error {
   191  	buf, err := p.Encode()
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	_, err = c.WriteTo(buf, addr)
   197  	return err
   198  }
   199  
   200  // 这个只能解密各种Request
   201  func DecodeRequestPacket(Secret []byte, buf []byte) (p *Packet, err error) {
   202  	p = &Packet{Secret: Secret}
   203  	p.Code = Code(buf[0])
   204  	p.Identifier = buf[1]
   205  	copy(p.Authenticator[:], buf[4:20])
   206  	//read attributes
   207  	b := buf[20:]
   208  	for {
   209  		if len(b) == 0 {
   210  			break
   211  		}
   212  		if len(b) < 2 {
   213  			return nil, fmt.Errorf("[radius.DecodePacket] unexcept EOF")
   214  		}
   215  		length := uint8(b[1])
   216  		if int(length) > len(b) {
   217  			return nil, fmt.Errorf("[radius.DecodePacket] invalid avp length len:%d len(b):%d", length, len(b))
   218  		}
   219  		avp, err := avpDecode(p, b[:length])
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  		p.AVPs = append(p.AVPs, avp)
   224  		b = b[length:]
   225  	}
   226  	//验证Message-Authenticator,并且通过测试验证此处算法是正确的
   227  	//此处不修改Message-Authenticator的值
   228  	err = p.checkMessageAuthenticator()
   229  	if err != nil {
   230  		return p, err
   231  	}
   232  	return p, nil
   233  }
   234  
   235  // 解密response包
   236  func DecodeResponsePacket(Secret []byte, buf []byte, RequestAuthenticator [16]byte) (p *Packet, err error) {
   237  	p = &Packet{
   238  		Secret:        Secret,
   239  		Authenticator: RequestAuthenticator,
   240  	}
   241  	p.Code = Code(buf[0])
   242  	p.Identifier = buf[1]
   243  	//read attributes
   244  	b := buf[20:]
   245  	for {
   246  		if len(b) == 0 {
   247  			break
   248  		}
   249  		if len(b) < 2 {
   250  			return nil, fmt.Errorf("[radius.DecodePacket] unexcept EOF")
   251  		}
   252  		length := uint8(b[1])
   253  		if int(length) > len(b) {
   254  			return nil, fmt.Errorf("[radius.DecodePacket] invalid avp length len:%d len(b):%d", length, len(b))
   255  		}
   256  		avp, err := avpDecode(p, b[:length])
   257  		if err != nil {
   258  			return nil, err
   259  		}
   260  		p.AVPs = append(p.AVPs, avp)
   261  		b = b[length:]
   262  	}
   263  	//验证Message-Authenticator,并且通过测试验证此处算法是正确的
   264  	//此处不修改Message-Authenticator的值
   265  	err = p.checkMessageAuthenticator()
   266  	if err != nil {
   267  		return p, err
   268  	}
   269  	return p, nil
   270  }
   271  
   272  //如果没有MessageAuthenticator也算通过
   273  func (p *Packet) checkMessageAuthenticator() (err error) {
   274  	AuthenticatorI := p.GetAVP(AVPTypeMessageAuthenticator)
   275  	if AuthenticatorI == nil {
   276  		return nil
   277  	}
   278  	Authenticator := AuthenticatorI.(*BinaryAVP)
   279  	AuthenticatorValue := Authenticator.Value
   280  	defer func() { Authenticator.Value = AuthenticatorValue }()
   281  	Authenticator.Value = make([]byte, 16)
   282  	content, err := p.encodeNoHash()
   283  	if err != nil {
   284  		return err
   285  	}
   286  	hasher := hmac.New(crypto.MD5.New, []byte(p.Secret))
   287  	hasher.Write(content)
   288  	if !hmac.Equal(hasher.Sum(nil), AuthenticatorValue) {
   289  		return ErrMessageAuthenticatorCheckFail
   290  	}
   291  	return nil
   292  }
   293  
   294  func (p *Packet) String() string {
   295  	s := "Code: " + p.Code.String() + "\n" +
   296  		"Identifier: " + strconv.Itoa(int(p.Identifier)) + "\n" +
   297  		"Authenticator: " + fmt.Sprintf("%#v", p.Authenticator) + "\n"
   298  	for _, avp := range p.AVPs {
   299  		s += avp.String() + "\n"
   300  	}
   301  	return s
   302  }
   303  
   304  //转成字符串map,便于进行log(序列化?),只有实际信息,已经把加密的东西剔除掉了
   305  func (p *Packet) ToStringMap() map[string]string {
   306  	out := make(map[string]string, len(p.AVPs))
   307  	for _, avp := range p.AVPs {
   308  		if avp.GetType() == AVPTypeMessageAuthenticator {
   309  			continue
   310  		}
   311  		out[avp.GetType().String()] = avp.ValueAsString()
   312  	}
   313  	out["Code"] = p.Code.String()
   314  	out["Identifier"] = strconv.Itoa(int(p.Identifier))
   315  	return out
   316  }
   317  
   318  func (p *Packet) GetUsername() (username string) {
   319  	avp := p.GetAVP(AVPTypeUserName)
   320  	if avp == nil {
   321  		return ""
   322  	}
   323  	return avp.(*StringAVP).Value
   324  }
   325  func (p *Packet) GetPassword() (password string) {
   326  	avp := p.GetAVP(AVPTypeUserPassword)
   327  	if avp == nil {
   328  		return ""
   329  	}
   330  	return avp.(*PasswordAVP).Value
   331  }
   332  
   333  func (p *Packet) GetNasIpAddress() (ip net.IP) {
   334  	avp := p.GetAVP(AVPTypeNASIPAddress)
   335  	if avp == nil {
   336  		return nil
   337  	}
   338  	return avp.(*IpAVP).Value
   339  }
   340  
   341  func (p *Packet) GetAcctStatusType() AcctStatusTypeEnum {
   342  	avp := p.GetAVP(AVPTypeAcctStatusType)
   343  	if avp == nil {
   344  		return AcctStatusTypeEnum(0)
   345  	}
   346  	return avp.(*Uint32EnumAVP).Value.(AcctStatusTypeEnum)
   347  }
   348  
   349  func (p *Packet) GetAcctSessionId() string {
   350  	avp := p.GetAVP(AVPTypeAcctSessionId)
   351  	if avp == nil {
   352  		return ""
   353  	}
   354  	return avp.(*StringAVP).Value
   355  }
   356  
   357  func (p *Packet) GetAcctTotalOutputOctets() uint64 {
   358  	out := uint64(0)
   359  	avp := p.GetAVP(AVPTypeAcctOutputOctets)
   360  	if avp != nil {
   361  		out += uint64(avp.(*Uint32AVP).Value)
   362  	}
   363  	avp = p.GetAVP(AVPTypeAcctOutputGigawords)
   364  	if avp != nil {
   365  		out += uint64(avp.(*Uint32AVP).Value) * (2 ^ 32)
   366  	}
   367  	return out
   368  }
   369  
   370  func (p *Packet) GetAcctTotalInputOctets() uint64 {
   371  	out := uint64(0)
   372  	avp := p.GetAVP(AVPTypeAcctInputOctets)
   373  	if avp != nil {
   374  		out += uint64(avp.(*Uint32AVP).Value)
   375  	}
   376  	avp = p.GetAVP(AVPTypeAcctInputGigawords)
   377  	if avp != nil {
   378  		out += uint64(avp.(*Uint32AVP).Value) * (2 ^ 32)
   379  	}
   380  	return out
   381  }
   382  
   383  // it is ike_id in strongswan client
   384  func (p *Packet) GetNASPort() uint32 {
   385  	avp := p.GetAVP(AVPTypeNASPort)
   386  	if avp == nil {
   387  		return 0
   388  	}
   389  	return avp.(*Uint32AVP).Value
   390  }
   391  
   392  func (p *Packet) GetNASIdentifier() string {
   393  	avp := p.GetAVP(AVPTypeNASIdentifier)
   394  	if avp == nil {
   395  		return ""
   396  	}
   397  	return avp.(*StringAVP).Value
   398  }
   399  
   400  func (p *Packet) GetEAPMessage() eap.Packet {
   401  	avp := p.GetAVP(AVPTypeEAPMessage)
   402  	if avp == nil {
   403  		return nil
   404  	}
   405  	return avp.(*EapAVP).Value
   406  }
   407  
   408  func (p *Packet) GetState() []byte {
   409  	avp := p.GetAVP(AVPTypeState)
   410  	if avp == nil {
   411  		return nil
   412  	}
   413  	return avp.GetValue().([]byte)
   414  }
   415  
   416  func (p *Packet) SetState(state []byte) {
   417  	p.SetAVP(&BinaryAVP{
   418  		Type:  AVPTypeState,
   419  		Value: state,
   420  	})
   421  }
   422  
   423  func (p *Packet) SetAcctInterimInterval(second int) {
   424  	p.SetAVP(&Uint32AVP{
   425  		Type:  AVPTypeAcctInterimInterval,
   426  		Value: uint32(second),
   427  	})
   428  }
   429  
   430  func (p *Packet) GetAcctSessionTime() uint32 {
   431  	avp := p.GetAVP(AVPTypeAcctSessionTime)
   432  	if avp == nil {
   433  		return 0
   434  	}
   435  	return avp.GetValue().(uint32)
   436  }